# Robust Registration of Catalogs
**Fan Tian, 12/01/2019** -  ftian4@jhu.edu <br/>

## Description
In this notebook, we demonstrate the robust registration algorithm [1] for cross-matching small catalogs (particularly those from HST images) that differ by a rotation and a shift. This is the latest version of the algorithm, which:
 - implements the "ring" algorithm, which divides all pairs within an initial search radius $R$ into overlapping rings of a specified width. <br/>
 - uses a simple annealing schedule for the astrometric uncertainty $\sigma$.

We also compare the robust estimates with those from the method of least squares [2]. <br/>
The first part of this notebook applies the algorithm to simulated HST/ACS/WFC catalogs. The second part demonstrates the cross-registration of an HST image (from the HLA catalog) to the Gaia DR2 catalog of the same field.

### Reference
[1] Tian, F., Budavári, T., Basu, A., Lubow, S. H., & White, R. L. (2019). Robust Registration of Astronomy Catalogs with Applications to the Hubble Space Telescope. _The Astronomical Journal_, 158(5), 191.
<a href="https://iopscience.iop.org/article/10.3847/1538-3881/ab3f38/meta">doi:10.3847/1538-3881/ab3f38</a>

[2] Budavári, T., & Lubow, S. H. (2012). Catalog Matching with Astrometric Correction and its Application to the Hubble Legacy Archive. _The Astrophysical Journal_, 761(2), 188.
<a href="https://iopscience.iop.org/article/10.1088/0004-637X/761/2/188">doi:10.1088/0004-637X/761/2/188</a>

**Based on prototype implementations of 5/31/2018 - Tamás Budavári, and of 3/29/2019 - Rick White**

In [None]:
%matplotlib inline
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import astropy

# Set page width to fill browser for longer output lines
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
# set width for pprint
astropy.conf.max_width = 150

In [None]:
# import cross-registration modules
import xregistration.simulation as sim
import xregistration.estimation as est
import xregistration.est_catalog as rcat

In [None]:
# global variables
# number of arcseconds per radian: multiply by it to convert radians to arcsec,
# divide by it to convert arcsec to radians
arc2rad = 3600 * 180 / np.pi
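Since `arc2rad` is the number of arcseconds in one radian, multiplying by it converts radians to arcseconds and dividing by it converts arcseconds to radians; the footprint area used later, `(202/arc2rad)**2`, follows the same rule. A minimal sketch of both conversions (the helper names `arcsec_to_rad` and `rad_to_arcsec` are hypothetical, introduced only for illustration):

```python
import numpy as np

# Number of arcseconds in one radian (about 206264.8); the notebook calls this arc2rad
ARCSEC_PER_RAD = 3600 * 180 / np.pi

def arcsec_to_rad(x_arcsec):
    """Convert an angle from arcseconds to radians."""
    return x_arcsec / ARCSEC_PER_RAD

def rad_to_arcsec(x_rad):
    """Convert an angle from radians to arcseconds."""
    return x_rad * ARCSEC_PER_RAD

# Footprint of a 202" x 202" field in steradians (small-field approximation)
side_rad = arcsec_to_rad(202.0)
area_sr = side_rad ** 2
print(f"{ARCSEC_PER_RAD:.1f} arcsec per radian")
print(f"202 arcsec = {side_rad:.3e} rad; footprint = {area_sr:.3e} sr")
```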

## Part 1. Simulation

## 1. Simulate mock universe and catalogs
- Simulate catalogs resembling an HST/ACS/WFC catalog with a field of view of 202" × 202"
- Approximately 1500 sources
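The simulation step can be pictured with a small flat-sky sketch. It does not reproduce `sim.mock`, `sim.cat`, or `sim.randomega`; the uniform selection variable `u`, the 0.01° rotation, and the shift are assumptions chosen only to mimic two overlapping catalogs of the same ~1500-source field with the 0.04 arcsec noise used below.

```python
import numpy as np

rng = np.random.default_rng(4444)
size, sigma, n_src = 202.0, 0.04, 1500  # field side (arcsec), per-axis noise (arcsec), sources

# Mock universe: uniform true positions plus a hypothetical selection variable u
truth = rng.uniform(0.0, size, size=(n_src, 2))
u = rng.uniform(0.0, 1.0, size=n_src)

def observe(xy, sigma, rng):
    """Perturb true positions with isotropic Gaussian noise (sigma per axis)."""
    return xy + rng.normal(0.0, sigma, size=xy.shape)

def rotate_shift(xy, theta_rad, shift, center):
    """Flat-sky rotation about `center` by theta, followed by a shift (arcsec)."""
    c, s = np.cos(theta_rad), np.sin(theta_rad)
    rot = np.array([[c, -s], [s, c]])
    return (xy - center) @ rot.T + center + shift

# Two catalogs selected from overlapping intervals of u, mirroring (0.2, 1) and (0, 0.9)
sel0 = (u >= 0.2) & (u < 1.0)
sel1 = (u >= 0.0) & (u < 0.9)
cat0 = observe(truth, sigma, rng)
cat1 = rotate_shift(observe(truth, sigma, rng), np.deg2rad(0.01),
                    np.array([1.5, -0.8]), np.array([size / 2, size / 2]))

# Mean offset over the sources present in both catalogs
both = sel0 & sel1
offset = np.hypot(*(cat1[both] - cat0[both]).T).mean()
print(f"{sel0.sum()} and {sel1.sum()} sources; {both.sum()} in common")
print(f"mean offset between catalogs: {offset:.3f} arcsec")
```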

In [None]:
# Initialize image size
size=202

# Initialize uncertainty parameter - sigma
sigma = 0.04

# Set seed
seed= 4444
np.random.seed(seed)

# Create mock universe
m = sim.mock(1500, size) # df with index that's the objid

# Create perturbed catalogs within selection interval - same size as m
cn = [sim.cat(m,sigma,l,h) for l,h in [(0.2, 1), (0, 0.9)]] # selection intervals

# Select catalogs - objid index retained
cs = [a[a.Selected] for a in cn] 

# Generate random omega0 and catalog0
omega0, c0 = sim.randomega(cs[0], scale=60)

# Generate catalog1, with omega1 = -omega0
omega1 = -1 * omega0
c1 = sim.trf(cs[1], omega1)

# transformed catalogs
co = [c0,c1] 

print("Average offset of two catalogs before transformation: {:2.3f} arcsec".format(sim.getsep(cs[0],cs[1],"mean")))
print("Average offset of two catalogs after transformation: {:2.3f} arcsec".format(sim.getsep(co[0],co[1],"mean")))

## 2. Robust Ring Estimation
- Cross-match pairs within rings
- Ring selection: width $\approx 4\sigma$
- Anneal $\sigma$ toward its final value during the initial iterations
- Stopping criterion: $|\omega_{t+1} - \omega_{t}| < \epsilon$
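The iteration outlined in the list above can be sketched as an iteratively reweighted least-squares loop that anneals $\sigma$ and then stops when $\boldsymbol{\omega}$ changes by less than $\epsilon$. This is a minimal, self-contained sketch under stated assumptions, not the `est.robust_ring` implementation: the geometric annealing schedule, the fixed $\gamma$, the helper names, and the synthetic data are all hypothetical, and the ring subsetting itself is omitted.

```python
import numpy as np

def weights(x, alpha):
    """Robust weight W(x) = alpha e^{-x^2/2} / (alpha e^{-x^2/2} + 1)."""
    g = alpha * np.exp(-0.5 * x**2)
    return g / (g + 1.0)

def solve_omega(r, c, w, sigma):
    """Solve A omega = b, A = sum w/s^2 (I - r r^T), b = sum w/s^2 (r x c)."""
    ws = w / sigma**2
    A = ws.sum() * np.eye(3) - np.einsum("q,qi,qj->ij", ws, r, r)
    b = (ws[:, None] * np.cross(r, c)).sum(axis=0)
    return np.linalg.solve(A, b)

def robust_irls(r, c, area, sigma, sigma_init, gamma, n_anneal=10, max_extra=100, eps=1e-13):
    """Iteratively reweighted estimate of omega with a geometric sigma schedule.

    r, c: (N, 3) unit vectors of image sources and their candidate calibrators.
    Sigma shrinks from sigma_init to sigma over n_anneal steps (an assumed schedule),
    then the loop runs at fixed sigma until |omega_{t+1} - omega_t| < eps.
    """
    schedule = sigma_init * (sigma / sigma_init) ** (np.arange(n_anneal + 1) / n_anneal)
    omega = np.zeros(3)
    for t in range(n_anneal + 1 + max_extra):
        s = schedule[min(t, n_anneal)]
        alpha = area / (2 * np.pi * s**2) * gamma / (1 - gamma)
        x = np.linalg.norm(c - r - np.cross(omega, r), axis=1) / s
        new_omega = solve_omega(r, c, weights(x, alpha), s)
        done = t >= n_anneal and np.linalg.norm(new_omega - omega) < eps
        omega = new_omega
        if done:
            break
    # Final weights evaluated at the target sigma with the converged omega
    alpha = area / (2 * np.pi * sigma**2) * gamma / (1 - gamma)
    x = np.linalg.norm(c - r - np.cross(omega, r), axis=1) / sigma
    return omega, weights(x, alpha)

# Synthetic test: 300 true pairs plus 300 false pairs in a ~200" field near +z
arcsec = np.pi / (180 * 3600)
rng = np.random.default_rng(0)
n_true, n_false = 300, 300
xy = rng.uniform(-100, 100, size=(n_true + n_false, 2)) * arcsec
r = np.column_stack([xy, np.ones(len(xy))])
r /= np.linalg.norm(r, axis=1, keepdims=True)
omega_true = np.array([2e-6, -1e-6, 0.0])  # about a 0.46" shift near the field center
c = r + np.cross(omega_true, r) + rng.normal(0.0, 0.04 * arcsec, size=r.shape)
c[n_true:] += rng.uniform(-3, 3, size=(n_false, 3)) * arcsec  # false matches
c /= np.linalg.norm(c, axis=1, keepdims=True)

omega_est, w = robust_irls(r, c, area=(200 * arcsec) ** 2, sigma=0.04 * arcsec,
                           sigma_init=1.0 * arcsec, gamma=0.5)
err_mas = np.linalg.norm((omega_est - omega_true)[:2]) / arcsec * 1000
print(f"shift-component error: {err_mas:.1f} mas")
print(f"mean weight, true pairs {w[:n_true].mean():.2f}, false pairs {w[n_true:].mean():.2f}")
```

Only the two components of $\boldsymbol{\omega}$ perpendicular to the line of sight are checked, because the rotation about the field center is weakly constrained in such a small field.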

### 2.1 Matched pairs within an initial search radius
#### Find all pairs that match within _radius_ (arcsec).
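For small catalogs, the pair search can be sketched by brute force on unit vectors: an angular radius $\theta$ corresponds to a chord length $2\sin(\theta/2)$. The helpers `pairs_within` and `tangent_to_xyz` are hypothetical, not part of `xregistration`; for large catalogs a KD-tree such as `scipy.spatial.cKDTree` would scale better. Column 0 indexes the first catalog and column 1 the second, matching how `pairs_HLA[:,0]` and `pairs_HLA[:,1]` are used in Part 2.

```python
import numpy as np

def pairs_within(a, b, radius_rad):
    """Return an (M, 2) array of index pairs (i, j) whose unit vectors a[i], b[j]
    are separated by less than radius_rad. Brute force: O(N1 * N2) time and memory."""
    chord = 2.0 * np.sin(0.5 * radius_rad)  # chord length corresponding to the angle
    d2 = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=2)
    i, j = np.nonzero(d2 < chord**2)
    return np.column_stack([i, j])

def tangent_to_xyz(x_arcsec, y_arcsec):
    """Unit vectors for small offsets (arcsec) from the +z axis."""
    arcsec = np.pi / (180 * 3600)
    v = np.column_stack([np.asarray(x_arcsec) * arcsec, np.asarray(y_arcsec) * arcsec,
                         np.ones(len(x_arcsec))])
    return v / np.linalg.norm(v, axis=1, keepdims=True)

a = tangent_to_xyz([0.0, 10.0, 50.0], [0.0, 0.0, 0.0])
b = tangent_to_xyz([0.5, 11.5, 100.0], [0.0, 0.0, 0.0])
pairs = pairs_within(a, b, radius_rad=2.0 * np.pi / (180 * 3600))  # 2 arcsec radius
print(pairs)
```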

In [None]:
# Initial search radius: approximately 1.1 times the maximum offset (arcsec)
radius = 1.1 * sim.getsep(co[0],co[1],"max")
print(f"search radius: {radius:.2f} arcsec")

print(f"{co[0].shape[0]} sources in input catalog and {co[1].shape[0]} sources in reference catalog")

### 2.2 Fast prototype of the robust iterative solver

Objective function: 
$$ 
    \tilde{\boldsymbol{\omega}}= 
    \arg\min_{\boldsymbol{\omega}}\sum_{q}\,
    \rho\left(\frac{ \left|\boldsymbol{\Delta}_{q}-\boldsymbol{\omega}\times\boldsymbol{r}_{q} \right|}{\sigma}\right)
$$

$$
    \rho(x) = -\ln 
    \left(
    \frac{\gamma_{*}}{2\pi\sigma^2}\ e^{-x^2/2} \,+\, \frac{1\!-\!\gamma_{*}}{\Omega}
    \right).
$$

- $\textbf{c}_q$: q-th calibrator direction
- $\textbf{r}_q$: q-th source direction of the image (to be corrected)
- $\boldsymbol{\Delta}_q = \textbf{c}_q - \textbf{r}_q$: separation between the q-th source and its calibrator
- $\boldsymbol{\omega}$: 3-D rotation vector
- $\sigma$: astrometric uncertainty
- $\gamma$: probability of being a true association
- $\gamma_{*} = \frac{\min (N_1, N_2)}{N}$, where $N$ is the total number of pairs, $N_1$ the number of sources in the input catalog, and $N_2$ the number of sources in the reference catalog
- $\Omega$: footprint area (steradians) 

Solve for $\tilde{\boldsymbol{\omega}}$ using $A\tilde{\boldsymbol{\omega}} = \textbf{b}$ with

\begin{equation}
\begin{array}{ccc}
    A =\displaystyle \sum_{q} \frac{w_{q}}{\sigma^{2}}
    \left(I-\boldsymbol{r}_{q}\!\otimes\boldsymbol{r}_{q}\right) & \textrm{and} &
    b = \displaystyle \sum_{q}
    \frac{w_{q}}{\sigma^{2}}
    \left(\boldsymbol{r}_{q}\!\times\boldsymbol{c}_{q}\right) 
\end{array}
\end{equation}
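A minimal NumPy sketch of these normal equations (the function name `solve_omega` is hypothetical). With noise-free calibrators generated from a small first-order rotation, the solution recovers the input rotation vector. For a small field, $A$ is poorly conditioned along the line of sight, so the rotation component about the field center is the least well determined.

```python
import numpy as np

def solve_omega(r, c, w, sigma):
    """Weighted normal equations A omega = b for the rotation vector omega.

    A = sum_q w_q/sigma^2 (I - r_q r_q^T),  b = sum_q w_q/sigma^2 (r_q x c_q)
    """
    ws = np.asarray(w, dtype=float) / sigma**2
    A = ws.sum() * np.eye(3) - np.einsum("q,qi,qj->ij", ws, r, r)
    b = (ws[:, None] * np.cross(r, c)).sum(axis=0)
    return np.linalg.solve(A, b)

rng = np.random.default_rng(1)
# Unit vectors scattered within ~0.01 rad of the +z axis
xy = rng.uniform(-0.01, 0.01, size=(500, 2))
r = np.column_stack([xy, np.ones(500)])
r /= np.linalg.norm(r, axis=1, keepdims=True)

# Calibrators from a small first-order rotation c = r + omega x r (renormalized)
omega_true = np.array([3e-6, -2e-6, 4e-6])
c = r + np.cross(omega_true, r)
c /= np.linalg.norm(c, axis=1, keepdims=True)

omega_est = solve_omega(r, c, w=np.ones(len(r)), sigma=1.0)
print(omega_est)
```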

\begin{equation*}
\begin{aligned}
    \text{Weight function: }
    w_q=W(x) &= \frac{\rho'(x)}{x}
          = \frac{\alpha e^{-x^2/2}}{\alpha e^{-x^2/2}+1}\\
  \alpha &= \frac{\Omega}{2 \pi \sigma^2} \frac{\gamma_{*}}{(1-\gamma_{*})}
\end{aligned}
\end{equation*}
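The closed form of $W(x)$ can be checked numerically against $\rho'(x)/x$. This sketch uses illustrative values ($\sigma$ = 0.04 arcsec, $\gamma_*$ = 0.3, a 202" × 202" footprint) and also computes where the weight crosses 0.5, the threshold used to label true matches in Part 2. The function names are hypothetical.

```python
import numpy as np

def rho(x, sigma, gamma, area):
    """Robust loss: -ln(gamma/(2 pi sigma^2) e^{-x^2/2} + (1 - gamma)/area)."""
    return -np.log(gamma / (2 * np.pi * sigma**2) * np.exp(-0.5 * x**2) + (1 - gamma) / area)

def weight(x, sigma, gamma, area):
    """Closed form W(x) = alpha e^{-x^2/2} / (alpha e^{-x^2/2} + 1)."""
    alpha = area / (2 * np.pi * sigma**2) * gamma / (1 - gamma)
    g = alpha * np.exp(-0.5 * x**2)
    return g / (g + 1.0)

arcsec = np.pi / (180 * 3600)
sigma, gamma, area = 0.04 * arcsec, 0.3, (202 * arcsec) ** 2

# Central finite difference of rho, divided by x, should match W(x)
x = np.linspace(0.5, 8.0, 16)
h = 1e-5
numeric = (rho(x + h, sigma, gamma, area) - rho(x - h, sigma, gamma, area)) / (2 * h) / x
closed = weight(x, sigma, gamma, area)
print("max |numeric - closed| =", np.abs(numeric - closed).max())

# The weight drops to 0.5 where alpha e^{-x^2/2} = 1, i.e. x = sqrt(2 ln alpha)
alpha = area / (2 * np.pi * sigma**2) * gamma / (1 - gamma)
x_half = np.sqrt(2 * np.log(alpha))
print(f"W = 0.5 at x = {x_half:.2f}, i.e. {x_half * 0.04:.3f} arcsec")
```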

**Robust solver output:** 
- omega: estimated rotation vector
- pairs: pairs in the optimal ring
- weights: robust weights $w_q$ of those pairs

**Input parameters** <br>
- area: footprint area in steradians, computed from the 202-arcsec side length of the image
- radius: initial search radius (arcsec)
- sigma: true astrometric uncertainty of the simulated catalogs, 0.04 arcsec
- gamma: fraction of true pairs (unknown; approximated by $\gamma_*$)
- ringwidth: ring width, 0.2 arcsec (empirical value)
- sigma_init: initial sigma for the annealing schedule (`sigma_init=4` in the cell below)
- niter: minimum number of iterations for convergence (`niter=50` in the cell below)
- nextr: maximum number of additional iterations, 100
- mid: if True, use the midpoints of the two catalogs as the reference

**Estimate Omega using Robust Ring Algorithm Version 2**

The Version 2 algorithm divides the pairs into rings that each contain an equal number of pairs, and performs the estimation in each ring. In this version, $\gamma$ is taken as the global value $\gamma_* = \frac{\min(N_1,N_2)}{N}$.
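One plausible reading of the Version 2 partition, sketched with hypothetical helpers: sort the pairs by separation and cut them into rings with a fixed number of pairs, letting neighboring rings overlap. The 50% overlap and the ring size are assumptions, not values taken from `est.robust_ring`. The global $\gamma_*$ is then a single ratio.

```python
import numpy as np

def gamma_star(n1, n2, n_pairs):
    """Global prior fraction of true pairs: min(N1, N2) / N."""
    return min(n1, n2) / n_pairs

def equal_count_rings(sep, pairs_per_ring, overlap=0.5):
    """Split pairs, sorted by separation, into rings with equal numbers of pairs.

    Consecutive rings share a fraction `overlap` of their pairs. Returns a list of
    index arrays into `sep`; every pair falls in at least one ring.
    """
    order = np.argsort(sep)
    step = max(1, int(round(pairs_per_ring * (1 - overlap))))
    rings = []
    for start in range(0, len(sep), step):
        rings.append(order[start:start + pairs_per_ring])
        if start + pairs_per_ring >= len(sep):
            break
    return rings

sep = np.array([3.2, 0.1, 7.5, 1.4, 9.9, 0.6, 5.0, 2.2, 8.1, 4.4])  # arcsec, illustrative
rings = equal_count_rings(sep, pairs_per_ring=4)
for k, idx in enumerate(rings):
    print(f"ring {k}: {sep[idx].min():.1f}-{sep[idx].max():.1f} arcsec, {len(idx)} pairs")
print("gamma_* =", gamma_star(300, 500, 1200))
```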

In [None]:
t0=time.time()

# image area in steradians
area=(202/arc2rad)**2

# estimate omega, and obtain pairs in the optimal ring
bestomega, bestpairs, bestwts = est.robust_ring(co[0], co[1], area, radius, sigma=0.04,
                                   sigma_init=4,
                                   niter=50, nextr=100, mid=True, printerror=False)

print("Total {:.3f} seconds to complete estimation".format(time.time()-t0))

### 2.3 Solve for $\boldsymbol{\omega}$ using the least-squares algorithm

Apply the least-squares method on pairs in the optimal ring.

<font color="red">Note that this comparison is optimistic for the L2 method, which has no way of its own to determine the
    best ring.</font>
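Why the L2 comparison is optimistic can be seen from the normal equations themselves: least squares is the same solve with every $w_q = 1$, so any false pairs left in the ring pull the solution. A hypothetical sketch with 200 true pairs and 50 false pairs that share a 1" offset; `w_ideal` stands in for weights that perfectly separate the two populations, which the robust weights only approximate.

```python
import numpy as np

def solve_omega(r, c, w):
    """Normal equations A omega = b; with w = 1 this is the least-squares (L2) solution."""
    A = w.sum() * np.eye(3) - np.einsum("q,qi,qj->ij", w, r, r)
    b = (w[:, None] * np.cross(r, c)).sum(axis=0)
    return np.linalg.solve(A, b)

arcsec = np.pi / (180 * 3600)
rng = np.random.default_rng(2)
n_true, n_false = 200, 50
xy = rng.uniform(-100, 100, size=(n_true + n_false, 2)) * arcsec
r = np.column_stack([xy, np.ones(len(xy))])
r /= np.linalg.norm(r, axis=1, keepdims=True)

omega_true = np.array([1e-6, 1e-6, 0.0])
c = r + np.cross(omega_true, r) + rng.normal(0.0, 0.04 * arcsec, size=r.shape)
c[n_true:, 0] += 1.0 * arcsec  # false pairs that happen to share a 1" offset in x
c /= np.linalg.norm(c, axis=1, keepdims=True)

is_true = np.arange(len(r)) < n_true
w_l2 = np.ones(len(r))
w_ideal = is_true.astype(float)  # idealized weights: 1 for true pairs, 0 for false

def shift_error_mas(omega):
    """Error of the two shift components of omega, in milliarcseconds."""
    return np.linalg.norm((omega - omega_true)[:2]) / arcsec * 1000

print(f"L2 error:      {shift_error_mas(solve_omega(r, c, w_l2)):.0f} mas")
print(f"ideal weights: {shift_error_mas(solve_omega(r, c, w_ideal)):.1f} mas")
```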

In [None]:
L2omega = est.L2_est(co[0], co[1], bestpairs, sigma=0.04)

### 2.4 Plot Catalogs

In [None]:
cat_cor_rob = [sim.trf(co[0],bestomega), sim.trf(co[1],-bestomega)]
cat_cor_L2 = [sim.trf(co[0],L2omega), sim.trf(co[1],-L2omega)]
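The corrections above apply a rotation vector to each catalog. What `sim.trf` does internally is not shown here; as a hedged sketch, a small rotation can be applied to first order ($\boldsymbol{r} \to \boldsymbol{r} + \boldsymbol{\omega}\times\boldsymbol{r}$, renormalized) or exactly with the Rodrigues formula. The helper names are hypothetical. For a rotation of about 19 arcsec the two differ by less than a milliarcsecond.

```python
import numpy as np

def rotate_first_order(xyz, omega):
    """Apply a small rotation to first order: v -> normalize(v + omega x v)."""
    v = xyz + np.cross(omega, xyz)
    return v / np.linalg.norm(v, axis=1, keepdims=True)

def rotate_exact(xyz, omega):
    """Rodrigues rotation by angle |omega| about the axis omega/|omega|."""
    theta = np.linalg.norm(omega)
    if theta == 0.0:
        return xyz.copy()
    k = omega / theta
    return (xyz * np.cos(theta) + np.cross(k, xyz) * np.sin(theta)
            + np.outer(xyz @ k, k) * (1.0 - np.cos(theta)))

rng = np.random.default_rng(3)
v = rng.normal(size=(1000, 3))
v /= np.linalg.norm(v, axis=1, keepdims=True)
omega = np.array([5e-5, -3e-5, 7e-5])  # |omega| is about 9e-5 rad, roughly 19 arcsec

diff = np.linalg.norm(rotate_first_order(v, omega) - rotate_exact(v, omega), axis=1)
print(f"max first-order error: {diff.max() * 206264.806 * 1000:.2f} mas")
```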

In [None]:
fig=plt.figure(figsize=(20,6))

fig.add_subplot(131)
plt.scatter(co[0].x*arc2rad, co[0].y*arc2rad, s=10, alpha=0.3)
plt.scatter(co[1].x*arc2rad, co[1].y*arc2rad, s=10, alpha=0.3)
plt.xlim(-50,250)
plt.ylim(-50,250)
plt.title("ORIGINAL")

fig.add_subplot(132)
plt.scatter(cat_cor_rob[0].x*arc2rad, cat_cor_rob[0].y*arc2rad, s=10, alpha=0.4)
plt.scatter(cat_cor_rob[1].x*arc2rad, cat_cor_rob[1].y*arc2rad, s=10, alpha=0.4)
plt.title("ROBUST METHOD")
print("Initial average offset: {:.3f} arcsec" .format(sim.getsep(co[0],co[1],"mean")))
print("Average offset after correction - ROBUST: {:.3f} arcsec" .format(sim.getsep(cat_cor_rob[0], cat_cor_rob[1],"mean")))

fig.add_subplot(133)
plt.scatter(cat_cor_L2[0].x*arc2rad, cat_cor_L2[0].y*arc2rad, s=10, alpha=0.4)
plt.scatter(cat_cor_L2[1].x*arc2rad, cat_cor_L2[1].y*arc2rad, s=10, alpha=0.4)
plt.title("L2 METHOD")
print("Average offset after correction - L2: {:.3f} arcsec" .format(sim.getsep(cat_cor_L2[0],cat_cor_L2[1],"mean")))
plt.show()

## Part 2. The HLA/Gaia Catalogs Cross-registration
**Adapted from 3/29/2019 - Rick White**

## 1. Read Data
### 1.1 Set parameters for a visit

- `imagename` = name of HLA dataset
- `radius` = maximum shift radius to search (arcsec)
- `requirePM` = True to require Gaia proper motions (must be False for some cluster fields)
- `limitGaia` = True to restrict the number of Gaia sources to ~200
- `flagcut` = maximum flag value to include in HLA catalog (5=all, 1=unsat, 0=stellar unsat)

In [None]:
# list of test images
# imagename = 'hst_9984_nl_acs_wfc' # far north with rotation
# imagename = 'hst_9984_ni_acs_wfc' # far north with rotation
imagename = 'hst_11664_22_wfc3_uvis' # big 90" shift
# imagename = 'hst_10775_a7_acs_wfc' # challenging image with large catalogs

radius = 120.0
requirePM = True
limitGaia = False
flagcut = 5

### 1.2 Read the HLA catalog for a dataset

This cell also computes each source's brightest magnitude and smallest flag value across all filters; Section 1.4 uses them to keep only sources (brighter than about magnitude 22) that might match Gaia sources.

In [None]:
# skip the slow catalog query if this dataset was already read in this session
if 'current_imagename' in locals() and current_imagename == imagename:
    print('Already read catalog for',imagename)
else:
    current_imagename = None
    imacat = rcat.getmultiwave(imagename)
    current_imagename = imagename
    print("Read {} sources for {}".format(len(imacat),imagename))

    # compute the brightest magnitude and smallest flag value across all filters
    # (the magnitude and flag cuts are applied in Section 1.4)
    magcols = []
    flagcols = []
    for col in imacat.colnames:
        if col.endswith('magauto'):
            magcols.append(col)
        elif col.endswith('_flags'):
            flagcols.append(col)
    if not magcols:
        raise ValueError("No magnitude columns found in catalog")
    if len(magcols) != len(flagcols):
        raise ValueError("Mismatch between magcols [{}] and flags [{}]?".format(
            len(magcols),len(flagcols)))
    print("Magnitudes {}".format(" ".join(magcols)))
    mags = imacat[magcols[0]]
    for col in magcols[1:]:
        mags = np.minimum(mags,imacat[col])
    flags = imacat[flagcols[0]]
    for col in flagcols[1:]:
        flags = np.minimum(flags,imacat[col])

### 1.3 Read the Gaia catalog with padding to allow for large shifts

The Gaia search box is expanded on all sides by 1.1 times the search radius (never less than 66 arcsec; 132 arcsec, or 2.2 arcmin, for the default radius of 120 arcsec) to allow for a shift that large.

In [None]:
mdec = imacat['dec'].mean()
mra = imacat['ra'].mean()
cdec = np.cos(rcat.d2r*mdec)

# pad using the search radius, but never less than 60 arcsec
gradius = max(60.0,radius)

# pad by 1.1*search radius on each side
pad = 1.1*gradius/3600.0
rpad = pad/cdec
ramin = imacat['ra'].min() - rpad
ramax = imacat['ra'].max() + rpad
decmin = imacat['dec'].min() - pad
decmax = imacat['dec'].max() + pad
new_params = (ramin,ramax,decmin,decmax)

if 'gcat_params' in locals() and gcat_params == new_params:
    print('Already read Gaia catalog for {} ({} sources)'.format(imagename,len(gcat_all)))
else:
    gcat_params = None
    gcat_all = rcat.gaiaquery(ramin,decmin,ramax,decmax)
    gcat_params = new_params
    print("Read {} Gaia sources".format(len(gcat_all)))

gcat = gcat_all

# compute ratio of area covered by data to extended area
area_rat = (ramax-ramin-2*rpad)*(decmax-decmin-2*pad)/((ramax-ramin)*(decmax-decmin))

if requirePM:
    # keep only objects with proper motions
    gcat = gcat[~gcat['pmra'].mask]
    print("Keeping {} Gaia sources with measured PMs".format(len(gcat)))

    # apply proper motions
    epoch_yr = rcat.getepoch(imagename)
    # make reference epoch a scalar if possible
    ref_epoch = gcat['ref_epoch']
    if (ref_epoch == ref_epoch.mean()).all():
        ref_epoch = ref_epoch[0]
    dt = epoch_yr-ref_epoch
    print("Updating gcat for {:.1f} yrs of PM".format(-dt))
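    # PM fields are in mas/yr; Gaia pmra already includes the cos(dec) factor.
    # Update the table columns so the corrected positions are used downstream.
    gcat['ra'] = gcat['ra'] + gcat['pmra']*(dt/(3600.0e3*np.cos(rcat.d2r*gcat['dec'])))
    gcat['dec'] = gcat['dec'] + gcat['pmdec']*(dt/3600.0e3)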
else:
    print("No Gaia PMs are used, all Gaia sources are retained")

# if number of Gaia sources is large, select just a subset of the fainter sources
# aim for about 200 sources within the field
if limitGaia:
    ngmax = int(round(200/area_rat))
    if len(gcat) > ngmax:
        print("Clipping to faintest",ngmax,"Gaia sources")
        gcat.sort('phot_g_mean_mag')
        gcat = gcat[-ngmax:]
gcat[:5]
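The proper-motion update in the cell above is a linear propagation from the Gaia reference epoch to the HST epoch. A minimal sketch of the same arithmetic for a single hypothetical star (the function `propagate_pm`, the position, and the 2009.5 epoch are illustrative; 2015.5 is the Gaia DR2 reference epoch). Because Gaia's `pmra` already includes the $\cos\delta$ factor, the change in RA in degrees is divided by $\cos\delta$, while the on-sky displacement is simply `pmra*dt`. Parallax and radial velocity are ignored.

```python
import numpy as np

def propagate_pm(ra_deg, dec_deg, pmra_masyr, pmdec_masyr, dt_yr):
    """Linear proper-motion update over dt_yr years (no parallax or radial velocity).

    pmra is mu_alpha* = mu_alpha cos(dec), so the RA change in degrees
    is divided by cos(dec).
    """
    ra = ra_deg + pmra_masyr * dt_yr / (3600.0e3 * np.cos(np.radians(dec_deg)))
    dec = dec_deg + pmdec_masyr * dt_yr / 3600.0e3
    return ra, dec

# Hypothetical star at Dec = +60 deg, propagated from epoch 2015.5 back to 2009.5
ra0, dec0 = 150.0, 60.0
ra1, dec1 = propagate_pm(ra0, dec0, pmra_masyr=10.0, pmdec_masyr=-5.0, dt_yr=2009.5 - 2015.5)

# On-sky displacement in mas
dra_mas = (ra1 - ra0) * np.cos(np.radians(dec0)) * 3600.0e3
ddec_mas = (dec1 - dec0) * 3600.0e3
print(f"dRA* = {dra_mas:.1f} mas, dDec = {ddec_mas:.1f} mas")
```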

### 1.4 Restrict the HLA catalog to sources close to the Gaia magnitude limit

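The HLA magnitude cut is $\min(G_{\rm lim} + 1.2,\ 22)$, where $G_{\rm lim}$ is the faintest Gaia magnitude in the field; for a typical Gaia field this gives a cut of about 22. Some Gaia fields have a much brighter limit, which makes the HLA cut correspondingly brighter. If more than 1000 HLA sources would pass, the cut is forced to magnitude 17. <br/>
This also keeps only sources with flags ≤ `flagcut`.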

In [None]:
gmaglim = gcat['phot_g_mean_mag'].max()
magcut = min(gmaglim + 1.2, 22.0)
print('Gaia mag limit {:.3f} -> HLA magnitude cut {}'.format(gmaglim,magcut))

# forcing this cut to see how this affects wider radius searches
if (mags <= magcut).sum() > 1000:
    magcut = 17.0
    print('XXX Forcing HLA magnitude cut {} XXX'.format(magcut))

wcut = np.where((mags <= magcut) & (flags <= flagcut))
bcat = imacat[wcut]
bmags = mags[wcut]
bflags = flags[wcut]
print("{} sources left after cut at mag {}, flags <= {}".format(len(bcat),magcut,flagcut))
bcat[:5]

### 1.5 Plot positions on sky

In [None]:
plt.rcParams.update({'font.size':14})
plt.figure(figsize=(8,8))
plt.plot(bcat['ra'],bcat['dec'],'ro',alpha=0.3,markersize=4,label='HLA')
plt.plot(gcat['ra'],gcat['dec'],'bo',alpha=0.3,markersize=4,label='Gaia')
plt.xlabel('RA [deg]')
plt.ylabel('Dec [deg]')
plt.title(imagename)
plt.legend(loc=3);

## 2. Matched pairs within an initial search radius
#### Find all pairs that match within _radius_ arcsec.

In [None]:
## Convert positions to Cartesian xyz coordinates 
# a = catalog to shift (the HLA catalog) 
# b = reference catalog (Gaia)
a = rcat.cat2xyz(bcat)
b = rcat.cat2xyz(gcat)

print(f"{a.shape[0]:d} sources in HLA and {b.shape[0]:d} sources in Gaia")
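`rcat.cat2xyz` converts catalog positions to Cartesian unit vectors, and the separations computed later (`sep`) use the chord length between unit vectors, which equals the angle in radians to high accuracy for small separations. A hedged sketch of the standard conversion and its inverse (the helper names are hypothetical; the `rcat` functions may differ in details such as input types):

```python
import numpy as np

def radec_to_xyz(ra_deg, dec_deg):
    """Unit vectors (N, 3) for RA/Dec in degrees."""
    ra, dec = np.radians(ra_deg), np.radians(dec_deg)
    return np.column_stack([np.cos(dec) * np.cos(ra), np.cos(dec) * np.sin(ra), np.sin(dec)])

def xyz_to_radec(xyz):
    """Inverse of radec_to_xyz; RA is returned in [0, 360) degrees."""
    x, y, z = xyz.T
    ra = np.degrees(np.arctan2(y, x)) % 360.0
    dec = np.degrees(np.arctan2(z, np.hypot(x, y)))
    return ra, dec

ra = np.array([0.0, 90.0, 180.0, 359.9])
dec = np.array([0.0, 45.0, 10.0, -89.0])
xyz = radec_to_xyz(ra, dec)
ra_back, dec_back = xyz_to_radec(xyz)

# For small separations the chord |a - b| (in radians) approximates the angle
arcsec = np.pi / (180 * 3600)
a, b = radec_to_xyz([10.0], [20.0]), radec_to_xyz([10.0], [20.0 + 1.0 / 3600])
print(f"chord for a 1-arcsec separation: {np.linalg.norm(a - b) / arcsec:.6f} arcsec")
```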

## 3. Use robust solver to estimate rotation of the HLA catalog to the reference
**Input parameters** <br>
- area: footprint area in steradians, computed from a 202-arcsec image side length
- sigma: astrometric uncertainty, about 0.02 arcsec (`sigma=0.01` in the cell below)
- sigma_init: initial sigma of 1 arcsec
- gamma: fraction of true pairs (unknown, approximate)
- niter: minimum number of iterations for convergence, 10
- nextr: maximum number of additional iterations, 100
- ringwidth: 0.3 arcsec ring width
- mid: False, so the reference is the Gaia catalog

In [None]:
t0=time.time()
area = (202/arc2rad)**2
radius = 120
print(f"Match to radius {radius} arcsec")

omega_HLA, pairs_HLA, wts_HLA = est.robust_ring(a, b, area, radius, sigma=0.01,
                                     sigma_init=1, niter=10, nextr=100, 
                                     mid=False, printerror=False)

print("Total {:.3f} seconds to complete estimation".format(time.time()-t0))

## 4. Plot to show catalog separation before and after correction

The top two panels show a zoomed-out view (a region of ±5 arcsec), while the bottom two are zoomed in (±0.12 arcsec). The left panels show the original distribution of offsets (note that it is centered far from zero), while the right panels show the offsets after applying the correction from the robust match (centered on zero). The left and right panels use identical scales.

Points with weights $w_q \geq 0.5$ are shown in red; these are treated as the "true" matches.
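The panels plot ΔRA and ΔDec in arcseconds for each pair. A minimal sketch of such offsets and of the rms quoted in the last legend, computed over pairs with $w_q \geq 0.5$. The helper names and the sign convention (second position minus first) are assumptions; `rcat.getdeltas` may use a different convention.

```python
import numpy as np

def radec_deltas_arcsec(ra1, dec1, ra2, dec2):
    """Offsets of (ra2, dec2) relative to (ra1, dec1) in arcsec: (dRA cos(dec), dDec).

    The RA difference is wrapped to [-180, 180) degrees so pairs straddling RA = 0
    are handled; cos(dec) uses the mean declination of each pair.
    """
    ra1, dec1, ra2, dec2 = map(np.asarray, (ra1, dec1, ra2, dec2))
    dra = (ra2 - ra1 + 180.0) % 360.0 - 180.0
    mdec = 0.5 * (dec1 + dec2)
    return dra * np.cos(np.radians(mdec)) * 3600.0, (dec2 - dec1) * 3600.0

def rms_mas(dra, ddec, w, threshold=0.5):
    """RMS offset (mas) of the pairs with weight >= threshold."""
    good = np.asarray(w) >= threshold
    return 1000.0 * np.sqrt((dra[good] ** 2 + ddec[good] ** 2).mean())

ra1 = np.array([10.0, 359.9999, 120.0])
dec1 = np.array([60.0, 0.0, -30.0])
ra2 = np.array([10.0 + 2.0 / 3600.0, 0.0001, 120.0])
dec2 = np.array([60.0, 0.0, -30.0 + 0.03 / 3600.0])
dra, ddec = radec_deltas_arcsec(ra1, dec1, ra2, dec2)
print(dra, ddec)
print(rms_mas(dra, ddec, w=np.array([0.1, 0.2, 0.9])))
```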

In [None]:
# Sort pairs for plot
sep = np.sqrt(((a[pairs_HLA[:,0]]-b[pairs_HLA[:,1]])**2).sum(axis=1))*arc2rad
ind = np.argsort(sep)
sep = sep[ind]
pairs_HLA = pairs_HLA[ind,:]
wts_HLA = wts_HLA[ind]

In [None]:
wmatch = np.where(wts_HLA>0.5)[0]
print('total pairs =', len(wts_HLA), 'sum wts =', wts_HLA.sum(), 'matched pairs =', len(wmatch))
print(a.shape, b.shape)

# first use only good pairs to get limits for plot
p0 = pairs_HLA[wmatch,0]
p1 = pairs_HLA[wmatch,1]
ra1 = bcat['ra'][p0]
dec1 = bcat['dec'][p0]
gra = gcat['ra'][p1]
gdec = gcat['dec'][p1]

rr = rcat.radec2xyz(ra1,dec1)
ra2, dec2 = rcat.xyz2radec(rr + np.cross(omega_HLA,rr))

dra1, ddec1 = rcat.getdeltas(ra1,dec1,gra,gdec)
dra2, ddec2 = rcat.getdeltas(ra2,dec2,gra,gdec)

# center shifted plot at zero and use the same range in arcsec for both plots
xcen1 = np.ma.median(dra1)
ycen1 = np.ma.median(ddec1)
xcen2 = 0.0
ycen2 = 0.0

# plot both good pairs and bad pairs near the match
p0 = pairs_HLA[:,0]
p1 = pairs_HLA[:,1]
ra1 = bcat['ra'][p0]
dec1 = bcat['dec'][p0]
gra = gcat['ra'][p1]
gdec = gcat['dec'][p1]
rr = rcat.radec2xyz(ra1,dec1)
ra2, dec2 = rcat.xyz2radec(rr + np.cross(omega_HLA,rr))
dra1, ddec1 = rcat.getdeltas(ra1,dec1,gra,gdec)
dra2, ddec2 = rcat.getdeltas(ra2,dec2,gra,gdec)

# transparency for box around legend
framealpha = 0.95

plt.figure(1,(12,12))

xsize = 5.0
xlims1 = (xcen1-xsize, xcen1+xsize)
ylims1 = (ycen1-xsize, ycen1+xsize)
xlims2 = (xcen2-xsize, xcen2+xsize)
ylims2 = (ycen2-xsize, ycen2+xsize)
# points to plot
wp = np.where(
    ((dra1>=xlims1[0]) & (dra1<=xlims1[1]) & (ddec1>=ylims1[0]) & (ddec1<=ylims1[1])) |
    ((dra2>=xlims2[0]) & (dra2<=xlims2[1]) & (ddec2>=ylims2[0]) & (ddec2<=ylims2[1]))
    )[0]
wgood = wp[wts_HLA[wp]>=0.5]
wbad = wp[wts_HLA[wp]<0.5]
print("{} good points {} bad points".format(len(wgood),len(wbad)))

plt.subplot(221)
plt.plot(dra1[wbad], ddec1[wbad], 'ko', markersize=2, label='original')
plt.plot(dra1[wgood], ddec1[wgood], 'ro', markersize=2, label=r'$w_q \geq 0.5$')
plt.ylabel(r'$\Delta$Dec [arcsec]')
plt.xlabel(r'$\Delta$RA [arcsec]')
plt.plot(xlims1,[0,0], 'g-', linewidth=0.5)
plt.plot([0,0], ylims1, 'g-', linewidth=0.5)
plt.xlim(xlims1)
plt.ylim(ylims1)
plt.legend(loc='upper left',framealpha=framealpha)

plt.subplot(222)
plt.plot(dra2[wbad], ddec2[wbad], 'ko', markersize=2, label='robust')
plt.plot(dra2[wgood], ddec2[wgood], 'ro', markersize=2, label=r'$w_q \geq 0.5$')
plt.xlabel(r'$\Delta$RA [arcsec]')
plt.plot(xlims2,[0,0], 'g-', linewidth=0.5)
plt.plot([0,0], ylims2, 'g-', linewidth=0.5)
plt.xlim(xlims2)
plt.ylim(ylims2)
plt.legend(loc='upper left',framealpha=framealpha)

xsize = 0.12
xlims1 = (xcen1-xsize, xcen1+xsize)
ylims1 = (ycen1-xsize, ycen1+xsize)
xlims2 = (xcen2-xsize, xcen2+xsize)
ylims2 = (ycen2-xsize, ycen2+xsize)
# points to plot
wp = np.where(
    ((dra1>=xlims1[0]) & (dra1<=xlims1[1]) & (ddec1>=ylims1[0]) & (ddec1<=ylims1[1])) |
    ((dra2>=xlims2[0]) & (dra2<=xlims2[1]) & (ddec2>=ylims2[0]) & (ddec2<=ylims2[1]))
    )[0]
wgood = wp[wts_HLA[wp]>=0.5]
wbad = wp[wts_HLA[wp]<0.5]
print("{} good points {} bad points".format(len(wgood),len(wbad)))

plt.subplot(223)
plt.plot(dra1[wgood], ddec1[wgood], 'ro', markersize=2, label='original')
plt.plot(dra1[wbad], ddec1[wbad], 'ko', markersize=2)
plt.ylabel(r'$\Delta$Dec [arcsec]')
plt.xlabel(r'$\Delta$RA [arcsec]')
plt.xlim(xlims1)
plt.ylim(ylims1)
plt.legend(loc='upper left',framealpha=framealpha)

plt.subplot(224)
plt.plot(dra2[wgood], ddec2[wgood], 'ro', markersize=2, label='robust')
plt.plot(dra2[wbad], ddec2[wbad], 'ko', markersize=2)
plt.xlabel(r'$\Delta$RA [arcsec]')
plt.plot(xlims2,[0,0], 'g-', linewidth=0.5)
plt.plot([0,0], ylims2, 'g-', linewidth=0.5)
plt.xlim(xlims2)
plt.ylim(ylims2)
plt.legend(loc='upper left',framealpha=framealpha,
             title='rms = {:.0f} mas'.format(
                 1000*np.sqrt((dra2[wgood]**2+ddec2[wgood]**2).mean())));

Note that in this example the robust solution corrected not only a large shift but also a small rotation. That is why the point distribution is much tighter after the correction has been applied.
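The shift and the rotation are two components of the same rotation vector $\boldsymbol{\omega}$: its projection onto the field-center direction $\boldsymbol{p}$ is a rotation about the line of sight, while $\boldsymbol{\omega}\times\boldsymbol{p}$ is the displacement of the field center. A hedged sketch with a hypothetical helper `decompose_omega` and a synthetic $\boldsymbol{\omega}$ (a 90" shift plus a 0.01° rotation). Applied to `omega_HLA` with the mean HLA RA/Dec, the same decomposition would separate the two corrections, although the sign convention of the rotation angle here is an assumption.

```python
import numpy as np

ARCSEC_PER_RAD = 180 * 3600 / np.pi

def decompose_omega(omega, ra_deg, dec_deg):
    """Split a small rotation vector into a shift of the field center and a
    rotation about the line of sight through that center.

    Returns (shift_arcsec, rotation_deg): |omega x p| in arcsec and omega . p in
    degrees (right-handed about p), where p is the unit vector of the field center.
    """
    ra, dec = np.radians(ra_deg), np.radians(dec_deg)
    p = np.array([np.cos(dec) * np.cos(ra), np.cos(dec) * np.sin(ra), np.sin(dec)])
    shift = np.linalg.norm(np.cross(omega, p)) * ARCSEC_PER_RAD
    rotation = np.degrees(omega @ p)
    return shift, rotation

# Hypothetical field center and a rotation vector built from known parts
ra_c, dec_c = 150.0, 2.0
p = np.array([np.cos(np.radians(dec_c)) * np.cos(np.radians(ra_c)),
              np.cos(np.radians(dec_c)) * np.sin(np.radians(ra_c)),
              np.sin(np.radians(dec_c))])
perp = np.cross(p, [0.0, 0.0, 1.0])  # a direction perpendicular to p
perp /= np.linalg.norm(perp)
omega = 90.0 / ARCSEC_PER_RAD * perp + np.radians(0.01) * p  # 90" shift, 0.01 deg rotation

shift, rot = decompose_omega(omega, ra_c, dec_c)
print(f"shift = {shift:.2f} arcsec, rotation = {rot:.4f} deg")
```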