# Testing dipole measurement and recovery of proper-motion star
Michael Wood-Vasey

Inspired by conversation with Eric Bellm and Lynne Jones.

1. Take a set of images of a region spaced over 10 years.
2. Simulate a star with a given proper motion.  Try things from 0.1 -- 10" over ten years.
3. First test single-frame recovery and astrometry
4. Then run subtractions for each Year N - Year 1 pair.
5. Analyze the dipoles and measurements of these DIA Sources.  Compare to single-frame measurements.  I think the interesting question of concern/interest for DIA is the ability and robustness of dipole fitting and measurement as the separation approaches zero.
6. Start with constant brightness source.  Then repeat for variable source and check recovery.
7. Compare individual fitting with scene-modeling.

Could start with just a very simplified gaussian PSF to get a sense, but I think the interesting questions are about the Science Pipeline measurements so shifting to (starting with) more realistic data where the Science Pipelines are really computing PSFs and convolution kernels would be helpful.

I would be tempted to start with DC2 (5 years), and then try it with HSC, DECam images (3 years, maybe 5 years with a little work?).  But the time lag isn't really key as one can scale up the proper motion to compensate.

Notes:
* DC2 stars do not have proper motion.
* I'm ignoring parallax.  One could simulate parallax if one really wanted to model end-to-end, but I don't think it's central to the basic question.

In [None]:
import os

from astropy.wcs import WCS
import astropy.units as u

import gc

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import lsst.afw.display as afwDisplay
from lsst.afw.image import MultibandExposure
from lsst.afw.math import Warper, WarperConfig
import lsst.geom as geom
from lsst.daf.butler import Butler, DimensionUniverse, DatasetType, CollectionType
from lsst.daf.butler.registry import MissingCollectionError
from lsst.ip.diffim import AlardLuptonSubtractConfig, AlardLuptonSubtractTask
from lsst.ip.diffim import DetectAndMeasureConfig, DetectAndMeasureTask
from lsst.pipe.tasks.makeWarp import MakeWarpConfig, MakeWarpTask
from lsst.source.injection.inject_engine import generate_galsim_objects, inject_galsim_objects_into_exposure
import lsst.sphgeom

In [None]:
afwDisplay.setDefaultBackend('matplotlib')
plt.style.use('tableau-colorblind10')
%matplotlib inline

In [None]:
user = os.getenv("USER")

collection = "2.2i/runs/DP0.2"
repo_config = "dp02"
output_collection = f"u/{user}/proper_motion"

In [None]:
INJECTED_IMAGES_EXIST = True
SUBTRACTIONS_OF_INJECTED_IMAGE_EXIST = True

In [None]:
butler = Butler(repo_config, collections=collection)

In [None]:
print(butler.registry.getDatasetType("calexp"))

In [None]:
skyMap = butler.get("skyMap", skymap="DC2")

In [None]:
# Do a spatial query for calexps using HTM levels following example in 04b_Intermediate_Butler_Queries.ipynb
ra, dec = 55, -30  # degrees

level = 10  # the resolution of the HTM grid
pixelization = lsst.sphgeom.HtmPixelization(level)

htm_id = pixelization.index(
    lsst.sphgeom.UnitVector3d(
        lsst.sphgeom.LonLat.fromDegrees(ra, dec)
    )
)

In [None]:
htm_id

Get the neighboring HTM pixels

In [None]:
# Build a list of four HTM indices around htm_id to use in the spatial query.
# NOTE(review): HTM is a quadtree, so the parent of an index is normally
#   htm_id // 4 and its children are parent * 4 + {0, 1, 2, 3}.  The base-10
#   arithmetic below does not obviously produce the siblings of htm_id (the
#   resulting list may not even contain htm_id) -- TODO confirm the intent.
parent_level = htm_id // 10
htm_ids = [parent_level * 10 + i for i in [0, 1, 2, 3]]

In [None]:
htm_ids

In [None]:
# Keep only the i-band calexps returned for *every* HTM index in htm_ids:
#   seed a set with the results for the first index, then intersect it with
#   the results for each index in turn.
# NOTE(review): the indices were computed with `level = 10` above but are
#   passed as `htm20` here -- confirm that the skypix level is the intended one.
hi = htm_ids[0]

# dataset_refs is an iterator, but each query is only a few hundred results,
#   so convert to a list for future convenience
dataset_refs = list(butler.registry.queryDatasets("calexp", htm20=hi, dataId={"band": "i"}))
dataset_refs = set(dataset_refs)
# The loop starts again at htm_ids[0], so the first query is repeated;
#   intersecting the set with its own contents leaves it unchanged.
for hi in htm_ids:
    dr = list(butler.registry.queryDatasets("calexp", htm20=hi, dataId={"band": "i"}))
    dataset_refs = dataset_refs.intersection(dr)

In [None]:
dataset_refs = list(dataset_refs)
# Sort by visitId to get a loose time order
ids_visit = [dr.dataId["visit"] for dr in dataset_refs]
dataset_refs = [dataset_refs[idx] for idx in np.argsort(ids_visit)]

print(dataset_refs)

In [None]:
print(f"Found {len(list(dataset_refs))} calexps")

In [None]:
visit_table = butler.get("visitTable") #, dataset_refs[0].dataId)

We should find ~17 calexps matching all.

## Inject object

In [None]:
# Simulate star
# Start with a constant magnitude
BASE_MAG_STAR = 17  # u.mag
BASE_STAR_RA, BASE_STAR_DEC = 54.982, -29.81
BASE_MJD = 60_000

# Wrap with a function to allow accessing consistent with a future expansion to time.
def mag_star(phase=0 * u.d):
    """Return the magnitude of the simulated star as an astropy quantity.

    ``phase`` is accepted so the call signature matches ``ra_star`` and
    ``dec_star``, but it is not used: the star has constant brightness.
    """
    constant_magnitude = BASE_MAG_STAR * u.mag
    return constant_magnitude

def ra_star(phase=0 * u.d, pm_ra=0.1 * u.arcsec / u.yr):
    """Return the Right Ascension of the simulated star at a given phase.

    The offset ``phase * pm_ra`` is added directly to the base RA coordinate
    (``BASE_STAR_RA`` in degrees); no cos(Dec) factor is applied.
    """
    reference_ra = BASE_STAR_RA * u.degree
    ra_shift = phase * pm_ra
    return reference_ra + ra_shift

def dec_star(phase=0 * u.d, pm_dec=0.1 * u.arcsec / u.yr):
    """Return the Declination of the simulated star at a given phase.

    The offset ``phase * pm_dec`` is added to the base Declination
    (``BASE_STAR_DEC`` in degrees).
    """
    reference_dec = BASE_STAR_DEC * u.degree
    dec_shift = phase * pm_dec
    return reference_dec + dec_shift

Create a catalog for each visit

In [None]:
dataset_refs

In [None]:
# Build one single-entry injection catalog per visit, keyed by visit ID.
# Each entry places the simulated star at the position it has at that
# visit's time, so the star moves from visit to visit.
si_cat = {}
for dr in dataset_refs:
    visit = dr.dataId["visit"]
    # Time of this visit, read from the `expMidptMJD` column of the visit table
    mjd = visit_table.loc[visit]["expMidptMJD"]
    # Elapsed time since the reference epoch BASE_MJD, in days
    phase = (mjd - BASE_MJD) * u.d
    si_cat[visit] = [dict(
        ra=ra_star(phase).to_value(u.degree),
        dec=dec_star(phase).to_value(u.degree),
        # mag_star() is called without a phase, so its default phase is used
        mag=mag_star().to_value(u.mag),
        source_type="DeltaFunction",
        index=["test"],
    )]


### Register a Collection to write injection catalog and injected images to

See `source_injection` si_demo_dc2_visit.ipynb

In [None]:
writeable_butler = Butler(repo_config,
                          run=output_collection,
                          collections=output_collection,
                          writeable=True)

if not INJECTED_IMAGES_EXIST:
    try:
        writeable_butler.removeRuns([output_collection])
    except MissingCollectionError:
        print("Writing into a new RUN collection")
        pass
    else:
        print("Prior RUN collection located and successfully removed")

_ = writeable_butler.registry.registerCollection(output_collection, type=CollectionType.RUN)

In [None]:
print(butler)

In [None]:
print(writeable_butler)

## Make an injection catalog generator, save a copy as a list, reprovide as generator

In [None]:
def injected_and_save_image(data_id,
                            si_cat,
                            read_butler,
                            write_butler,
                            dataset_type="calexp",
                            mask_plane_name: str = "INJECTED",
                            calib_flux_radius: float = 12.0,
                            draw_size_scale: float = 1.0,
                            draw_size_max: int = 1000,
                            verbose=True):
    """Load an image, inject a catalog into it, and save it with another butler.

    Parameters
    ----------
    data_id
        Data ID of the exposure to read and, after injection, to write.
    si_cat
        Mapping from visit ID (``calexp.visitInfo.id``) to the injection
        catalog to use for that visit.
    read_butler
        Butler used to load the original exposure.
    write_butler
        Butler used to store the injected exposure under the same
        ``dataset_type`` and ``data_id``.
    dataset_type
        Dataset type to read and write.
    mask_plane_name, calib_flux_radius, draw_size_scale, draw_size_max
        Passed through unchanged to ``inject_galsim_objects_into_exposure``.
    verbose
        If true, print progress messages.
    """
    calexp = read_butler.get(dataset_type, dataId=data_id)

    si_object_generator = generate_galsim_objects(
        injection_catalog=si_cat[calexp.visitInfo.id],
        wcs=calexp.getWcs(),
        photo_calib=calexp.getPhotoCalib(),
        fits_alignment="wcs",
    )

    if verbose:
        # f-string so that the exposure itself is reported, not the
        # literal text "{calexp}".
        print(f"Inserting objects into {calexp}")

    # The exposure is modified in place.  The return value (draw sizes,
    # common bounds, FFT size errors, PSF compute errors) is not needed here.
    _ = inject_galsim_objects_into_exposure(
        calexp,
        si_object_generator,
        mask_plane_name=mask_plane_name,
        calib_flux_radius=calib_flux_radius,
        draw_size_scale=draw_size_scale,
        draw_size_max=draw_size_max,
    )

    if verbose:
        print("Saving newly injected image: ", calexp)
        print(write_butler)

    write_butler.put(calexp, dataset_type, dataId=data_id)

In [None]:
print(writeable_butler.run)

In [None]:
if not INJECTED_IMAGES_EXIST:
    for dr in dataset_refs:
        injected_and_save_image(dr.dataId, si_cat, butler, writeable_butler)

# Run subtraction between calexps 2-N and calexp 1.

In [None]:
config = AlardLuptonSubtractConfig()
task = AlardLuptonSubtractTask(config=config)

In [None]:
template = writeable_butler.get("calexp", dataset_refs[0].dataId)

In [None]:
def subtract(science, template, source):
    """Warp the template onto the science image and subtract it.

    Parameters
    ----------
    science
        Science exposure; its WCS and bounding box define the target frame.
    template
        Template exposure to be warped onto the science frame.
    source
        Source catalog passed to the subtraction task.

    Returns
    -------
    The result of ``task.run`` for the module-level ``task``
    (an ``AlardLuptonSubtractTask``).
    """
    warper = Warper.fromConfig(WarperConfig())

    warped_template = warper.warpExposure(
        science.getWcs(), template, destBBox=science.getBBox()
    )
    # Add PSF.  I think doing this directly without warping is wrong.
    # At least the x,y mapping should be updated.
    warped_template.setPsf(template.getPsf())

    # Use the catalog that was passed in rather than a global name.
    return task.run(warped_template, science, source)


def detect(science, subtraction):
    """Run detection and measurement on a difference image.

    A ``DetectAndMeasureTask`` with default configuration is run on the
    science exposure together with the matched template and difference
    image held by ``subtraction``; the task's result is returned.
    """
    measure_task = DetectAndMeasureTask(config=DetectAndMeasureConfig())
    return measure_task.run(
        science,
        subtraction.matchedTemplate,
        subtraction.difference,
    )

In [None]:
def subtract_and_detect(data_id: dict,
                        template: lsst.afw.image.exposure.ExposureF,
                        butler: Butler):
    """
    Subtract template image from image referred to by data_id and run detection.

    Butler needs to be writeable to store output of subtraction and detection.

    Parameters
    ----------
    data_id
        Data ID of the science calexp (and of its ``src`` catalog).  The
        outputs are stored under the same data ID.
    template
        Template exposure to subtract from the science image.
    butler
        Butler used both to read the inputs and to write the
        ``goodSeeingDiff_differenceExp`` and ``goodSeeingDiff_diaSrc`` outputs.
    """
    # Use the data_id argument throughout; do not rely on a global loop
    # variable happening to refer to the same dataset.
    science = butler.get("calexp", dataId=data_id)
    source_catalog = butler.get("src", dataId=data_id)

    subtraction = subtract(science, template, source_catalog)
    butler.put(subtraction.difference, "goodSeeingDiff_differenceExp", dataId=data_id)

    detections = detect(science, subtraction)
    butler.put(detections.diaSources, "goodSeeingDiff_diaSrc", dataId=data_id)

In [None]:
# The template is the first image, so start at 1:
# This dataset_ref list is sorted by visit,
#   which should be equivalent to sorting by MJD
if not SUBTRACTIONS_OF_INJECTED_IMAGE_EXIST:
    for dr in dataset_refs[1:]:
        subtract_and_detect(dr.dataId, template, writeable_butler)

## Show cut outs for each subtraction

## Some helper utilities for plotting

In [None]:
def show_image_on_wcs(calexp, figsize=(8, 8), x=None, y=None,
                      pixel_extent=None, stamp_size=None,
                      marker="o", color="red", size=20):
    """
    Display an exposure (or a cutout of it) on axes using its WCS.

    Parameters
    ----------
    calexp
        Exposure to display.
    figsize
        Size of the new matplotlib figure.
    x, y
        Pixel coordinates to mark: scalars or equal-length sequences.
        If sequences, the first point sets the center of the stamp.
    pixel_extent
        (x_begin, x_end, y_begin, y_end) region to display, in the pixel
        coordinates of the exposure's bounding box.  Defaults to the full image.
    stamp_size
        Side length in pixels of a square cutout centered on the first
        (x, y) point.  Only used when x and y are given.
        Specifying both pixel_extent and stamp_size is undefined.
    marker, color, size
        Marker style, edge color(s), and size passed to ``plt.scatter``.

    Raises
    ------
    ValueError
        If the requested region does not overlap the image.
    """
    plt.figure(figsize=figsize)
    plt.subplot(projection=WCS(calexp.getWcs().getFitsMetadata()))

    bbox = calexp.getBBox()
    if stamp_size is not None and x is not None and y is not None:
        half_stamp = stamp_size / 2
        # If x and y are of different types, then user should clarify what they wanted
        first_x, first_y = (x, y) if np.isscalar(x) else (x[0], y[0])
        pixel_extent = (int(first_x - half_stamp), int(first_x + half_stamp),
                        int(first_y - half_stamp), int(first_y + half_stamp))
    if pixel_extent is None:
        pixel_extent = (int(bbox.beginX), int(bbox.endX),
                        int(bbox.beginY), int(bbox.endY))

    # Clamp the region to the image so that a stamp near an edge does not
    # produce negative (wrap-around) array indices.
    x_begin = max(pixel_extent[0], bbox.beginX)
    x_end = min(pixel_extent[1], bbox.endX)
    y_begin = max(pixel_extent[2], bbox.beginY)
    y_end = min(pixel_extent[3], bbox.endY)
    if x_end <= x_begin or y_end <= y_begin:
        raise ValueError(f"Requested region {pixel_extent} does not overlap the image {bbox}")
    pixel_extent = (x_begin, x_end, y_begin, y_end)

    # Image array is y, x.
    # So we select from the image array in [Y_Begin:Y_End, X_Begin:X_End]
    # But then `extent` is (X_Begin, X_End, Y_Begin, Y_End).
    # The array is indexed from zero, so subtract the bounding-box origin.
    cutout = calexp.image.array[y_begin - bbox.beginY:y_end - bbox.beginY,
                                x_begin - bbox.beginX:x_end - bbox.beginX]
    plt.imshow(cutout, cmap="gray", vmin=-200.0, vmax=400,
               extent=pixel_extent, origin="lower")
    plt.grid(color="white", ls="solid")
    plt.xlabel("Right Ascension")
    plt.ylabel("Declination")
    if x is not None and y is not None:
        plt.scatter(x, y, s=size, marker=marker, edgecolor=color, facecolor="none")
    plt.show()

### Identify pixel regions to focus on

In [None]:
def getSiXyFromCalexp(visit_id, calexp, si_cat=si_cat):
    """
    Return the pixel position of the first injected object for a visit.

    visit_id: Key into si_cat selecting the catalog that has RA, Dec (degrees)
    calexp: Exposure whose WCS converts RA, Dec -> x, y
    si_cat: Mapping from visit ID to a list of catalog entries
    """
    injected = si_cat[visit_id][0]
    sky_position = geom.SpherePoint(injected["ra"], injected["dec"], geom.degrees)
    return calexp.getWcs().skyToPixel(sky_position)

In [None]:
template_xy_coords = getSiXyFromCalexp(template.visitInfo.id, template)

In [None]:
stamp_size = 400

Here's the template

In [None]:
xy_coords = getSiXyFromCalexp(template.visitInfo.id, template)
show_image_on_wcs(template, x=xy_coords.x, y=xy_coords.y, stamp_size=stamp_size)

Now let's step through the calexps and subtractions

We can see the difference between the location of the star at the template epoch ("green") from the location of the star in the science epoch ("red").

In [None]:
def plot_calexp_cutout(data_id, template, butler=writeable_butler):
    """Show a cutout of a calexp, warped to the template frame, with star positions.

    The calexp for ``data_id`` is warped onto the template's WCS and
    bounding box.  The star's position at the template epoch is marked in
    green and its position at the calexp's epoch in red.  The cutout size
    is taken from the module-level ``stamp_size``.

    Parameters
    ----------
    data_id
        Data ID of the calexp to display.
    template
        Template exposure defining the display frame.
    butler
        Butler used to load the calexp.
    """
    # Load the exposure for the requested data ID, not whatever a global
    # loop variable currently refers to.
    calexp = butler.get("calexp", dataId=data_id)

    # Warp onto the template so every cutout has the same orientation
    warper = Warper.fromConfig(WarperConfig())
    warped_calexp = warper.warpExposure(
        template.getWcs(), calexp, destBBox=template.getBBox()
    )
    # Add PSF.  I think doing this directly without warping is wrong.
    # At least the x,y mapping should be updated.
    warped_calexp.setPsf(calexp.getPsf())

    template_xy_coords = getSiXyFromCalexp(template.visitInfo.id, warped_calexp)
    xy_coords = getSiXyFromCalexp(warped_calexp.visitInfo.id, warped_calexp)
    colors = ["green", "red"]
    print(calexp.visitInfo.id)
    show_image_on_wcs(warped_calexp,
                      x=[template_xy_coords.x, xy_coords.x],
                      y=[template_xy_coords.y, xy_coords.y],
                      color=colors, stamp_size=stamp_size,
                      figsize=(3, 3))

In [None]:
stamp_size = 100
# Note that each image will be shown in its own orientation.
for dr in dataset_refs[1:]:
    plot_calexp_cutout(dr.dataId, template, writeable_butler)


The injected star shows up clearly in the subtraction against the original template.

If we subtract the injected template from the injected science we see a dipole

In [None]:
def plot_subtraction_cutout(data_id, tempalte, butler=writeable_butler):
    """Show a cutout of a difference image, warped to the template frame.

    The ``goodSeeingDiff_differenceExp`` for ``data_id`` is warped onto the
    template's WCS and bounding box.  The star's position at the template
    epoch is marked in green and its position at the science epoch in red.
    The cutout size is taken from the module-level ``stamp_size``.
    DIA sources are not overplotted.

    Parameters
    ----------
    data_id
        Data ID of the difference image to display.
    tempalte
        Template exposure defining the display frame.  (The parameter name
        is misspelled; it is kept so existing calls continue to work.)
    butler
        Butler used to load the difference image.
    """
    # Alias the misspelled parameter so that the template passed by the
    # caller is used instead of a global variable named `template`.
    template = tempalte

    difference = butler.get("goodSeeingDiff_differenceExp", dataId=data_id)

    # Warp onto the template so every cutout has the same orientation
    warper = Warper.fromConfig(WarperConfig())
    warped_difference = warper.warpExposure(
        template.getWcs(), difference, destBBox=template.getBBox()
    )
    # Add PSF.  I think doing this directly without warping is wrong.
    # At least the x,y mapping should be updated.
    warped_difference.setPsf(difference.getPsf())

    template_xy_coords = getSiXyFromCalexp(template.visitInfo.id, warped_difference)
    xy_coords = getSiXyFromCalexp(warped_difference.visitInfo.id, warped_difference)
    colors = ["green", "red"]
    print(difference.visitInfo.id)
    show_image_on_wcs(warped_difference,
                      x=[template_xy_coords.x, xy_coords.x],
                      y=[template_xy_coords.y, xy_coords.y],
                      color=colors, stamp_size=stamp_size,
                      figsize=(3, 3))

In [None]:
stamp_size = 100
# Note that each image will be shown in its own orientation.
for dr in dataset_refs[1:]:
    plot_subtraction_cutout(dr.dataId, template, writeable_butler)

Did we recover injected object?

In [None]:
# Examine the DIA sources of a single science visit.  The detection results
# are read back from the butler rather than from an in-memory task result.
# NOTE(review): the last visit in dataset_refs is used here as an
#   assumption -- pick whichever science visit is of interest.
science_data_id = dataset_refs[-1].dataId
science = writeable_butler.get("calexp", dataId=science_data_id)
# Pixel position of the injected star in this science image
science_xy_coords = getSiXyFromCalexp(science.visitInfo.id, science)
dia_src_injected = writeable_butler.get("goodSeeingDiff_diaSrc", dataId=science_data_id).asAstropy()

In [None]:
plt.scatter(dia_src_injected["slot_Centroid_x"], dia_src_injected["slot_Centroid_y"])
plt.scatter([science_xy_coords.x], [science_xy_coords.y], color="red", s=20)

In [None]:
# Match in a simple way
threshold_dist = 2  # pixels
threshold_dist_sq = threshold_dist ** 2

dist_sq = (dia_src_injected["slot_Centroid_x"] - science_xy_coords.x)**2 + \
          (dia_src_injected["slot_Centroid_y"] - science_xy_coords.y)**2

idx, = np.where(dist_sq < threshold_dist_sq)
matching_injected = dia_src_injected[idx]

In [None]:
matching_injected

There are many diffim measurements and flags.  Here we're most directly interested in the dipole flag.

In [None]:
import re
# dipole_cols = re.compile("diffim.*ipole")
dipole_cols = re.compile("_value")

[c for c in list(matching_injected.columns) if dipole_cols.search(c)]

In [None]:
# `detect` builds its DetectAndMeasureTask from a default
# DetectAndMeasureConfig, so read the dipole-fit thresholds from a default
# config created here (the one inside `detect` is local to that function).
detect_and_measure_config = DetectAndMeasureConfig()
dipole_fit_config = detect_and_measure_config.measurement.plugins['ip_diffim_DipoleFit']
dipole_min_sn = dipole_fit_config.minSn
dipole_threshold_ratio = dipole_fit_config.maxFluxRatio
print(f"Dipole flag will be set for objects with S/N > {dipole_min_sn:0.2f}")
print(f"and {1-dipole_threshold_ratio:0.2f} < neg_flux/tot_flux = 1 - pos_flux/tot_flux < {dipole_threshold_ratio:0.2f}")

In [None]:
neg_flux = np.abs(dia_src_injected["ip_diffim_PsfDipoleFlux_neg_instFlux"])
pos_flux = np.abs(dia_src_injected["ip_diffim_PsfDipoleFlux_pos_instFlux"])
tot_flux = neg_flux + pos_flux
neg_ratio = neg_flux / tot_flux
pos_ratio = pos_flux / tot_flux

plt.scatter(neg_ratio, pos_ratio)
plt.xlabel("neg_flux / (pos_flux + neg_flux)")
plt.ylabel("pos_flux / (pos_flux + neg_flux)")
span_kwargs = {"color": "orange", "alpha": 0.1}
plt.axvspan(dipole_threshold_ratio, 1, **span_kwargs)
plt.axvspan(0, 1 - dipole_threshold_ratio, **span_kwargs)
plt.axhspan(dipole_threshold_ratio, 1, **span_kwargs)
plt.axhspan(0, 1 - dipole_threshold_ratio, **span_kwargs)

ax = plt.gca()
plt.xlim(0, 1)
plt.ylim(0, 1)
ax.set_aspect("equal")

The dipoles are those in the extremes (upper-left and lower-right corners).

This is a perfect y = 1 - x line by construction because we're plotting y/(x+y)  vs. x/(x+y) = 1 - y/(x+y).

So we can equivalently look at this in one dimension:

In [None]:
bins = np.linspace(-0.05, 1.05, 51)
plt.hist(pos_ratio, bins=bins)
plt.xlabel("pos_flux / (pos_flux + neg_flux)")
plt.axvspan(dipole_threshold_ratio, 1, **span_kwargs)
plt.axvspan(0, 1 - dipole_threshold_ratio, **span_kwargs)
plt.xlim(0, 1)
# plt.hist(neg_ratio, bins=bins, histtype="step");

In [None]:
columns_of_interest = ["ip_diffim_PsfDipoleFlux_pos_instFlux", "ip_diffim_PsfDipoleFlux_neg_instFlux", "ip_diffim_ClassificationDipole_value"]
matching_injected[columns_of_interest]

### Appendix.  Snippets of code I might want later

In [None]:
# Data set, DC2  Pick an RA, Dec region.  get 10 images
tract = 4639
patch = 0
data_id = {"tract": tract, "patch": patch}

# Get images that overlap this patch in i-band
data_id["band"] = "i"

In [None]:
class deferred_wrapper:
    """Hold an object and hand it back through a ``get()`` method."""

    def __init__(self, cat):
        # Keep a reference to the wrapped object; no copy is made.
        self.cat = cat

    def get(self):
        """Return the wrapped object unchanged."""
        return self.cat

# si_cat_deferred = deferred_wrapper(si_cat)

In [None]:
# If you want a test calexp, here's one:
NEED_TEST_CALEXP = False
if NEED_TEST_CALEXP:
    test_dataId = {"visit": 421727, "detector": 157, "band": "i"}  # Known existing exposure
    # Use the data ID defined on the line above (there is no `dataId` variable)
    test_calexp = butler.get("calexp", dataId=test_dataId)

In [None]:
def remove_figure(fig):
    """
    Release a figure and its images to reduce memory footprint.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        Figure to be removed.

    Returns
    -------
    None
    """
    # Collect every image on every axes, then detach each one
    images = [image for axes in fig.get_axes() for image in axes.get_images()]
    for image in images:
        image.remove()
    # Clear the figure, close it, and then run the garbage collector
    fig.clf()
    plt.close(fig)
    gc.collect()