# Reading FITS images from LB telescopes
Fully functional class. Requires access to images and a list of the folders where to look for them.


Basic packages: ``numpy``, ``pandas``, ``matplotlib``, ``astropy``.

In [4]:
import os, sys, glob
from pathlib import Path
import datetime as dt
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import matplotlib.patches as patches
%matplotlib qt

from astropy.io import fits
from astropy import units as u
from astropy.wcs import WCS
# from astropy.wcs.wcsapi import SlicedLowLevelWCS
from astropy.wcs.utils import skycoord_to_pixel
from astropy.visualization import SimpleNorm, simple_norm
from astropy.visualization import make_lupton_rgb
from astropy.visualization.wcsaxes import add_scalebar
from astropy.visualization.wcsaxes import SphericalCircle
from astropy.coordinates import Angle
from astropy.coordinates import SkyCoord
from astropy.nddata import Cutout2D
from math import ceil




In [145]:
class image_viewer:
    def __init__(self, directory: str = '',
               list_available: bool = False,
               folder_list = False,
               previous_df = False,
               print_error = True):
        """Class to quickly open FITS images. Searches in given directory.
        
        Attributes
        ---------
        directory : str
            Directory where images are stored. If none given look in current working directory.

        list_available : bool : False
            Wether to print the resulting dataframe of found images or not

        folder_list : optional, list of str
            Extra directories to inspect for images and save their folder path from working directory
        
        previous_df : optional, pd.DataFrame or str
            Previous df with files to be added to the new one of the files found in ``folder_list``

        print_error: bool, optional
            If False, no error warnings will be printed (good for reading large datasets)
        
        Methods
        --------
        return_index()
            Returns the image path and index in the datafile given one or the other.
        
        header_info()
            Method to view general header info.

        view()
            Method to view images.
        
        view_multiple()
            Method to view multiple images in subplots of a figure
        """
        self.folder_list = folder_list
        print('Current working directory: ' + os.getcwd())
        if directory=='':
            directory = os.getcwd()
        if directory != os.getcwd():
            self.dir_img = os.path.join(os.getcwd(),directory)
        else: self.dir_img = directory
        print('Image directory defined: ' + self.dir_img)

        # list of images in dir_img and where were they
        files = list(Path(self.dir_img).glob('*.fits'))
        folder_found = ['']*len(files)
        # list of images in the different folders of folder_list and the corresponding folder
        if folder_list!= False:
            for fl in folder_list:
                fi = list(Path(os.path.join(self.dir_img, fl)).glob('*.fits'))
                files=files+fi
                folder_found =folder_found+[fl]*len(fi)

        files_data = []
        # creation of data dictionary
        for k, f in enumerate(files):
            try:
                name = f.name
                path = str(f.resolve())
                try: telescope, camera, date_time, object, filter = name.split('_')
                except: 
                    if print_error: print('ERROR WITH FILENAME FORMAT CONVENTION EXPECTED')
                size_MB = f.stat().st_size / 1e6
                created = pd.to_datetime(f.stat().st_ctime, unit="s")
                files_data.append({"filename": name, "path": path, "telescope": telescope, 'camera': camera,
                                   "object": object, "filter": filter[:-5], "size_MB": size_MB,
                                   "date_time": pd.to_datetime(date_time, format='%Y-%m-%d-%H-%M-%S-%f'),
                                   "folder_found": folder_found[k]})
            except: 
                if print_error: print('Error with file: %s'%f)
                
        if len(files)==0:
            print('WARNING: NO IMAGE FILES FOUND')
            return
        # creation of dataframe
        df_files = pd.DataFrame(files_data).sort_values("filename").reset_index(drop=True)
        # Addition of previous dataframe
        if type(previous_df) != bool:
            if type(previous_df) != pd.DataFrame:
                if type(previous_df) == str:
                    if previous_df[-3:] == 'pkl': previous_df = pd.read_pickle(previous_df)
                    elif previous_df[-3:] == 'csv' : previous_df = pd.read_csv(previous_df)
                    else: 
                        print('ERROR: unrecognized DataFrame format. Use \'.pkl\' or \'.csv\'.')
                        return
            self.df_files = pd.concat([df_files, previous_df], ignore_index = True).drop_duplicates(subset = 'filename', keep= 'last')
        else: self.df_files = df_files
        # print available images if requested
        if list_available:
            print(self.df_files)
        print('Total number of images found: ', len(self.df_files))
    
    def return_index(self, image):
        """
        Returns the image path and index in the datafile given one or the other.

        Parameters
        ----------
        image: int / str
            int - image index in datafile \n
            str - image path
        """
        if type(image)==int:
            image_str = self.df_files.iloc[image].filename
            image_int = image
        else: 
            image_str = image
            try: image_int = self.df_files.index[self.df_files['filename']==image].to_list()[0]
            except:
                print('\n ERROR: FILENAME NOT FOUND')
                return
        if self.folder_list != False:
            folder_name = self.df_files.iloc[image_int].folder_found
            image_str = os.path.join(folder_name, image_str)
        return image_str, image_int


    def header_info(self, image,
                    interesting_keys = ['INSTRUME', 'OBJECT', 'FILTER', 'INTEGT', 'DATE-OBS',
                                        'RA', 'DEC', 'NAXIS1', 'NAXIS2', 'SCALE', 'FOVX', 'FOVY',
                                        'CCW', 'CRPIX1', 'CRPIX2', 'FWHM']
                                        ):
        """Method to view general header info.
        
        Parameters
        ----------
        image : int / str
            int - index of desired file in dataframe \n
            string - path to desired fits file
            
        interesting_keys: list / 'all'
            list - list of strings with header keyword \n
            'all' - will print the whole header
        """
        image_str, image_int = self.return_index(image)
        
        # Extracting data from header
        with fits.open(os.path.join(self.dir_img, image_str)) as hdul:
            heads = hdul[0].header
            hdul.close()
        # printing basic header info
        print('Image: %s'%image_str)
        print('\n   --- HEADER DATA ---')
        try:
            if type(interesting_keys) == str and interesting_keys=='all':
                print(repr(heads))
            else:
                for k in interesting_keys:
                    if heads.comments[k]!='':
                        print(k, ' = ', heads[k], '  ---  ', heads.comments[k])
                    else:
                        print(k, ' = ', heads[k])
        except:
            print('WARNING: WRONG interesting_keys PARAMETER.')
            print('         Try a list with the strings of header keys or just the string \'all\'')

    def view_image(self, image,
                    RGB = False,
                    nrows_ncols = None,
                    figsize = None,
                    manipulation_kw = {
                       'centered' : True,
                       'zoom' : False,
                       'stretch' : 'linear',
                       'percentile' : None,
                       'vminmax' : (None, None)
                       },
                    plotting_kw = {
                        'cmap' : 'gray',
                        'scalebar_arcsec' : 5,
                        'scalebar_frame' : False,
                        'add_circle' : None
                        }
                    ):
        """
        Method to view images. Takes dictionary keywords for ``data_manipulation`` and ``plotting``.
        """
        # Multiple images
        if type(image) == list and RGB == False:
            print('------\nViewing multiple images:')
            n_image = len(image)
            if nrows_ncols == None:
                if n_image <= 3: nrows_ncols = (1, n_image)
                else: nrows_ncols = (ceil(np.sqrt(n_image)), ceil(np.sqrt(n_image)))
            image_list = image

        # Simple image Non RGB
        if type(image) != list:
            print('------\nViewing image:')
            n_image = 1
            image_list = [image]
        # RGB image
        if RGB == True: 
            n_image = 1
            colors = ['R', 'G', 'B']
            cutout_RGB = []
            if manipulation_kw['stretch']!='linear':
                print('Overriding for a linear stretch')
                manipulation_kw['stretch']='linear'
            print('------\nRGB color composite image:')
            if n_image == 1: nrows_ncols = (1,1)
            image_list = image

        self.nr_nc = nrows_ncols
        n_data = len(image_list)

        # if manipulation and plotting are dicts, use the same setup for all images
        if type(manipulation_kw) == dict: manipulation_kw = [manipulation_kw]*n_data
        if type(plotting_kw) == dict: plotting_kw = [plotting_kw]*n_data

        fig, axes = plt.subplots(self.nr_nc[0], self.nr_nc[1],
                                 figsize = figsize)
        if n_image == 1: axes = [axes]
        axes = np.array(axes).reshape(-1)
        
        for i, (img, m_k, p_k) in enumerate(zip(image_list, manipulation_kw, plotting_kw)):
            self.img_str, self.img_int = self.return_index(img)
            cutout, norm = self.data_manipulation(self.img_str, **m_k)

            if RGB == False:
                print('    Object: ',self.df_files.object.iloc[self.img_int],
                  '  -  Filter: ',self.df_files['filter'].iloc[self.img_int])
                self.plotting(cutout, norm, fig, axes[i], i,
                              **p_k)
            else:
                if i==0: print('    Object: ', self.df_files.iloc[self.img_int].object)
                print('    - ',colors[i],': ', self.df_files['filter'].iloc[self.img_int])
                cutout_RGB.append(norm(cutout.data))
                if i == len(image_list)-1:
                    for k in range(3):
                        print(np.min(cutout_RGB[k]), np.max(cutout_RGB[k]))
                    rgb_default = make_lupton_rgb(cutout_RGB[0].data, cutout_RGB[1].data, cutout_RGB[2].data)
                    axes[0].remove()
                    ax = fig.add_subplot(1,1, 1, projection = cutout.wcs)
                    ax.imshow(rgb_default, origin='lower')
        # if RGB:
        #     print('    Object: ', self.df_files.iloc[self.img_int].object)
        #     for i, (img, m_k, p_k) in enumerate(zip(image_list, manipulation_kw, plotting_kw)):
        #         print('    - ',colors[i],': ', self.df_files['filter'].iloc[self.img_int])
        #         self.img_str, self.img_int = self.return_index(img)

        #     fig, ax = plt.subplots(subplot_kw={'projection': wcs})
        plt.tight_layout()
        plt.show()

    def data_manipulation(self, image_str,
                          centered = True, 
                          zoom = False,
                          stretch = 'linear',
                          percentile = None,
                          vminmax = (None, None)
                          ):
        """
        Method to prepare images for manipulation. It is internally called. Crops the image and sets visualization normalization and stretch.

        Parameters
        ---------
        image : int / string / list
            int - index of desired file in dataframe \n
            string - path to desired fits file \n

        centered : True or tuple, optional
            (x,y) - int for pix coordinates \n
            (RA, DEC) - wcs coordinates. Accepting both strings or angle values

        zoom : False or Value or Tuple, optional
            int / (int, int) - pixel size in x and y axis \n
            Angle / (Angle, Angle) - angular size in RA and DEC
        
        stretch : str, optional
            Image stretch to enhance detail visualization \n
            ``linear``, ``sqrt``, ``power``, ``log``, ``sinh``, ``asinh``
        
        percentile : int or tuple, optional
            ``int`` - Middle percentile of values to consider for normalization; 
            ``tuple`` - Lower and upper percentile of values to consider for normalization
        
        vminmax : tuple, optional
            Min and max pixel values for normalization. Overrides ``percentile``.
            If set as None, keeps the absolute min or max of image
        """
        

        # Extracting data from header
        with fits.open(os.path.join(self.dir_img, image_str)) as hdul:
            data = hdul[0].data.astype(np.float32)
            heads = hdul[0].header
            wcs = WCS(heads)
            hdul.close()
        
        # obtaining central px coordinates
        x_shape = data.shape[1]
        y_shape = data.shape[0]
        if centered == True:
            center_px = (x_shape//2, y_shape//2)
        if type(centered)==tuple:
            if type(centered[0]) == int: # input in px units
                center_px = tuple(centered)
            if type(centered[0]) == str: # input in str to be converted to deg
                center_angle = SkyCoord(centered[0], centered[1], frame = 'icrs')
                center_px = skycoord_to_pixel(center_angle, wcs, origin=0)
        
        # setting zoom
        if zoom == False:
            zoom = (x_shape, y_shape)
        if type(zoom) == str:
            zoom = Angle(zoom)
        if type(zoom)== tuple:
            if type(zoom[0]) == str:
                zoom = (Angle(zoom[0]), Angle(zoom[1]))
        if type(zoom)==tuple:
            zoom = zoom[::-1]
        
        # slicing image
        try:
            cutout = Cutout2D(data, position = center_px, size = zoom, wcs = wcs)
        except:
            print('\n --- \nERROR: the cutout region is outside of the image.')
            return

        # norm definition
        if type(percentile) == int or percentile == None:
            percentile_minmax = (None, None)
        if type(percentile) == tuple:
            percentile_minmax = percentile
            percentile = None
        if stretch not in {'linear', 'sqrt', 'power', 'log', 'sinh', 'asinh'}:
            print('ERROR: Stretch should be one of \'linear\', \'sqrt\', \'power\', \'log\', \'sinh\', \'asinh\'')
            plt.close()
            return
        norm = simple_norm(cutout.data, stretch = stretch, 
                           vmin = vminmax[0], vmax = vminmax[1],
                           percent = percentile,
                           min_percent = percentile_minmax[0],
                           max_percent = percentile_minmax[1])
        
        return cutout, norm
        
    def plotting(self,
                 cutout, norm, fig, ax, ax_i,
                cmap = 'gray',
                scalebar_arcsec = 5, scalebar_frame = False,
                add_circle = None
                ):
        """
        Method to plot images, obtains edited data from ``self.data_manipulation()``.

        Parameters
        ---------
        cutout : Cutout2D
            Selected cutout object from ``data_manipulation``

        norm : Norm
            Selected norm from ``data_manipulation``

        cmap : str, optional
            Select the desired colormap for the image

        scalebar_arcsec : int, optional
            Angular size of scalebar in arcsec units
        
        scalebar_frame : bool, optional
            Add frame or not

        add_circle : dict, list of dicts or None, optional
            Parameters to plot a circle overlay. If None, no circle is plotted. If multiple circles are desired, enter a list of dicts.\n
            Expected keys: \n
                'center' : tuple 
                    (RA, DEC) coordinates as astropy Angle or SkyCoord
                'size' : astropy.units.Quantity
                    Angular size (e.g., astropy Angle with units).
                'color' : str, optional
                    Circle edge color.
                'label' : str, optional
                    Label for the circle to use in legend.
            
        fig_kwrds : None or dict, optional
            Dict with all the keywords desired to insert in ``plt.subplots()``

        figure : None or dict ..... tuple or axis
            Dict used by view_multiple method. Expected keys: \n
                'is_simple' : bool
                'create_fig' : bool
                    True or False
                'figsize' : tuple
                    Looked at if ``create_fig = True``
                'nrows_ncols' : tuple
                    Looked at if ``create_fig = True``
                'fig' : plt.figure object
                    Looked at if ``create_fig = False``
                'ax' : plt.axis object
                    Looked at if ``create_fig = False``
                'im_i' : int
                    Subplot index (image index). Looked at if ``create_fig = False``

            None - creates normal figure, does not return nothing \n
            tuple (int, int) - creates figure with specified conditions. Returns (fig, ax) \n
            tuple (ax, int, int) - plots image in specified ax[int,int]
        """
        # # WCS projection (for simple figures)
        # if figure['is_simple'] == True:
        #     fig, ax = plt.subplots(subplot_kw=dict(projection=cutout.wcs), **fig_kwrds)
        # # Figure creation if multiple figures are in use
        # if figure['is_simple'] == False:
        #     create_fig = figure['create_fig']
        #     nrows_ncols = figure['nrows_ncols']
        #     im_i = figure['im_i']
        # if create_fig == True:
        #     figsize = figure['figsize']
        #     fig, axes = plt.subplots(nrows_ncols[0], nrows_ncols[1],
        #                             #  projection = None,
        #                                 figsize = figsize)
        #     axes = np.array(axes).reshape(-1)
        #     ax = axes[0]
        # else:
        #     fig = figure['fig']
        #     ax = figure['ax']
        with fits.open(os.path.join(self.dir_img, self.img_str)) as hdul:
            heads = hdul[0].header
            hdul.close()
        ax.remove()
        ax = fig.add_subplot(self.nr_nc[0], self.nr_nc[1], ax_i+1, projection = cutout.wcs)
        
        # colorbar
        cax = ax.imshow(cutout.data,
                        norm = norm, origin = 'lower',
                        cmap = cmap)
        cbar = plt.colorbar(cax)
        cbar.set_label('ADU', rotation=270, labelpad=15)
        cbar.ax.tick_params(labelsize=10)

        # Scale bar choosing color depending on luminance of cmap
        scalebar_angle = scalebar_arcsec/3600*u.deg
        rgba = plt.get_cmap(cmap)(0.0)
        luminance = 0.299*rgba[0] + 0.587*rgba[1] + 0.114*rgba[2]
        scalebar_color = 'white' if (luminance < 0.5 and scalebar_frame == False) else 'black'
        add_scalebar(ax, scalebar_angle, label="%s arcsec"%str(scalebar_arcsec), color=scalebar_color, frame=scalebar_frame)
        # Axis and title
        ax.set(xlabel='RA', ylabel='Dec')
        ax.coords.grid(color='gray', alpha=0.5, linestyle='solid')
        title_str = (r'$\bf{Object}$: %s - $\bf{Telescope}$: %s - $\bf{Seeing}$: %.1f$^{\prime\prime}$''\n'
                    r'$\bf{Camera}$: %s - $\bf{Filter}$: %s - $\bf{Integration}$: %s s''\n'
                    r'$\bf{SNR}$: %s - $\bf{Date time}$: %s'
                    %(self.df_files.iloc[self.img_int]['object'],
                    self.df_files.iloc[self.img_int]['telescope'],
                    (float(heads['FWHM'])*float(heads['SCALE'])),
                    self.df_files.iloc[self.img_int]['camera'],
                    self.df_files.iloc[self.img_int]['filter'],
                    heads['INTEGT'], heads['OBJECSNR'],
                    self.df_files.iloc[self.img_int]['date_time'].strftime("%Y-%m-%d %H:%M")))
        ax.set_title(title_str)
        ax.minorticks_on()

        # Optional plot of circles
        if add_circle is not None:
            if type(add_circle) != list:
                add_circle = [add_circle]
            for d_circle in add_circle:
                center = d_circle.get('center')
                size = d_circle.get('size')
                color = d_circle.get('color')
                label = d_circle.get('label')
                c = SphericalCircle((Angle(center[0]), Angle(center[1])),
                                    Angle(size),
                                    edgecolor = color,
                                    facecolor = 'none',
                                    transform = ax.get_transform('icrs'))
                ax.add_patch(c)
        # Only show figure if simple, return fig and axes if multiple
        # if figure == None:
        #     plt.tight_layout()
        #     plt.show()
        #     return
        # else:
        #     if figure['create_fig'] == True:
        #         return fig, ax
        

    # def view(self, image,
    #          centered = True, 
    #          zoom = False,
    #          stretch = 'linear',
    #          percentile = None,
    #          vminmax = (None, None),
    #          cmap = 'gray',
    #          scalebar_arcsec = 5, scalebar_frame = False,
    #          add_circle = None,
    #          fig_kwrds = {},
    #          figure = None,
    #          RGB = False
    #          ):
    #     """
    #     Method to view images.

    #     Parameters
    #     ---------
    #     image : int / string / list
    #         int - index of desired file in dataframe \n
    #         string - path to desired fits file \n

    #     centered : True or tuple, optional
    #         (x,y) - int for pix coordinates \n
    #         (RA, DEC) - wcs coordinates. Accepting both strings or angle values

    #     zoom : False or Value or Tuple, optional
    #         int / (int, int) - pixel size in x and y axis \n
    #         Angle / (Angle, Angle) - angular size in RA and DEC
        
    #     stretch : str, optional
    #         Image stretch to enhance detail visualization \n
    #         ``linear``, ``sqrt``, ``power``, ``log``, ``sinh``, ``asinh``
        
    #     percentile : int or tuple, optional
    #         ``int`` - Middle percentile of values to consider for normalization; 
    #         ``tuple`` - Lower and upper percentile of values to consider for normalization
        
    #     vminmax : tuple, optional
    #         Min and max pixel values for normalization. Overrides ``percentile``.
    #         If set as None, keeps the absolute min or max of image

    #     cmap : str, optional
    #         Select the desired colormap for the image

    #     scalebar_arcsec : int, optional
    #         Angular size of scalebar in arcsec units
        
    #     scalebar_frame : bool, optional
    #         Add frame or not

    #     add_circle : dict, list of dicts or None, optional
    #         Parameters to plot a circle overlay. If None, no circle is plotted. If multiple circles are desired, enter a list of dicts.\n
    #         Expected keys: \n
    #             'center' : tuple 
    #                 (RA, DEC) coordinates as astropy Angle or SkyCoord
    #             'size' : astropy.units.Quantity
    #                 Angular size (e.g., astropy Angle with units).
    #             'color' : str, optional
    #                 Circle edge color.
    #             'label' : str, optional
    #                 Label for the circle to use in legend.
        
    #     fig_kwrds : None or dict, optional
    #         Dict with all the keywords desired to insert in ``plt.subplots()``

    #     figure : None or dict ..... tuple or axis
    #         Dict used by view_multiple method. Expected keys: \n
    #             'create_fig' : bool
    #                 True or False
    #             'figsize' : tuple
    #                 Looked at if ``create_fig = True``
    #             'nrows_ncols' : tuple
    #                 Looked at if ``create_fig = True``
    #             'fig' : plt.figure object
    #                 Looked at if ``create_fig = False``
    #             'ax' : plt.axis object
    #                 Looked at if ``create_fig = False``
    #             'im_i' : int
    #                 Subplot index (image index). Looked at if ``create_fig = False``

    #         None - creates normal figure, does not return nothing \n
    #         tuple (int, int) - creates figure with specified conditions. Returns (fig, ax) \n
    #         tuple (ax, int, int) - plots image in specified ax[int,int]
    #     """
            
    #     if RGB == False:
    #         image_str, image_int = self.return_index(image)
    #         print('Viewing ', image_str)
        

    #     # Extracting data from header
    #     with fits.open(os.path.join(self.dir_img, image_str)) as hdul:
    #         data = hdul[0].data.astype(np.float32)
    #         heads = hdul[0].header
    #         wcs = WCS(heads)
    #         hdul.close()
        
    #     # obtaining central px coordinates
    #     x_shape = data.shape[1]
    #     y_shape = data.shape[0]
    #     if centered == True:
    #         center_px = (x_shape//2, y_shape//2)
    #     if type(centered)==tuple:
    #         if type(centered[0]) == int: # input in px units
    #             center_px = tuple(centered)
    #         if type(centered[0]) == str: # input in str to be converted to deg
    #             center_angle = SkyCoord(centered[0], centered[1], frame = 'icrs')
    #             center_px = skycoord_to_pixel(center_angle, wcs, origin=0)
        
    #     # setting zoom
    #     if zoom == False:
    #         zoom = (x_shape, y_shape)
    #     if type(zoom) == str:
    #         zoom = Angle(zoom)
    #     if type(zoom)== tuple:
    #         if type(zoom[0]) == str:
    #             zoom = (Angle(zoom[0]), Angle(zoom[1]))
    #     if type(zoom)==tuple:
    #         zoom = zoom[::-1]
        
    #     # slicing image
    #     try:
    #         cutout = Cutout2D(data, position = center_px, size = zoom, wcs = wcs)
    #     except:
    #         print('\n --- \nERROR: the cutout region is outside of the image.')
    #         return

    #     # norm definition
    #     if type(percentile) == int or percentile == None:
    #         percentile_minmax = (None, None)
    #     if type(percentile) == tuple:
    #         percentile_minmax = percentile
    #         percentile = None
    #     if stretch not in {'linear', 'sqrt', 'power', 'log', 'sinh', 'asinh'}:
    #         print('ERROR: Stretch should be one of \'linear\', \'sqrt\', \'power\', \'log\', \'sinh\', \'asinh\'')
    #         plt.close()
    #         return
    #     norm = simple_norm(cutout.data, stretch = stretch, 
    #                        vmin = vminmax[0], vmax = vminmax[1],
    #                        percent = percentile,
    #                        min_percent = percentile_minmax[0],
    #                        max_percent = percentile_minmax[1])
    #     # WCS projection (for simple figures)
    #     if figure == None:
    #         fig, ax = plt.subplots(subplot_kw=dict(projection=cutout.wcs), **fig_kwrds)

    #     # Figure creation if multiple figures are in use
    #     if type(figure) == dict:
    #         create_fig = figure['create_fig']
    #         nrows_ncols = figure['nrows_ncols']
    #         im_i = figure['im_i']
    #         if create_fig == True:
    #             figsize = figure['figsize']
    #             fig, axes = plt.subplots(nrows_ncols[0], nrows_ncols[1],
    #                                     #  projection = None,
    #                                      figsize = figsize)
    #             axes = np.array(axes).reshape(-1)
    #             ax = axes[0]
    #         else:
    #             fig = figure['fig']
    #             ax = figure['ax']
    #         ax.remove()
    #         ax = fig.add_subplot(nrows_ncols[0], nrows_ncols[1], im_i+1, projection = cutout.wcs)
            
    #     # colorbar
    #     cax = ax.imshow(cutout.data,
    #                     norm = norm, origin = 'lower',
    #                     cmap = cmap)
    #     cbar = plt.colorbar(cax)
    #     cbar.set_label('ADU', rotation=270, labelpad=15)
    #     cbar.ax.tick_params(labelsize=10)

    #     # Scale bar choosing color depending on luminance of cmap
    #     scalebar_angle = scalebar_arcsec/3600*u.deg
    #     rgba = plt.get_cmap(cmap)(0.0)
    #     luminance = 0.299*rgba[0] + 0.587*rgba[1] + 0.114*rgba[2]
    #     scalebar_color = 'white' if (luminance < 0.5 and scalebar_frame == False) else 'black'
    #     # if override_scalebar_color != False:
    #     #     scalebar_color = override_scalebar_color
    #     add_scalebar(ax, scalebar_angle, label="%s arcsec"%str(scalebar_arcsec), color=scalebar_color, frame=scalebar_frame)
    #     # Axis and title
    #     ax.set(xlabel='RA', ylabel='Dec')
    #     ax.coords.grid(color='gray', alpha=0.5, linestyle='solid')
    #     title_str = (r'$\bf{Object}$: %s - $\bf{Telescope}$: %s - $\bf{Seeing}$: %.1f$^{\prime\prime}$''\n'
    #                  r'$\bf{Camera}$: %s - $\bf{Filter}$: %s - $\bf{Integration}$: %s s''\n'
    #                  r'$\bf{SNR}$: %s - $\bf{Date time}$: %s'
    #                  %(self.df_files.iloc[image_int]['object'],
    #                    self.df_files.iloc[image_int]['telescope'],
    #                    (float(heads['FWHM'])*float(heads['SCALE'])),
    #                    self.df_files.iloc[image_int]['camera'],
    #                    self.df_files.iloc[image_int]['filter'],
    #                    heads['INTEGT'], heads['OBJECSNR'],
    #                    self.df_files.iloc[image_int]['date_time'].strftime("%Y-%m-%d %H-%M")))
    #     ax.set_title(title_str)
    #     ax.minorticks_on()

    #     # Optional plot of circles
    #     if add_circle is not None:
    #         if type(add_circle) != list:
    #             add_circle = [add_circle]
    #         for d_circle in add_circle:
    #             center = d_circle.get('center')
    #             size = d_circle.get('size')
    #             color = d_circle.get('color')
    #             label = d_circle.get('label')
    #             c = SphericalCircle((Angle(center[0]), Angle(center[1])),
    #                                 Angle(size),
    #                                 edgecolor = color,
    #                                 facecolor = 'none',
    #                                 transform = ax.get_transform('icrs'))
    #             ax.add_patch(c)
    #     # Only show figure if simple, return fig and axes if multiple
    #     if figure == None:
    #         plt.tight_layout()
    #         plt.show()
    #         return
    #     else:
    #         if figure['create_fig'] == True:
    #             return fig, axes


    # def view_multiple(self,
    #                   image_list : list,
    #                   view_kwrds = {},
    #                   nrows_ncols = None,
    #                   figsize = None
    #                   ):
    #     """
    #     Method to view multiple images, recurring to the self.view method.

    #     Paramters
    #     ------
    #     image_list : list
    #         Contains the list of paths or indexes to files in the dataframe

    #     view_kwrds : list
    #         List of dictionaries with for each argument in self.view

    #     nrows_ncols : None or tuple, optional
    #         If None, look for a squared relation

    #     figsize : None or tuple
    #         Figure size
    #     """
    #     n_image = len(image_list)
    #     if nrows_ncols == None:
    #         if n_image <= 3: nrows_ncols = (1, n_image)
    #         else: nrows_ncols = (ceil(np.sqrt(n_image)), ceil(np.sqrt(n_image)))
    #     if type(view_kwrds) == dict:
    #         view_kwrds = [view_kwrds]*n_image
    #     for i, (image, v_k) in enumerate(zip(image_list, view_kwrds)):
    #         if i==0:
    #             fig, axes = self.view(image, figure = {'create_fig' : True,
    #                                                    'figsize' : figsize,
    #                                                    'nrows_ncols' : nrows_ncols,
    #                                                    'im_i' : i}, 
    #                                  **v_k)
                
    #         else:
    #             self.view(image, 
    #                       figure = {'create_fig' : False,
    #                                 'nrows_ncols' : nrows_ncols,
    #                                 'im_i' : i,
    #                                 'fig' : fig,
    #                                 'ax' : axes[i]}, 
    #                       **v_k)
    #     plt.tight_layout()
    #     plt.show()

    # def view_color(self, images):
    #     """
    #     RGB color composite image with ``make_lupton_rgb`` function of the ``astropy`` package.
    #     All given images must belong to the same filter
    #     """
    #     print('------\nRGB color composite image:')
    #     colors = ['R', 'G', 'B']
    #     data_filters = []
    #     filters = []
    #     wcs_all = []

    #     for i, image in enumerate(images):
    #         image_str, image_int = self.return_index(image)
    #         if i==0:
    #             print('    Object: ', self.df_files.iloc[image_int].object)
    #         print('    - ',colors[i],': ', self.df_files['filter'].iloc[image_int])

    #         # Extracting data from header
    #         with fits.open(os.path.join(self.dir_img, image_str)) as hdul:
    #             data = hdul[0].data.astype(np.float32)
    #             heads = hdul[0].header
    #             wcs = WCS(heads)
    #             hdul.close()
    #         data_filters.append(data)
    #         wcs_all.append(wcs)

    #     rgb_default = make_lupton_rgb(data_filters[0], data_filters[1], data_filters[2])
    #     fig, ax = plt.subplots(subplot_kw={'projection': wcs})
    #     ax.imshow(rgb_default, origin='lower')


    def read_data(self, image):
        """Method to view images."""
        image_str, image_int = self.return_index(image)
        print('Viewing ', image_str)

        # Extracting data from header
        with fits.open(os.path.join(self.dir_img, image_str)) as hdul:
            data = hdul[0].data.astype(np.float32)
            hdul.close()
        return data



In [146]:
iv = image_viewer(directory = 'test_images', folder_list = ['2025-11-06'])
iv.view_image([2,1,0], RGB = True, 
              figsize = (12,8),
              manipulation_kw={'zoom' : False,
                               'stretch' : 'linear',
                               'percentile' : (50,99.99)}
              )

Current working directory: /Users/oscar/LB/grav_lens
Image directory defined: /Users/oscar/LB/grav_lens/test_images
Total number of images found:  4
------
RGB color composite image:
    Object:  ZTF25abnjznp
    -  R :  SDSSi
    -  G :  SDSSr
    -  B :  SDSSg
-0.1159927418829325 1.0257955817777074
-0.09655508328699171 1.0331418309061144
-0.018096460101566566 1.0420655726170345


In [141]:
iv = image_viewer(directory = 'test_images', folder_list = ['2025-11-06'])
# iv.df_files
iv.view_multiple([0,1,2], 
                 figsize=(12,8),
                 view_kwrds={
                             'zoom' : ('0 0 50 d', '0:0:50 d'),
                            'stretch' : 'linear',
                            'percentile' : (50,90)})
# iv.view_color([3,1,0])

Current working directory: /Users/oscar/LB/grav_lens
Image directory defined: /Users/oscar/LB/grav_lens/test_images
Total number of images found:  4
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-19-33-541318_ZTF25abnjznp_SDSSg.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-24-44-679453_ZTF25abnjznp_SDSSr.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-29-56-000503_ZTF25abnjznp_SDSSi.fits


In [None]:
iv = image_viewer(directory = 'test_images', folder_list = ['2025-11-06'])#, previous_df='df_files.pkl')
# iv.df_files#.groupby(['telescope', 'folder_found']).size()

obj_coords = ('07:16:34.5h','+38:21:08d')
indexes = [0,1,2,3]

iv.view_multiple(indexes,
                 figsize = (12,8),
                 view_kwrds={
                     'centered' : obj_coords,
                     'zoom' : ('0 0 20 d', '0:0:20 d'),
                     'stretch' : 'sinh',
                     'percentile' : (10,100),
                     'scalebar_frame' : True
                    #  'add_circle' : {'center' : obj_coords,
                    #                  'size' : '0:0:5d',
                    #                  'color' : 'white'}
                        })

Current working directory: /Users/oscar/LB/grav_lens
Image directory defined: /Users/oscar/LB/grav_lens/test_images
Total number of images found:  4
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-19-33-541318_ZTF25abnjznp_SDSSg.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-24-44-679453_ZTF25abnjznp_SDSSr.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-29-56-000503_ZTF25abnjznp_SDSSi.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-35-07-118053_ZTF25abnjznp_SDSSzs.fits


In [63]:
obj_coords = ('350.3458d','-03.5082d')
indexes = [2,3,4,5]
obj_coords = ('07:16:34.5h','+38:21:08d')
indexes = [0,1,2,3]

iv.view_multiple(indexes,
                 figsize = (12,8),
                 view_kwrds={
                     'centered' : obj_coords,
                     'zoom' : ('0 0 25 d', '0:0:25 d'),
                     'stretch' : 'sinh',
                     'percentile' : (10,100),
                     'scalebar_frame' : True,
                     'add_circle' : {'center' : obj_coords,
                                     'size' : '0:0:5d',
                                     'color' : 'white'}
                        })

Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-19-33-541318_ZTF25abnjznp_SDSSg.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-24-44-679453_ZTF25abnjznp_SDSSr.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-29-56-000503_ZTF25abnjznp_SDSSi.fits
Viewing  2025-11-06/TTT3_iKon936-1_2025-11-07-02-35-07-118053_ZTF25abnjznp_SDSSzs.fits


In [244]:
iv = image_viewer('test_images', list_available=False, folder_list=['2025-11-05'])
# iv.df_files
# iv.header_info(2, interesting_keys='all')
# iv.view(4) #span_x=1000, span_y=1000)
iv.view(2, 
        # centered = ('350.3458d','-03.5082d'),#True,#('22h 40m 30.3s', '+3° 21′ 31″'),
        zoom = ('0 0 40 d', '0:0:40 d'),
        stretch = 'linear',
        # percentile = (1,99),
        # vminmax = (2e4, None),
        cmap = 'gray',
        scalebar_arcsec= 5,
        add_circle= {'center' : ('350.3458d','-03.5082d'),
                     'size' : '0:0:3d',
                     'color' : 'red'})
plt.tight_layout()

Current working directory: /Users/oscar/LB/grav_lens
Image directory defined: /Users/oscar/LB/grav_lens/test_images
Total number of images found:  10
Viewing  2025-11-05/TTT3_iKon936-1_2025-11-05-21-45-13-833208_DESI-350.3458-03.5082_SDSSg.fits


In [54]:
iv = image_viewer('test_images/2025-11-02')

for im_i, filter in zip([1,2],['g','r']):
    hdulist = fits.open(os.path.join(iv.dir_img,iv.return_index(im_i)[0]))
    hdu = hdulist[0]
    globals()[filter] = hdu.data
    hdulist.close()
def normalize(data):
    data_min, data_max = np.percentile(data, (50, 99))
    data = np.clip(data, data_min, data_max)
    return (data - data_min) / (data_max - data_min)
g_n = normalize(g)
r_n = normalize(r)
i = np.zeros_like(g)
rgb_default = make_lupton_rgb(g_n, r_n, i, Q=10, stretch=0.5)#, filename="ngc6976-default.jpeg")
fig, ax = plt.subplots()
ax.imshow(rgb_default, origin='lower')

Current working directory: /Users/oscar/LB/grav_lens
Image directory defined: /Users/oscar/LB/grav_lens/test_images/2025-11-02
Total number of images found:  4


<matplotlib.image.AxesImage at 0x36c7891d0>

In [None]:
g.min(), g.max(), np.(g).mode(), np.median(g)

AttributeError: 'numpy.ndarray' object has no attribute 'mode'

In [44]:
from astropy.utils.data import get_pkg_data_filename

# Read in the three images downloaded from here:
g_name = get_pkg_data_filename('visualization/reprojected_sdss_g.fits.bz2')
r_name = get_pkg_data_filename('visualization/reprojected_sdss_r.fits.bz2')
i_name = get_pkg_data_filename('visualization/reprojected_sdss_i.fits.bz2')
g = fits.getdata(g_name)
r = fits.getdata(r_name)
i = fits.getdata(i_name)

rgb_default = make_lupton_rgb(i, r, g)
fig, ax = plt.subplots()
ax.imshow(rgb_default, origin='lower')

<matplotlib.image.AxesImage at 0x10e28f9d0>

# Work with historic dataframe imported from .pkl file

In [20]:
df = pd.read_pickle('df_files_all.pkl')
#df.set_index('date_time').groupby('object').resample('D').size().unstack(0)
df['day'] = df['date_time'].dt.date
df

Unnamed: 0,filename,path,telescope,camera,object,filter,size_MB,date_time,day
0,ATLAS-Teide-P_2023-05-01-22-01-07-273279_AL62....,/home/oscarsoler/work/red/2023-05-01/ATLAS-Tei...,err,err,err,,244.37664,2000-01-01 00:00:00.000000,2000-01-01
1,ATLAS-Teide-P_2023-05-02-04-51-46-695793_AM82....,/home/oscarsoler/work/red/2023-05-01/ATLAS-Tei...,err,err,err,,244.37088,2000-01-01 00:00:00.000000,2000-01-01
2,ATLAS-Teide-P_2023-05-04-00-05-04-730911_Focus...,/home/oscarsoler/work/red/2023-05-03/ATLAS-Tei...,err,err,err,,244.37664,2000-01-01 00:00:00.000000,2000-01-01
3,ATLAS-Teide-P_2023-05-04-00-23-31-426872_Focus...,/home/oscarsoler/work/red/2023-05-03/ATLAS-Tei...,err,err,err,,244.37088,2000-01-01 00:00:00.000000,2000-01-01
4,ATLAS-Teide-P_2023-05-04-00-32-10-828223_Focus...,/home/oscarsoler/work/red/2023-05-03/ATLAS-Tei...,err,err,err,,244.37088,2000-01-01 00:00:00.000000,2000-01-01
...,...,...,...,...,...,...,...,...,...
1357711,TTT3_iKon936-1_2025-10-31-06-43-58-638998_TOI-...,/home/oscarsoler/work/red/2025-10-30/TTT3_iKon...,TTT3,iKon936-1,TOI-2267,SDSSzs,16.79040,2025-10-31 06:43:58.638998,2025-10-31
1357712,TTT3_iKon936-1_2025-10-31-06-44-28-510311_TOI-...,/home/oscarsoler/work/red/2025-10-30/TTT3_iKon...,TTT3,iKon936-1,TOI-2267,SDSSzs,16.79040,2025-10-31 06:44:28.510311,2025-10-31
1357713,TTT3_iKon936-1_2025-10-31-06-44-58-381037_TOI-...,/home/oscarsoler/work/red/2025-10-30/TTT3_iKon...,TTT3,iKon936-1,TOI-2267,SDSSzs,16.79040,2025-10-31 06:44:58.381037,2025-10-31
1357714,TTT3_iKon936-1_2025-10-31-06-45-28-250995_TOI-...,/home/oscarsoler/work/red/2025-10-30/TTT3_iKon...,TTT3,iKon936-1,TOI-2267,SDSSzs,16.79040,2025-10-31 06:45:28.250995,2025-10-31


In [27]:
grav_lens = ['QSO0957+561', 'Q2237+030', 'MG1654+1346', 'SDSSJ1004+4112', 'LBQS1333+0113', 'SDSSJ0819+5356',
             'EinsteinCross']

# for g in grav_lens:
#     try: print(g, df.value_counts(['object']).loc[(g)])
#     except: print()


# Filter the DataFrame
df_filtered = df[df["object"].isin(grav_lens)].copy()

# Group by day and object, count observations
daily_counts = (
    df_filtered.groupby(["day", "object"])
    .size()
    .reset_index(name="count")
    )

# daily_counts.groupby(['filter']).size()


## Timeline overview of gravitational lens object observations

In [131]:

# Pivot for plotting
pivoted = daily_counts.pivot(index="day", columns="object", values="count").fillna(0)
pivot_for_plot = pivoted.where(pivoted != 0, np.nan)
# Remove days where all objects are NaN (no observations at all)
pivot_for_plot = pivot_for_plot.dropna(how="all")
# optional: sort index (dates) to ensure proper time order
pivot_for_plot = pivot_for_plot.sort_index()
# Plot
fig, ax = plt.subplots()

pivot_for_plot.plot(kind="line", marker="o", lw=0, alpha=0.5, figsize=(10, 5), ax=ax)


ax.set_title("Daily Observations by Object")
ax.minorticks_on()
ax.set_xlabel("Date")
ax.tick_params('x', rotation = 45)
ax.set_ylabel("Number of Observations")
ax.legend(title="Object")
ax.grid(alpha=0.5)
plt.tight_layout()
plt.show()

qt.qpa.backingstore: Back buffer dpr of 1 doesn't match <NSViewBackingLayer: 0x11a229690> contents scale of 2 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 1 doesn't match <NSViewBackingLayer: 0x32ffbe780> contents scale of 2 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 1 doesn't match <NSViewBackingLayer: 0x11a229690> contents scale of 2 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 1 doesn't match <NSViewBackingLayer: 0x32ffbe780> contents scale of 2 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 2 doesn't match <NSViewBackingLayer: 0x11a229690> contents scale of 1 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 1 doesn't match <NSViewBackingLayer: 0x32ffbe780> contents scale of 2 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 1 doesn't match <NSViewBackingLayer: 0x32ffbe780> contents scale of 2 - updating layer to match.
qt.qpa.backingstore: Back buffer dpr of 1

In [94]:
plt.savefig('grav_lens_overview_object_date.png', dpi=200)

(0.12156862745098039, 0.4666666666666667, 0.7058823529411765)

## Overview of timeline observations separated by filter for each object

In [130]:
# Assiging each filter a color
filters_list = df_filtered.groupby('filter').size().keys().to_list()
n_filters = len(filters_list)

colors_dict = {filters_list[i]: plt.cm.tab10.colors[i] for i in range(n_filters)}
colors_list = [plt.cm.tab10.colors[i] for i in range(n_filters)]

# Create figure
# CORRECT to adjust to total number of filters
fig, ax = plt.subplots(4, 2)

for i, obj in enumerate(grav_lens):
    # Group by date, object, filter and count observations
    summary = (
        df_filtered[df_filtered['object']==obj].groupby(['day', 'object', 'filter'])
        .size()
        .reset_index(name='n_observations')
    )
    # Pivot the grouped data: rows are dates, columns are (object, filter) multiindex, values are counts
    pivoted = summary.pivot_table(
                                index='day',
                                columns=['object', 'filter'],
                                values='n_observations',
                                fill_value=0
)
# Plot as stacked bar chart
    #piv_plot = pivoted[pivoted['object']==obj].copy()
    iax = i//2
    jax = i%2
    # extract filters used for the object
    filters_object = summary.groupby('filter').size().keys().to_list()
    colors_object = [colors_dict.get(f, '#333333') for f in filters_object]
    pivoted.plot(
        kind='bar',
        stacked=True,
        figsize=(15, 7),
        ax=ax[iax, jax],
        legend = False,
        color = colors_object
    )
    ax[iax, jax].locator_params(axis='x', nbins=5)
    ax[iax, jax].tick_params('x', rotation = 0)
    ax[iax, jax].set_xlabel('')
    ax[iax, jax].set_title(obj)

    if iax == ax.shape[0]-1:
        ax[iax, jax].set_xlabel('Observation Date')
    if jax==0:
       ax[iax, jax].set_ylabel('Number of\nObservations')
    else:
        ax[iax, jax].set_ylabel('')

fig.suptitle('Daily Number of Observations per Object and Filter')
# Custom legend handles
ax[-1,-1].set_axis_off()
handles = [patches.Patch(color = colors_list[i], label =filters_list[i]) for i in range(n_filters)]
ax[-1,-1].legend(title='Filter', #bbox_to_anchor=(1.05, 1), loc='lower center',
           handles = handles, ncols = n_filters)
plt.tight_layout()
plt.show()

In [92]:
plt.savefig('grav_lens_overview_object_date_filter.png', dpi=200)

In [96]:
filters_list

['SDSSg', 'SDSSi', 'SDSSr', 'SDSSu', 'SDSSy', 'SDSSzs']

In [125]:
df_filtered.groupby(['object', 'filter', 'telescope']).size()

object          filter  telescope
EinsteinCross   SDSSg   TTT3           28
                SDSSr   TTT3           28
LBQS1333+0113   SDSSg   TTT1           65
                        TTT2           90
                SDSSi   TTT1           65
                        TTT2           90
                SDSSr   TTT1           65
                        TTT2           90
MG1654+1346     SDSSg   TTT1          158
                        TTT2          469
                        TTT3          322
                SDSSi   TTT1          159
                        TTT2          468
                        TTT3          322
                SDSSr   TTT1          158
                        TTT2          468
                        TTT3          322
Q2237+030       SDSSg   TTT3           77
                SDSSi   TTT3           77
                SDSSr   TTT3           77
                SDSSzs  TTT3           77
QSO0957+561     SDSSg   TST          1388
                        TTT1         1744


# Future work

In [None]:
# ------- FUTURE WORK: IMAGE DATA HISTOGRAM ---------
# data = iv.read_data(1)
print(np.mean(data))
print(np.median(data))
print(np.std(data))
print(np.max(data),'-', np.min(data))
plt.hist(data.ravel(), bins=100)
plt.yscale('log')
plt.xscale('log')
plt.show()