# Region Cataloger

This notebook takes a DS9 region file and calculates the photometry in the regions
for the specified image types.

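The gist of the notebook is as follows:

1. Read the DS9 region file.
2. For each image specification (telescope, instrument, detector, filter), query MAST around the regions and download the matching HAP-SVM drizzled (`DRC`) images.
3. Subtract a 2D background from each image and measure aperture statistics inside every region with `photutils`, averaging the results when several images cover the regions.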

## Pre-Processing Directives

<code style="background:yellow;color:black">If running on Google Colab</code>, uncomment the line below (remove the leading `# `) and run the cell.

In [None]:
# %pip install astropy astroquery regions photutils

## Imports

In [None]:
# Python Imports
from pathlib import Path
from functools import lru_cache
from typing import Union, Iterable, NoReturn
from tqdm.notebook import tqdm
from warnings import catch_warnings, simplefilter

In [None]:
# Numerical Imports
import numpy as np

In [None]:
# Astropy Imports
from astropy import units as u
from astropy.io import fits
from astropy.wcs import WCS
from astropy.table import Table, QTable, vstack
from astropy.coordinates import SkyCoord

# Astroquery Imports
from astroquery.mast import Observations

# Regions
from regions import Regions

# PhotUtils
from photutils.aperture import ApertureStats, SkyCircularAperture as SkyCA
from photutils.background import Background2D

## Typing

In [None]:
# Type alias for anything that can be converted into a filesystem path
PathLike = Union[Path, str]

## Classes

In [None]:
# Image Type Class
class ImageSpecification:
    """Class to keep the image type associated with a space telescope.

    An image specification is identified by its telescope, instrument,
    detector and filter. ``base_path`` only controls where downloaded files
    are cached; it does not take part in equality or hashing.
    """

    # Class Constructor
    def __init__(
            self, telescope: str, instrument: str, detector: str, filter: str,
            base_path: 'PathLike' = Path('.')
        ):
        """Initializes the ImageSpecification.

        Parameters
        ----------
        telescope, instrument, detector, filter : str
            Components identifying the image type
            (e.g. ``'HST'``, ``'ACS'``, ``'WFC'``, ``'F475W'``).
        base_path : str or Path, optional
            Root directory under which the per-type cache directory lives.
        """
        self.telescope = telescope
        self.instrument = instrument
        self.detector = detector
        self.filter = filter
        self.base_path = Path(base_path)

    # Identity of the Image Type
    def _key(self) -> tuple:
        """Returns the tuple of fields that identifies this image type."""
        return (self.telescope, self.instrument, self.detector, self.filter)

    # Default Representation
    def __repr__(self):
        """Returns a string representation of the ImageSpecification."""
        return f"{self.telescope}-{self.instrument}-{self.detector}-{self.filter}"

    # String Representation
    def __str__(self):
        """Returns a string representation of the ImageSpecification."""
        return self.__repr__()

    # Equality (must agree with __hash__ so equal specs find the same dict entry)
    def __eq__(self, other):
        """Two specifications are equal when telescope, instrument, detector and filter match."""
        if not isinstance(other, ImageSpecification):
            return NotImplemented
        return self._key() == other._key()

    # Hash Function
    def __hash__(self):
        """Returns a hash of the ImageSpecification (consistent with ``__eq__``)."""
        return hash(self._key())

    # Make Directory
    def make_directory(self) -> None:
        """Creates the cache directory for the image type (no error if it exists)."""
        self.cache_path.mkdir(parents=True, exist_ok=True)

    # Get the MAST Instrument
    @property
    def mast_instrument(self) -> str:
        """Returns the MAST instrument name, ``'<instrument>/<detector>'``."""
        return f"{self.instrument}/{self.detector}"

    # Get the Cache Directory
    @property
    def cache_path(self) -> Path:
        """Returns the cache directory for the image type.

        The path is rebuilt on every access. It is cheap, and this means it
        always reflects the current ``base_path`` and spec fields. A
        method-level ``lru_cache`` would also keep instances alive.
        """
        return self.base_path / self.telescope / self.instrument / self.detector / self.filter

    # From FITS Header
    @classmethod
    def from_header(
            cls, header: 'fits.Header', base_path: 'PathLike' = Path('.')
        ) -> 'ImageSpecification':
        """Creates an ImageSpecification from a FITS header.

        The filter name is built from the ``FILTER``, ``FILTER1`` and
        ``FILTER2`` keywords, in that order. Missing values, duplicates and any
        value containing ``'clear'`` (case-insensitive) are skipped, and the
        remaining names are joined with ``'-'``.
        """
        # dict.fromkeys de-duplicates while preserving keyword order, so the
        # joined filter name is deterministic (set iteration order is not).
        candidates = dict.fromkeys(
            header.get(key) for key in ('FILTER', 'FILTER1', 'FILTER2')
        )
        filters = [
            filt for filt in candidates
            if filt is not None and 'clear' not in filt.lower()
        ]

        # Return the ImageSpecification
        return cls(
            telescope=header.get('TELESCOP'),
            instrument=header.get('INSTRUME'),
            detector=header.get('DETECTOR'),
            filter='-'.join(filters),
            base_path=base_path
        )

    # From a FITS File
    @classmethod
    def from_fits(
            cls, fits_file: 'PathLike', base_path: 'PathLike' = Path('.')
        ) -> 'ImageSpecification':
        """Creates an ImageSpecification from the primary header of a FITS file."""
        return cls.from_header(fits.getheader(fits_file, ext=0), base_path=base_path)

    # From String
    @classmethod
    def from_string(
            cls, img_type_str: str, base_path: 'PathLike' = Path('.')
        ) -> 'ImageSpecification':
        """Creates an ImageSpecification from a 'Telescope-Instrument-Detector-Filter' string.

        Only the first three dashes are treated as separators. A filter that
        itself contains dashes therefore survives, which is how ``from_header``
        joins multiple filters, so ``from_string(repr(spec))`` round-trips.

        Raises
        ------
        ValueError
            If the string has fewer than four dash-separated parts.
        """
        parts = img_type_str.split('-', 3)
        if len(parts) != 4:
            raise ValueError("Image type string must be in the format 'Telescope-Instrument-Detector-Filter'")
        return cls(*parts, base_path=base_path)

In [None]:
# Extending the default Regions class to handle SkyCoord more easily
class ExtendedRegions(Regions):
    """Extends the Regions class for easier SkyCoord handling."""

    # Have to Call the Parent Read Method due to the RegionsRegistry
    @classmethod
    def read(
        cls, filename: PathLike, format: str=None, cache: bool=False, **kwargs
    ) -> 'ExtendedRegions':
        """Reads regions from a file and returns an ExtendedRegions object.

        All arguments are forwarded unchanged to ``Regions.read``.
        """
        regions = Regions.read(filename, format=format, cache=cache, **kwargs)

        # Copy the individual regions into a plain list instead of nesting the
        # parent Regions container inside this one.
        return cls(list(regions))

    # Make a Getter for the Coordinates
    @property
    def coordinates(self) -> SkyCoord:
        """Returns a SkyCoord holding the center of every region.

        Recomputed on each access. A method-level ``lru_cache`` would keep
        instances alive and could return stale values if the region list is
        modified after the first call.
        """
        return SkyCoord([reg.center for reg in self])

    # Calculate the Center and Extent of the Regions
    def get_center_and_extent(
            self, extentAdd : u.Quantity=10*u.arcsec
            ) -> tuple[SkyCoord, u.Quantity]:
        """Calculates the center and extent of the regions.

        The center is the direction of the mean unit vector of the region
        centers. Averaging RA directly would fail for regions straddling
        RA = 0/360 deg (e.g. 359.9 and 0.1 deg would average to 180 deg).

        Parameters
        ----------
        extentAdd : Quantity, optional
            Angular padding added to the extent.

        Returns
        -------
        center : SkyCoord
            Mean position of the region centers.
        extent : Quantity
            Largest separation between ``center`` and any region center,
            in arcmin, plus ``extentAdd``.
        """
        # Build the coordinates once (the property recomputes them)
        coords = self.coordinates
        ra = coords.ra.to_value(u.rad)
        dec = coords.dec.to_value(u.rad)

        # Mean of the Cartesian unit vectors
        x = np.mean(np.cos(dec) * np.cos(ra))
        y = np.mean(np.cos(dec) * np.sin(ra))
        z = np.mean(np.sin(dec))

        # Back to spherical; a negative arctan2 RA is wrapped by SkyCoord
        center = SkyCoord(
            ra=np.arctan2(y, x) * u.rad,
            dec=np.arctan2(z, np.hypot(x, y)) * u.rad,
            frame=coords.frame
        )

        # Get the Extent (Max Sep)
        extent = center.separation(coords).max().to(u.arcmin)

        # Return the Center and Extent
        return center, extent + extentAdd

In [None]:
# Image Getter Class
# Accepted input for image types: a single specification or a collection of them
ImageSpecTypes = Union[Iterable[ImageSpecification], ImageSpecification]
class ImageGetter:
    """Queries MAST and downloads HAP-SVM drizzled (DRC) images per image type."""

    # Constructor
    def __init__(self, img_types: ImageSpecTypes):
        """Initializes the ImageGetter with one or more image types.

        Parameters
        ----------
        img_types : ImageSpecification or iterable of ImageSpecification
            Image types to download. A single specification is accepted as
            well as any iterable of them.
        """
        # A bare ImageSpecification is not iterable, so wrap it first
        if isinstance(img_types, ImageSpecification):
            img_types = [img_types]
        self.img_types = list(img_types)

    # Download the Images
    def download_images(
            self, regions: ExtendedRegions, cache: bool = True
        ) -> dict[ImageSpecification, Union[Table, None]]:
        """Downloads images for the specified regions and image types.

        MAST is queried once around the regions' center and extent. The
        resulting observations are then filtered per image type, and matching
        products are downloaded into each type's cache directory.

        Parameters
        ----------
        regions : ExtendedRegions
            Regions that define the search area.
        cache : bool, optional
            Passed to ``Observations.download_products``.

        Returns
        -------
        dict
            Maps each image type to the download manifest Table, or to None
            when no observation matched that image type.
        """

        # Get the Center and Extent of the Regions
        center, extent = regions.get_center_and_extent()

        # Query MAST for Observations (once, shared by every image type)
        obs_table = Observations.query_region(
            center, radius=extent
        )

        # Loop through each Image Type
        responses = {}
        for img_type in tqdm(self.img_types, desc="Downloading Images for Each Filter"):

            # Make the Directory
            img_type.make_directory()

            # Filter the Observations (columns are CAOM observation fields)
            # https://mast.stsci.edu/api/v0/_c_a_o_mfields.html
            filt_table = Observations.filter_products(
                obs_table,
                intentType='science',
                obs_collection=img_type.telescope,
                instrument_name=img_type.mast_instrument,
                filters=img_type.filter,
                calib_level=[3],
                project='HAP',
                provenance_name=['HAP-SVM'],
                dataproduct_type='image'
            )

            # Download the Images (filters apply to product fields)
            # https://mast.stsci.edu/api/v0/_productsfields.html
            if len(filt_table):
                responses[img_type] = Observations.download_products(
                    filt_table['obsid'], download_dir=img_type.cache_path,
                    productSubGroupDescription=['DRC'],
                    project='HAP-SVM',
                    calib_level=[3],
                    flat=True,
                    cache=cache
                )
            else:
                responses[img_type] = None

        # Return the Responses
        return responses

In [None]:
# Make the Region Cataloger Class
class RegionCataloger:
    """Class to handle the cataloging of regions and downloading images.

    Construction downloads the images for every image type (through
    ImageGetter) and measures aperture statistics for every region. The
    results are stored in ``self.photometry``, keyed by image type.
    """

    # Constructor
    def __init__(self, regions: ExtendedRegions, img_types: ImageSpecTypes):
        """Initializes the RegionCataloger with regions and image types.

        Parameters
        ----------
        regions : ExtendedRegions
            Regions to measure.
        img_types : ImageSpecification or iterable of ImageSpecification
            Image types to download and measure.
        """

        # Store the Inputs (a bare ImageSpecification is not iterable, so wrap it)
        self.regions = regions
        if isinstance(img_types, ImageSpecification):
            img_types = [img_types]
        self.img_types = list(img_types)

        # Make an ImageGetter
        self.image_getter = ImageGetter(self.img_types)

        # Download the Images
        self.responses = self._download_images(cache=True)

        # Calculate the Photometry for each Image Spec
        self.photometry = {}
        for img_type, response in tqdm(self.responses.items(), desc="Calculating Photometry for Each Filter"):
            self.photometry[img_type] = self._calculate_photometry(img_type, response)

    # Download Images for Regions
    def _download_images(
            self, cache: bool = True
        ) -> dict[ImageSpecification, Union[Table, None]]:
        """Calls the ImageGetter to download images for the specified regions."""
        return self.image_getter.download_images(self.regions, cache=cache)

    # Calculate Photometry for a given Image Type
    def _calculate_photometry(
            self, img_type: ImageSpecification, response: Union[Table, None]
        ) -> Union[QTable, None]:
        """Calculates the photometry for a given image type (None if nothing was downloaded)."""
        if response is None:
            return None
        return self._aperture_photometry(img_type, response)

    # Call Photutils to Calculate the Photometry
    def _aperture_photometry(
            self, img_type: ImageSpecification, response: Table
        ) -> QTable:
        """Calculates the aperture photometry for a given image type.

        Each downloaded image whose WCS footprint contains the regions' center
        is background-subtracted (``Background2D``, 64x64 boxes). Its
        ``ApertureStats`` are then measured in a circular aperture per region.
        The per-image tables are combined by ``_combine_tables``.
        """

        # Get Regions Centroid
        center = self.regions.get_center_and_extent()[0]

        # Loop through the Possible Images
        photTables = []
        for fileName in tqdm(response['Local Path'], desc=f"Processing {img_type} Images", leave=False):

            # Open the Image
            with fits.open(fileName) as hdul:
                image_data = hdul['SCI'].data
                header = hdul['SCI'].header
                wcs = WCS(header)

            # If the WCS does not contain the Regions, continue
            if not wcs.footprint_contains(center):
                continue

            # Get the Background
            with catch_warnings():
                simplefilter("ignore")
                bkg = Background2D(image_data, (64, 64), filter_size=(3, 3))

            # Subtract once per image: this array does not depend on the region
            sub_data = image_data - bkg.background

            # Measure every Region
            stats = [
                ApertureStats(sub_data, SkyCA(reg.center, r=reg.radius), wcs=wcs)
                for reg in tqdm(self.regions, desc=f"Processing Regions for {Path(fileName).name}", leave=False)
            ]

            # Convert the Stats to a QTable
            photTables.append(QTable(vstack(
                [stat.to_table() for stat in stats],
                metadata_conflicts='silent'
            )))

        return self._combine_tables(photTables)

    # Merge the Per-Image Tables
    @staticmethod
    def _combine_tables(photTables: list[QTable]) -> QTable:
        """Combines per-image photometry tables into a single table.

        Returns an empty QTable when no image covered the regions, and the
        only table when exactly one did. Otherwise it returns a copy of the
        first table in which every floating-point column holds the
        NaN-ignoring mean over all tables. Other columns (integer ids, SkyCoord
        mixin columns, which ``np.nanmean`` cannot average) keep the first
        table's values.

        NOTE(review): pixel-coordinate columns, if present, are averaged across
        the pixel grids of different images, as in the original; confirm this
        is meaningful.
        """
        if not photTables:
            return QTable()
        if len(photTables) == 1:
            return photTables[0]

        # Get first table as a base
        photTable = photTables[0].copy()
        with catch_warnings():
            simplefilter("ignore")  # Ignore where columns are NaN everywhere (like error cols)
            for col in photTable.colnames:
                first = photTable[col]
                if not (isinstance(first, np.ndarray) and first.dtype.kind == 'f'):
                    continue
                # np.stack dispatches to Quantity and keeps units; a plain list
                # given to nanmean would be converted to bare values.
                photTable[col] = np.nanmean(
                    np.stack([table[col] for table in photTables]),
                    axis=0
                )

        return photTable

## Make the Catalog

In [None]:
# Default Image Specifications, given as (telescope, instrument, detector, filter)
DEF_IMG_SPECS = [
    ImageSpecification(*spec)
    for spec in (
        ('HST', 'WFC3', 'UVIS', 'F336W'),
        ('HST', 'ACS', 'WFC', 'F475W'),
        ('HST', 'ACS', 'WFC', 'F814W'),
    )
]

# DS9 region file to catalog
REGION_FILE_NAME = 'ExampleNDs.reg'

In [None]:
regs = ExtendedRegions.read(filename=REGION_FILE_NAME)  # load the DS9 regions

In [None]:
catalogs = RegionCataloger(regions=regs, img_types=DEF_IMG_SPECS)  # download images and measure photometry

In [None]:
catalogs.photometry[DEF_IMG_SPECS[1]]  # Access photometry for the second image type (HST ACS/WFC F475W)