# Notebook to showcase the basic functioning of the wfc3_dash module

Everything here can also be run under a single main function, explained at the end of this tutorial.

## Imports

* *astroquery.mast Observations* used to download IMA files from the MAST HST archive
* *astropy.io import fits* used to open the files
* *matplotlib.pyplot* used to plot the images
* *numpy* used for some math

In [None]:
from astroquery.mast import Observations
from astropy.io import fits 
from astropy.table import Table
import matplotlib.pyplot as plt 
import numpy as np
from drizzlepac import tweakreg
from drizzlepac import astrodrizzle

%matplotlib notebook 

## Introduction

The wfc3_dash submodule of wfc3_tools reduces the effects of spacecraft drift in WFC3/IR images taken in DASH mode (i.e. under gyro control rather than Fine Guidance Sensor control).

This notebook works on a single exposure, but it can easily be adapted to work on all exposures within a DASH visit or even a whole DASH program.

## Downloading some relevant data

#### Get the table of observations associated with GO-14114 (PI van Dokkum, the first proposal to use the DASH mode)

In [None]:
obsTable = Observations.query_criteria(proposal_id=['14114'])

#### Get the full list of products associated with the table and restrict the list to IMA files

In [None]:
product_list = Observations.get_product_list(obsTable)
BM = (product_list['productSubGroupDescription']  == 'IMA') 
product_list = product_list[BM]

#### Display (part of) the IMA files list

In [None]:
product_list.show_in_notebook(display_length=5)

#### Pick a single exposure file to work on

In [None]:
myID = product_list['obsID'][0:1]

#### Download the IMA and FLT files for that exposure. The standard pipeline-FLT will be used for comparison with the detrended final product

In [None]:
download = Observations.download_products(myID,mrp_only=False,productSubGroupDescription=['IMA','FLT'])

#### Display the results of the download operation

In [None]:
download

#### Read the files that were just downloaded locally 

In [None]:
# Path prefix: everything except the last 8 characters of the file name ('ima.fits' or 'flt.fits')
localpathtofile = download['Local Path'][0][:-8]
localpathtofile

original_ima = fits.open(localpathtofile+'ima.fits')
original_flt = fits.open(localpathtofile+'flt.fits')
original_ima.info()

#### Plot the individual reads of the IMA file
Note: the individual 'SCI' extensions are stored in reverse order, with 'SCI', 1 corresponding to the last read

In [None]:
nsamp = original_ima[0].header['NSAMP']
print('NSAMP',nsamp)
fig,axarr = plt.subplots((nsamp+3)//4,4, figsize=(9,3*((nsamp+3)//4)))

for i in range(1,4*((nsamp+3)//4)+1):

    row = (i-1)//4
    col = (i-1)%4
    if i <= nsamp:  # 'SCI' extensions are numbered 1..NSAMP
        immed = np.nanmedian(original_ima['SCI',i].data)
        stdev = np.nanstd(original_ima['SCI',i].data)
        axarr[row,col].imshow(original_ima['SCI',i].data,clim=[immed-.3*stdev,immed+.5*stdev],cmap='Greys',origin='lower')
        axarr[row,col].set_title('SCI '+str(i))
        axarr[row,col].set_xticks([]) 
        axarr[row,col].set_yticks([]) 
    else:
        fig.delaxes(axarr[row,col])

fig.tight_layout()
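
As a small illustration of the reverse ordering noted above, the helper below maps a chronological read index to the corresponding 'SCI' extension number. It is only a sketch: `read_to_extver` and the value of `nsamp` are hypothetical and not part of wfc3_dash, and it assumes the file holds exactly NSAMP 'SCI' extensions.

```python
def read_to_extver(read_index, nsamp):
    """Map a chronological read index (0 = zeroth read) to its 'SCI' extension number."""
    if not 0 <= read_index < nsamp:
        raise ValueError("read_index must satisfy 0 <= read_index < nsamp")
    # Extensions are stored in reverse order: the last read is ('SCI', 1)
    return nsamp - read_index


nsamp = 12  # hypothetical NSAMP value, for illustration only
chronological = [("SCI", read_to_extver(k, nsamp)) for k in range(nsamp)]
print(chronological[0], chronological[-1])
```

Iterating over `chronological` visits the reads in the order in which they were taken.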

## Run the individual steps of the DASH pipeline

#### This cell is a temporary workaround that allows relative imports until the whole wfc3_dash submodule is properly packaged and installed within the wfc3_tools module

In [None]:
import os
import sys
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from reduce_dash import DashData

### 1. Create a DashData object using the path to the ima file we have downloaded above

In [None]:
myDash = DashData(localpathtofile+'ima.fits')

### 2. Create diff files. 

A diff file contains the counts accumulated between two consecutive reads.  
The diff files are written to disk in a directory named ./diff under the current working directory (cwd).  
When creating diff files, the first difference (between the 1st and 0th reads) is ignored because of  
its very short exposure time of 2.9 seconds, which results in a noisy image.
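
The differencing idea can be sketched for a single pixel as follows. This is not the split_ima() implementation: `read_differences` is a hypothetical name, the numbers are invented, and the sketch assumes the reads are given as cumulative counts in chronological order. If the reads are stored as count rates, each one would first need to be multiplied by its sample time, which is left out here.

```python
def read_differences(cumulative_counts, skip_first=True):
    """Return the counts accumulated between consecutive reads of one pixel.

    cumulative_counts must be in chronological order (index 0 = zeroth read).
    """
    diffs = [later - earlier
             for earlier, later in zip(cumulative_counts, cumulative_counts[1:])]
    # Optionally drop the first, very short difference (read 1 minus read 0)
    return diffs[1:] if skip_first else diffs


counts = [0.0, 3.0, 28.0, 53.0, 78.0]  # invented cumulative counts
print(read_differences(counts))
print(read_differences(counts, skip_first=False))
```

With five reads there are four differences, of which three remain once the first one is dropped.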

In order to create a correct error extension, the split_ima() method calls the utils.get_flat() function.  
This function reads the name of the flat field used to calibrate the ima images from the ima file header.  
If the flat file is not present locally in a directory named ./iref under the cwd, get_flat() downloads  
it from the CRDS database (https://hst-crds.stsci.edu/unchecked_get/references/hst/)  
and places it in ./iref.
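
The lookup logic described above can be sketched without touching the network. The function name `locate_flat`, the example header value, and the assumption that the header stores the flat as `iref$<filename>` are illustrative; this is not the actual utils.get_flat() code.

```python
import os

# Base URL quoted in the text above
CRDS_BASE = "https://hst-crds.stsci.edu/unchecked_get/references/hst/"


def locate_flat(header_value, iref_dir="iref"):
    """Return (local_path, url, needs_download) for a flat-field reference file."""
    # Assumption: the header value looks like 'iref$<filename>'; keep the file name only
    filename = header_value.split("$")[-1]
    local_path = os.path.join(iref_dir, filename)
    url = CRDS_BASE + filename
    # A download is needed only when the file is not already present locally
    return local_path, url, not os.path.exists(local_path)


local_path, url, needs_download = locate_flat("iref$example_pfl.fits")
print(local_path, url, needs_download)
```

Checking for the local copy first means the download happens at most once per flat.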

In [None]:
myDash.split_ima()

#### Plot the diff files

In [None]:
ndiff = len(myDash.diff_files_list)
print('Number of diff files',ndiff)
fig,axarr = plt.subplots((ndiff+3)//4,4, figsize=(9,3*((ndiff+3)//4)))

for i in range(4*((ndiff+3)//4)):

    row = (i)//4
    col = (i)%4
    if (i < ndiff):
        diff_i = fits.open(myDash.diff_files_list[i]+'_diff.fits')
        immed = np.nanmedian(diff_i['SCI'].data)
        stdev = np.nanstd(diff_i['SCI'].data)
        axarr[row,col].imshow(diff_i['SCI'].data,clim=[immed-.3*stdev,immed+.5*stdev],cmap='Greys',origin='lower')
        axarr[row,col].set_title('Diff:'+str(i+1))
        axarr[row,col].set_xticks([]) 
        axarr[row,col].set_yticks([]) 
    else:
        fig.delaxes(axarr[row,col])

fig.tight_layout()
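
Both plotting cells above lay the panels out on a four-column grid using integer arithmetic. The sketch below isolates that arithmetic; `grid_shape` and `grid_position` are illustrative names, not functions from wfc3_dash or matplotlib.

```python
def grid_shape(n_panels, ncols=4):
    """Number of rows and columns needed to hold n_panels panels."""
    # Ceiling division: for ncols=4 this equals (n_panels + 3) // 4
    nrows = (n_panels + ncols - 1) // ncols
    return nrows, ncols


def grid_position(index, ncols=4):
    """Row and column of the panel with zero-based index."""
    return divmod(index, ncols)


print(grid_shape(15), grid_position(5))
```

Any grid cells beyond the last panel are the ones removed with `fig.delaxes` in the cells above.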

### 3. Create an association file

This file mimics a typical association file for dithered exposures, which astrodrizzle uses  
to align and stack multiple exposures taken at the same sky position with small dithers.  
We exploit the fact that a WFC3/IR exposure taken under gyro control can be effectively split into  
individual pseudo-exposures (the diff images).  
Astrodrizzle can treat such pseudo-exposures as individual dithers and combine them.

In [None]:
myDash.make_pointing_asn()

#### Show the content of the asn file

In [None]:
asn_filename = 'diff/{}_asn.fits'.format(myDash.root)
asn_table = Table(fits.getdata(asn_filename, ext=1))
asn_table.show_in_notebook()
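
The layout of such a table can be sketched in plain Python. The column names and member types follow the commented-out association code later in this notebook; the function name `build_asn_rows` and the file names are hypothetical, and this is not the make_pointing_asn() implementation.

```python
def build_asn_rows(member_names, product_name):
    """Build association rows: one per input exposure plus one for the product."""
    rows = [{"MEMNAME": name, "MEMTYPE": "EXP-DTH", "MEMPRSNT": True}
            for name in member_names]
    # The combined product is listed last and is not yet present on disk
    rows.append({"MEMNAME": product_name, "MEMTYPE": "PROD-DTH", "MEMPRSNT": False})
    return rows


# Hypothetical diff-file roots, for illustration only
rows = build_asn_rows(["example_01", "example_02", "example_03"], "example")
for row in rows:
    print(row)
```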

---------------------------------------------------------------

### 4. Subtract Background from original FLT

This is necessary in order to create a drizzled science array and a segmentation map to help subtract the background from the new FLT's. <br>

In [None]:
myDash.subtract_background_flt()

### 5. Subtract Background from *new* FLT's

Next, subtract the background from the individual reads taken from the original IMA file, using the DRZ and SEG images produced during the background subtraction of the original FLT. <br>
By default, this function subtracts the background and writes its value to the header. Setting the parameter subtract to False leaves the data unchanged and only writes the background value to the header. <br>
Set the parameter reset_stars_dq to True to reset cosmic-ray flags within objects to 0 (because the centers of the stars are flagged).

In [None]:
myDash.subtract_background_reads()
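
The effect of the subtract parameter can be pictured with a toy version of the operation. This is a sketch under stated assumptions, not the wfc3_dash implementation: the function name `estimate_background`, the use of a median, and the convention that a segmentation label of 0 marks sky pixels are all assumptions made for illustration.

```python
from statistics import median


def estimate_background(pixels, seg_labels, subtract=True):
    """Estimate the sky level from unmasked pixels and optionally subtract it."""
    # Assumption: label 0 in the segmentation map means "no source here"
    sky = [p for p, label in zip(pixels, seg_labels) if label == 0]
    background = median(sky)
    if subtract:
        pixels = [p - background for p in pixels]
    return pixels, background


pixels = [1.0, 1.2, 9.0, 0.8, 1.1]  # invented values; 9.0 sits on a source
seg = [0, 0, 1, 0, 0]
subtracted, bkg = estimate_background(pixels, seg)
unchanged, _ = estimate_background(pixels, seg, subtract=False)
print(bkg, subtracted, unchanged)
```

In both cases the estimated level is returned, so it can be recorded even when the data are left untouched.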

### 6. Fix Cosmic Rays

In [None]:
myDash.fix_cosmic_rays()

### 7a. Align reads to catalog

#### Determine coordinates and search area from the WCS's of your images (procedure taken from Gaia_alignment).

In [None]:
import glob

from astropy import units as u
from astropy.io import fits
from astropy.wcs import WCS
from astropy.visualization import wcsaxes
from astropy.coordinates import SkyCoord

from matplotlib.patches import Polygon
import matplotlib.cm as cm


# ----------------------------------------------------------------------------------------------------------

#use coordinates of original exposure
def get_footprints(im_name):
    """Calculates positions of the corners of the science extensions of some image 'im_name' in sky space"""
    footprints = []
    hdu = fits.open(im_name)
    
    flt_flag = 'flt.fits' in im_name or 'flc.fits' in im_name
    
    # Loop ensures that each science extension in a file is accounted for. 
    for ext in hdu:
        if 'SCI' in ext.name:
            hdr = ext.header
            wcs = WCS(hdr, hdu)
            footprint = wcs.calc_footprint(hdr, undistort=flt_flag)
            footprints.append(footprint)
    
    hdu.close()
    return footprints

# ----------------------------------------------------------------------------------------------------------
def bounds(footprint_list):
    """Calculate RA/Dec bounding box properties from multiple RA/Dec points"""
    
    # flatten list of extensions into numpy array of all corner positions
    merged = [ext for image in footprint_list for ext in image]
    merged = np.vstack(merged)
    ras, decs = merged.T
    
    # Compute width/height
    delta_ra = (max(ras)-min(ras))
    delta_dec = max(decs)-min(decs)

    # Compute midpoints
    ra_midpt = (max(ras)+min(ras))/2.
    dec_midpt = (max(decs)+min(decs))/2.
    

    return ra_midpt, dec_midpt, delta_ra, delta_dec
# ----------------------------------------------------------------------------------------------------------

            
images = glob.glob(localpathtofile+'flt.fits')
footprint_list = list(map(get_footprints, images))

# # If that's slow, here's a version that runs it in parallel:
# from multiprocessing import Pool
# p = Pool(8)
# footprint_list = list(p.map(get_footprints, images))
# p.close()
# p.join()

ra_midpt, dec_midpt, delta_ra, delta_dec = bounds(footprint_list)

coord = SkyCoord(ra=ra_midpt, dec=dec_midpt, unit=u.deg)
print(coord)
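
To make the output of bounds() concrete, here is a plain-Python equivalent applied to a made-up footprint. The function name `bounding_box` and the corner coordinates are invented for illustration.

```python
def bounding_box(corners):
    """Midpoint and extent of a list of (ra, dec) corners, all in degrees."""
    ras = [ra for ra, _ in corners]
    decs = [dec for _, dec in corners]
    delta_ra = max(ras) - min(ras)
    delta_dec = max(decs) - min(decs)
    ra_midpt = (max(ras) + min(ras)) / 2.0
    dec_midpt = (max(decs) + min(decs)) / 2.0
    return ra_midpt, dec_midpt, delta_ra, delta_dec


# Invented footprint corners (RA, Dec) in degrees
corners = [(150.10, 2.20), (150.14, 2.20), (150.14, 2.24), (150.10, 2.24)]
ra_mid, dec_mid, d_ra, d_dec = bounding_box(corners)
print(ra_mid, dec_mid, d_ra, d_dec)
```

Two limitations of this simple approach, shared by bounds() above, are worth keeping in mind: it does not handle footprints that straddle the RA = 0/360 degree boundary, and delta_ra is a difference in the RA coordinate rather than an on-sky angle, so the two diverge by a factor of cos(Dec) away from the equator.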

#### Querying from Gaia

In [None]:
from astropy.units import Quantity
from astroquery.gaia import Gaia

width = Quantity(delta_ra, u.deg)
height = Quantity(delta_dec, u.deg)

In [None]:
# Perform the query!
r = Gaia.query_object_async(coordinate=coord, width=width, height=height)

In [None]:
# Print the table
r

In [None]:
ras = r['ra']
decs = r['dec']
mags = r['phot_g_mean_mag']
ra_error = r['ra_error']
dec_error = r['dec_error']

#### Aligning Data to Catalog

In [None]:
from astropy.table import Table

tbl = Table([ras, decs]) # Make a temporary table of just the positions
tbl.write('gaia.cat', format='ascii.fast_commented_header') # Save the table to a file.  The format argument ensures
                                                            # the first line will be commented out.

In [None]:
thresh = 10.

def get_error_mask(catalog, max_error):
    """Returns a mask for rows in catalog where RA and Dec error are less than max_error"""
    ra_mask = catalog['ra_error']< max_error
    dec_mask = catalog['dec_error'] < max_error
    mask = ra_mask & dec_mask
#     print('Cutting sources with error higher than {}'.format(max_error))
#     print('Number of sources before filtering: {}\nAfter filtering: {}\n'.format(len(mask),sum(mask)))
    return mask

mask = get_error_mask(r, thresh)

tbl_filtered = Table([ras[mask], decs[mask]])  # Keep only the sources that pass the error cut
tbl_filtered.write('gaia_filtered_{}_mas.cat'.format(thresh), format='ascii.fast_commented_header')
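
The selection performed by get_error_mask() keeps a source only when both its RA and Dec uncertainties are below the threshold. The plain-Python sketch below shows the same logic on invented rows; `filter_by_error` is a hypothetical helper and the values are not real Gaia measurements.

```python
def filter_by_error(rows, max_error):
    """Keep rows whose RA and Dec errors are both below max_error."""
    return [row for row in rows
            if row["ra_error"] < max_error and row["dec_error"] < max_error]


# Invented catalog rows; errors are in the same units as the threshold
catalog = [
    {"ra": 150.11, "dec": 2.21, "ra_error": 0.4, "dec_error": 0.5},
    {"ra": 150.12, "dec": 2.22, "ra_error": 12.0, "dec_error": 0.3},
    {"ra": 150.13, "dec": 2.23, "ra_error": 2.0, "dec_error": 15.0},
]
kept = filter_by_error(catalog, max_error=10.0)
print(len(catalog), len(kept))
```

A source with a large error in only one coordinate is still rejected, which matches the `ra_mask & dec_mask` combination used above.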

In [None]:
ls *cat

In [None]:
pwd

#### Align images
By default, the images are aligned to a catalog. The parameters ref_catalog and ref_image specify the reference catalog and the reference image, respectively. <br>
To not align to catalog, set parameter align_method to ... <br>
To not subtract background, set parameter subtract_background to False.

[WIP]: Try to fix align function.

In [None]:
# from astropy.io.fits import getdata


# asn_filename = 'diff/{}_asn.fits'.format(myDash.root)

# asn_list=[]
# for index, file in enumerate(myDash.diff_files_list):
#     asn_list += [file +'_diff.fits']
    
# #asn_list.append(myDash.root)

# # Create Primary HDU:
# hdr = fits.Header()
# hdr['FILENAME'] = "'" + asn_filename + "'"
# hdr['FILETYPE'] = 'ASN_TABLE'
# hdr['ASN_ID'] = "'" + myDash.root + "'"
# hdr['ASN_TABLE'] = "'" + asn_filename + "'"
# hdr['COMMENT'] = "This association table is for the read differences for the IMA."
# primary_hdu = fits.PrimaryHDU(header=hdr)

# # Create the information in the asn file
# num_mem = len(asn_list)

# asn_mem_names = np.array(asn_list)
# asn_mem_types =  np.full(num_mem,'EXP-DTH',dtype=np.chararray)
# asn_mem_types[-1] = 'PROD-DTH'
# asn_mem_prsnt = np.ones(num_mem, dtype=np.bool_)
# asn_mem_prsnt[-1] = 0

# hdu_data = fits.BinTableHDU().from_columns([fits.Column(name='MEMNAME', format='17A', array=asn_mem_names), 
#             fits.Column(name='MEMTYPE', format='14A', array=asn_mem_types), 
#             fits.Column(name='MEMPRSNT', format='L', array=asn_mem_prsnt)])

# # Create the final asn file
# hdu = fits.HDUList([primary_hdu, hdu_data])

# if 'EXTEND' not in hdu[0].header.keys():
#     hdu[0].header.update('EXTEND', True, after='NAXIS')
        
# hdu.writeto(asn_filename, overwrite=True)

# data = getdata(asn_filename, 1)

In [None]:
myDash.align(align_method='CATALOG', ref_catalog = 'gaia.cat')

Plot aligned science images.

In [None]:
sci = fits.getdata('final_drz_sci.fits')

fig = plt.figure(figsize=(20, 20))

plt.imshow(sci, vmin=-0.05, vmax=0.4, cmap='Greys_r', origin='lower')

### 7b. Align reads to each other

This is an alternative to aligning the reads to a catalog.

In [None]:
myDash.make_read_catalog()

## Using main function to reduce data