# Testing dipole measurement and recovery of a proper-motion star
Michael Wood-Vasey

Inspired by conversation with Eric Bellm and Lynne Jones.

1. Take a set of images of a region spaced over 10 years.
2. Simulate a star with a given proper motion.  Try things from 0.1 -- 10" over ten years.
3. First test single-frame recovery and astrometry
4. Then run subtractions for each Year N - Year 1 pair.
5. Analyze the dipoles and measurements of these DIA Sources.  Compare to single-frame measurements.  I think the interesting questions for DIA are how well, and how robustly, dipoles can be fit and measured as the separation between the two lobes approaches zero.
6. Start with constant brightness source.  Then repeat for variable source and check recovery.
7. Compare individual fitting with scene-modeling.

Could start with just a very simplified gaussian PSF to get a sense, but I think the interesting questions are about the Science Pipeline measurements so shifting to (starting with) more realistic data where the Science Pipelines are really computing PSFs and convolution kernels would be helpful.

I would be tempted to start with DC2 (5 years), and then try it with HSC, DECam images (3 years, maybe 5 years with a little work?).  But the time lag isn't really key as one can scale up the proper motion to compensate.

Notes:
* DC2 stars do not have proper motion.
* I'm ignoring parallax.  One could simulate parallax if one really wanted to model end-to-end, but I don't think it's central to the basic question.

In [None]:
from astropy.wcs import WCS
import astropy.units as u

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import gc


import lsst.afw.display as afwDisplay
from lsst.afw.image import MultibandExposure
from lsst.afw.math import Warper, WarperConfig
import lsst.geom as geom
from lsst.daf.butler import Butler, DimensionUniverse, DatasetType, CollectionType
from lsst.daf.butler.registry import MissingCollectionError
from lsst.ip.diffim import AlardLuptonSubtractConfig, AlardLuptonSubtractTask
from lsst.ip.diffim import DetectAndMeasureConfig, DetectAndMeasureTask
from lsst.pipe.tasks.makeWarp import MakeWarpConfig, MakeWarpTask
from lsst.source.injection.inject_engine import generate_galsim_objects, inject_galsim_objects_into_exposure
import lsst.sphgeom

In [None]:
# Render afwDisplay output through matplotlib, inline in the notebook.
afwDisplay.setDefaultBackend('matplotlib')
plt.style.use('tableau-colorblind10')
%matplotlib inline

In [None]:
# DC2 tract/patch of interest.  Plan: pick an RA, Dec region and get ~10 images of it.
tract = 4639
patch = 0
data_id = {"tract": tract, "patch": patch}

In [None]:
# Read-only butler over the DP0.2 DC2 processing runs.
collection = "2.2i/runs/DP0.2"
repo_config = "dp02"
butler = Butler(repo_config, collections=collection)

In [None]:
# Inspect the calexp dataset type definition.
print(butler.registry.getDatasetType("calexp"))

In [None]:
# Get images that overlap this patch in i-band.
# NOTE(review): data_id is not passed to the calexp queries below; those
# filter on band via their own dataId={"band": "i"}.
data_id["band"] = "i"

In [None]:
# NOTE(review): skyMap is not used by any later cell shown here.
skyMap = butler.get("skyMap", skymap="DC2")

In [None]:
# Do a spatial query for calexps using HTM levels following example in 04b_Intermediate_Butler_Queries.ipynb
ra, dec = 55, -30  # degrees

level = 10  # the resolution of the HTM grid
pixelization = lsst.sphgeom.HtmPixelization(level)

# ID of the level-10 HTM trixel containing (ra, dec).
# NOTE(review): later cells pass level-10 IDs as `htm20=` to queryDatasets,
# as the 04b tutorial does; confirm the butler interprets them as intended.
htm_id = pixelization.index(
    lsst.sphgeom.UnitVector3d(
        lsst.sphgeom.LonLat.fromDegrees(ra, dec)
    )
)

In [None]:
# Display the trixel ID.
htm_id

In [None]:
# HTM IDs encode the hierarchy in base 4: the four children of trixel `p`
# are 4*p + {0, 1, 2, 3}.  Integer-divide by 4 to get the parent trixel,
# then enumerate its four children (htm_id itself plus its three siblings).
# NB: despite its name, `parent_level` holds the parent trixel *ID*, not a level.
parent_level = htm_id // 4
htm_ids = [parent_level * 4 + i for i in range(4)]

In [None]:
# Display the four sibling trixel IDs that will be queried.
htm_ids

In [None]:
# Find i-band calexps that overlap *all* of the sibling trixels: seed a set
# with the refs overlapping the first trixel, then intersect it with the refs
# overlapping each remaining trixel.  Each query returns a lazy iterator of a
# few hundred refs, which is materialized before intersecting.
dataset_refs = set(butler.registry.queryDatasets("calexp", htm20=htm_ids[0], dataId={"band": "i"}))
for hi in htm_ids[1:]:
    dr = list(butler.registry.queryDatasets("calexp", htm20=hi, dataId={"band": "i"}))
    dataset_refs &= set(dr)

In [None]:
# The set produced by the spatial query has no defined order; make it a list.
dataset_refs = list(dataset_refs)
# Sort by visitId to get a loose time order
ids_visit = [dr.dataId["visit"] for dr in dataset_refs]
dataset_refs = [dataset_refs[idx] for idx in np.argsort(ids_visit)]

print(dataset_refs)

In [None]:
# dataset_refs is already a list, so its length can be taken directly.
print(f"Found {len(dataset_refs)} calexps")

In [None]:
# Visit-level summary table; later cells look up each visit's mid-exposure
# MJD via visit_table.loc[visit]["expMidptMJD"].
visit_table = butler.get("visitTable")

We expect to find ~17 i-band calexps that overlap all four queried HTM trixels.

## Some helper utilities for plotting

In [None]:
def remove_figure(fig):
    """
    Dispose of a matplotlib figure and the memory held by its images.

    Every image artist on every axes is detached, then the figure is
    cleared and closed, and finally the garbage collector is run.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
        The figure to dispose of.

    Returns
    -------
    None
    """
    image_artists = [artist for axis in fig.get_axes() for artist in axis.get_images()]
    for artist in image_artists:
        artist.remove()
    fig.clf()
    plt.close(fig)
    gc.collect()

In [None]:
def show_image_on_wcs(calexp, figsize=(8, 8), x=None, y=None, pixel_extent=None,
                      marker="o", color="red", size=20, vmin=-200.0, vmax=400):
    """Display (a cutout of) an exposure with RA/Dec axes from its WCS.

    Parameters
    ----------
    calexp : lsst.afw.image.Exposure
        Exposure to display; its WCS provides the sky-coordinate axes.
    figsize : tuple, optional
        Matplotlib figure size in inches.
    x, y : float or sequence of float, optional
        Parent-frame pixel coordinates of markers to overplot.  Markers are
        drawn only when both are given.
    pixel_extent : tuple of int, optional
        (x_begin, x_end, y_begin, y_end) in parent-frame pixel coordinates.
        Defaults to the exposure's full bounding box.  The requested region
        is clipped to the bounding box.
    marker, color, size : optional
        Passed to ``plt.scatter`` for the markers; ``color`` may be a list
        with one entry per marker.
    vmin, vmax : float, optional
        Display stretch limits (defaults match the previous hard-coded values).
    """
    plt.figure(figsize=figsize)
    plt.subplot(projection=WCS(calexp.getWcs().getFitsMetadata()))

    bbox = calexp.getBBox()
    if pixel_extent is None:
        pixel_extent = (int(bbox.beginX), int(bbox.endX),
                        int(bbox.beginY), int(bbox.endY))

    # Clip to the exposure so slicing never wraps around on negative indices
    # (e.g. a cutout centred near the image edge).
    x_begin = max(int(pixel_extent[0]), int(bbox.beginX))
    x_end = min(int(pixel_extent[1]), int(bbox.endX))
    y_begin = max(int(pixel_extent[2]), int(bbox.beginY))
    y_end = min(int(pixel_extent[3]), int(bbox.endY))

    # pixel_extent is in parent-frame coordinates, but the image array is
    # indexed from the exposure origin (XY0) and ordered [y, x].  `extent`
    # stays in parent coordinates, ordered (x_begin, x_end, y_begin, y_end).
    xy0 = calexp.getXY0()
    x0, y0 = xy0.getX(), xy0.getY()
    cutout = calexp.image.array[y_begin - y0:y_end - y0, x_begin - x0:x_end - x0]
    plt.imshow(cutout, cmap="gray", vmin=vmin, vmax=vmax,
               extent=(x_begin, x_end, y_begin, y_end), origin="lower")

    plt.grid(color="white", ls="solid")
    plt.xlabel("Right Ascension")
    plt.ylabel("Declination")
    if x is not None and y is not None:
        plt.scatter(x, y, s=size, marker=marker, edgecolor=color, facecolor="none")
    plt.show()

## Inject object

In [None]:
# Simulate star
# Start with a constant magnitude
BASE_MAG_STAR = 17  # u.mag
BASE_STAR_RA, BASE_STAR_DEC = 54.982, -29.81  # degrees (see ra_star/dec_star)
BASE_MJD = 60_000  # reference epoch; `phase` below is measured from this MJD

# Wrap with a function to allow accessing consistent with a future expansion to time.
def mag_star(phase=0 * u.d):
    """Return the star's magnitude as an astropy Quantity in mag.

    ``phase`` is accepted for interface symmetry but ignored: the star
    currently has constant brightness.
    """
    return BASE_MAG_STAR * u.mag

def ra_star(phase=0 * u.d, pm_ra=0.1 * u.arcsec / u.yr):
    """Return the star's RA (Quantity) at time ``phase`` after BASE_MJD.

    The position moves linearly: BASE_STAR_RA + phase * pm_ra.

    NOTE(review): pm_ra is added to RA directly, with no 1/cos(Dec) factor,
    so it is a rate in the RA coordinate itself rather than the on-sky
    mu_alpha* = mu_alpha cos(Dec) that catalogs usually quote.  Confirm
    which convention is intended.
    """
    return BASE_STAR_RA * u.degree + phase * pm_ra

def dec_star(phase=0 * u.d, pm_dec=0.1 * u.arcsec / u.yr):
    """Return the star's Dec (Quantity) at time ``phase`` after BASE_MJD.

    The position moves linearly: BASE_STAR_DEC + phase * pm_dec.
    """
    return BASE_STAR_DEC * u.degree + phase * pm_dec

Create a catalog for each visit

In [None]:
# Display the time-ordered calexp refs.
dataset_refs

In [None]:
# Build a one-row source-injection catalog per visit, placing the star at its
# proper-motion-shifted position for that visit's mid-exposure MJD.
si_cat = {}
for dr in dataset_refs:
    visit = dr.dataId["visit"]
    mjd = visit_table.loc[visit]["expMidptMJD"]
    # Time elapsed since the reference epoch BASE_MJD, in days.
    phase = (mjd - BASE_MJD) * u.d
    # NOTE(review): if several refs share a visit (e.g. different detectors;
    # not confirmed here), the entry is simply rewritten with identical values.
    si_cat[visit] = [dict(
        ra=ra_star(phase).to_value(u.degree),
        dec=dec_star(phase).to_value(u.degree),
        mag=mag_star().to_value(u.mag),
        source_type="DeltaFunction",
        index=["test"],
    )]


### Register a Collection to write injection catalog and injected images to

See `source_injection` si_demo_dc2_visit.ipynb

In [None]:
# A writeable butler is needed to put injection catalogs into the repo.
writeable_butler =  Butler(repo_config, writeable=True)

si_input_collection = "u/wmwv/proper_motion"

# Start from a clean RUN collection: any existing run with this name is
# removed (destructively) before it is re-registered below.
try:
    writeable_butler.removeRuns([si_input_collection])
except MissingCollectionError:
    print("Writing into a new RUN collection")
    pass
else:
    print("Prior RUN collection located and successfully removed")

_ = writeable_butler.registry.registerCollection(si_input_collection, type=CollectionType.RUN)

### Register a DatasetType for these catalogs in the repo

This only needs to be done once per repository, and it has already been done for `2.2i/runs/DP0.2` at the IDF RSP.  That is why `registerDatasetType` returns `False` here: the DatasetType is already registered.

In [None]:
# Use the repository's own dimension universe rather than a freshly
# constructed default DimensionUniverse(), so the DatasetType definition
# matches the dimension configuration of the repo it is registered in.
si_dataset_type = DatasetType(
    "si_cat",
    dimensions=["skymap", "tract"],
    storageClass="DataFrame",
    universe=writeable_butler.dimensions,
)

# Returns True if newly registered, False if already present.
writeable_butler.registry.registerDatasetType(si_dataset_type)

### Put the source injection catalog into the repo

In [None]:
# Store one injection catalog per visit in the RUN collection.
# NOTE(review): si_dataset_type's dimensions are only (skymap, tract); `visit`
# is not one of them.  Confirm the butler keeps these puts distinct rather
# than treating every visit as the same (skymap, tract) data ID.
for visit, cat in si_cat.items():
    si_dataId = dict(tract=tract, visit=visit, skymap="DC2")

    writeable_butler.put(pd.DataFrame(cat), si_dataset_type, si_dataId, run=si_input_collection)

# Run subtraction between two calexps

In [None]:
# Alard-Lupton PSF-matched image subtraction with default configuration.
config = AlardLuptonSubtractConfig()
task = AlardLuptonSubtractTask(config=config)

In [None]:
# Show the data IDs actually used below: the template is the first
# (lowest visit ID) calexp and the science image is the last one.
print(dataset_refs[0].dataId)
print(dataset_refs[-1].dataId)

In [None]:
# Template = first calexp, science = last calexp (refs are sorted by visit ID).
template = butler.get("calexp", dataset_refs[0].dataId)
science = butler.get("calexp", dataset_refs[-1].dataId)
# Single-frame source catalog for the science calexp; passed to the subtraction task.
source_catalog = butler.get("src", dataId=dataset_refs[-1].dataId)

In [None]:
# Warper used to resample the templates onto the science image's WCS and bbox.
warper_config = WarperConfig()
warper =  Warper.fromConfig(warper_config)

In [None]:
science_wcs = science.getWcs()
science_bbox = science.getBBox()

In [None]:
# Visit ID of the science exposure (the same attribute is used to key into si_cat below).
science.visitInfo.id

## Make an injection catalog generator, save a copy as a list, reprovide as generator

In [None]:
# Injection objects for the template, built from the template visit's catalog
# entry (star at its template-epoch position).
template_si_object_generator = generate_galsim_objects(
    injection_catalog=si_cat[template.visitInfo.id],
    wcs=template.getWcs(),
    photo_calib=template.getPhotoCalib(),
    fits_alignment = "wcs",
)

# Keep a reusable list copy, and re-expose it as a generator for injection.
template_si_objects = list(template_si_object_generator)
template_si_object_generator = (o for o in template_si_objects)

In [None]:
# Injection objects for the science image.  These must come from the SCIENCE
# visit's catalog entry so that the star is placed at its science-epoch
# (proper-motion-shifted) position; using the template's entry would inject
# it at the template position and remove the motion we want to measure.
science_si_object_generator = generate_galsim_objects(
    injection_catalog=si_cat[science.visitInfo.id],
    wcs=science.getWcs(),
    photo_calib=science.getPhotoCalib(),
    fits_alignment="wcs",
)

# Keep a reusable list copy, and re-expose it as a generator for injection.
science_si_objects = list(science_si_object_generator)
science_si_object_generator = (o for o in science_si_objects)

In [None]:
print(template_si_objects)

In [None]:
print(science_si_objects)

In [None]:
# Settings forwarded to inject_galsim_objects_into_exposure below.
mask_plane_name: str = "INJECTED"
calib_flux_radius: float = 12.0
draw_size_scale: float = 1.0
draw_size_max: int = 1000

In [None]:
# Inject into copies so the original template/science exposures stay untouched;
# the un-injected `template` is still warped and subtracted later.
injected_template = template.clone()
injected_science = science.clone()

In [None]:
# Return value, per the author's note, is
# (draw_sizes, common_bounds, fft_size_errors, psf_compute_errors); discarded here.
# NOTE(review): this presumably consumes template_si_object_generator, so
# re-running the cell without re-creating the generator would inject nothing.

_ = inject_galsim_objects_into_exposure(
    injected_template,
    template_si_object_generator,
    mask_plane_name=mask_plane_name,
    calib_flux_radius=calib_flux_radius,
    draw_size_scale=draw_size_scale,
    draw_size_max=draw_size_max,
)

In [None]:
# Inject the science-epoch star into the science copy, with the same settings.
_ = inject_galsim_objects_into_exposure(
    injected_science,
    science_si_object_generator,
    mask_plane_name=mask_plane_name,
    calib_flux_radius=calib_flux_radius,
    draw_size_scale=draw_size_scale,
    draw_size_max=draw_size_max,
)

In [None]:
# Cutout full width in pixels, centred on the injected star in the science image.
size = 400

half_size = size / 2
# NOTE(review): element [1] of the first injected-object entry is presumably its pixel position.
science_xy_coords = science_si_objects[0][1]
# (x_begin, x_end, y_begin, y_end), the order show_image_on_wcs expects.
science_xy_extent = (int(science_xy_coords.x - half_size), int(science_xy_coords.x + half_size),
                     int(science_xy_coords.y - half_size), int(science_xy_coords.y + half_size))

In [None]:
# Position of the star at the TEMPLATE epoch, projected into science-image
# pixels, so both epochs' positions can be marked on the same cutout.
template_star = si_cat[template.visitInfo.id][0]
science_xy_coords_template = science.getWcs().skyToPixel(
    geom.SpherePoint(template_star["ra"], template_star["dec"], geom.degrees)
)

In [None]:
# Cutout around the injected star in the template's own pixel frame.
template_xy_coords = template_si_objects[0][1]
template_xy_extent = (int(template_xy_coords.x - half_size), int(template_xy_coords.x + half_size),
                      int(template_xy_coords.y - half_size), int(template_xy_coords.y + half_size))

In [None]:
print(science_xy_extent)
print(template_xy_extent)

In [None]:
# Resample both templates onto the science WCS and bbox, as ip_diffim's getTemplate does:
# https://github.com/lsst/ip_diffim/blob/7a89bf037d00e8b7659df4fab0fd3c36f68be89f/python/lsst/ip/diffim/getTemplate.py#L531
warped_template = warper.warpExposure(science_wcs, template, destBBox=science_bbox)
warped_injected_template = warper.warpExposure(science_wcs, injected_template, destBBox=science_bbox)

Attach a PSF to the warped templates.  I think copying the template PSF over directly, without warping it, is wrong: at the very least its x,y mapping should be updated to the science-image pixel frame.

In [None]:
# NOTE(review): the original template PSF is attached unchanged.  It is not
# transformed by the warp, so its pixel-frame mapping may not match the warped image.
warped_template.setPsf(template.getPsf())
warped_injected_template.setPsf(injected_template.getPsf())

Here's what the template looks like when warped to the science image wcs and bounding box.

In [None]:
# Injected science image, cut out around the star's science-epoch position.
show_image_on_wcs(injected_science, x=science_xy_coords.x, y=science_xy_coords.y, pixel_extent=science_xy_extent)

In [None]:
# Warped injected template over the same science-frame cutout.
show_image_on_wcs(warped_injected_template, x=science_xy_coords.x, y=science_xy_coords.y, pixel_extent=science_xy_extent)

### Simple subtraction of arrays with no matching of flux level or PSF confirms that we have a displacement

In [None]:
# Naive pixel difference: no PSF matching and no flux scaling.
diff = science.clone()
diff.image.array = injected_science.image.array - warped_injected_template.image.array

In [None]:
show_image_on_wcs(diff, x=science_xy_coords.x, y=science_xy_coords.y, pixel_extent=science_xy_extent)

### Now let's do the subtraction right with PSF matching and flux normalization

In [None]:
# Injected science minus the un-injected warped template.
subtraction_injected = task.run(warped_template, injected_science, source_catalog)

In [None]:
# Injected science minus the injected warped template (star present at both epochs).
subtraction_injected_injected = task.run(warped_injected_template, injected_science, source_catalog)

In [None]:
subtraction_injected

In [None]:
show_image_on_wcs(subtraction_injected.matchedScience, x=science_xy_coords.x, y=science_xy_coords.y, pixel_extent=science_xy_extent)

In [None]:
show_image_on_wcs(subtraction_injected.matchedTemplate, x=science_xy_coords.x, y=science_xy_coords.y, pixel_extent=science_xy_extent)

The injected star shows up clearly in the subtraction against the original template.

In [None]:
show_image_on_wcs(subtraction_injected.difference, x=science_xy_coords.x, y=science_xy_coords.y, pixel_extent=science_xy_extent)

In [None]:
# Marker positions in science-image pixels: red = science epoch, green = template epoch.
x = [science_xy_coords.x, science_xy_coords_template.x]
y = [science_xy_coords.y, science_xy_coords_template.y]
colors = ["red", "green"]

Here's the template we injected a source on.  We can see the difference between the location of the star at the template epoch ("green") from the location of the star in the science epoch ("red").

In [None]:
# PSF-matched injected template, with both epochs' star positions marked.
show_image_on_wcs(subtraction_injected_injected.matchedTemplate, x=x, y=y, color=colors, pixel_extent=science_xy_extent)

If we subtract the injected template from the injected science we see a dipole

In [None]:
# Difference of injected science and injected template, with both epochs' positions marked.
show_image_on_wcs(subtraction_injected_injected.difference, x=x, y=y, color=colors, pixel_extent=science_xy_extent)

In [None]:
def show_image_with_mask_plane(calexp, figsize=(8, 8)):
    """Show an exposure through afwDisplay with its mask planes overlaid.

    The image uses an asinh/zscale stretch, and mask planes are drawn at
    80% transparency with DETECTED in blue.  After display, the matplotlib
    figure is disposed of with ``remove_figure``.

    Returns
    -------
    lsst.afw.display.Display
        The display object, e.g. for querying mask-plane colors afterwards.
    """
    figure = plt.subplots(figsize=figsize)[0]
    afw_display = afwDisplay.Display(frame=figure)
    afw_display.scale('asinh', 'zscale')
    afw_display.setMaskTransparency(80)
    afw_display.setMaskPlaneColor('DETECTED', 'blue')
    afw_display.mtv(calexp)
    plt.show()
    remove_figure(figure)
    return afw_display

In [None]:
show_image_with_mask_plane(warped_template)

In [None]:
show_image_with_mask_plane(injected_science)

In [None]:
display = show_image_with_mask_plane(subtraction_injected.difference)

In [None]:
display = show_image_with_mask_plane(subtraction_injected_injected.difference)

In [None]:
# Inspect the mask-plane colors and the setMaskPlaneColor API of the last display.
print("Mask plane bit definitions:\n", display.getMaskPlaneColor())
print("\nMask plane methods:\n")
help(display.setMaskPlaneColor)

Run detection on subtraction

In [None]:
# Detection and measurement (including dipole fitting) on difference images, default config.
detect_and_measure_config = DetectAndMeasureConfig()
detect_and_measure_task = DetectAndMeasureTask(config=detect_and_measure_config)

In [None]:
# DIA sources from the subtraction against the un-injected template.
detect_and_measure_injected = detect_and_measure_task.run(injected_science, subtraction_injected.matchedTemplate, subtraction_injected.difference)

In [None]:
# DIA sources from the subtraction against the injected template.
detect_and_measure_injected_injected = detect_and_measure_task.run(injected_science, subtraction_injected_injected.matchedTemplate, subtraction_injected_injected.difference)

Did we recover injected object?

In [None]:
dia_src_injected = detect_and_measure_injected.diaSources.asAstropy()

In [None]:
# All DIA source centroids, with the injected star's science-epoch position in red.
plt.scatter(dia_src_injected["slot_Centroid_x"], dia_src_injected["slot_Centroid_y"])
plt.scatter([science_xy_coords.x], [science_xy_coords.y], color="red", s=20)

In [None]:
# Match in a simple way: keep DIA sources whose centroid lies within
# threshold_dist pixels of the injected star's science-epoch position.
threshold_dist = 2  # pixels
threshold_dist_sq = threshold_dist ** 2

# Use the centroid slot, the same columns used to plot the DIA sources and to
# match the injected-template subtraction.
dist_sq = (dia_src_injected["slot_Centroid_x"] - science_xy_coords.x)**2 + \
          (dia_src_injected["slot_Centroid_y"] - science_xy_coords.y)**2

idx, = np.where(dist_sq < threshold_dist_sq)
matching_injected = dia_src_injected[idx]

In [None]:
# DIA source(s) matched to the injected star.
matching_injected

There are many diffim measurements and flags.  Here we're most directly interested in the dipole flag.

In [None]:
import re
# Narrower alternative, restricted to diffim dipole columns:
# dipole_cols = re.compile("diffim.*ipole")
# Current choice: every column whose name contains "_value"
# (e.g. ip_diffim_ClassificationDipole_value).
dipole_cols = re.compile("_value")

[c for c in list(matching_injected.columns) if dipole_cols.search(c)]

In [None]:
# DipoleFit plugin thresholds used when classifying DIA sources as dipoles.
dipole_min_sn = detect_and_measure_config.measurement.plugins['ip_diffim_DipoleFit'].minSn
dipole_threshold_ratio = detect_and_measure_config.measurement.plugins['ip_diffim_DipoleFit'].maxFluxRatio
# Short alias, kept for cells that refer to `threshold_ratio`.
threshold_ratio = dipole_threshold_ratio
print(f"Dipole flag will be set for objects with S/N > {dipole_min_sn:0.2f}")
print(f"and {1-dipole_threshold_ratio:0.2f} < neg_flux/tot_flux = 1 - pos_flux/tot_flux < {dipole_threshold_ratio:0.2f}")

In [None]:
# Fraction of the total (absolute) dipole-fit flux in each lobe, for all DIA sources.
neg_flux = np.abs(dia_src_injected["ip_diffim_PsfDipoleFlux_neg_instFlux"])
pos_flux = np.abs(dia_src_injected["ip_diffim_PsfDipoleFlux_pos_instFlux"])
tot_flux = neg_flux + pos_flux
neg_ratio = neg_flux / tot_flux
pos_ratio = pos_flux / tot_flux

plt.scatter(neg_ratio, pos_ratio)
plt.xlabel("neg_flux / (pos_flux + neg_flux)")
plt.ylabel("pos_flux / (pos_flux + neg_flux)")
# Shade the regions outside the [1 - ratio, ratio] window set by the
# DipoleFit maxFluxRatio threshold (dipole_threshold_ratio, defined earlier).
span_kwargs = {"color": "orange", "alpha": 0.1}
plt.axvspan(dipole_threshold_ratio, 1, **span_kwargs)
plt.axvspan(0, 1 - dipole_threshold_ratio, **span_kwargs)
plt.axhspan(dipole_threshold_ratio, 1, **span_kwargs)
plt.axhspan(0, 1 - dipole_threshold_ratio, **span_kwargs)

ax = plt.gca()
plt.xlim(0, 1)
plt.ylim(0, 1)
ax.set_aspect("equal")

The dipoles are those in the extremes (upper-left and lower-right corners).

This is a perfect $y = 1 - x$ line by construction.  Writing $p$ and $n$ for the positive- and negative-lobe fluxes, we plot $y = p/(p+n)$ against $x = n/(p+n) = 1 - p/(p+n)$.

So we can equivalently look at this in one dimension:

In [None]:
# One-dimensional view: the positive-lobe fraction alone carries all the information.
bins = np.linspace(-0.05, 1.05, 51)
plt.hist(pos_ratio, bins=bins)
plt.xlabel("pos_flux / (pos_flux + neg_flux)")
# Shade outside the [1 - ratio, ratio] window from the DipoleFit maxFluxRatio.
plt.axvspan(dipole_threshold_ratio, 1, **span_kwargs)
plt.axvspan(0, 1 - dipole_threshold_ratio, **span_kwargs)
plt.xlim(0, 1)
# The negative-lobe histogram would be the mirror image of the positive one:
# plt.hist(neg_ratio, bins=bins, histtype="step");

Our DIA Source from the injected star is not classified as a dipole because it fails the dipole flux-ratio criterion.  This is expected and reasonable: no star was injected into the template for this subtraction, so the source should not have any negative lobe.

In [None]:
# Lobe fluxes and the dipole classification for the matched source(s).
columns_of_interest = ["ip_diffim_PsfDipoleFlux_pos_instFlux", "ip_diffim_PsfDipoleFlux_neg_instFlux", "ip_diffim_ClassificationDipole_value"]
matching_injected[columns_of_interest]

Let's look at the subtraction where a star was injected into both the template and the science image, at slightly displaced positions.

In [None]:
dia_src_injected_injected = detect_and_measure_injected_injected.diaSources.asAstropy()

In [None]:
# List columns whose names contain "_x" (candidate x-coordinate columns).
x_col = re.compile("_x")
[c for c in dia_src_injected_injected.columns if x_col.search(c)]

In [None]:
# All DIA source centroids, with the injected star's science-epoch position in red.
plt.scatter(dia_src_injected_injected["slot_Centroid_x"], dia_src_injected_injected["slot_Centroid_y"])
plt.scatter([science_xy_coords.x], [science_xy_coords.y], color="red", s=20)

In [None]:
# Match in a simple way: keep DIA sources whose centroid lies within
# threshold_dist pixels of the injected star's science-epoch position.
# NOTE(review): a dipole's centroid may fall between the two lobes, so a
# large motion between epochs could push it outside this radius.
threshold_dist = 2  # pixels
threshold_dist_sq = threshold_dist ** 2

dist_sq = (dia_src_injected_injected["slot_Centroid_x"] - science_xy_coords.x)**2 + \
          (dia_src_injected_injected["slot_Centroid_y"] - science_xy_coords.y)**2

idx, = np.where(dist_sq < threshold_dist_sq)
matching_injected_injected = dia_src_injected_injected[idx]

In [None]:
# Lobe flux fractions for the matched source(s) from the injected-template subtraction.
neg_flux = np.abs(matching_injected_injected["ip_diffim_PsfDipoleFlux_neg_instFlux"])
pos_flux = np.abs(matching_injected_injected["ip_diffim_PsfDipoleFlux_pos_instFlux"])
tot_flux = neg_flux + pos_flux
neg_ratio = neg_flux / tot_flux
pos_ratio = pos_flux / tot_flux

print(pos_ratio)

In [None]:
matching_injected_injected[columns_of_interest]

### Appendix.  Snippets of code I might want later

In [None]:
class deferred_wrapper:
    """Hold an object and hand it back unchanged from ``get()``.

    Presumably meant to stand in for a butler deferred-load handle, so an
    in-memory catalog can be passed where a ``.get()``-able object is expected.
    """

    def __init__(self, cat):
        # The wrapped object, exposed as-is.
        self.cat = cat

    def get(self):
        """Return the wrapped object (the same instance, not a copy)."""
        wrapped = self.cat
        return wrapped

# si_cat_deferred = deferred_wrapper(si_cat)

In [None]:
# If you want a test calexp, here's one (set the flag to True to load it):
NEED_TEST_CALEXP = False
if NEED_TEST_CALEXP:
    test_dataId = {"visit": 421727, "detector": 157, "band": "i"}  # Known existing exposure
    test_calexp = butler.get("calexp", dataId=test_dataId)