In [1]:
from pathlib import Path
from typing import Literal, Callable
from warnings import warn
from functools import wraps
import time

import scipy.stats as stats
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from scipy.ndimage import label

from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.wcs import WCS
import astropy.visualization as vis

from lsst.geom import SpherePoint, Point2I
import lsst.geom as geom
from lsst.daf.butler import Butler
from lsst.rsp import get_tap_service
from lsst.afw.math import Warper, WarperConfig
from lsst.source.injection import generate_injection_catalog
from lsst.source.injection import VisitInjectConfig, VisitInjectTask
from lsst.ip.diffim.subtractImages import AlardLuptonSubtractTask, AlardLuptonSubtractConfig
from lsst.ip.diffim.detectAndMeasure import DetectAndMeasureTask, DetectAndMeasureConfig


def timeit(func):
    @wraps(func)
    def timeit_wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        end_time = time.perf_counter()
        total_time = end_time - start_time
        print(f'Function {func.__name__} took {total_time:.4f} seconds')
        return result
    return timeit_wrapper

In [2]:
service = get_tap_service("tap")

# Source detection over multiple simulated afterglows

We put together the code for anomaly detection we developed previously and execute it over our simulated afterglow events.

## Detect, table, annotate pipeline

In [3]:
enrich_short_labels = {
    "base_PsfFlux_instFlux": "flux_base_psf",
    "ip_diffim_forced_PsfFlux_instFlux": "flux_diffim_psf",
    "base_CircularApertureFlux_12_0_instFlux": "flux_base_12",
}
enrich_short_errors = {f"{k}Err": f"{v}Err" for k, v in enrich_short_labels.items()}
enrich_target_labels = [*enrich_short_labels.keys()]
enrich_target_errors = [f"{label}Err" for label in enrich_target_labels]


def entwine(a: list, b: list) -> list:
    return [item for pair in zip(a, b) for item in pair]


def enrich_snr(df: pd.DataFrame, labels_flux: list[str], labels_err: list[str]) -> pd.DataFrame:
    if any(missing_labels := [l for l in entwine(enrich_target_labels, enrich_target_errors) if l not in df.columns]):
        raise ValueError(f"Some of the labels are missing: {' '.join(missing_labels)}.")
    _d = {
        f"snr_{label_f}": df[label_f] / df[label_err] 
        for label_f, label_err in zip(labels_flux, labels_err)
    }
    return pd.concat((df, pd.DataFrame(_d)), axis=1)

def enrich_mag(df: pd.DataFrame, labels_flux: list[str], labels_err: list[str], f2mag: Callable) -> pd.DataFrame:
    if any(missing_labels := [l for l in entwine(enrich_target_labels, enrich_target_errors) if l not in df.columns]):
        raise ValueError(f"Some of the labels are missing: {' '.join(missing_labels)}.")
    _d = {}
    for label_f, label_err in zip(labels_flux, labels_err):
        _d[f"mag_{label_f}"] = [f2mag(f) for f in df[label_f]]
        _d[f"magErrBot_{label_f}"] = [-2.5 * np.log10(1 + delta_f / f) for f, delta_f in zip(df[label_f], df[label_err],)]
        _d[f"magErrTop_{label_f}"] = [-2.5 * np.log10(1 - delta_f / f) for f, delta_f in zip(df[label_f], df[label_err],)]
    return pd.concat((df, pd.DataFrame(_d)), axis=1)

@timeit
def detect(science, template, difference, threshold: float | None = None):
    config = DetectAndMeasureConfig()
    if threshold is not None:
        config.detection.thresholdValue = threshold
    anomalies = DetectAndMeasureTask(config=config).run(science, template, difference)
    _df = anomalies.getDict()["diaSources"].asAstropy().to_pandas()
    _df = enrich_snr(_df, enrich_target_labels, enrich_target_errors)
    _df = enrich_mag(_df, enrich_target_labels, enrich_target_errors, science.getPhotoCalib().instFluxToMagnitude)
    _d = (
        {
            "is_negative": _df["is_negative"],
            "coord_ra": np.rad2deg(_df["coord_ra"]),
            "coord_dec": np.rad2deg(_df["coord_dec"]),
            "area_psf": _df["base_PsfFlux_area"],
        } | {v: _df[k] for k, v in enrich_short_labels.items()}
          | {v: _df[k] for k, v in enrich_short_errors.items()}
          | {f"snr_{v}": _df[f"snr_{k}"] for k, v in enrich_short_labels.items()} 
          | {f"mag_{v}": _df[f"mag_{k}"] for k, v in enrich_short_labels.items()} 
          | {f"magErrBot_{v}": _df[f"magErrBot_{k}"] for k, v in enrich_short_labels.items()} 
          | {f"magErrTop_{v}": _df[f"magErrTop_{k}"] for k, v in enrich_short_labels.items()}
    )
    return pd.DataFrame(_d).sort_values("snr_flux_base_psf", ascending=False).reset_index(drop=True)


ANNOTATION_PREFIX = "ann"

@timeit
def annotate_donuts(table_df: pd.DataFrame) -> pd.DataFrame:
    mask = table_df["flux_base_psf"] < 0
    return pd.concat([table_df, pd.DataFrame({f"{ANNOTATION_PREFIX}_donut": mask})], axis=1)


@timeit
def annotate_mask(table_df: pd.DataFrame, difference) -> pd.DataFrame:
    keywords: list[list[str]] = [get_mask_labels_coord(difference, ra, dec) for ra, dec in table_df[["coord_ra", "coord_dec"]].values]
    return pd.concat(
        (table_df,
        pd.DataFrame({
                f"{ANNOTATION_PREFIX}_mask_{k.lower()}": [k in ks for ks in keywords ]
                for k in set([k for ks in keywords for k in ks])
        })),
        axis=1,
    ).reset_index(drop=True)


@timeit
def annotate_star_closeness(
    table_df: pd.DataFrame,
    star_catalog: pd.DataFrame,
    tolerance_arcsec: float = 0.2, # pixel size
) -> pd.DataFrame:
    source_coords = table_df[["coord_ra", "coord_dec"]].values
    ref_coords = star_catalog[["coord_ra", "coord_dec"]].values
    tolerance_rad = np.deg2rad(tolerance_arcsec / 3600.0)
    s_ = np.deg2rad(source_coords[:, np.newaxis, :])
    r_ = np.deg2rad(ref_coords[np.newaxis, :, :])
    distances = np.acos(np.sin(s_[:, :, 1]) * np.sin(r_[:, :, 1]) + np.cos(s_[:, :, 1]) * np.cos(r_[:, :, 1]) * np.cos(s_[:, :, 0] - r_[:, :, 0]))
    mask = np.any(distances  < tolerance_rad, axis=1)
    return pd.concat([table_df, pd.DataFrame({f"{ANNOTATION_PREFIX}_closestar": mask})], axis=1)


# TODO: improve this weak implementation
@timeit
def annotate_ext_overlap(
    table_df: pd.DataFrame,
    science,
    galaxies_catalog: pd.DataFrame,
)-> pd.DataFrame:
    mask_detected = get_mask(science, "DETECTED").astype(int)
    labelled, nlabels = label(mask_detected)
    out = np.zeros_like(labelled).astype(bool)
    # TODO: improve this implementation vectorizing it
    for i, source in galaxies_catalog.iterrows():
        ra, dec = source[["coord_ra", "coord_dec"]].values
        if not science.containsSkyCoords(ra * u.deg, dec * u.deg):
            continue
        pixel = coord_to_pix(science, ra, dec)
        value = labelled[*pixel]
        if value == 0:
            continue
        out |= (labelled == value)
    mask = table_df[["coord_ra", "coord_dec"]].apply(
        lambda x: out[*coord_to_pix(science, x.coord_ra, x.coord_dec)], 
        axis=1
    )
    return pd.concat([table_df, pd.DataFrame({f"{ANNOTATION_PREFIX}_extoverlap": mask})], axis=1)
    

### Mask Tools

In [4]:
def get_mask(image, mask_names: str | list[str]) -> np.ndarray:
    mask = image.getMask()
    if isinstance(mask_names, str):
        mask_names = [mask_names]

    mask_array = mask.getArray()
    out = np.zeros_like(mask_array)
    for mask_name in mask_names:
        target_bit = mask.getMaskPlane(mask_name)
        out |= (mask_array & (2 ** target_bit)) != 0
    return out

def coord_to_pix(image, ra, dec) -> tuple[int, int]:
    wcs = WCS(image.getWcs().getFitsMetadata())
    coord = SkyCoord(
        ra = ra * u.deg,
        dec = dec * u.deg
    )
    i, j = wcs.world_to_array_index(coord)
    # apparently can return pixel at boundary
    return min(i, image.getHeight() - 1), min(j, image.getWidth() - 1), 

def get_mask_value_coord(image, ra, dec) -> int:
    return image.getMask()[*coord_to_pix(image, ra, dec)[::-1]]

def get_mask_labels_coord(image, ra, dec) -> list[str]:
    return [
        keyword 
        for keyword, bit in image.getMask().getMaskPlaneDict().items() 
        if get_mask_value_coord(image, ra, dec) & (2 ** bit) != 0
    ]

### Catalog helpers

In [5]:
def sregion_to_vertices(sregion: str, closed=False):
    """Convert the s_region from the ObsCore table into two
    arrays containing the x and y vertices, in order to plot
    boxes using matplotlib.

    from notebook DPO02_02c.
    """
    temp = sregion.split(' ')
    xvertices = []
    yvertices = []
    ix = 2
    iy = 3
    for c in range(4):
        xvertices.append(float(temp[ix]))
        yvertices.append(float(temp[iy]))
        ix += 2
        iy += 2
    if closed:
        xvertices.append(xvertices[0])
        yvertices.append(yvertices[0])
    return list(zip(xvertices, yvertices))


def polygon_string(sregion: str):
    """
    Usage:
        * Plug in queries like:
        
            ```
            (.. more query ..)
            WHERE CONTAINS(POINT('ICRS', obj.coord_ra, obj.coord_dec), {polygon_string(s_region)}) = 1
            (.. more query ..)                                          ^^^^^^^^^^^^^^^^^^^^^^^^
            ```
    """
    return f"""POLYGON('ICRS', {
    ', '.join(map(
        lambda x: f'{x[0]:.6f}, {x[1]:.6f}', 
        sregion_to_vertices(sregion, False),
    ))})"""


def calexp_ccdvisitid(calexp) -> str:
    detector = calexp.detector.getId()
    id_ = calexp.visitInfo.id
    ccdvisitid = f"{id_}{detector:03d}"
    return ccdvisitid


@timeit
def sources_in(service, calexp, extendedness: bool | None = None) -> pd.DataFrame:
    band = calexp.getFilter().bandLabel
    ccdvisitid = calexp_ccdvisitid(calexp)

    query_sregion = (
        f"""
        SELECT s_region 
        FROM dp02_dc2_catalogs.ObsCore 
        WHERE lsst_ccdvisitid = {ccdvisitid} 
        AND dataproduct_subtype = 'lsst.calexp'
        """
    )
    s_region = (_t := service.search(query_sregion).to_table())[0][0]
    assert len(_t) == 1
    and_extendedness = f"AND obj.{band}_extendedness = {int(extendedness)}" if extendedness is not None else ""
    query_objects = (
        f"""
        SELECT obj.coord_ra AS coord_ra,
        obj.coord_dec AS coord_dec,
        obj.{band}_extendedness AS {band}_extendedness,
        scisql_nanojanskyToAbMag(obj.{band}_cModelFlux) AS cModelMag_{band}, 
        obj.{band}_cModelFlux AS {band}_cModelFlux,
        obj.footprintArea AS footprintArea 
        FROM dp02_dc2_catalogs.Object AS obj 
        WHERE CONTAINS(POINT('ICRS', obj.coord_ra, obj.coord_dec), {polygon_string(s_region)}) = 1
        AND obj.detect_isPrimary = 1 
        {and_extendedness}
        """
    )
    return service.search(query_objects).to_table().to_pandas()

In [6]:
def detection_pipeline(science, template, difference, threshold: float | None = None, footprint_area: int | None = None) -> pd.DataFrame:
    print("Starting source detection.")
    anomalies = detect(science, template, difference, threshold=threshold)
    sources = sources_in(service, science)
    pointlike_sources = sources[sources["r_extendedness"] == 0]
    results = annotate_star_closeness(
        annotate_mask(
            annotate_donuts(anomalies), 
            difference,
        ),
        pointlike_sources,
    )
    if footprint_area is None:
        return results
    extended_sources = sources[sources["footprintArea"] >= footprint_area]
    return annotate_ext_overlap(results, template, extended_sources)


# Testing

In [7]:
from doppelganger.doppelganger.rubin.inject import sfis_pipeline

In [8]:
ra = 56.90063
dec = -33.94851
mag = 17.
band = "r"
footprint_area = None
threshold = 4

images = sfis_pipeline(service, ra, dec, mag, band)

Retrieving visit table, choosing one visit at random.
Starting source injection.


Starting DIA.


In [9]:
annotated_anomalies = detection_pipeline(**images, footprint_area = footprint_area, threshold = threshold)

Starting source detection.


  warn("Using UFloat objects with std_dev==0 may give unexpected results.")


  _d[f"magErrBot_{label_f}"] = [-2.5 * np.log10(1 + delta_f / f) for f, delta_f in zip(df[label_f], df[label_err],)]
  _d[f"magErrTop_{label_f}"] = [-2.5 * np.log10(1 - delta_f / f) for f, delta_f in zip(df[label_f], df[label_err],)]


Function detect took 8.8913 seconds
Function sources_in took 10.3326 seconds
Function annotate_donuts took 0.0006 seconds
Function annotate_mask took 44.3252 seconds
Function annotate_star_closeness took 0.0443 seconds


In [10]:
annotated_anomalies

Unnamed: 0,is_negative,coord_ra,coord_dec,area_psf,flux_base_psf,flux_diffim_psf,flux_base_3,flux_base_psfErr,flux_diffim_psfErr,flux_base_3Err,...,magErrTop_flux_diffim_psf,magErrTop_flux_base_3,ann_donut,ann_mask_sat_template,ann_mask_detected,ann_mask_injected_core,ann_mask_injected,ann_mask_crosstalk,ann_mask_detected_negative,ann_closestar
0,False,56.900630,-33.948510,40.101875,1.179062e+06,1.180266e+06,1.174872e+06,1588.421135,1582.884205,1531.898861,...,0.001457,0.001417,False,False,True,True,True,False,False,False
1,False,56.934948,-33.831676,40.586796,1.869201e+05,2.445706e+06,1.004794e+05,2351.615958,2315.587965,2080.323484,...,0.001028,0.022715,False,True,True,False,False,False,False,False
2,False,56.760003,-33.751603,40.373829,1.737938e+05,2.412877e+06,1.085861e+05,2347.490718,2303.039682,2077.629573,...,0.001037,0.020975,False,True,True,False,False,False,False,False
3,False,56.900940,-33.942128,40.123875,1.135761e+05,2.167483e+06,6.824161e+04,2203.599289,2170.771344,1973.718735,...,0.001088,0.031865,False,True,True,False,False,False,False,False
4,False,56.887211,-33.860588,40.358974,6.968733e+04,2.071987e+06,3.858901e+04,2151.381497,2126.038838,1927.395971,...,0.001115,0.055630,False,True,True,False,False,False,True,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
305,True,56.796426,-33.765788,40.299557,-2.741390e+03,8.906580e+04,-5.149372e+02,472.116597,466.298062,886.513366,...,0.005699,-1.087059,True,False,False,False,False,False,True,False
306,True,56.907274,-33.955091,40.130928,-2.682977e+03,6.290474e+04,-3.469876e+03,426.260160,420.922995,855.026975,...,0.007290,-0.239156,True,False,False,False,False,False,True,False
307,False,56.978152,-33.821781,40.793056,-1.208045e+04,9.386341e+05,-1.891633e+02,1439.050135,1429.611710,1416.614027,...,0.001655,-2.322121,True,False,True,False,False,False,True,False
308,True,56.957800,-33.827187,40.688553,-8.138419e+03,3.507061e+05,-7.920030e+03,842.774208,835.978531,1143.792703,...,0.002591,-0.146461,True,False,False,False,False,False,True,False


In [11]:
annotated_anomalies.columns

Index(['is_negative', 'coord_ra', 'coord_dec', 'area_psf', 'flux_base_psf',
       'flux_diffim_psf', 'flux_base_3', 'flux_base_psfErr',
       'flux_diffim_psfErr', 'flux_base_3Err', 'snr_flux_base_psf',
       'snr_flux_diffim_psf', 'snr_flux_base_3', 'mag_flux_base_psf',
       'mag_flux_diffim_psf', 'mag_flux_base_3', 'magErrBot_flux_base_psf',
       'magErrBot_flux_diffim_psf', 'magErrBot_flux_base_3',
       'magErrTop_flux_base_psf', 'magErrTop_flux_diffim_psf',
       'magErrTop_flux_base_3', 'ann_donut', 'ann_mask_sat_template',
       'ann_mask_detected', 'ann_mask_injected_core', 'ann_mask_injected',
       'ann_mask_crosstalk', 'ann_mask_detected_negative', 'ann_closestar'],
      dtype='object')

# Operations

In [12]:
np.random.seed(0)

In [None]:
dirname = "data/08_anomalies_250722"

events_df = pd.read_csv("data/06_sourceCatalogSim/afterglow_host.csv")
for i, (ra, dec, mag) in events_df[["ra", "dec", "m"]].iterrows():
    print(f"START SIMULATION of source with index {i:02d}, ra = {ra:.4f}, dec = {dec:.4f}, m = {mag:.4f}")
    results = sfis_pipeline(service, ra, dec, mag, "r")
    annotated_anomalies = detection_pipeline(**results, threshold=4, footprint_area = 4000)
    ccdvisitid = calexp_ccdvisitid(results['science'])
    rastr, decstr = map(lambda x: f"{x:.3f}".replace(".", "d").replace("-", "m"), [ra, dec])
    magstr = f"{mag:.1f}".replace(".", "d")
    idstring = f"i{i:02d}_ccdvisitid{ccdvisitid}_ra{rastr}_dec{decstr}_mag{magstr}"
    
    for k in ["science", "template", "difference"]:
        results[k].writeFits(f"{dirname}/anomalies_{k}_{idstring}.fits")
    annotated_anomalies.to_csv(f"{dirname}/anomalies_{idstring}.csv")
    print("ENDS SIMULATION.\n\n")