# Aligning HST WFC3/IR Images Using `subpixal`

***
## About this Notebook
**Author:** Mihai Cara, STScI
<br>**Initial version on:** 12/27/2018
<br>**Updated on:** 01/02/2018

***
## Introduction

The World Coordinate System (WCS) of an image often contains small errors. These alignment errors must be removed before the images can be processed further, e.g., before they can be combined into a mosaic. Images are said to be aligned (in a relative sense) _on the sky_ when the image coordinates _of the same object_ (present in several images) convert to approximately the same sky coordinates (using each image's WCS).
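To make this definition of relative alignment concrete, here is a minimal sketch that uses a toy linear WCS (a reference pixel, a reference sky position, and a pixel scale). The `ToyWCS` class, the `sky_separation` helper, and all numbers are hypothetical and are not part of `subpixal` or `astropy`; a real WFC3/IR WCS also includes rotation and distortion.

```python
import math
from dataclasses import dataclass


@dataclass
class ToyWCS:
    """Hypothetical linear WCS: no rotation, no distortion, flat sky."""
    crpix: tuple   # reference pixel (x, y)
    crval: tuple   # sky position of the reference pixel, in arcsec
    scale: float   # pixel scale, in arcsec/pixel

    def pix2sky(self, x, y):
        return (self.crval[0] + (x - self.crpix[0]) * self.scale,
                self.crval[1] + (y - self.crpix[1]) * self.scale)


def sky_separation(wcs_a, xy_a, wcs_b, xy_b):
    """Distance (arcsec) between the sky positions of two pixel positions."""
    sky_a = wcs_a.pix2sky(*xy_a)
    sky_b = wcs_b.pix2sky(*xy_b)
    return math.hypot(sky_a[0] - sky_b[0], sky_a[1] - sky_b[1])


# The same object as measured in two dithered images (made-up positions):
obj_in_1 = (600.0, 450.0)
obj_in_2 = (610.0, 446.0)

wcs1 = ToyWCS(crpix=(507.0, 507.0), crval=(0.0, 0.0), scale=0.13)
# Second image: one WCS with a 0.05 arcsec error, one without it.
wcs2_bad = ToyWCS(crpix=(517.0, 503.0), crval=(0.05, 0.0), scale=0.13)
wcs2_good = ToyWCS(crpix=(517.0, 503.0), crval=(0.0, 0.0), scale=0.13)

misaligned = sky_separation(wcs1, obj_in_1, wcs2_bad, obj_in_2)
aligned = sky_separation(wcs1, obj_in_1, wcs2_good, obj_in_2)
print(f"misaligned: {misaligned:.3f} arcsec, aligned: {aligned:.3f} arcsec")
```

In this toy setup the two images are aligned when the separation is close to zero; the residual separation is what an alignment procedure tries to remove.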

In this notebook we illustrate how to set up a simple workflow for aligning images using the `subpixal` package, which is designed for sub-pixel image alignment based on cross-correlation.
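As a rough illustration of how cross-correlation can yield a shift to better than one pixel, the sketch below cross-correlates two 1D Gaussian profiles and refines the integer peak with a parabola through the peak and its two neighbours. This is a generic textbook technique chosen for brevity, not necessarily the algorithm `subpixal` uses; all function names here are hypothetical.

```python
import math


def cross_correlate(a, b, max_lag):
    """Return {lag: sum_i a[i] * b[i + lag]} for integer lags."""
    n = len(a)
    return {
        lag: sum(a[i] * b[i + lag] for i in range(n) if 0 <= i + lag < n)
        for lag in range(-max_lag, max_lag + 1)
    }


def subpixel_peak(cc):
    """Refine the integer peak using a parabola through three samples."""
    k = max(cc, key=cc.get)
    if k - 1 not in cc or k + 1 not in cc:
        return float(k)  # peak at the edge of the search range
    y0, y1, y2 = cc[k - 1], cc[k], cc[k + 1]
    denom = y0 - 2.0 * y1 + y2
    if denom == 0:
        return float(k)
    return k + 0.5 * (y0 - y2) / denom


def gaussian_profile(center, n=64, sigma=3.0):
    return [math.exp(-0.5 * ((i - center) / sigma) ** 2) for i in range(n)]


ref = gaussian_profile(30.0)   # source in the reference image
img = gaussian_profile(32.3)   # same source, shifted by 2.3 pixels
shift = subpixel_peak(cross_correlate(ref, img, max_lag=8))
print(f"estimated shift: {shift:.2f} pixels")
```

The integer-lag peak alone would only locate the shift to the nearest pixel; the parabolic refinement recovers the fractional part approximately.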

<font color=red>**WARNING:** When working with real data, **BACKUP** ALL DATA before using `subpixal` as it modifies input data _in place_.</font>
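One simple way to follow this advice is to copy the input files into a separate directory before running the alignment. The helper below is a minimal sketch; the function name `backup_files` and the default directory name `backup` are assumptions made for illustration.

```python
import glob
import os
import shutil


def backup_files(pattern, backup_dir='backup'):
    """Copy files matching `pattern` into `backup_dir`; return the copies."""
    os.makedirs(backup_dir, exist_ok=True)
    copied = []
    for path in sorted(glob.glob(pattern)):
        dest = os.path.join(backup_dir, os.path.basename(path))
        shutil.copy2(path, dest)  # copy2 also preserves file timestamps
        copied.append(dest)
    return copied
```

For example, `backup_files('*_flt.fits')` would copy every FLT file in the current working directory into `backup/`.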

***
## Imports

In [None]:
# import subpixal for image alignment
from subpixal import align_images, Drizzle, SExImageCatalog

# for image retrieval from archive:
import os
import glob
import shutil
from astroquery.mast import Observations
from astropy import table

# for plotting
%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np

***
# 1. Download Data

For this example, we have chosen HST WFC3/IR observations of GRB-110328A in the F160W filter. The data come from the GO/DD proposal 12447 _"The Nature of the Remarkable Transient GRB 110328A"_ (PI: Andrew S. Fruchter).

Data are downloaded using the `astroquery` API to access the [MAST](http://archive.stsci.edu) archive. The `astroquery.mast` [documentation](http://astroquery.readthedocs.io/en/latest/mast/mast.html) has more examples of how to find and download data from MAST.

In [None]:
# If mastDownload directory already exists, delete it
# and all subdirectories it contains:
if os.path.isdir('mastDownload'):
    shutil.rmtree('mastDownload')

# Retrieve the observation information.
obs_table1 = Observations.query_criteria(obs_id='ibof02*', filters='F160W', obstype='ALL')
products1 = Observations.get_product_list(obs_table1)
obs_table2 = Observations.query_criteria(obs_id='ibvwh2*', filters='F160W', obstype='ALL')
products2 = Observations.get_product_list(obs_table2)
products = table.vstack([products1, products2])

Observations.download_products(products, mrp_only=False,
                               productSubGroupDescription=['FLC', 'FLT'], 
                               extension='fits')

def copy_mast_to_cwd():
    """
    Copy the files from the mastDownload directory to the current working
    directory, leaving the originals in mastDownload as a backup. Return
    a list of image file names in the CWD.
    
    """
    downloaded_fits_files = (glob.glob('mastDownload/HST/ibof*/ibof*flt.fits') +
                             glob.glob('mastDownload/HST/ibvwh*/ibvwh*flt.fits'))
    fits_files = []
    for fil in downloaded_fits_files:
        base_name = os.path.basename(fil)
        fits_files.append(base_name)
        if os.path.isfile(base_name):
            os.remove(base_name)
        shutil.copy2(fil, '.')
        
    return fits_files

flt_files = copy_mast_to_cwd()

***
# 2. [Optionally] Perform Initial Alignment Using `tweakreg`

`subpixal` performs well when alignment errors are small. When the misalignment is large, it is helpful to perform an initial alignment using some other method. In particular, when images contain several non-saturated stars, `tweakreg` can be used for the initial alignment.

In [None]:
from drizzlepac import tweakreg

tweakreg.TweakReg(
    ','.join(flt_files), 
    reusename=True,
    conv_width=3.5, 
    #refimage='', 
    threshold=0.02, 
    minobj=5,
    searchrad=5,
    searchunits='pixel',
    configobj=None, 
    interactive=False,
    shiftfile=False, 
    updatehdr=True)

***
# 3. Create a `Drizzle` Object to Run Drizzle on a Set of Images

In [None]:
driz = Drizzle(input='ibof*_flt.fits,ibvw*_flt.fits',
               build=False,
               clean=False,
               driz_cr_corr=True,
               skystat='mean',
               skylower=0,
               skylsigma=2.0,
               skyusigma=2.0,
               stepsize=1,
               driz_sep_kernel='square',
               combine_type='imedian',
               combine_nlow=2,
               final_wcs=True, final_pixfrac=0.8,
               final_scale=0.06666666,
               final_refimage='')

***
# 4. Set Up Image Catalog Object

Because we want to use `SExtractor` to find sources in images, we create a `SExImageCatalog` object and set up appropriate filters. In this case, we want to exclude stars, faint sources, and very diffuse sources.
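To show what a list of `(field, comparison, value)` conditions does to a catalog, here is a small stand-alone sketch. It is not how `SExImageCatalog` implements filtering; the `apply_filters` helper and the source values are made up for illustration, and only the filter tuples match the ones used in this notebook.

```python
import operator

# Map comparison strings to functions (hypothetical helper table).
OPS = {'>': operator.gt, '>=': operator.ge, '<': operator.lt,
       '<=': operator.le, '==': operator.eq, '!=': operator.ne}


def apply_filters(sources, filters):
    """Keep only sources that satisfy every (field, op, value) condition."""
    return [s for s in sources
            if all(OPS[op](s[field], value) for field, op, value in filters)]


# Made-up catalog rows:
sources = [
    {'id': 1, 'flux': 150.0, 'pos_snr': 120.0, 'fwhm': 3.2},   # passes all
    {'id': 2, 'flux': 12.0, 'pos_snr': 95.0, 'fwhm': 2.8},     # too faint
    {'id': 3, 'flux': 90.0, 'pos_snr': 85.0, 'fwhm': 11.5},    # too diffuse
    {'id': 4, 'flux': 60.0, 'pos_snr': 40.0, 'fwhm': 2.5},     # low S/N
]
filters = [('flux', '>', 20), ('pos_snr', '>', 80),
           ('fwhm', '>', 1), ('fwhm', '<', 8)]

kept_ids = [s['id'] for s in apply_filters(sources, filters)]
print(kept_ids)
```

All conditions are combined with a logical AND, so a source is dropped as soon as it fails any one of them.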

In [None]:
cat = SExImageCatalog(None, 'sextractor.cfg', max_stellarity=0.7)
cat.append_filters([('flux', '>', 20), ('pos_snr', '>', 80), ('fwhm', '>', 1), ('fwhm', '<', 8)])

***
# 5. Align Images

In [None]:
fit_history = align_images(
    cat,
    driz,
    wcslin=None,
    fitgeom='general',
    nclip=3,
    sigma=2,
    nmax=50,
    eps_shift=1e-2,
    iterative=False,
    history='last'
)

***
# 6. Display Cutouts of Sources Used For Alignment

We can use the "history" returned by the `align_images()` function to inspect which sources were used for fitting linear transformations and which sources were excluded (clipped out) from the fit. We will use the `'jet'` colormap for the excluded sources and the `'gray'` colormap for sources used in fitting.

For each source we plot, from left to right, the following panels:
- Cutout from the corresponding CR-cleaned image;
- Cutout from the drizzled image;
- Drizzled cutout blotted back into the CR-cleaned image's grid;
- Supersampled cross-correlation images.
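The idea of clipping sources out of a fit can be sketched with a generic iterative sigma-clipping loop. This is an illustration under the assumption that `nclip` and `sigma` carry their usual meaning (number of clipping iterations and rejection threshold in standard deviations); it clips 1D residuals around their mean and is not `subpixal`'s actual fitting code. The residual values are made up.

```python
import statistics


def sigma_clip_indices(residuals, nclip=3, sigma=2.0):
    """Return indices of residuals that survive iterative sigma clipping."""
    kept = list(range(len(residuals)))
    for _ in range(nclip):
        if len(kept) < 3:
            break
        vals = [residuals[i] for i in kept]
        center = statistics.mean(vals)
        spread = statistics.pstdev(vals)
        if spread == 0:
            break
        new_kept = [i for i in kept
                    if abs(residuals[i] - center) <= sigma * spread]
        if len(new_kept) == len(kept):
            break  # nothing was rejected; the selection has converged
        kept = new_kept
    return kept


# Made-up fit residuals (pixels); the source at index 6 is an outlier.
residuals = [0.02, -0.01, 0.03, 0.00, -0.02, 0.01, 0.90, -0.03]
used = sigma_clip_indices(residuals)
excluded = [i for i in range(len(residuals)) if i not in used]
print("used:", used, "excluded:", excluded)
```

In the plot below, `fit['img_indx']` plays the role of `used`: cutouts whose index appears in it are drawn in gray, and the rest in jet.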

In [None]:
ii = fit_history[-1]['finfo'][0]['image_info']
fit = fit_history[-1]['finfo'][0]['fit_info']

def find_minmax(data):
    return min(map(np.amin, data)), max(map(np.amax, data))

nrows = len(ii['image_cutouts'])
plt.figure(figsize=(5, nrows))
mindata1, maxdata1 = find_minmax([i.data for i in ii['image_cutouts']])
mindata2, maxdata2 = find_minmax(ii['ICC'])
idx = list(fit['img_indx'])

for k, (ict, dct, ic, blt) in enumerate(zip(ii['image_cutouts'], ii['driz_cutouts'], ii['ICC'],
                                                ii['blotted_cutouts'])):
    # plot excluded cutouts using 'jet' color map:
    cm = plt.cm.gray if k in idx else plt.cm.jet

    mindata1, maxdata1 = find_minmax([ict.data])
    mindata2, maxdata2 = find_minmax([ic])

    d1 = (ict.data - mindata1) / (maxdata1 - mindata1)
    d2 = (dct.data - mindata1) / (maxdata1 - mindata1)
    b00 = (blt.data - mindata1) / (maxdata1 - mindata1)
    d3 = (ic - mindata2) / (maxdata2 - mindata2)

    ax = plt.subplot(nrows, 4, 4 * k + 1); ax.axis('off')
    ax.imshow(d1, cmap=cm, origin='lower', interpolation='none', aspect='equal');
    ax.scatter([ict.cutout_src_pos[0]], [ict.cutout_src_pos[1]], marker='x')

    ax = plt.subplot(nrows, 4, 4 * k + 2); ax.axis('off')
    ax.imshow(d2, cmap=cm, origin='lower', interpolation='none', aspect='equal');
    ax.scatter([dct.cutout_src_pos[0]], [dct.cutout_src_pos[1]], marker='x')

    ax = plt.subplot(nrows, 4, 4 * k + 3); ax.axis('off')
    ax.imshow(b00, cmap=cm, origin='lower', interpolation='none', aspect='equal');

    ax = plt.subplot(nrows, 4, 4 * k + 4); ax.axis('off')
    ax.imshow(d3, cmap=cm, origin='lower', interpolation='none', aspect='equal');