<img align="left" src = https://project.lsst.org/sites/default/files/Rubin-O-Logo_0.png width=250 style="padding: 10px"> 
<b>Testing dipole measurement and recovery of a star with proper-motion</b> <br>
Contact author: Michael Wood-Vasey <br>
Last verified to run: 2023-05-15 <br>
LSST Science Pipelines version: Weekly 2023_07 + source_injection tickets/DM-34253 <br>
Container Size: large <br>
Targeted learning level: intermediate <br>

What do subtractions of a star that is moving look like?

Note: This Notebook is written below the PipelineTask level.  Rather, it uses individual Tasks directly and reads/writes output products to the butler.  This is pedagogically useful for understanding how that works, and practically helpful in working with the evolving `source_injection` package.  However, this structure is not scalable to larger runs (100+ images).  Such large-scale runs should be done as part of an integrated Task that can be connected and run through large-scale cluster job submission.

Inspired by conversation with Eric Bellm and Lynne Jones.

1. [x] Identify a set of images of a region spaced over 10 years.
2. [~] Simulate a star with a given proper motion.  Try things from 0.1 -- 10" over ten years.  
    a. Currently doing just one 0.1"/yr example.
3. [x] First test single-frame recovery and astrometry
4. [x] Then run subtractions for each Year N - Year 1 pair.
5. [x] Analyze the dipoles and measurements of these DIA Sources.  Compare to single-frame measurements.  I think the interesting question for DIA is the ability and robustness of dipole fitting and measurement as the separation approaches zero.
6. [x] Start with constant brightness source.
7. [ ] Repeat for variable source and check recovery.
8. [ ] Compare individual fitting with scene-modeling.

Ideas for next steps from 2023-05-16 DESC DIA TT/DIA Science Unit Meeting:

1. [ ] Explore other dipole parameters such as dipole_length, angle, SNR.
2. [ ] Explore a range of magnitudes.
3. [ ] Think about classification cuts.
4. [ ] Analytic answer to 0.35-0.65.
5. [ ] Signed total_flux.  Actual sum of pos_flux and tot_flux.
6. [ ] Forced total flux.  Forced total PSF.
7. [ ] Match to catalog?  Maybe that's already there in the variance plane?
8. [ ] If we convince ourselves that we need a new measurement, we can make the argument.

This Notebook is designed to work with DC2[1] (5 years).  It should be extensible to HSC (3 years) or DECam (5 years) images with appropriate changes to `repo_config` and the RA, Dec.  The time baseline of the data set isn't really key, as one can scale up the proper motion to compensate[2].

Footnotes:

[1] DC2 stars do not have proper motion.  
[2] I'm ignoring parallax.  One could simulate parallax if one really wanted to model end-to-end, but I don't think it's central to the basic question.
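
The scaling mentioned above is simple arithmetic: keep the total displacement fixed and divide by the available baseline. A minimal sketch, assuming purely linear motion; the function name `scaled_proper_motion` is hypothetical and not part of any LSST package.

```python
def scaled_proper_motion(pm_arcsec_per_yr, target_baseline_yr, available_baseline_yr):
    """Proper motion that gives the same total shift over a different baseline."""
    # Total displacement the target scenario would produce
    total_displacement = pm_arcsec_per_yr * target_baseline_yr
    # Spread that same displacement over the baseline we actually have
    return total_displacement / available_baseline_yr


# 0.1"/yr over 10 years is a 1" total shift; the DC2 data span about 5 years
pm_dc2 = scaled_proper_motion(0.1, 10.0, 5.0)
print(f"{pm_dc2:.2f} arcsec/yr")
```

The same helper covers a 3-year HSC baseline by changing the last argument.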

In [None]:
import os

import astropy
from astropy.wcs import WCS
import astropy.units as u
from astropy.coordinates import SkyCoord

import gc

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import lsst.afw.display as afwDisplay
from lsst.afw.image import MultibandExposure
from lsst.afw.math import Warper, WarperConfig
import lsst.geom as geom
from lsst.daf.butler import Butler, DimensionUniverse, DatasetType, CollectionType
from lsst.daf.butler.registry import MissingCollectionError
from lsst.ip.diffim import AlardLuptonSubtractConfig, AlardLuptonSubtractTask
from lsst.ip.diffim import DetectAndMeasureConfig, DetectAndMeasureTask
from lsst.pipe.tasks.makeWarp import MakeWarpConfig, MakeWarpTask
from lsst.source.injection.inject_engine import generate_galsim_objects, inject_galsim_objects_into_exposure
import lsst.sphgeom

In [None]:
afwDisplay.setDefaultBackend('matplotlib')
plt.style.use('tableau-colorblind10')
%matplotlib inline

In [None]:
user = os.getenv("USER")

collection = "2.2i/runs/DP0.2"
repo_config = "dp02"
output_collection = f"u/{user}/proper_motion"

In [None]:
INJECTED_IMAGES_EXIST = False
SUBTRACTIONS_OF_INJECTED_IMAGE_EXIST = False

In [None]:
butler = Butler(repo_config, collections=collection)

In [None]:
print(butler.registry.getDatasetType("calexp"))

In [None]:
# Do a spatial query for calexps using HTM levels following example in 04b_Intermediate_Butler_Queries.ipynb
ra, dec = 55, -30  # degrees

level = 10  # the resolution of the HTM grid
pixelization = lsst.sphgeom.HtmPixelization(level)

htm_id = pixelization.index(
    lsst.sphgeom.UnitVector3d(
        lsst.sphgeom.LonLat.fromDegrees(ra, dec)
    )
)

In [None]:
htm_id

Get the neighboring HTM pixels
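
For reference, HTM IDs are hierarchical: the eight root trixels have IDs 8 through 15, and the four children of a trixel with ID `p` have IDs `4*p` through `4*p + 3`. A minimal pure-Python sketch of that arithmetic follows; the helper names are hypothetical. Note that siblings share a parent but are not the complete set of spatial neighbors of a trixel.

```python
def htm_level(htm_id: int) -> int:
    """Subdivision level of an HTM ID; root trixels (8..15) are level 0."""
    # A level-L ID occupies 4 + 2*L bits
    return (htm_id.bit_length() - 4) // 2


def htm_parent(htm_id: int) -> int:
    """ID of the trixel one level up."""
    return htm_id >> 2


def htm_siblings(htm_id: int) -> list:
    """The four trixels (including htm_id itself) that share a parent."""
    parent = htm_parent(htm_id)
    return [4 * parent + i for i in range(4)]


print(htm_level(34), htm_parent(34), htm_siblings(34))
```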

In [None]:
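# NOTE: HTM children of parent p are numbered 4*p + i, so this base-10 arithmetic
#   selects four consecutive nearby IDs rather than the siblings of htm_id.
parent_level = htm_id // 10
htm_ids = [parent_level * 10 + i for i in [0, 1, 2, 3]]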

In [None]:
htm_ids

In [None]:
hi = htm_ids[0]

# dataset_refs is an iterator, but each query is only a few hundred results,
#   so convert to a list for future convenience
dataset_refs = list(butler.registry.queryDatasets("calexp", htm20=hi, dataId={"band": "i"}))
dataset_refs = set(dataset_refs)
for hi in htm_ids:
    dr = list(butler.registry.queryDatasets("calexp", htm20=hi, dataId={"band": "i"}))
    dataset_refs = dataset_refs.intersection(dr)

In [None]:
dataset_refs = list(dataset_refs)
# Sort by visitId to get a loose time order
ids_visit = [dr.dataId["visit"] for dr in dataset_refs]
dataset_refs = [dataset_refs[idx] for idx in np.argsort(ids_visit)]

print(dataset_refs)

In [None]:
print(f"Found {len(list(dataset_refs))} calexps")

In [None]:
visit_table = butler.get("visitTable") #, dataset_refs[0].dataId)

We should find ~17 calexps that overlap all four HTM pixels.

## Inject object

In [None]:
# Simulate star
# Start with a constant magnitude
BASE_MAG_STAR = 17  # u.mag
BASE_STAR_RA, BASE_STAR_DEC = 54.982, -29.81
BASE_MJD = 60_000

# Wrap these in functions so the interface stays the same if they later depend on time.
def mag_star(phase=0 * u.d):
    return BASE_MAG_STAR * u.mag

def ra_star(phase=0 * u.d, pm_ra=0.1 * u.arcsec / u.yr):
    return BASE_STAR_RA * u.degree + phase * pm_ra

def dec_star(phase=0 * u.d, pm_dec=0.1 * u.arcsec / u.yr):
    return BASE_STAR_DEC * u.degree + phase * pm_dec

Create a catalog for each visit
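
Because the proper motion is added directly to the RA and Dec coordinates, the on-sky shift in the RA direction is smaller than the coordinate shift by a factor of cos(Dec). A minimal sketch with plain floats, assuming a Julian year of 365.25 days (the same definition Astropy uses for `u.yr`); the helper names are hypothetical.

```python
import math

DAYS_PER_YEAR = 365.25  # Julian year (assumption stated above)
ARCSEC_PER_DEG = 3600.0


def position_at(mjd, base_ra_deg, base_dec_deg, base_mjd, pm_ra=0.1, pm_dec=0.1):
    """Linear motion applied directly to the coordinates (pm in arcsec/yr)."""
    years = (mjd - base_mjd) / DAYS_PER_YEAR
    ra = base_ra_deg + years * pm_ra / ARCSEC_PER_DEG
    dec = base_dec_deg + years * pm_dec / ARCSEC_PER_DEG
    return ra, dec


def on_sky_shift_arcsec(ra0, dec0, ra1, dec1):
    """Small-angle separation; the RA difference is scaled by cos(dec)."""
    d_ra = (ra1 - ra0) * math.cos(math.radians(0.5 * (dec0 + dec1)))
    d_dec = dec1 - dec0
    return math.hypot(d_ra, d_dec) * ARCSEC_PER_DEG


ra0, dec0 = 54.982, -29.81
ra1, dec1 = position_at(60_000 + 365.25, ra0, dec0, 60_000)
shift = on_sky_shift_arcsec(ra0, dec0, ra1, dec1)
print(f"{shift:.4f} arcsec after one year")
```

At this declination the one-year shift comes out a little below the 0.141 arcsec one would get from adding the two 0.1 arcsec components in quadrature.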

In [None]:
dataset_refs

In [None]:
injected_cat = {}
for dr in dataset_refs:
    visit = dr.dataId["visit"]
    mjd = visit_table.loc[visit]["expMidptMJD"]
    phase = (mjd - BASE_MJD) * u.d
    injected_cat[visit] = [dict(
                                ra=ra_star(phase).to_value(u.degree),
                                dec=dec_star(phase).to_value(u.degree),
                                mag=mag_star().to_value(u.mag),
                                source_type="DeltaFunction",
                                index=["test"],
    )]


### Register a Collection to write injection catalog and injected images to

See `source_injection` si_demo_dc2_visit.ipynb

In [None]:
writeable_butler = Butler(repo_config,
                          run=output_collection,
                          collections=output_collection,
                          writeable=True)

if not INJECTED_IMAGES_EXIST:
    try:
        writeable_butler.removeRuns([output_collection])
    except MissingCollectionError:
        print("Writing into a new RUN collection")
        pass
    else:
        print("Prior RUN collection located and successfully removed")

_ = writeable_butler.registry.registerCollection(output_collection, type=CollectionType.RUN)

In [None]:
print(butler)

In [None]:
print(writeable_butler)

## Inject the catalog into each image and save the result

In [None]:
def injected_and_save_image(data_id,
                            injected_cat,
                            read_butler,
                            write_butler,
                            dataset_type="calexp",
                            mask_plane_name: str = "INJECTED",
                            calib_flux_radius: float = 12.0,
                            draw_size_scale: float = 1.0,
                            draw_size_max: int = 1000,
                            verbose=True):
    """Load an image, inject a catalog into it, and save it to a new collection."""
    calexp = read_butler.get(dataset_type, dataId=data_id)

    injected_object_generator = generate_galsim_objects(
        injection_catalog=injected_cat[calexp.visitInfo.id],
        wcs=calexp.getWcs(),
        photo_calib=calexp.getPhotoCalib(),
        fits_alignment = "wcs",
    )

    # ( draw_sizes, common_bounds, fft_size_errors, psf_compute_errors, ) = 

    if verbose:
        print(f"Inserting objects into {calexp}")
        
    _ = inject_galsim_objects_into_exposure(
        calexp,
        injected_object_generator,
        mask_plane_name=mask_plane_name,
        calib_flux_radius=calib_flux_radius,
        draw_size_scale=draw_size_scale,
        draw_size_max=draw_size_max,
    )
    
    if verbose:
        print("Saving newly injected image: ", calexp)
        print(write_butler)

    write_butler.put(calexp, dataset_type, dataId=data_id)

In [None]:
print(writeable_butler.run)

In [None]:
if not INJECTED_IMAGES_EXIST:
    for dr in dataset_refs:
        injected_and_save_image(dr.dataId, injected_cat, butler, writeable_butler)

## Run subtraction between calexps 2-N and calexp 1

In [None]:
config = AlardLuptonSubtractConfig()
task = AlardLuptonSubtractTask(config=config)

In [None]:
template = writeable_butler.get("calexp", dataset_refs[0].dataId)

In [None]:
def subtract(science, template, source_catalog):
    warper_config = WarperConfig()
    warper = Warper.fromConfig(warper_config)

    science_wcs = science.getWcs()
    science_bbox = science.getBBox()

    # Warp the template onto the science image's WCS and bounding box.
    warped_template = warper.warpExposure(science_wcs, template, destBBox=science_bbox)
    # Attach the template's PSF.  Copying it unwarped is probably wrong:
    #   at the least, the x, y mapping should be updated.
    warped_template.setPsf(template.getPsf())
    
    # Now let's do the subtraction
    subtraction = task.run(warped_template, science, source_catalog)
    
    return subtraction


def detect(science, subtraction):
    # Run detection on subtraction
    detect_and_measure_config = DetectAndMeasureConfig()
    detect_and_measure_task = DetectAndMeasureTask(config=detect_and_measure_config)

    detect_and_measure = detect_and_measure_task.run(science, subtraction.matchedTemplate, subtraction.difference)

    return detect_and_measure

In [None]:
def subtract_and_detect(data_id: dict,
                        template: lsst.afw.image.exposure.ExposureF,
                        previous_butler: Butler,
                        writeable_butler: Butler):
    """
    Subtract template image from image referred to by data_id and run detection.
    
    Butler needs to be writeable to store output of subtraction and detection.
    """
    science = writeable_butler.get("calexp", data_id)
    source_catalog = previous_butler.get("src", dataId=data_id)

    subtraction = subtract(science, template, source_catalog)
    writeable_butler.put(subtraction.difference, "goodSeeingDiff_differenceExp", dataId=data_id)

    detection_catalog = detect(science, subtraction)
    writeable_butler.put(detection_catalog.diaSources, "goodSeeingDiff_diaSrc", dataId=data_id)

In [None]:
# The template is the first image, so start at 1:
# This dataset_ref list is sorted by visit,
#   which should be equivalent to sorting by MJD
if not SUBTRACTIONS_OF_INJECTED_IMAGE_EXIST:
    for dr in dataset_refs[1:]:
        subtract_and_detect(dr.dataId, template, butler, writeable_butler)

## Show cut outs for each subtraction

## Some helper utilities for plotting
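
One caveat for the cutout helper below: it converts a stamp size into pixel extents without clipping, so a star within half a stamp of the detector edge gives a negative start index, which NumPy slicing interprets as counting from the far end of the array. A minimal sketch of a clipped version, assuming the image bounding box starts at (0, 0); the function name and the 4072 x 4000 detector size in the example are illustrative assumptions.

```python
def clipped_stamp_extent(x, y, stamp_size, width, height):
    """Return (x0, x1, y0, y1) for a stamp centered on (x, y), clipped to the image."""
    half = stamp_size / 2
    x0 = max(int(x - half), 0)
    x1 = min(int(x + half), width)
    y0 = max(int(y - half), 0)
    y1 = min(int(y + half), height)
    return x0, x1, y0, y1


# A star 30 pixels from the left edge: the stamp is truncated rather than wrapped
print(clipped_stamp_extent(30.0, 2000.0, 100, 4072, 4000))
```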

In [None]:
def show_image_on_wcs(calexp, figsize=(8, 8), ax=None, x=None, y=None,
                      pixel_extent=None, stamp_size=None,
                      marker="o", color="red", size=20):
    """
    Show an image with an RA, Dec grid overlaid.  Optionally add markers.
    
    Notes
    -----
    Specifying both pixel_extent and stamp_size is undefined.
    """
    if ax is None:
        fig = plt.figure(figsize=figsize)
        plt.subplot(projection=WCS(calexp.getWcs().getFitsMetadata()))
        ax = plt.gca()

    if stamp_size is not None and x is not None and y is not None:    
        half_stamp = stamp_size / 2
        # If x and y are of different types, then user should clarify what they wanted
        if np.isscalar(x):
            first_x = x
            first_y = y
        else:
            first_x = x[0]
            first_y = y[0]
            
        pixel_extent = (int(first_x - half_stamp), int(first_x + half_stamp),
                        int(first_y - half_stamp), int(first_y + half_stamp))
    if pixel_extent is None:
        pixel_extent = (int(calexp.getBBox().beginX), int(calexp.getBBox().endX),
                        int(calexp.getBBox().beginY), int(calexp.getBBox().endY))
    # Image array is y, x.  
    # So we select from the image array in [Y_Begin:Y_End, X_Begin:X_End]
    # But then `extent` is (X_Begin, X_End, Y_Begin, Y_End)

    im = ax.imshow(calexp.image.array[pixel_extent[2]:pixel_extent[3], pixel_extent[0]:pixel_extent[1]],
                   cmap="gray", vmin=-200.0, vmax=400,
                   extent=pixel_extent, origin="lower")
    ax.grid(color="white", ls="solid")
    ax.set_xlabel("Right Ascension")
    ax.set_ylabel("Declination")
    if x is not None and y is not None:
        ax.scatter(x, y, s=size, marker=marker, edgecolor=color, facecolor="none")

### Identify pixel regions to focus on

In [None]:
def getSiXyFromCalexp(visit_id, calexp, injected_cat=injected_cat):
    """
    visit_id: Visit index into catalog that has RA, Dec
    calexp: Determines the WCS frame that converts RA, Dec -> x, y
    """
    this_injected_cat = injected_cat[visit_id]
    xy_coords = calexp.getWcs().skyToPixel(geom.SpherePoint(this_injected_cat[0]["ra"], this_injected_cat[0]["dec"], geom.degrees))
    return xy_coords    

In [None]:
template_xy_coords = getSiXyFromCalexp(template.visitInfo.id, template)

In [None]:
stamp_size = 400

Here's the template

In [None]:
xy_coords = getSiXyFromCalexp(template.visitInfo.id, template)
show_image_on_wcs(template, x=xy_coords.x, y=xy_coords.y, stamp_size=stamp_size)

Now let's step through the calexps and subtractions

We can see the difference between the location of the star at the template epoch ("green") and its location at the science epoch ("red").

In [None]:
def plot_calexp_cutout(data_id, template, butler=writeable_butler, ax=None, visit_table=visit_table, injected_cat=injected_cat, verbose=True):
    calexp = butler.get("calexp", dataId=data_id)
    # Warp the calexp onto the template's WCS so cutouts share the template orientation
    warper_config = WarperConfig()
    warper = Warper.fromConfig(warper_config)

    template_wcs = template.getWcs()
    template_bbox = template.getBBox()
    
    # Add PSF.  I think doing this directly without warping is wrong.  At least the x,y mapping should be updated
    warped_calexp = warper.warpExposure(template_wcs, calexp, destBBox=template_bbox)
    warped_calexp.setPsf(calexp.getPsf())

    template_xy_coords = getSiXyFromCalexp(template.visitInfo.id, warped_calexp, injected_cat=injected_cat)
    xy_coords = getSiXyFromCalexp(warped_calexp.visitInfo.id, warped_calexp, injected_cat=injected_cat)
    colors = ["green", "red"]
    
    visit_id = calexp.visitInfo.id
    ra, dec = injected_cat[visit_id][0]["ra"], injected_cat[visit_id][0]["dec"]
    template_ra, template_dec = injected_cat[template.visitInfo.id][0]["ra"], injected_cat[template.visitInfo.id][0]["dec"]
    separation = SkyCoord(ra, dec, unit=u.degree).separation(SkyCoord(template_ra, template_dec, unit=u.degree))

    if verbose:
        print(f"Visit: {visit_id}, MJD: {visit_table.loc[visit_id]['expMidptMJD']:0.6f}, " + \
              f"RA: {ra:0.7f}, Dec: {dec:0.7f}, Displacement: {separation.arcsec:0.6f} arcsec")

    show_image_on_wcs(warped_calexp,
                      ax=ax,
                      x=[template_xy_coords.x, xy_coords.x],
                      y=[template_xy_coords.y, xy_coords.y],
                      color=colors, stamp_size=stamp_size,
                      figsize=(3, 3))
        
    del calexp

In [None]:
def get_matched_object(dia_src: astropy.table.Table, ra, dec):
    # Match in a simple way
    threshold_dist = 2 / 3600  # 2 arcseconds, expressed in degrees
    threshold_dist_sq = threshold_dist ** 2

    dist_sq = ((np.rad2deg(dia_src["coord_ra"]) - ra) * np.cos(dia_src["coord_dec"]))**2 + \
              (np.rad2deg(dia_src["coord_dec"]) - dec)**2

    idx, = np.where(dist_sq < threshold_dist_sq)
    try:
        matching_injected = dia_src[idx]
    except Exception:
        matching_injected = None
    
    return matching_injected

In [None]:
def plot_subtraction_cutout(data_id, template, butler=writeable_butler, ax=None, verbose=True, figsize=None):
    calexp = butler.get("goodSeeingDiff_differenceExp", dataId=data_id)
    dia_src = butler.get("goodSeeingDiff_diaSrc", dataId=data_id)
    # I find Astropy Tables easier to think about than the custom lsst.afw.table SourceCatalog
    dia_src = dia_src.asAstropy()
    # Warp the difference image onto the template's WCS so cutouts share the template orientation
    warper_config = WarperConfig()
    warper = Warper.fromConfig(warper_config)

    template_wcs = template.getWcs()
    template_bbox = template.getBBox()
    
    # Add PSF.  I think doing this directly without warping is wrong.  At least the x,y mapping should be updated
    warped_calexp = warper.warpExposure(template_wcs, calexp, destBBox=template_bbox)
    warped_calexp.setPsf(calexp.getPsf())

    template_xy_coords = getSiXyFromCalexp(template.visitInfo.id, warped_calexp)
    xy_coords = getSiXyFromCalexp(warped_calexp.visitInfo.id, warped_calexp)
    x_array = [template_xy_coords.x, xy_coords.x]
    y_array = [template_xy_coords.y, xy_coords.y]
    colors = ["green", "red"]
    size = [20, 20]
   
    visit_id = calexp.visitInfo.id
    ra, dec = injected_cat[visit_id][0]["ra"], injected_cat[visit_id][0]["dec"]
    template_ra, template_dec = injected_cat[template.visitInfo.id][0]["ra"], injected_cat[template.visitInfo.id][0]["dec"]
    separation = SkyCoord(ra, dec, unit=u.degree).separation(SkyCoord(template_ra, template_dec, unit=u.degree))
    
    if verbose:
        print(f"Visit: {visit_id}, MJD: {visit_table.loc[visit_id]['expMidptMJD']:0.6f}, " + \
              f"RA: {ra:0.7f}, Dec: {dec:0.7f}, Displacement: {separation.arcsec:0.6f} arcsec")
    
    matching_injected = get_matched_object(dia_src, ra, dec)
    
    columns_of_interest = ["ip_diffim_PsfDipoleFlux_pos_instFlux",
                           "ip_diffim_PsfDipoleFlux_neg_instFlux",
                           "ip_diffim_ClassificationDipole_value"]

    if matching_injected is None or len(matching_injected) < 1:
        print("No matching dia source found for injected object.")
    else:
        print("Dipole: ")
        matching_injected[columns_of_interest].pprint(max_width=-1)
        
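        # Use the first match; SpherePoint takes scalar angles rather than table columns
        dia_src_xy_coords = warped_calexp.getWcs().skyToPixel(
            geom.SpherePoint(float(matching_injected["coord_ra"][0]),
                             float(matching_injected["coord_dec"][0]),
                             geom.radians))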
        x_array.append(dia_src_xy_coords.x)
        y_array.append(dia_src_xy_coords.y)
        colors.append("blue")
        size.append(40)

    show_image_on_wcs(warped_calexp, ax=ax, x=x_array, y=y_array, color=colors, size=size, stamp_size=stamp_size, figsize=figsize);
    
    del calexp
    del dia_src

In [None]:
stamp_size = 100
figsize = (8, 5)
# Note that each image will be shown in its own orientation.
for dr in dataset_refs[1:]:
    plt.figure(figsize=figsize)
    plt.subplot(1, 2, 1, projection=WCS(template.getWcs().getFitsMetadata()))
    plot_calexp_cutout(dr.dataId, template, writeable_butler, ax=plt.gca(), verbose=False);
    plt.subplot(1, 2, 2, projection=WCS(template.getWcs().getFitsMetadata()))
    plot_subtraction_cutout(dr.dataId, template, writeable_butler, ax=plt.gca());
    plt.tight_layout()
    plt.show()

Green: Location of star in template  
Red: Location of star in science  
Blue: Location of matching dia_src

The injected star shows up clearly in the subtraction against the original template.

If we subtract the injected template from the injected science image, we see a dipole.
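
To build intuition for how the dipole fades as the separation approaches zero, here is a one-dimensional toy model: the difference of two unit-amplitude Gaussian profiles offset by a separation `d`. This is only a sketch under simplifying assumptions (identical Gaussian PSFs, no noise, one dimension) and is not what `ip_diffim` computes. For `d` much smaller than `sigma`, the peak of the difference is approximately `d * exp(-1/2) / sigma`, so the signal shrinks linearly with the shift.

```python
import math


def gaussian(x, center, sigma, amplitude=1.0):
    return amplitude * math.exp(-0.5 * ((x - center) / sigma) ** 2)


def dipole_peak(separation, sigma, n=4001, half_width=5.0):
    """Peak of G(x - d/2) - G(x + d/2), sampled over +/- half_width * sigma."""
    h = separation / 2
    peak = 0.0
    for i in range(n):
        x = (-half_width + 2 * half_width * i / (n - 1)) * sigma
        # Positive lobe sits on the science-epoch side, negative on the template side
        diff = gaussian(x, +h, sigma) - gaussian(x, -h, sigma)
        peak = max(peak, diff)
    return peak


sigma = 1.0
for d in (0.05, 0.1, 0.5, 1.0, 2.0):
    linear = d * math.exp(-0.5) / sigma  # small-shift estimate
    print(f"d={d:4.2f} sigma: peak={dipole_peak(d, sigma):.4f}, estimate={linear:.4f}")
```

The two columns agree closely for small shifts and diverge as the lobes separate, where the peak saturates toward the amplitude of a single profile.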

### Appendix.  Snippets of code one might want later

In [None]:
def remove_figure(fig):
    """
    Remove a figure to reduce memory footprint.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        Figure to be removed.

    Returns
    -------
    None
    """
    # get the axes and clear their images
    for ax in fig.get_axes():
        for im in ax.get_images():
            im.remove()
    fig.clf()       # clear the figure
    plt.close(fig)  # close the figure
    gc.collect()    # call the garbage collector