# Example of Running `tweakwcs` to Align JWST images

***
## About this notebook
**Authors:** Mihai Cara (STScI), Clare Shanahan (STScI)
<br>**Updated On:** 12/05/2018

This notebook illustrates a basic workflow for using `tweakwcs` to align JWST images.

***
## Imports

In [None]:
from copy import deepcopy

from jwst.datamodels import ImageModel
import matplotlib.pyplot as plt 
import numpy as np
from photutils import DAOStarFinder, detect_threshold
from tweakwcs import tweak_image_wcs

# 'matplotlib inline' for displaying plots nicely in this notebook.
%matplotlib inline 

***
## Download & Load Data

The data for this example are hosted on Box and can be downloaded by anyone via the static links below. When executed, the cell below downloads these files to your machine and stores their local paths as strings in the variables `example_file_1` and `example_file_2`.

In [None]:
from astropy.utils.data import download_file

url1 = 'https://stsci.box.com/shared/static/3z78bjae14pj6nq3plvauukpmw1li1hc.fits'
url2 = 'https://stsci.box.com/shared/static/d301ydkzxqfnd559rl4r8qts5hw2x1is.fits'
example_file_1 = download_file(url1, cache=True)
example_file_2 = download_file(url2, cache=True)

Load JWST ImageModels from downloaded files and inspect exposures.

In [None]:
im1 = ImageModel(example_file_1)
im2 = ImageModel(example_file_2)

f, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
ax1.imshow(im1.data, origin='lower', vmin=-0.5, vmax=0.5, cmap='Greys_r')
ax1.set_title('Image 1')
ax2.imshow(im2.data, origin='lower', vmin=-0.5, vmax=0.5, cmap='Greys_r')
ax2.set_title('Image 2')
pass;

We are working with simulated NIRCam observations: two dithered exposures of the same field. Because these datasets are already accurately aligned, we will introduce an artificial WCS error to show how `tweakwcs` can correct misalignment. To do this, we simply copy the GWCS from `im1` and assign it to `im2`, mimicking a scenario where the pointing information for `im2` is incorrect, so that its pixel positions map to the wrong sky positions.

In [None]:
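wcs_im1 = im1.meta.wcs
wcs_im2 = deepcopy(wcs_im1)
# Replace im2's WCS with the copy, so im2 now (incorrectly) uses im1's pointing
im2.meta.wcs = wcs_im2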
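To see why reusing `im1`'s WCS misplaces the sources in `im2`, here is a toy model. It uses a hypothetical flat-sky linear WCS (`toy_pixel_to_sky`, not part of GWCS) with an assumed pixel scale and field center chosen only for illustration. Evaluating a star's `im2` pixel position with `im1`'s reference pixel displaces it on the sky by the dither (about 240 pixels per axis) times the pixel scale.

```python
import numpy as np


def toy_pixel_to_sky(x, y, crpix, crval, scale_arcsec):
    """Hypothetical linear WCS: pixel (x, y) -> (ra, dec) in degrees.

    Ignores rotation, distortion and the cos(dec) factor; illustration only.
    RA decreases with increasing x (east to the left).
    """
    ra = crval[0] - (np.asarray(x) - crpix[0]) * scale_arcsec / 3600.0
    dec = crval[1] + (np.asarray(y) - crpix[1]) * scale_arcsec / 3600.0
    return ra, dec


scale = 0.03               # arcsec/pixel, assumed value
crval = (53.16, -27.79)    # hypothetical field center (deg)
wcs1_crpix = (1000.0, 1000.0)
# Image 2 is dithered by (+240, -240) pixels, so the pixel that sees
# the same sky point moves by the same amount.
wcs2_crpix = (1240.0, 760.0)

# A star observed at this pixel position in image 2
x2, y2 = 1500.0, 600.0
true_ra, true_dec = toy_pixel_to_sky(x2, y2, wcs2_crpix, crval, scale)
wrong_ra, wrong_dec = toy_pixel_to_sky(x2, y2, wcs1_crpix, crval, scale)

# Sky offset (arcsec) caused by using the wrong WCS
dra = (wrong_ra - true_ra) * 3600.0
ddec = (wrong_dec - true_dec) * 3600.0
print(dra, ddec)
```

Both offsets equal 240 pixels times the assumed scale, which is the kind of systematic, several-arcsecond error `tweakwcs` is asked to remove below.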

***
## Typical Workflow to Align Two or More Images

We will align the two images (the ImageModels loaded above) using the most basic workflow.

This process consists of the following steps:

1. Detect sources in each image and create source catalogs.
2. Insert these catalogs into the `meta` attribute of each ImageModel, where `tweakwcs` will look for them.
3. Align the images with `tweak_image_wcs`.
4. Inspect the output to verify the quality of the alignment.
5. When satisfied, save the aligned image.


### Create Source Catalogs

`tweakwcs` does not create the source catalogs used for matching; they must be passed in. This increases flexibility: users may generate these catalogs any way they like (segmentation maps, PSF fitting, or `DAOStarFinder` from `photutils` as used here, for example). Users may also apply selection criteria to the source catalogs before passing them in, such as selecting only point sources or sources in a certain magnitude range. In this example, we simply keep all sources detected by `DAOStarFinder` above a certain threshold.

Images will be aligned based on sources common to both frames. Once the sources have been detected in an image, the resulting catalog is added to that ImageModel's `meta`.
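As an example of the kind of selection criteria mentioned above, the sketch below keeps only sources above a minimum flux and/or the N brightest sources. `select_sources` is a hypothetical helper; it assumes the catalog has a `flux` column (as `DAOStarFinder` catalogs do) and returns row indices.

```python
import numpy as np


def select_sources(flux, n_brightest=None, min_flux=None):
    """Return sorted row indices of sources passing simple flux cuts."""
    flux = np.asarray(flux, dtype=float)
    idx = np.arange(flux.size)
    if min_flux is not None:
        # drop faint sources
        idx = idx[flux[idx] >= min_flux]
    if n_brightest is not None:
        # order the remaining sources by decreasing flux, keep the first n
        order = np.argsort(flux[idx])[::-1]
        idx = idx[order[:n_brightest]]
    return np.sort(idx)


flux = np.array([5.0, 1.0, 10.0, 3.0, 8.0])
keep = select_sources(flux, n_brightest=3)
print(keep)  # [0 2 4]
```

With an astropy table, the result could be applied as `cat = cat[select_sources(cat['flux'], n_brightest=200)]`; the value 200 is arbitrary.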

In [None]:
# Iterate over the images, detect sources, and insert each catalog into its ImageModel
for catno, im in enumerate([im1, im2], start=1):
    # detect_threshold returns a threshold image; it is constant here, so take one value
    threshold = detect_threshold(im.data, snr=50.)[0, 0]
    daofind = DAOStarFinder(fwhm=5, threshold=threshold)
    cat = daofind(im.data)
    # rename the centroid columns to the 'x' and 'y' names used by tweakwcs
    cat.rename_column('xcentroid', 'x')
    cat.rename_column('ycentroid', 'y')
    cat.meta['name'] = 'im{:d} sources'.format(catno)
    print(len(cat), 'sources detected in im{:d}'.format(catno))
    im.meta['catalog'] = cat
cat1 = im1.meta['catalog']
cat2 = im2.meta['catalog']
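The detection threshold above comes from `detect_threshold`, which (per the `photutils` documentation) computes background + N × background RMS, estimating both from sigma-clipped statistics when they are not supplied; here N is given by `snr=50.`. The numpy sketch below approximates that idea; `sigma_clipped_threshold` is a hypothetical helper, not the `photutils` implementation.

```python
import numpy as np


def sigma_clipped_threshold(data, nsigma, clip_sigma=3.0, max_iters=10):
    """Estimate median + nsigma * std of the background via iterative sigma clipping."""
    values = np.asarray(data, dtype=float).ravel()
    values = values[np.isfinite(values)]
    for _ in range(max_iters):
        med = np.median(values)
        std = np.std(values)
        # reject pixels (e.g. stars) far from the current background estimate
        keep = np.abs(values - med) <= clip_sigma * std
        if keep.all():
            break
        values = values[keep]
    return np.median(values) + nsigma * np.std(values)


# Synthetic image: unit-variance noise plus one bright "star"
rng = np.random.default_rng(0)
data = rng.normal(0.0, 1.0, size=(200, 200))
data[50:55, 50:55] += 1000.0
thr = sigma_clipped_threshold(data, nsigma=5.0)
print(thr)
```

The bright pixels are clipped in the first iteration, so the threshold lands near 5 times the noise level rather than being inflated by the star.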

#### Plot detected sources

In [None]:
f, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))
f.suptitle('All detected sources')
ax1.imshow(im1.data, origin='lower', vmin=-0.5, vmax=0.5, cmap='Greys_r')
ax1.scatter(cat1['x'], cat1['y'], facecolor='None', edgecolor='r')
ax1.set_xlim(0, 2000); ax1.set_ylim(0, 2000); ax1.set_title('Image 1')
ax2.imshow(im2.data, origin='lower', vmin=-0.5, vmax=0.5, cmap='Greys_r')
ax2.scatter(cat2['x'], cat2['y'], facecolor='None', edgecolor='r')
ax2.set_xlim(0, 2000); ax2.set_ylim(0, 2000); ax2.set_title('Image 2'); pass;

### Inspecting alignment

It is useful to cross-match sources common to both images so that the quality of the alignment can be inspected before and after re-alignment.
If the images are well aligned, then projecting the same source from each image onto the sky with that image's GWCS transformation should give the same sky position, and the residuals should be minimal.

To cross-match sources, we can estimate the shift between the two images, apply this rudimentary correction, and find the closest matching pairs between the catalogs using a KD tree. By visual inspection, these images are dithered by about 240 pixels in both x and y (+240 in x and -240 in y going from image 1 to image 2). Applying this shift gives a very rough alignment, good enough to match some of the common sources. Because the shift is only an estimate, we also impose an upper bound on the separation of matched pairs (5 pixels) to avoid pairing unrelated sources.

**Note:** This step is only for diagnosing the quality of the alignment and is not necessary for the core alignment process. You may cross-match sources with many other methods, but this one is fairly straightforward.

In [None]:
from scipy.spatial import KDTree

# Shift the first catalog by the rough dither offset so that it
# approximately overlaps the second catalog in pixel space
shifted_x1 = cat1['x'] + 240
shifted_y1 = cat1['y'] - 240

# Build a KD tree from the shifted cat1 positions and, for every
# source in cat2, find the nearest cat1 source
tree = KDTree(list(zip(shifted_x1, shifted_y1)))
pts = list(zip(cat2['x'], cat2['y']))
distance_pix_cat1, idx_cat1 = tree.query(pts)

# Keep only pairs closer than 5 pixels (excluding zero-separation pairs)
distance_mask = (distance_pix_cat1 > 0) & (distance_pix_cat1 < 5)
matched_tab2 = cat2[distance_mask]
matched_tab1 = cat1[idx_cat1[distance_mask]]
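The matching in the cell above can also be written as a brute-force nearest-neighbour search, which makes the two ingredients explicit: the rough pixel shift and the maximum-separation cut. `match_nearest` is a hypothetical helper using only numpy; `scipy.spatial.KDTree` performs the same nearest-neighbour query much more efficiently for large catalogs.

```python
import numpy as np


def match_nearest(x1, y1, x2, y2, dx, dy, max_sep):
    """For each source in catalog 2, find the nearest shifted source in catalog 1.

    Returns (idx2, idx1, sep) for pairs separated by less than max_sep pixels.
    Uses an O(N*M) distance matrix, fine for a few thousand sources.
    """
    sx = np.asarray(x1, dtype=float) + dx
    sy = np.asarray(y1, dtype=float) + dy
    x2 = np.asarray(x2, dtype=float)
    y2 = np.asarray(y2, dtype=float)
    # pairwise separations, shape (len(cat2), len(cat1))
    d = np.hypot(x2[:, None] - sx[None, :], y2[:, None] - sy[None, :])
    nearest = np.argmin(d, axis=1)
    sep = d[np.arange(d.shape[0]), nearest]
    good = sep < max_sep
    return np.flatnonzero(good), nearest[good], sep[good]


# Toy catalogs: the first two cat2 sources are shifted copies of cat1 sources
idx2, idx1, sep = match_nearest(
    [100, 500, 900], [800, 600, 400],
    [340.5, 740.2, 50.0], [560.3, 359.9, 50.0],
    dx=240, dy=-240, max_sep=5,
)
print(idx2, idx1, sep)
```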

Verify that each source we determined was a match corresponds to the same source in each image by overplotting them.

In [None]:
plt.figure(figsize=(7, 7))
plt.imshow(im2.data, origin='lower', vmin=-0.5, vmax=0.5, cmap='Greys_r')
# Apply the same rough (+240, -240) pixel shift used for matching so that cat1 overlays image 2
plt.scatter(matched_tab1['x'] + 240, matched_tab1['y'] - 240, facecolor='None', edgecolor='b', label='cat 1')
plt.scatter(matched_tab2['x'], matched_tab2['y'], facecolor='None', edgecolor='r', label='cat 2')
plt.xlim(0, 2000); plt.ylim(0, 2000); plt.title('Cross-matched sources'); plt.legend(loc='best'); pass;

With a set of common sources between each image, we can project these sources onto the sky and inspect the residuals between the projection from image 1 and image 2. 

Here we simply plot the differences in RA and Dec between the sky positions projected from image 1 and image 2. This is sufficient for a quick look in this case, but in general RA differences must be scaled by cos(Dec), or the great-circle distance between the points computed, for the residuals to be meaningful, especially far from the celestial equator.
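A minimal sketch of both corrections, assuming positions in degrees: the great-circle separation via the Vincenty formula, and the small-offset approximation (ΔRA·cos Dec, ΔDec). The helper names are hypothetical; `astropy.coordinates.SkyCoord.separation` provides the same great-circle computation in practice.

```python
import numpy as np


def angular_separation_arcsec(ra1, dec1, ra2, dec2):
    """Great-circle separation in arcsec (Vincenty formula); inputs in degrees."""
    ra1, dec1, ra2, dec2 = map(np.radians, (ra1, dec1, ra2, dec2))
    dra = ra2 - ra1
    num1 = np.cos(dec2) * np.sin(dra)
    num2 = np.cos(dec1) * np.sin(dec2) - np.sin(dec1) * np.cos(dec2) * np.cos(dra)
    den = np.sin(dec1) * np.sin(dec2) + np.cos(dec1) * np.cos(dec2) * np.cos(dra)
    return np.degrees(np.arctan2(np.hypot(num1, num2), den)) * 3600.0


def tangent_offsets_arcsec(ra1, dec1, ra2, dec2):
    """Small-offset approximation: (dRA * cos(Dec), dDec) in arcsec."""
    ra1, dec1, ra2, dec2 = (np.asarray(v, dtype=float) for v in (ra1, dec1, ra2, dec2))
    # wrap the RA difference into [-180, 180) degrees to handle the 0/360 boundary
    dra_deg = (ra2 - ra1 + 180.0) % 360.0 - 180.0
    dec_mid = np.radians(0.5 * (dec1 + dec2))
    return dra_deg * np.cos(dec_mid) * 3600.0, (dec2 - dec1) * 3600.0


# At Dec = 60 deg, a 2 arcsec difference in RA is only ~1 arcsec on the sky
print(angular_separation_arcsec(10.0, 60.0, 10.0 + 2 / 3600, 60.0))
print(tangent_offsets_arcsec(10.0, 60.0, 10.0 + 2 / 3600, 60.0))
```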

In [None]:
im1_ra, im1_dec = wcs_im1(matched_tab1['x'],matched_tab1['y'])
im2_ra, im2_dec = wcs_im2(matched_tab2['x'],matched_tab2['y'])

q = plt.quiver(im1_ra, im1_dec,-1.*(im1_ra-im2_ra),im1_dec-im2_dec)
plt.quiverkey(q, 0.1, 0.1, 10/3600., '10"', color = 'r')
plt.xlim(plt.xlim()[::-1])
plt.title('residuals, before alignment')
plt.xlabel('ra (reference image)')
plt.ylabel('dec (reference image)'); pass;
print('median absolute offset before alignment, RA and Dec (arcsec): ', 
      np.median(np.abs(im2_ra-im1_ra)*3600.), 
      np.median(np.abs(im2_dec-im1_dec)*3600.))

The systematic pattern of these residuals and the large offset of several arcseconds indicate a poor alignment. `tweakwcs` can improve this alignment so that sources projected with the GWCS from image 1 better match the sky positions of the same sources projected with the GWCS from image 2.

## Align Images with tweakwcs

With source catalogs created and added to each image model, the images can now be aligned with `tweak_image_wcs`. This function modifies the ImageModels in memory, adjusting their GWCS information. By default, the catalog of the first image passed in is used as the reference. In this case that is `im1`, so `im2` will be aligned to `im1` and only the WCS of `im2` will be adjusted.

Users also have the option of supplying an external reference catalog to align each image to. If that were done in this case, the WCSs of both `im1` and `im2` would be aligned to that catalog, and so both models would be modified.

In [None]:
tweak_image_wcs([im1, im2])

###  Inspect Corrected WCS

`tweakwcs` inserts a correction, labeled 'tangent-plane linear correction', into the full transformation pipeline. If we print the WCS of `im2`, we can see that this correction has been inserted.

In [None]:
print(im2.meta.wcs)
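Conceptually, a "linear correction" of this kind is a 2×2 matrix plus a shift that maps matched source positions in the tangent plane of the image being aligned onto the reference positions. The sketch below fits such a general affine transform by least squares with numpy. It is an illustration of the idea only, not `tweakwcs`'s fitting code, which differs in details (for example, it supports more restricted fit geometries and outlier rejection); `fit_affine` and the simulated data are hypothetical.

```python
import numpy as np


def fit_affine(xy_in, xy_ref):
    """Least-squares fit of xy_ref ~= xy_in @ A.T + t.

    xy_in, xy_ref: arrays of shape (N, 2) with N >= 3 matched positions.
    Returns (A, t): a 2x2 matrix and a length-2 shift.
    """
    xy_in = np.asarray(xy_in, dtype=float)
    xy_ref = np.asarray(xy_ref, dtype=float)
    # design matrix with rows [x, y, 1]
    design = np.column_stack([xy_in, np.ones(len(xy_in))])
    # solve for both output coordinates at once; coeffs has shape (3, 2)
    coeffs, *_ = np.linalg.lstsq(design, xy_ref, rcond=None)
    return coeffs[:2].T, coeffs[2]


# Simulated tangent-plane positions (arcsec) with a small rotation and shift
rng = np.random.default_rng(42)
xy_in = rng.uniform(-60.0, 60.0, size=(20, 2))
theta = np.radians(0.01)
A_true = np.array([[np.cos(theta), -np.sin(theta)],
                   [np.sin(theta), np.cos(theta)]])
t_true = np.array([2.0, -1.5])
xy_ref = xy_in @ A_true.T + t_true

A, t = fit_affine(xy_in, xy_ref)
rotation_deg = np.degrees(np.arctan2(A[1, 0], A[0, 0]))
print(rotation_deg, t)
```

With noise-free input the fit recovers the simulated rotation and shift exactly; with real catalogs, centroid errors and mismatches make outlier rejection important.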

Now let's remake the residual plot from before the alignment, this time using the re-aligned WCSs.

In [None]:
im1_ra, im1_dec = im1.meta.wcs(matched_tab1['x'],matched_tab1['y'])
im2_ra, im2_dec = im2.meta.wcs(matched_tab2['x'],matched_tab2['y'])

q = plt.quiver(im1_ra, im1_dec,-3600.*(im1_ra-im2_ra),3600*(im1_dec-im2_dec))
plt.quiverkey(q,0.1,0.1,0.001,'1mas', color = 'r')
plt.title('residuals, after alignment')
plt.xlim(plt.xlim()[::-1])
plt.xlabel('ra (reference image)')
plt.ylabel('dec (reference image)'); pass;
print('median absolute offset after alignment, RA and Dec (arcsec): ',
      np.median(3600.*np.abs(im2_ra - im1_ra)), 
      np.median(3600.*np.abs(im2_dec - im1_dec)))

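The random orientation of the residual vectors indicates a good fit with no remaining systematic offset. Because `quiver` scales the arrows automatically, their lengths are not to scale with the axes; the 1 mas key and the printed median offsets show that the residuals are in fact tiny.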
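"Randomness" can also be checked numerically: if the mean residual vector is much larger than its standard error, a systematic offset remains. The hypothetical helper below computes the mean, its standard error, and the RMS of a set of residual vectors. Inputs are assumed to be offsets in arcsec, e.g. RA differences scaled by cos(Dec) and Dec differences for the matched sources.

```python
import numpy as np


def residual_summary(dx, dy):
    """Return (mean, standard error of the mean, RMS) for residual vectors (dx, dy)."""
    dx = np.asarray(dx, dtype=float)
    dy = np.asarray(dy, dtype=float)
    mean = np.array([dx.mean(), dy.mean()])
    # sample scatter per axis (needs at least two residuals)
    scatter = np.array([dx.std(ddof=1), dy.std(ddof=1)])
    stderr = scatter / np.sqrt(dx.size)
    rms = np.sqrt(np.mean(dx**2 + dy**2))
    return mean, stderr, rms


# Purely random-looking residuals versus the same residuals with a +10 offset in x
mean, stderr, rms = residual_summary([1.0, -1.0, 1.0, -1.0], [0.5, -0.5, 0.5, -0.5])
mean_sys, stderr_sys, _ = residual_summary([11.0, 9.0, 11.0, 9.0], [0.5, -0.5, 0.5, -0.5])
print(mean / stderr, mean_sys / stderr_sys)
```

A mean-to-error ratio well above a few in either axis suggests the alignment still has a systematic component.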

### Save Aligned Image

Because `tweakwcs` works in memory, we have to explicitly save the tweaked GWCS to a file once we are satisfied with the alignment. This can be done with the `ImageModel.write` method.

In [None]:
im2.write('im2_aligned.fits')