In [None]:
from astroquery.mast import Observations
from astropy.io import fits 
from astropy.table import Table
import matplotlib.pyplot as plt 
import numpy as np

from astropy.io import ascii

from glob import glob
from drizzlepac import astrodrizzle

import os

%matplotlib notebook 

## Downloading some relevant data

#### Get the table of observations associated with proposal 15238

In [None]:
# Query MAST for the single association of proposal 15238 used throughout this notebook.
obsTable = Observations.query_criteria(
    proposal_id=['15238'],
    obs_id=['IDNM0J030'],
)

#### Get the full list of products associated with the table and restrict the list to IMA files

In [None]:
product_list = Observations.get_product_list(obsTable)

# Boolean mask keeping only the IMA products (the files that hold the individual reads).
BM = product_list['productSubGroupDescription'] == 'IMA'
product_list = product_list[BM]

product_list.show_in_notebook(display_length=5)

#### Pick a single exposure file to work on. To create usable data, you will have to follow this workflow for every individual IMA file in your dataset.

In [None]:
# Keep a length-1 slice of the obsID column (not a scalar) for the first IMA product.
myID = product_list['obsID'][:1]

#### Download the IMA and FLT files for that exposure. The standard pipeline-FLT will be used for comparison with the detrended final product

In [None]:
# Fetch both the IMA (input to DASH) and the pipeline FLT (for comparison) of that exposure.
download = Observations.download_products(
    myID,
    mrp_only=False,
    productSubGroupDescription=['IMA', 'FLT'],
)

#### Display the results of the download operation

In [None]:
# Display the download manifest; its 'Local Path' column is used in the next cell.
download

#### Read the files that were just downloaded locally 

In [None]:
# Both downloaded products share one path prefix and differ only in their 8-character suffix
# ('ima.fits' or 'flt.fits'). Stripping that suffix from whichever row comes first therefore
# yields a prefix that opens either file. Check the assumption rather than slicing blindly.
assert download['Local Path'][0].endswith(('ima.fits', 'flt.fits')), download['Local Path'][0]
localpathtofile = download['Local Path'][0][:-len('ima.fits')]
# A bare expression in the middle of a cell is never displayed, so print the prefix explicitly.
print(localpathtofile)

# Both HDULists are kept open on purpose: original_ima is read by the plotting cell below.
original_ima = fits.open(localpathtofile + 'ima.fits')
original_flt = fits.open(localpathtofile + 'flt.fits')
original_ima.info()

#### Plot the individual reads of the IMA file
Note: the individual `'SCI'` extensions are stored in reverse order, with `('SCI', 1)` corresponding to the last read.

In [None]:
# Number of samples recorded in the IMA header; one ('SCI', i) extension is plotted per sample.
nsamp = original_ima[0].header['NSAMP']
print('NSAMP',nsamp)

ncols = 4
nrows = (nsamp + ncols - 1) // ncols
# squeeze=False keeps `axarr` 2-D even when everything fits in one row (nsamp <= 4).
# Without it, plt.subplots returns a 1-D array and `axarr[row, col]` raises IndexError.
fig, axarr = plt.subplots(nrows, ncols, figsize=(9, 3 * nrows), squeeze=False)

for i in range(1, nrows * ncols + 1):
    row, col = divmod(i - 1, ncols)
    if i <= nsamp:
        # ('SCI', 1) is the last read: the extensions are stored in reverse order.
        read_data = original_ima['SCI', i].data
        immed = np.nanmedian(read_data)
        stdev = np.nanstd(read_data)
        # Asymmetric stretch around the median to bring out faint structure.
        axarr[row, col].imshow(read_data, clim=[immed - .3 * stdev, immed + .5 * stdev], cmap='Greys', origin='lower')
        axarr[row, col].set_title('SCI ' + str(i))
        axarr[row, col].set_xticks([])
        axarr[row, col].set_yticks([])
    else:
        # Remove the unused panels that pad out the last row.
        fig.delaxes(axarr[row, col])

fig.tight_layout()

## Run the individual steps of the DASH pipeline
Run the DASH pipeline for a single exposure.  
This procedure showcases the capabilities and customization options of the DASH pipeline.


#### This cell is inserted temporarily to allow for relative imports until the whole wfc3_dash submodule is properly packaged and installed within the wfc3_tools module

In [None]:
import sys
import os

# Temporary shim: put the parent directory on sys.path so that `reduce_dash` can be
# imported until the wfc3_dash submodule is packaged and installed within wfc3_tools.
module_path = os.path.abspath('..')
if module_path not in sys.path:
    sys.path.append(module_path)

from reduce_dash import DashData

### 1. Create a DashData object using the path to the IMA file we have downloaded above

In [None]:
# Wrap the downloaded IMA; the pipeline FLT is passed along for the later FLT-based steps.
myDash = DashData(
    localpathtofile + 'ima.fits',
    flt_file_name=localpathtofile + 'flt.fits',
)
print(myDash.root)

### 2. Create diff files

A diff file contains the counts accumulated between two reads.  
The diff files are written to disk in a directory named ./diff under the current working directory (cwd).  
In creating diff files, the first difference (between the 1st and 0th reads) is ignored because of  
its very short exposure time of 2.9 seconds, which results in a noisy image.

In order to create a correct error extension, the `split_ima()` method calls the `utils.get_flat()` and `utils.get_IDCtable()` functions.  
`get_flat()` reads the name of the flat field used for calibrating the IMA images from the IMA file header.  
`get_IDCtable()` reads the name of the image distortion correction table (IDCTAB), a reference file containing the distortion coefficients used to correct for distortion in MAST drizzled data products.  
If the flat file is not present locally in a directory named ./iref under the cwd, get_flat() will download   
the flat field file from the CRDS database https://hst-crds.stsci.edu/unchecked_get/references/hst/ 
and place it in ./iref . Similarly for the IDCTAB reference file.

In [None]:
# Split the IMA into per-read diff images written to ./diff; may download the flat-field
# and IDCTAB reference files into ./iref if they are not already present.
myDash.split_ima()

#### Plot the diff files

In [None]:
ndiff = len(myDash.diff_files_list)
print('Number of diff files',ndiff)

if ndiff == 0:
    # plt.subplots(1, 0) would raise; explain what is missing instead.
    print('No diff files to plot; run myDash.split_ima() first.')
else:
    # Up to 4 diffs go in a single row; more are wrapped onto a 4-column grid.
    if ndiff > 4:
        ncols = 4
        nrows = (ndiff + ncols - 1) // ncols
        figsize = (9, 3 * nrows)
    else:
        nrows, ncols = 1, ndiff
        figsize = (15, 15)

    # squeeze=False always returns a 2-D array of axes. This also covers a single row,
    # and a single diff, where plt.subplots would otherwise return a bare Axes.
    fig, axarr = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

    for i in range(nrows * ncols):
        row, col = divmod(i, ncols)
        if i >= ndiff:
            # Remove the unused panels that pad out the last row.
            fig.delaxes(axarr[row, col])
            continue

        # Load THIS diff's SCI data before computing its display limits (the old <=4 branch
        # used the previous/undefined file). getdata also closes the file, so no handle leaks.
        diff_sci = fits.getdata(myDash.diff_files_list[i] + '_diff.fits', 'SCI')
        immed = np.nanmedian(diff_sci)
        stdev = np.nanstd(diff_sci)
        ax = axarr[row, col]
        ax.imshow(diff_sci, clim=[immed - .3 * stdev, immed + .5 * stdev], cmap='Greys', origin='lower')
        ax.set_title('Diff:' + str(i + 1))
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()

### 3. Create an association file

This file mimics a typical association file for dithered exposures, that is used by astrodrizzle   
to align and stack multiple exposures taken at the same sky position with small dithers.  
We exploit the fact that a WFC3/IR exposure taken under gyro control can be effectively split into   
individual pseudo-exposures (the diff images).  
AstroDrizzle can treat such pseudo-exposures as individual dithers and combine them.

In [None]:
# Write the association file (read back below as diff/<root>_asn.fits) that groups the
# diff images so AstroDrizzle can treat them as dithered pseudo-exposures.
myDash.make_pointing_asn()

#### Show the content of the asn file

In [None]:
# The association table lives in the first extension of the asn FITS file.
asn_filename = 'diff/{}_asn.fits'.format(myDash.root)
asn_table = Table(fits.getdata(asn_filename, 1))
asn_table.show_in_notebook()

### 4. Create Segmentation Map

#### Create segmentation map from original FLT
Make segmentation map from original FLT image to assist with background subtraction and fixing of cosmic ray flags.  
We first use `create_seg_map` to create a segmentation map from the original FLT file using `photutils`.   

In [None]:
# Build a segmentation map and source list from the original FLT; the outputs are read
# back below from segmentation_maps/<root>_seg.fits and <root>_source_list.dat.
myDash.create_seg_map()

#### View segmentation map.

In [None]:
rootname = myDash.root

# Segmentation map of the original FLT, shown as a binary (0/1) stretch.
segmap_name = 'segmentation_maps/' + rootname + '_seg.fits'
segmap = fits.getdata(segmap_name)
print(segmap_name)

fig = plt.figure(figsize=(6, 8))
plt.imshow(segmap, cmap='Greys_r', origin='lower', vmin=0, vmax=1)

#### Print and read source list.

In [None]:
# Source list written alongside the FLT segmentation map.
sourcelist_name = 'segmentation_maps/' + rootname + '_source_list.dat'
sourcelist = ascii.read(sourcelist_name)
print(sourcelist)

#### Create segmentation map and source list from diff files
Make source lists from our difference files created from the IMA so that `TweakReg` can better align these difference files to catalogs, each other, etc.
The function `diff_seg_map` needs a list of difference files that contain the full path name.

In [None]:
# Absolute paths of this exposure's diff images, sorted by filename; diff_seg_map needs full paths.
diff_glob = 'diff/{}_*_diff.fits'.format(rootname)
cat_images = sorted(os.path.basename(x) for x in glob(diff_glob))
# An empty list would silently make the next step do nothing, so fail loudly here instead.
assert cat_images, 'no files match {!r}; run myDash.split_ima() first'.format(diff_glob)

diffpath = os.path.abspath('diff')
# os.path.join instead of a hard-coded '/' so the paths are correct on every platform.
sc_diff_files = [os.path.join(diffpath, s) for s in cat_images]

In [None]:
# diff_seg_map needs the full path of every diff image; overwrite=True presumably
# replaces maps left over from a previous run.
myDash.diff_seg_map(
    cat_images=sc_diff_files,
    overwrite=True,
)

In [None]:
# Segmentation map of the first diff image (suffix _01), shown as a binary (0/1) stretch.
segmap_name = 'segmentation_maps/' + rootname + '_01_diff_seg.fits'
segmap = fits.getdata(segmap_name)
print(segmap_name)

fig = plt.figure(figsize=(6, 8))
plt.imshow(segmap, cmap='Greys_r', origin='lower', vmin=0, vmax=1)

### 5. Subtract Background from diff files
Subtract the background from the individual reads taken from the original IMA file, using the DRZ and SEG images produced in the background subtraction of the original FLT.  
By default, this function will subtract the background and write it to the header. Setting parameter subtract to False will not subtract the background and only write it to the header.  
Set parameter reset_stars_dq to True to reset cosmic rays within objects to 0 (because the centers of the stars are flagged).

In [None]:
# Background-subtract the diff images using this method's default settings.
myDash.subtract_background_reads()

### 6. Fix Cosmic Rays

In [None]:
# Fix the cosmic-ray flags in the diff images using this method's default settings.
myDash.fix_cosmic_rays()

### 7. Align reads to each other
Align reads to one another by aligning each to the first diff file.  

Uses TweakReg to update the WCS information in the headers of the diff files, then drizzles the images together using Astrodrizzle.  

Refer to documentation to customize parameters for TweakReg and AstroDrizzle. 

(**Note:** `UnboundLocalError: local variable 'sig' referenced before assignment` can be solved by lowering the `threshold` parameter.)

In [None]:
# Dry run: compute the shifts only. Headers and WCS are left untouched and nothing is
# drizzled until the shifts have been inspected.
myDash.align(
    updatehdr=False,
    updateWCS=False,
    astrodriz=False,
)

Print the shifts file to analyze how well the alignment went.
Do not update header until shifts are satisfactory.

In [None]:
# Sort the matches so the file shown is deterministic if several shift files exist
# (glob returns them in arbitrary order).
shift_file = sorted(glob('shifts/shifts_*.txt'))
if not shift_file:
    raise FileNotFoundError('No shifts/shifts_*.txt file found; run myDash.align() first.')

# Read through a context manager so the file handle is closed.
with open(shift_file[0]) as shifts_fh:
    print(shifts_fh.read())

Update header and WCS information, then plot final drizzled image.

Listed below are all the inputs available through the call to `myDash.align()`, which runs `TweakReg` and `AstroDrizzle`, shown with their default values. `TweakReg` and `AstroDrizzle` themselves accept further inputs that could be an integral part of the workflow for users of DASH. The example in this notebook keeps the defaults for every input except `threshold`, which is lowered to 20.

```python
myDash.align(subtract_background = True,
             align_method = None,
             ref_catalog = None,
             create_diff_source_lists = True,
             updatehdr = True,
             updateWCS = True,
             wcsname = 'DASH',
             threshold = 50.,
             cw = 3.5,
             searchrad = 20.,
             astrodriz = True,
             cat_file = 'catalogs/diff_catfile.cat',
             drz_output = None,
             move_files = False)
```

In [None]:
# Full run: all other inputs keep their defaults (update headers/WCS, then drizzle).
# threshold is lowered from its default of 50, which avoids the UnboundLocalError noted earlier.
myDash.align(threshold=20.0)

In [None]:
# DASH product: the drizzled science image written by myDash.align() above.
sci_name = myDash.root + '_drz_sci.fits'
# Comparison image: the original IMA. Reuse the prefix of the file that was actually
# downloaded (as the file-reading cell does) instead of re-deriving the
# 'mastDownload/HST/<root>/' layout by hand, so the two cannot drift apart.
# NOTE(review): the variable keeps its historical `og_flt` name but holds IMA data.
og_flt_name = localpathtofile + 'ima.fits'
sci = fits.getdata(sci_name)
# With no extension given, getdata returns the first HDU that has data. For an IMA this is
# presumably ('SCI', 1), i.e. the last read.
og_flt = fits.getdata(og_flt_name)

# Original on the left, DASH result on the right, with the same display stretch for both.
fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(20, 10))

ax1.set_title('DASH Pipeline Reduced Science File')
ax2.set_title('Original IMA (not reduced using pipeline)')

ax1.imshow(sci, vmin=0, vmax=40, cmap='Greys_r', origin='lower', aspect="auto")
ax2.imshow(og_flt, vmin=0, vmax=40, cmap='Greys_r', origin='lower', aspect="auto");