<a id="title_ID"></a>
# JWST Pipeline Validation Notebook: Calwebb_coron3 for MIRI coronagraphic imaging
<span style="color:red"> **Instruments Affected**</span>: MIRI, NIRCam 

### Table of Contents

<div style="text-align: left"> 
    
<br> [Introduction](#intro)
<br> [JWST CalWG Algorithm](#algorithm)
<br> [Defining Terms](#terms)
<br> [Test Description](#test_descr)
<br> [Data Description](#data_descr)
<br> [Imports](#imports_ID)
<br> [Load Input Data](#data_load)
<br> [Run the Pipeline](#run_pipeline)
<br> [Examine Outputs](#testing) 
<br> [About This Notebook](#about)
<br>    

</div>

<a id="intro"></a>
# Introduction

This is the validation notebook for stage 3 coronagraphic processing of MIRI 4QPM/Lyot exposures.

The calwebb_coron3 module is stage 3 of the JWST Science Calibration Pipeline for NIRCam and MIRI coronagraphic data.  The inputs to this stage are the calibrated slope images (calwebb_image2 output) and the output is a reference-PSF-subtracted image stack. The steps are listed in Figure 1 with the flow from the top to the bottom.

The stage 3 coronagraphic pipeline ([`calwebb_coron3`](https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/calwebb_coron3.html#calwebb-coron3)) is to be applied to associations of calibrated NIRCam and MIRI coronagraphic exposures (output of [`calwebb_image2`](https://jwst-pipeline.readthedocs.io/en/stable/jwst/pipeline/calwebb_image2.html)), and is used to produce PSF-subtracted, resampled, combined images of the source object. For more information on `calwebb_coron3`, please visit the links below.

> Module description: https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/calwebb_coron3.html#calwebb-coron3

> Pipeline code: https://github.com/spacetelescope/jwst/blob/master/jwst/coron/


[Top of Page](#title_ID)

<a id="algorithm"></a>
# JWST CalWG Algorithm

The `coron3` pipeline consists of the following steps:

- **`outlier_detection`**: identifies and flags bad pixels/outliers in the input images.

- **`stack_refs`**: stacks the reference PSFs into a single 3D data cube.

- **`align_refs`**: aligns the stack of reference PSFs with the target PSFs.

- **`klip`**: uses the Karhunen-Loeve Image Plane (KLIP) algorithm to fit and subtract an optimal PSF from the target PSFs.

- **`resample`**: combines the PSF-subtracted target images into a single product.

The current status of these algorithms is summarized in the link below:

> https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/calwebb_coron3.html#calwebb-coron3


[Top of Page](#title_ID)

<a id="terms"></a>
# Defining Terms

- **JWST**: James Webb Space Telescope ([see documentation](https://jwst-docs.stsci.edu/))
- **MIRI**: Mid-Infrared Instrument ([see documentation](https://jwst-docs.stsci.edu/mid-infrared-instrument))
- **NIRCam**: Near-Infrared Camera ([see documentation](https://jwst-docs.stsci.edu/near-infrared-camera))
- **4QPM**: 4 Quadrant Phase Mask ([see documentation](https://jwst-docs.stsci.edu/mid-infrared-instrument/miri-instrumentation/miri-coronagraphs#MIRICoronagraphs-4qpm))
- **Lyot**: coronagraph design incorporating a classical Lyot spot ([see documentation](https://jwst-docs.stsci.edu/mid-infrared-instrument/miri-instrumentation/miri-coronagraphs#MIRICoronagraphs-lyotcoron))
- **PanCAKE**: an in-house tool at STScI used to simulate coronagraphic PSFs 
([see documentation](https://github.com/spacetelescope/pandeia-coronagraphy))
- **SGD**: Small Grid Dither 
([see documentation](https://jwst-docs.stsci.edu/methods-and-roadmaps/jwst-high-contrast-imaging/hci-proposal-planning/hci-small-grid-dithers))
- **RSCD**: Reset-Switch Charge Decay ([see documentation](https://jwst-pipeline.readthedocs.io/en/stable/jwst/rscd/description.html))
    
    

    
 [Top of Page](#title_ID)

<a id="test_descr"></a>
# Test Description

This notebook tests the following steps applied by `calwebb_coron3` for pipeline version == **'1.9.5'**.
 - [**outlier_detection**](#outlier_detection)
 - [**stack_refs**](#stack_refs)
 - [**align_refs**](#align_refs)
 - [**klip**](#klip)
 - [**resample**](#resample)

These tests are performed using MIRI F1140C 4QPM coronagraphic data (see [Data Description](#data_descr)).


[Top of Page](#title_ID)

<a id="data_descr"></a>
# Data Description

### Input Data:

This notebook uses MIRI F1140C 4QPM coronagraphic observations of the young debris disk around HD 141569A obtained from JWST ERS program #[1386](https://www.stsci.edu/jwst/science-execution/approved-programs/dd-ers/program-1386): High Contrast Imaging of Exoplanets and Exoplanetary Systems with JWST. 

The data consist of two science exposures (6 integrations, 940 groups) and five reference PSF exposures (4 integrations, 400 groups), based on the following observation scenario: (1) two science observations of the disk around HD 141569A, performed at separate roll angles in order to obtain 360° azimuthal coverage of the disk; and (2) a 5-point small grid dither (SGD) pattern executed on the PSF reference star HD 140986, to obtain a set of five slightly offset reference PSF observations.

The data has the following naming format:

* **Science exposures**: 
    - `jw01386022001_04101_00001_mirimage_calints.fits` (HD 141569 - *Roll 1*)
    - `jw01386023001_04101_00001_mirimage_calints.fits` (HD 141569 - *Roll 2*)<br>
    
* **PSF Reference exposures**:
    - `jw01386024001_04101_00001_mirimage_calints.fits` (HD 140986 - *position 1/5 in SGD pattern*)
    - `jw01386024001_04101_00002_mirimage_calints.fits` (HD 140986 - *position 2/5 in SGD pattern*)
    - `jw01386024001_04101_00003_mirimage_calints.fits` (HD 140986 - *position 3/5 in SGD pattern*)
    - `jw01386024001_04101_00004_mirimage_calints.fits` (HD 140986 - *position 4/5 in SGD pattern*)
    - `jw01386024001_04101_00005_mirimage_calints.fits` (HD 140986 - *position 5/5 in SGD pattern*)
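
The file names listed above can be split into their component fields with a regular expression. The sketch below is illustrative only: it assumes the standard exposure-level naming scheme `jw<ppppp><ooo><vvv>_<gg><s><aa>_<eeeee>_<detector>_<suffix>.fits`, and the pattern `EXPOSURE_NAME` and the helper `parse_exposure_name` are hypothetical names introduced here, not part of the `jwst` package.

```python
import re

# Assumed exposure-level naming scheme (hypothetical helper, not a jwst API):
# jw<ppppp><ooo><vvv>_<gg><s><aa>_<eeeee>_<detector>_<suffix>.fits
EXPOSURE_NAME = re.compile(
    r"jw(?P<program>\d{5})(?P<observation>\d{3})(?P<visit>\d{3})"
    r"_(?P<visit_group>\d{2})(?P<parallel_seq>\d)(?P<activity>[0-9a-z]{2})"
    r"_(?P<exposure>\d{5})_(?P<detector>[a-z0-9]+)_(?P<suffix>[a-z0-9]+)\.fits"
)


def parse_exposure_name(filename):
    """Return the name fields of an exposure-level file as a dict of strings."""
    match = EXPOSURE_NAME.fullmatch(filename)
    if match is None:
        raise ValueError(f"Not an exposure-level file name: {filename}")
    return match.groupdict()


# Third position of the small grid dither on the PSF reference star
fields = parse_exposure_name("jw01386024001_04101_00003_mirimage_calints.fits")
for key, value in fields.items():
    print(f"{key:>13}: {value}")
```

Read this way, the science and reference exposures differ in the observation number, while the five dither positions differ only in the exposure number.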

The ERS data used in this tutorial can be downloaded directly from MAST (mast.stsci.edu). The following link provides a shortcut: [1386 MAST data download](https://mast.stsci.edu/portal/Mashup/Clients/Mast/Portal.html?searchQuery=%7B%22service%22%3A%22CAOMBYPROPID%22%2C%22inputText%22%3A%5B%7B%22paramName%22%3A%22proposal_id%22%2C%22niceName%22%3A%22proposal_id%22%2C%22values%22%3A%5B%221386%22%5D%2C%22valString%22%3A%221386%22%2C%22displayString%22%3A%221386%22%2C%22isDate%22%3Afalse%2C%22facetType%22%3A%22discrete%22%7D%5D%2C%22paramsService%22%3A%22Mast.Caom.Filtered%22%2C%22title%22%3A%22Proposal%20ID%20Results%22%2C%22columns%22%3A%22*%22%2C%22caomVersion%22%3Anull%7D).


### Reference Files:

The `align_refs` step requires a PSFMASK reference file containing a 2D mask that is used as a weight function when computing shifts between images. 

> File description: https://jwst-pipeline.readthedocs.io/en/latest/jwst/align_refs/description.html#psfmask-reffile
 

### Association File:

The `calwebb_coron3` pipeline requires an association (ASN) file that lists one or more exposures of a science target and one or more reference PSF exposures. The individual target and reference PSF exposures should be in the form of 3D calibrated (`_calints`) products from `calwebb_image2` processing. Each pipeline step will loop over the 3D stack of per-integration images contained in each exposure.


> Level 3 Associations documentation: https://jwst-pipeline.readthedocs.io/en/latest/jwst/associations/level3_asn_rules.html

In this notebook we will use a provided ASN file:

 - `jw01386-c1027_coron3_00001_asn.json`
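
To make the role of the ASN file more concrete, the sketch below parses a trimmed-down association and groups its members by exposure type. The JSON text is a hypothetical stand-in written inline for illustration (the provided file contains additional keys and members), and `members_by_type` is a helper name introduced here; it assumes the usual Level 3 layout of `products`, each with a list of `members` carrying `expname` and `exptype`.

```python
import json

# Hypothetical, trimmed-down association; the real ASN file has more keys and members
asn_text = """
{
  "asn_type": "coron3",
  "products": [
    {
      "name": "example_product",
      "members": [
        {"expname": "jw01386022001_04101_00001_mirimage_calints.fits", "exptype": "science"},
        {"expname": "jw01386024001_04101_00001_mirimage_calints.fits", "exptype": "psf"},
        {"expname": "jw01386024001_04101_00002_mirimage_calints.fits", "exptype": "psf"}
      ]
    }
  ]
}
"""


def members_by_type(asn):
    """Group member file names by their exposure type."""
    grouped = {}
    for product in asn["products"]:
        for member in product["members"]:
            grouped.setdefault(member["exptype"], []).append(member["expname"])
    return grouped


asn = json.loads(asn_text)
grouped = members_by_type(asn)
for exptype, names in grouped.items():
    print(exptype, len(names), names)
```

The same helper could be applied to the provided file by loading it with `json.load` instead of the inline string.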



[Top of Page](#title_ID)

<a id="imports_ID"></a>
# Imports
* `jwst` is the main package for the James Webb Space Telescope (JWST) pipeline.
* `jwst.pipeline.Coron3Pipeline` is the pipeline class used to reduce JWST coronagraphic imaging data.
* `os` to interact with the operating system.
* `glob` to search for files.
* `tempfile.TemporaryDirectory` to create a temporary directory for testing.
* `numpy` for array calculations and manipulation.
* `matplotlib.pyplot` for visualizing data.
* `mpl_toolkits.axes_grid1` provides a way to create axes with adjustable size and position.
* `astropy.io.fits` for working with FITS files.
* `ci_watson.artifactory_helpers.get_bigdata` for downloading the data files from the Artifactory server.
* `pysiaf.Siaf` for accessing the JWST Science Instrument Aperture Files (SIAFs).

[Top of Page](#title_ID)

In [None]:
# Import necessary modules
import os
import glob
from tempfile import TemporaryDirectory

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

from astropy.io import fits
from astropy.utils.data import download_file

from ci_watson.artifactory_helpers import get_bigdata

from pysiaf import Siaf

import jwst
from jwst.pipeline import Coron3Pipeline

# Set up notebook for inline plotting and configure inline figures to remain open
%matplotlib inline
%config InlineBackend.close_figures=False

# Set up the CRDS_PATH environment accordingly
if 'CRDS_CACHE_TYPE' in os.environ:
    if os.environ['CRDS_CACHE_TYPE'] == 'local':
        os.environ['CRDS_PATH'] = os.path.join(os.environ['HOME'], 'crds', 'cache')
    elif os.path.isdir(os.environ['CRDS_CACHE_TYPE']):
        os.environ['CRDS_PATH'] = os.environ['CRDS_CACHE_TYPE']
# Use .get() so this cell does not raise a KeyError when CRDS_PATH has not been set
print('CRDS (Calibration Reference Data System) cache location: {}'.format(os.environ.get('CRDS_PATH', 'not set')))

In [None]:
jwst.__version__
# should output '1.9.5'

<a id="data_load"></a>
# Load Input Data

In [None]:
# Create a temporary directory to hold notebook output, and change the working directory to that directory
data_dir = TemporaryDirectory()
print(f"Created temporary directory: {data_dir.name}")
os.chdir(data_dir.name)
print(f"Changed to temporary directory: {os.getcwd()}")

In [None]:
# Copy the test files from Artifactory into the temporary directory
data_files = ['jw01386022001_04101_00001_mirimage_calints.fits',
              'jw01386023001_04101_00001_mirimage_calints.fits',
              'jw01386024001_04101_00001_mirimage_calints.fits',
              'jw01386024001_04101_00002_mirimage_calints.fits',
              'jw01386024001_04101_00003_mirimage_calints.fits',
              'jw01386024001_04101_00004_mirimage_calints.fits',
              'jw01386024001_04101_00005_mirimage_calints.fits',
              'jw01386-c1027_coron3_00001_asn.json']

for f in data_files:
    file = get_bigdata('jwst_validation_notebooks',
                       'validation_data',
                       'calwebb_coron3',
                       'coron3_miri_test', f)

In [None]:
ls

[Top of Page](#title_ID)
<a id="run_pipeline"></a>


------------

# Run the Pipeline

In [None]:
asn_file = 'jw01386-c1027_coron3_00001_asn.json' # Define ASN file
myCoron3Pipeline = Coron3Pipeline()
myCoron3Pipeline.save_results = True
myCoron3Pipeline.klip.truncate = 5 # it's a disk target so we don't want to use many KLIP modes. Default is 50
myCoron3Pipeline.output_dir = os.getcwd()
myCoron3Pipeline.run(asn_file) # run pipeline

[Top of Page](#title_ID)
<a id="testing"></a>
--------------
# Examine Output Data 

First, we define some plotting helper functions that overlay the SIAF aperture boundaries and reference point on the data.

In [None]:
# First, here are some plotting utilities to overlay SIAF information about the subarray
# over the images
miri_siaf = Siaf("MIRI")
aper_name = f'MIRIM_{fits.getval(data_files[0], "SUBARRAY", 0)}'

def plot_with_aper(img, aper_name=aper_name, ax=None, fig_kws={}, plot_kws={}, cbar_kws={}):
    """
    Plots an image with a specified aperture and reference point overlaid on top.

    Parameters
    ----------
    img : numpy.ndarray
        The 2D image data to be plotted.
    aper_name : str, optional
        The SIAF-searchable name of the aperture to be overlaid.
    ax : matplotlib.axes.Axes, optional
        The matplotlib axis object to use for plotting. If not provided, a new axis is created.
    fig_kws : dict, optional
        Additional keyword arguments to pass to the `plt.subplots` function when creating a new axis.
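    plot_kws : dict, optional
        Additional keyword arguments to pass to the `ax.pcolor` function when plotting the image 
        (e.g. ax.pcolor(x, y, img, **plot_kws)).
    cbar_kws : dict, optional
        Additional keyword arguments to pass to the `plt.colorbar` function when adding the colorbar.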

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    ax : matplotlib.axes.Axes
        The matplotlib axis object used for plotting.
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, **fig_kws)
    else:
        fig = ax.get_figure()
    ax.set_xlabel("subarray pixel")
    ax.set_ylabel("subarray pixel")

    aper = miri_siaf[aper_name]
    aper.plot(frame='sci', ax=ax, fill=False, mark_ref=True, c='C1')

    # using the aperture corner definitions to set the plot coordinates
    # ensures that the plotted image will match up with the aperture's 
    # features, e.g. the reference point
    corners = aper.corners('sci')
    x, y = np.meshgrid(np.arange(corners[0].min(), corners[0].max()+1, 1),
                       np.arange(corners[1].min(), corners[1].max()+1, 1))
    
    im = ax.pcolor(x, y, img, zorder=-1, **plot_kws)
    
    lolim, hilim = np.min(corners, axis=1), np.max(corners, axis=1)
    ax.set_xlim(lolim[0]-5, hilim[0]+5)
    ax.set_ylim(lolim[1]-5, hilim[1]+5)
    
    # Add colorbar
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    cbar = plt.colorbar(im, cax=cax, **cbar_kws)
    cbar.ax.tick_params(labelsize=8)
    return fig, ax

def plot_stack_with_aper(images, aper_name=aper_name, fig_kws={}, plot_kws={}):
    """
    Plots a stack of images with a specified aperture overlaid on top.

    Parameters
    ----------
    images : numpy.ndarray
        3D stack of image data to be plotted.
    aper_name : str, optional
        The SIAF-searchable name of the aperture to be overlaid. 
    fig_kws : dict, optional
        Additional keyword arguments to pass to the `plt.subplots` function when creating the figure.
    plot_kws : dict, optional
        Additional keyword arguments to pass to the `plot_with_aper` function when plotting each image.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    """
    # initialize the figure
    ncols = 2
    nrows = np.ceil(images.shape[0]/ncols).astype(int)
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols,
                             figsize=(6*ncols, 6*nrows),
                             squeeze=False)
    for ax, img in zip(axes.ravel(), images):
        plot_with_aper(img, aper_name=aper_name, ax=ax,
                       fig_kws={}, plot_kws=plot_kws)
        
    plt.tight_layout()
    return fig

<a id="outlier_detection"></a>
###  `outlier_detection`: (*'_median.fits' and '_crfints.fits' product*)

The role of the `outlier_detection` step is to identify and flag any remaining cosmic rays or other artifacts in the calibrated images left over from previous calibrations. The results of this process have the same format and content as the input `_calints` products. The only difference is that the DQ arrays have been updated to contain CR flags. 

*Output*: **CR-Flagged per integration product** <br>
*File suffix*: **'_crfints'**

> Step description: https://jwst-pipeline.readthedocs.io/en/stable/jwst/outlier_detection/index.html#outlier-detection-step
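
Because DQ arrays are bit masks, the pixels flagged by a step can be isolated with bitwise operations. The following is a minimal sketch on synthetic arrays with hypothetical flag values; it assumes the DQ planes are unsigned 32-bit integers and does not use any `jwst` code.

```python
import numpy as np

# Synthetic DQ planes with hypothetical flag values (unsigned 32-bit bit masks)
input_dq = np.array([[0, 1], [4, 0]], dtype=np.uint32)
flagged_dq = np.array([[0, 1], [4, 17]], dtype=np.uint32)

# Bits that are set after the step but were not set before it
new_bits = flagged_dq & ~input_dq
newly_flagged = new_bits != 0
print("Number of newly flagged pixels:", int(newly_flagged.sum()))

# Pitfall: subtracting unsigned arrays wraps around instead of going negative
wrapped = input_dq - flagged_dq
print("Wrapped difference at the flagged pixel:", int(wrapped[1, 1]))

# Casting to a signed type first gives the expected signed difference
signed_diff = input_dq.astype(np.int64) - flagged_dq.astype(np.int64)
print("Signed difference at the flagged pixel:", int(signed_diff[1, 1]))
```

The bitwise form also reports which bits changed, which a plain difference image does not.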

Here let's open one of these CR-flagged products:

In [None]:
crflag_image_hdu = fits.open('jw01386022001_04101_00001_mirimage_c1027_crfints.fits') 
crflag_image = crflag_image_hdu[1].data
print(crflag_image_hdu.info())

In [None]:
input_image_hdu = fits.open('jw01386022001_04101_00001_mirimage_calints.fits') # input calibrated per-int product
print("'_calints' data product dimensions: "+str(input_image_hdu[1].data.shape))
print("'_crfints' data product dimensions: "+str(crflag_image.shape))

As expected, the `crfints` product has the same format as the `calints` product (*number of integrations x NY x NX*).
The only difference is that the DQ arrays have been updated to contain CR flags:

In [None]:
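# Difference the input and CR-flagged DQ arrays to see the flagged pixels.
# Cast to a signed type first: the DQ arrays are unsigned, so a direct subtraction would wrap around.
dq_diff = input_image_hdu['DQ'].data.astype(np.int64) - crflag_image_hdu['DQ'].data.astype(np.int64)
fig = plot_stack_with_aper(dq_diff)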

----------


<a id="stack_refs"></a>
###  `stack_refs`:  Stack PSF References (*'_psfstack' product*)

The role of the `stack_refs` step is to stack all of the PSF reference exposures (specified in the input ASN file) into a single 3D data cube for use by subsequent coronagraphic steps. The size of the stack should be equal to the sum of the number of integrations in each input PSF exposure.  The image data are simply copied and reformatted and should not be modified in any way.

*Output*: **3D PSF Image Stack** <br>
*File suffix*: **'_psfstack'**

> Step description: https://jwst-pipeline.readthedocs.io/en/latest/jwst/stack_refs/index.html#stack-refs-step

In [None]:
stacked_cube_hdu = fits.open('jw01386-c1027_t006_miri_f1140c-mask1140_psfstack.fits')
stacked_ref_images = stacked_cube_hdu[1].data
print(stacked_cube_hdu.info())

In [None]:
print("'_psfstack' data product dimensions: "+str(stacked_ref_images.shape))

Plot just the stacked reference images

In [None]:
vmin, vmax = np.nanquantile(stacked_ref_images[0], [0.01, 0.99])
plot_kws = {'vmin':vmin, 'vmax':vmax}
fig = plot_stack_with_aper(stacked_ref_images, plot_kws=plot_kws)

The `stack_refs` step has successfully stacked the reference PSF exposures into a single 3D '*_psfstack*' product, with size equal to the sum of the number of integrations in each input PSF exposure. To confirm that the image data have not been modified, the input PSF images are subtracted from each image in the stack below:

In [None]:
# Gather the input reference images
# data_files[2:7] are the five reference PSF exposures; concatenate their SCI cubes
# along the integration axis so that they match the layout of the stacked product
input_ref_images = np.concatenate([fits.getdata(f) for f in data_files[2:7]], axis=0)

In [None]:
fig = plot_stack_with_aper(stacked_ref_images - input_ref_images)
fig.suptitle("Stacked References, Differenced");
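
A numerical comparison can complement the difference plot. The sketch below uses synthetic cubes as hypothetical stand-ins for the five reference exposures; `stack_matches_inputs` is a helper name introduced here, and `equal_nan=True` is used on the assumption that real data contain NaN pixels that should compare as equal.

```python
import numpy as np


def stack_matches_inputs(stack, exposures):
    """True if `stack` is exactly the input cubes concatenated along the integration axis."""
    expected = np.concatenate(exposures, axis=0)
    return stack.shape == expected.shape and bool(np.array_equal(stack, expected, equal_nan=True))


rng = np.random.default_rng(0)
# Hypothetical stand-ins: five exposures, each with shape (nints, ny, nx) = (4, 8, 8)
exposures = [rng.normal(size=(4, 8, 8)) for _ in range(5)]
exposures[2][1, 3, 3] = np.nan  # include a NaN pixel, as in real data

good_stack = np.concatenate(exposures, axis=0)
modified_stack = good_stack.copy()
modified_stack[0, 0, 0] += 1.0

print("Stack size:", good_stack.shape[0])
print("Unmodified stack matches:", stack_matches_inputs(good_stack, exposures))
print("Modified stack matches:", stack_matches_inputs(modified_stack, exposures))
```

In the notebook, the same check could be applied to `stacked_ref_images` and the list of input reference cubes.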

-----------

<a id="align_refs"></a>


### `align_refs`:  Align PSF References (*'_psfalign' product*)

The role of the `align_refs` step is to align the coronagraphic PSF images with the science target images. It does so by computing the offsets between the science target and reference PSF images and shifting the PSF images into alignment. The output of the `align_refs` step is a 4D data product containing a stack of 2D PSF images aligned to each integration within a corresponding science target exposure.

*Output*: **4D aligned PSF Images** <br>
*File suffix*: **_psfalign**


> Step description: https://jwst-pipeline.readthedocs.io/en/latest/jwst/align_refs/index.html#align-refs-step
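
The idea of measuring an offset and shifting a reference into alignment can be illustrated with a toy example. This is not the pipeline's method: the sketch below swaps in a plain FFT cross-correlation that recovers whole-pixel shifts only and ignores the PSFMASK weighting, and `estimate_shift` is a hypothetical helper name.

```python
import numpy as np


def estimate_shift(target, ref):
    """Integer (dy, dx) by which `ref` must be rolled to line up with `target`."""
    # Circular cross-correlation computed via FFTs; its peak sits at the offset
    xcorr = np.fft.ifft2(np.fft.fft2(target) * np.conj(np.fft.fft2(ref))).real
    peak = np.unravel_index(np.argmax(xcorr), xcorr.shape)
    # Offsets beyond half the image size correspond to negative shifts
    return tuple(int(p) if p <= n // 2 else int(p) - n for p, n in zip(peak, xcorr.shape))


# Synthetic reference PSF: a Gaussian centred in a 32 x 32 image
yy, xx = np.mgrid[:32, :32]
ref = np.exp(-((yy - 16) ** 2 + (xx - 16) ** 2) / (2 * 2.0 ** 2))

# Synthetic target: the same PSF displaced by 2 pixels in y and -3 pixels in x
target = np.roll(ref, shift=(2, -3), axis=(0, 1))

shift = estimate_shift(target, ref)
aligned_ref = np.roll(ref, shift=shift, axis=(0, 1))
print("Estimated shift (dy, dx):", shift)
print("Max residual after alignment:", float(np.abs(aligned_ref - target).max()))
```

Real small grid dither offsets are a fraction of a pixel, so a sub-pixel estimate and interpolation would be needed in practice.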

In [None]:
aligned_cube_hdu = fits.open('jw01386023001_04101_00001_mirimage_c1027_psfalign.fits')
aligned_cube_hdu.info()

In [None]:
aligned_cube_data = (aligned_cube_hdu[1].data)
print("'_psfalign' data product dimensions: " + str(aligned_cube_data.shape))

In [None]:
# cubes after alignment
aligned_images = aligned_cube_data[0]

vmin, vmax = np.nanquantile(aligned_images[0], [0.01, 0.99])
plot_kws = {'vmin':vmin, 'vmax':vmax}
fig = plot_stack_with_aper(aligned_images, plot_kws=plot_kws)
#fig.suptitle("Aligned References");

Show the aligned images, each differenced against the one before it (the first image is shown as normal)

In [None]:
# subtract each aligned image from the one before it
diff_images = np.diff(aligned_cube_data[0], axis=0, prepend=0)

#vmin, vmax = np.nanquantile(diff_images[0], [0.01, 0.99])
plot_kws = {}
fig = plot_stack_with_aper(diff_images, plot_kws=plot_kws)
#fig.suptitle("Aligned References, Differenced");

The `align_refs` step has successfully aligned the PSF images: note the smaller residuals in the difference images. The first integration of each exposure is systematically different due to RSCD.
 
The output is indeed a 4D '*_psfalign*' product. In FITS axis order, the 3rd axis has length equal to the total number of per-integration reference images in the input PSF stack (20) and the 4th axis has length equal to the number of integrations in the input science target image (6); the corresponding numpy array shape is therefore (6, 20, NY, NX).
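
The relation between the FITS axis order and the numpy array shape can be checked on a placeholder array. The sketch below uses a small hypothetical image size and an array of zeros in place of the real product.

```python
import numpy as np

# Placeholder for a '_psfalign' cube: 6 science integrations, 20 reference images,
# and a small hypothetical image size
n_sci_ints, n_ref_images, ny, nx = 6, 20, 4, 4
aligned = np.zeros((n_sci_ints, n_ref_images, ny, nx))

# FITS axes (NAXIS1, NAXIS2, NAXIS3, NAXIS4) are the numpy shape in reverse order
fits_axis_lengths = aligned.shape[::-1]
print("FITS axis lengths:", fits_axis_lengths)

# Indexing the first numpy axis selects the references aligned to one science integration
refs_for_first_int = aligned[0]
print("References aligned to integration 0:", refs_for_first_int.shape)
```

This is why indexing the cube with `[0]` yields the full set of reference images aligned to the first science integration.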

------------
<a id="klip"></a>
### `klip`:  Reference PSF Subtraction

The role of the `klip` step is to apply the Karhunen-Loeve Image Plane (KLIP) algorithm on the science target images, using an accompanying set of aligned reference PSF images (result of the `align_refs` step) in order to fit and subtract an optimal PSF from the science target image. The PSF fitting and subtraction is applied to each integration image independently. The output is a 3D stack of PSF-subtracted images of the science target, having the same dimensions as the input science target product.

*Output*: **3D PSF-subtracted image** <br>
*File suffix*: **_psfsub**

> Step description: https://jwst-pipeline.readthedocs.io/en/latest/jwst/klip/index.html#klip-step
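
The core of the KLIP idea can be sketched in a few lines of numpy: build an orthonormal Karhunen-Loeve basis from the reference images, project the target onto the leading modes, and subtract that projection. This is a toy version under simplifying assumptions (per-image mean subtraction, synthetic data, no bad-pixel handling), not the pipeline's implementation; `klip_subtract` is a hypothetical name, and its `truncate` argument only mirrors the role of the step parameter of the same name.

```python
import numpy as np


def klip_subtract(target, refs, truncate):
    """Toy KLIP: subtract from `target` (ny, nx) its projection onto the first
    `truncate` Karhunen-Loeve modes of `refs` (nrefs, ny, nx)."""
    t = target.ravel() - target.mean()
    r = refs.reshape(refs.shape[0], -1)
    r = r - r.mean(axis=1, keepdims=True)

    # Eigen-decompose the small (nrefs x nrefs) covariance matrix of the references
    evals, evecs = np.linalg.eigh(r @ r.T)
    order = np.argsort(evals)[::-1][:truncate]
    evals, evecs = evals[order], evecs[:, order]

    # KL basis images: orthonormal linear combinations of the reference images
    basis = (evecs.T @ r) / np.sqrt(evals)[:, None]

    # Model PSF = projection of the target onto the retained modes
    psf_model = (basis @ t) @ basis
    return (t - psf_model).reshape(target.shape)


rng = np.random.default_rng(1)
refs = rng.normal(size=(5, 8, 8))            # synthetic reference images
pure_psf = 2.0 * refs[0] - 0.5 * refs[3]     # target made only of reference structure

residual_all_modes = klip_subtract(pure_psf, refs, truncate=5)
residual_two_modes = klip_subtract(pure_psf, refs, truncate=2)
print("Max residual, 5 modes:", float(np.abs(residual_all_modes).max()))
print("Max residual, 2 modes:", float(np.abs(residual_two_modes).max()))
```

Using fewer modes leaves more of the stellar PSF in the residual but also removes less of any extended astrophysical signal.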


In [None]:
sub_hdu = fits.open('jw01386022001_04101_00001_mirimage_c1027_psfsub.fits')
sub_hdu.info()

In [None]:
subtracted_image = sub_hdu[1].data
print("Science target image dimensions: " + str(fits.getdata('jw01386022001_04101_00001_mirimage_c1027_crfints.fits', 'SCI').shape))
print("PSF subtracted image dimensions: " + str(subtracted_image.shape))

Note that the PSF subtracted image has the same dimensions as the input target image.

In [None]:
vmin, vmax = np.nanquantile(subtracted_image[0], [0.01, 0.99])

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8,8))
fig, ax = plot_with_aper(subtracted_image[0], aper_name=aper_name, ax=ax, plot_kws={'vmin':vmin, 'vmax':vmax})
ax.set_title("PSF subtracted image, Roll 1")

The `klip` step has successfully made a synthetic PSF reference image and subtracted it from the target image. The output is a 3D data array containing a stack of 2D PSF-subtracted science target images, one per integration.

------------------------------------------------------
<a id="resample"></a>
### `resample`: Image resampling

The role of the `resample` step is to combine the PSF-subtracted and CR-flagged images into a single resampled image. The `resample` routine will resample each input 2D image based on the WCS and distortion information and will combine multiple resampled images into a single undistorted product for the science target.  

*Output*: **2D resampled image** <br>
*File Suffix*: **_i2d**


> Step description: https://jwst-pipeline.readthedocs.io/en/latest/jwst/resample/index.html#resample-step


In [None]:
res_hdu = fits.open('jw01386-c1027_t006_miri_f1140c-mask1140_i2d.fits')
res_hdu.info()

In [None]:
resampled_image = res_hdu[1].data
print("Resampled image dimensions: " + str(resampled_image.shape))

In [None]:
vmin, vmax = np.nanquantile(resampled_image, [0.01, 0.99])
fig, ax = plt.subplots(1, 1, figsize=(8,8))
im = ax.imshow(resampled_image, vmin=vmin, vmax=vmax, origin='lower')
# Add colorbar
divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)
cbar = plt.colorbar(im, cax=cax)
cbar.ax.tick_params(labelsize=8)

[Top of Page](#title_ID)
<br>
<a id="future_tests"></a>
--------------------

<a id="about"></a>
## About this Notebook
**Authors:** 
- Bryony F. Nickson (Staff Scientist, *MIRI Branch*) 
- J. Brendan Hagan
- Jonathan Aguilar (Staff Scientist, *MIRI Branch*)

<br> **Updated On:** 21/12/2022

[Top of Page](#title_ID)