# Example of Running `tweakwcs` to Align JWST images

***
## About this Notebook
**Authors:** Mihai Cara, Clare Shanahan (STScI)
<br>**Updated On:** 10/15/2018

This notebook illustrates how to use `tweakwcs` to align JWST images.

***
## Imports

In [None]:
from astropy.table import Column, Table
from astropy.modeling.models import RotateNative2Celestial
from astropy.modeling import models
from astropy.coordinates import SkyCoord
from astropy import units as u
from copy import deepcopy
from jwst.datamodels import ImageModel, DataModel
from jwst.assign_wcs import AssignWcsStep
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from os import path
from photutils import detect_threshold, DAOStarFinder
from tweakwcs import tweak_image_wcs, tweak_wcs, TPMatch, JWSTgWCS

***
## Load Data

In [None]:
# Load JWST image models:
data_dir = '/grp/jwst/wit/nircam/hilbert/simulated_ramps/image_reg_hack_day/simulated_data/'
im1 = ImageModel(data_dir + 'jw98765001001_01101_00001_nrca1_cal.fits')
im2 = ImageModel(data_dir + 'jw98765001001_01102_00001_nrca1_cal.fits')

# Simulated data are already aligned; artificially shift image 2 by prepending a (-12, +16) pixel shift to its detector-to-V2V3 transform, to represent a pointing error
artificial_offset = models.Shift(-12) & models.Shift(16) | im2.meta.wcs.get_transform('detector','v2v3')
im2.meta.wcs.set_transform('detector', 'v2v3', artificial_offset)
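
In `astropy.modeling`, `&` combines models so that each one acts on a separate input, and `|` passes the outputs of one model to the next. As a plain-Python sketch of what the modified pipeline does to a single pixel position (the `shift_then` helper and the toy linear transform standing in for the real detector-to-V2V3 mapping are hypothetical, for illustration only), the position is offset by (-12, +16) pixels before the original transform sees it:

```python
def shift_then(dx, dy, transform):
    """Mimic (Shift(dx) & Shift(dy)) | transform for a single (x, y) point."""
    def combined(x, y):
        # '&': each Shift acts on its own input coordinate
        xs, ys = x + dx, y + dy
        # '|': the shifted coordinates are passed on to the next transform
        return transform(xs, ys)
    return combined

# Hypothetical stand-in for detector -> v2v3: a plain scale (arcsec per pixel).
def toy_detector_to_v2v3(x, y, scale=0.031):
    return x * scale, y * scale

original = toy_detector_to_v2v3
offset = shift_then(-12, 16, toy_detector_to_v2v3)

print(original(100, 200))
print(offset(100, 200))  # same result as original(88, 216)
```

Because the shift acts in pixel space, every source in image 2 now maps to a sky position displaced by roughly 20 pixels' worth of angle, which is the error the alignment must recover.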

***
## Example 1: Typical Workflow to Align Two or More Images

First, we will align the two images (the two ImageModels created above) using the most basic workflow, which consists of the following steps:

1. Detect sources in each image and create source catalogs.
2. Insert these catalogs into the ImageModels, where tweakwcs will look for them.
3. Align the images with `tweak_image_wcs`.
4. Inspect the output to verify the alignment.
5. Save the aligned image.

Note that tweakwcs does not create source catalogs internally; they must be passed in (unlike tweakreg, for those familiar with it). This increases flexibility: users may generate source catalogs any way they like (segmentation maps, PSF fitting, or DAOStarFinder as used here, for example) and may apply selection criteria to the catalogs before passing them in. In this example, we simply keep all sources that DAOStarFinder detects above a fixed threshold.
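
As an example of such selection criteria, a catalog could be trimmed before it is inserted into the model. The sketch below uses a hypothetical helper, `select_sources` (not part of tweakwcs or photutils), that keeps sources at least `edge` pixels from the image border and then retains only the `nmax` brightest by flux. Plain NumPy arrays stand in for the catalog columns:

```python
import numpy as np

def select_sources(x, y, flux, shape, edge=20, nmax=None):
    """Return indices of sources at least `edge` pixels from the border,
    optionally limited to the `nmax` brightest (highest flux)."""
    x, y, flux = np.asarray(x), np.asarray(y), np.asarray(flux)
    ny, nx = shape
    # Keep only sources whose centroids lie safely inside the image.
    inside = (x >= edge) & (x < nx - edge) & (y >= edge) & (y < ny - edge)
    idx = np.flatnonzero(inside)
    if nmax is not None:
        # Sort surviving sources by decreasing flux and keep the first nmax.
        idx = idx[np.argsort(flux[idx])[::-1][:nmax]]
    return idx

x = np.array([5.0, 100.0, 500.0, 1990.0, 800.0])
y = np.array([50.0, 100.0, 500.0, 1000.0, 1500.0])
flux = np.array([10.0, 3.0, 8.0, 50.0, 1.0])
keep = select_sources(x, y, flux, shape=(2048, 2048), edge=20, nmax=2)
print(keep)
```

With a DAOStarFinder table (which includes `x`/`y` after renaming and a `flux` column), the returned indices could be applied as `cat[keep]`.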

### 1. Create Source Catalogs

Images will be aligned based on sources common between the frames. Source catalogs are generated for each image using the DAOStarFinder algorithm in photutils. Once the sources have been detected, the resulting catalog will be added to the image model. 

The source finding parameters in this example have been optimized for this dataset - you may need to tweak these to find sources optimally in your images. 
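
`detect_threshold` estimates a detection level from the image background and noise, and returns a threshold image with the same shape as the input (hence the `[0, 0]` in the cell below). As a rough NumPy approximation of the idea, not photutils' exact algorithm (the `simple_threshold` name, the clipping scheme, and its defaults are assumptions), the threshold is the sigma-clipped background plus `snr` times the clipped standard deviation:

```python
import numpy as np

def simple_threshold(data, snr=25.0, nsigma_clip=3.0, iters=5):
    """Approximate a detection threshold: clipped median + snr * clipped std."""
    values = np.asarray(data, dtype=float).ravel()
    values = values[np.isfinite(values)]
    for _ in range(iters):
        med, std = np.median(values), np.std(values)
        # Discard pixels far from the median (e.g. stars) before re-estimating.
        keep = np.abs(values - med) < nsigma_clip * std
        if keep.all():
            break
        values = values[keep]
    return np.median(values) + snr * np.std(values)

rng = np.random.default_rng(0)
image = rng.normal(loc=0.0, scale=0.1, size=(200, 200))
image[100, 100] = 50.0  # a single bright "star"
print(simple_threshold(image, snr=25.0))
```

A higher `snr` keeps only brighter sources, which usually makes catalogs shorter and cleaner for alignment.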

In [None]:
for catno, im in enumerate([im1, im2], start=1):  # create a catalog for each image and insert it into its ImageModel
    # detect_threshold returns a threshold image; use its first pixel as a single scalar threshold
    threshold = detect_threshold(im.data, snr=25.)[0, 0]
    daofind = DAOStarFinder(fwhm=5, threshold=threshold)
    cat = daofind(im.data)
    # Rename the centroid columns to 'x' and 'y', the position column names used in this workflow
    cat.rename_column('xcentroid', 'x')
    cat.rename_column('ycentroid', 'y')
    cat.meta['name'] = 'im{:d} sources'.format(catno)
    print(len(cat), 'sources detected')
    im.meta['catalog'] = cat
cat1 = im1.meta['catalog']
cat2 = im2.meta['catalog']

#### Plot Detected Sources

In [None]:
f, (ax1, ax2) = plt.subplots(nrows = 1, ncols = 2,figsize = (12,6)); f.suptitle('All Detected sources')
ax1.imshow(im1.data, origin = 'lower', vmin = -0.5, vmax = 0.5, cmap = 'Greys_r')
ax1.scatter(cat1['x'], cat1['y'], facecolor = 'None', edgecolor = 'r')
ax1.set_xlim(0,2000); ax1.set_ylim(0,2000); ax1.set_title('Image 1')
ax2.imshow(im2.data, origin = 'lower', vmin = -0.5, vmax = 0.5, cmap = 'Greys_r')
ax2.scatter(cat2['x'],cat2['y'], facecolor = 'None', edgecolor = 'r')
ax2.set_xlim(0,2000); ax2.set_ylim(0,2000); ax2.set_title('Image 2'); pass;

It is useful to match common sources between the images in order to inspect the quality of the alignment before and after running tweakwcs.
If the images are well aligned, projecting the same source from each image onto the sky with that image's GWCS transformation should give nearly the same position, so the residuals should be small.

To cross-match sources, we can estimate the shift between the two images, apply this rough correction, and then find the closest matching pairs between the catalogs using a KD tree. By visual inspection, these images are dithered by about 250 pixels in both x and y (+250 in x and -250 in y, as applied below), which gives a rudimentary alignment good enough to match some of the common sources. Because the shift is only an estimate, we also accept only pairs whose separation after the shift falls in a narrow window (10 to 18 pixels, chosen for this dataset) to avoid mismatching sources.

Note: This step only diagnoses the quality of the alignment and is not necessary for the core alignment process. You may cross-match sources with many other methods, but this one is fairly straightforward.
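
The same shifted nearest-neighbour matching can be sketched without a KD tree using NumPy broadcasting, which is adequate for catalogs of a few thousand sources. `match_with_shift` is a hypothetical helper; its shift and separation window mirror the values used in the next cell:

```python
import numpy as np

def match_with_shift(x1, y1, x2, y2, dx, dy, dmin, dmax):
    """For each source in catalog 2, find the nearest source in catalog 1
    after shifting catalog 1 by (dx, dy); keep pairs with dmin < d < dmax.

    Returns (idx1, idx2): matched indices into catalog 1 and catalog 2."""
    x1s = np.asarray(x1, dtype=float) + dx
    y1s = np.asarray(y1, dtype=float) + dy
    x2 = np.asarray(x2, dtype=float)
    y2 = np.asarray(y2, dtype=float)
    # Pairwise distances, shape (len(cat2), len(cat1)).
    d = np.hypot(x2[:, None] - x1s[None, :], y2[:, None] - y1s[None, :])
    nearest = d.argmin(axis=1)
    dist = d[np.arange(d.shape[0]), nearest]
    good = (dist > dmin) & (dist < dmax)
    return nearest[good], np.flatnonzero(good)

x1 = [0.0, 100.0, 1000.0]
y1 = [500.0, 600.0, 1000.0]
x2 = [1262.0, 262.0, 5000.0]
y2 = [760.0, 260.0, 5000.0]
idx1, idx2 = match_with_shift(x1, y1, x2, y2, dx=250, dy=-250, dmin=10, dmax=18)
print(idx1, idx2)
```

Brute force needs O(N·M) time and memory, so the KD tree used in the notebook scales better for large catalogs; the matching logic is the same.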

In [None]:
# Roughly shift the first catalog to match the second (pixel coordinates only; no sky projection yet)
shifted_x1 = cat1['x'] + 250
shifted_y1 = cat1['y'] - 250

# Build a KD tree on the shifted cat1 positions; for each source in cat2, find its nearest neighbor in cat1
from scipy.spatial import KDTree
tree = KDTree(list(zip(shifted_x1, shifted_y1)))
pts = list(zip(cat2['x'], cat2['y']))
distance_pix_cat1, idx_cat1 = tree.query(pts)

# Keep only pairs whose separation falls in the expected window (a boolean array, not a list)
distance_mask = (distance_pix_cat1 > 10) & (distance_pix_cat1 < 18)
matched_tab2 = cat2[distance_mask]
matched_tab1 = cat1[idx_cat1[distance_mask]]

#### Plot Matched Sources

Verify that each source we determined was a match corresponds to the same source in each image by overplotting them.

In [None]:
plt.figure(figsize = (7,7))
plt.imshow(im2.data, origin = 'lower', vmin = -0.5, vmax = 0.5, cmap = 'Greys_r')
plt.scatter(matched_tab1['x'] + 250, matched_tab1['y'] - 250, facecolor = 'None', edgecolor = 'b', label = 'cat 1')
plt.scatter(matched_tab2['x'], matched_tab2['y'], facecolor = 'None', edgecolor = 'r', label = 'cat 2')
plt.xlim(0,2000); plt.ylim(0,2000); plt.title('Cross-matched sources'); plt.legend(loc = 'best'); pass;

With a set of common sources between each image, we can project these sources onto the sky and inspect the residuals between the projection from image 1 and image 2. 

In [None]:
im1_ra, im1_dec = im1.meta.wcs(matched_tab1['x'],matched_tab1['y'])
im2_ra, im2_dec = im2.meta.wcs(matched_tab2['x'],matched_tab2['y'])

plt.quiver(im1_ra, im1_dec,im1_ra-im2_ra,im1_dec-im2_dec)
plt.title('residuals, before alignment')
plt.xlabel('ra (reference image)')
plt.ylabel('dec (reference image)'); pass;
print('median absolute offset before alignment, RA and Dec (deg): ',np.median(np.abs(im2_ra - im1_ra)), np.median(np.abs(im2_dec - im1_dec)))

The systematic pattern of these residuals (they all point in roughly the same direction), together with their large magnitude, which would span several pixels on many detectors, indicates a poor alignment. tweakwcs can correct this.
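
To compare sky residuals with the pixel grid, RA differences must be scaled by cos(Dec) before converting to arcseconds. The sketch below (the `residuals_in_pixels` helper is hypothetical) assumes a pixel scale of about 0.031 arcsec/pixel, roughly that of the NIRCam short-wavelength detectors; treat this value as an assumption and substitute your own:

```python
import numpy as np

def residuals_in_pixels(ra1, dec1, ra2, dec2, pixel_scale_arcsec=0.031):
    """Convert RA/Dec differences (degrees) into offsets in pixels."""
    ra1, dec1, ra2, dec2 = map(np.asarray, (ra1, dec1, ra2, dec2))
    # Wrap RA differences into [-180, 180) so 359.99 vs 0.01 is a small offset.
    dra = (ra1 - ra2 + 180.0) % 360.0 - 180.0
    # An RA difference corresponds to a smaller angle on the sky by cos(dec).
    mean_dec = np.deg2rad(0.5 * (dec1 + dec2))
    dx = dra * np.cos(mean_dec) * 3600.0 / pixel_scale_arcsec
    dy = (dec1 - dec2) * 3600.0 / pixel_scale_arcsec
    return dx, dy

# One pixel (0.031 arcsec) along RA at the equator:
dx, dy = residuals_in_pixels(10.0, 0.0, 10.0 - 0.031 / 3600, 0.0)
print(dx, dy)
```

Applied to `im1_ra, im1_dec, im2_ra, im2_dec` from the cell above, `np.median(np.hypot(dx, dy))` would give the typical misalignment in pixels.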

### 2. Align Images Using Image Source Catalogs

With source catalogs created and added to each image model, the images can now be aligned with `tweak_image_wcs`.

By default, the reference catalog is created from the first image in the list. In this case, im2 is therefore aligned to im1, and only the WCS of im2 is adjusted.
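
Conceptually, once sources are matched, alignment reduces to a least-squares fit of a linear transformation between matched positions. The NumPy sketch below fits a general 2D affine transform (a 2×2 matrix plus a shift) with `numpy.linalg.lstsq`. It only illustrates the idea and is not tweakwcs' implementation, which also matches the catalogs and rejects outliers; `fit_affine` is a hypothetical name:

```python
import numpy as np

def fit_affine(src, dst):
    """Least-squares fit of dst ~ src @ M.T + t for (N, 2) point arrays.

    Returns the 2x2 matrix M and the shift vector t."""
    src = np.asarray(src, dtype=float)
    dst = np.asarray(dst, dtype=float)
    # Design matrix [x, y, 1] so the shift is fitted together with M.
    A = np.column_stack([src, np.ones(len(src))])
    coeffs, *_ = np.linalg.lstsq(A, dst, rcond=None)  # shape (3, 2)
    M = coeffs[:2].T
    t = coeffs[2]
    return M, t

# Synthetic check: rotate by 0.1 degree and shift by (-12, 16).
theta = np.deg2rad(0.1)
M_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta),  np.cos(theta)]])
t_true = np.array([-12.0, 16.0])
rng = np.random.default_rng(1)
src = rng.uniform(0, 2048, size=(50, 2))
dst = src @ M_true.T + t_true
M, t = fit_affine(src, dst)
print(M, t)
```

With real, noisy positions the fit would not be exact, which is why iterative outlier rejection matters in practice.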

In [None]:
tweak_image_wcs([im1, im2])

### 3. Inspect the Corrected WCS and Save the Aligned Image

tweakwcs inserts a correction, labeled 'tangent-plane linear correction', into the full transformation pipeline. Printing the WCS of im2 shows that this step has been inserted.
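
For orientation, a linear correction in a tangent plane can be written generically as an affine map of tangent-plane coordinates $(\xi, \eta)$, with a $2\times2$ matrix $\mathbf{M}$ (rotation, scale, and skew) and a shift $\mathbf{t}$. This is a generic form, not necessarily the exact parametrization tweakwcs stores; for a pure translation, $\mathbf{M}$ is the identity matrix and only $\mathbf{t}$ is nonzero:

$$
\begin{pmatrix} \xi' \\ \eta' \end{pmatrix}
= \mathbf{M} \begin{pmatrix} \xi \\ \eta \end{pmatrix} + \mathbf{t},
\qquad
\mathbf{M} = \begin{pmatrix} m_{11} & m_{12} \\ m_{21} & m_{22} \end{pmatrix},
\quad
\mathbf{t} = \begin{pmatrix} t_\xi \\ t_\eta \end{pmatrix}
$$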

In [None]:
print(im2.meta.wcs)

Now let's look at the same residual plot that was made before the alignment with the re-aligned images.

In [None]:
im1_ra, im1_dec = im1.meta.wcs(matched_tab1['x'],matched_tab1['y'])
im2_ra, im2_dec = im2.meta.wcs(matched_tab2['x'],matched_tab2['y'])

plt.quiver(im1_ra, im1_dec,im1_ra-im2_ra,im1_dec-im2_dec)
plt.title('residuals, after alignment')
plt.xlabel('ra (reference image)')
plt.ylabel('dec (reference image)'); pass;
print('median absolute offset after alignment, RA and Dec (deg): ',np.median(np.abs(im2_ra - im1_ra)), np.median(np.abs(im2_dec - im1_dec)))

The residuals now point in random directions, which indicates that no systematic offset remains, and the median absolute offsets are much smaller than before alignment.
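
One rough, quantitative check for a leftover systematic offset is to compare the mean residual with its standard error: a mean several standard errors away from zero suggests a remaining shift. The helper below is a hypothetical sketch of that test, not something tweakwcs provides, and the 3-sigma cutoff is an arbitrary choice:

```python
import numpy as np

def has_systematic_offset(dx, dy, nsigma=3.0):
    """Return True if the mean of either residual component differs from zero
    by more than nsigma standard errors of the mean."""
    for r in (np.asarray(dx, dtype=float), np.asarray(dy, dtype=float)):
        sem = r.std(ddof=1) / np.sqrt(r.size)
        if abs(r.mean()) > nsigma * sem:
            return True
    return False

rng = np.random.default_rng(2)
noise_x = rng.normal(0.0, 0.05, 200)
noise_y = rng.normal(0.0, 0.05, 200)
print(has_systematic_offset(noise_x, noise_y))        # residuals scattered about zero
print(has_systematic_offset(noise_x + 0.5, noise_y))  # residuals with a common shift
```

It could be applied to `im1_ra - im2_ra` and `im1_dec - im2_dec` from the cell above, both before and after alignment.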

Because tweakwcs modifies the WCS only in memory, the tweaked GWCS must be explicitly saved to a file once we are satisfied with the alignment. This can be done with the `ImageModel.save` method.

In [None]:
im2.save('im2_aligned.fits')