# Understanding datamodels, transforms, and the GWCS object created for NIRCAM Wide Field Slitless Spectroscopy

### Simulated NIRCAM data are available from MAST: http://archive.stsci.edu/jwst/simulations/index.html

Only first order spectra are visible in most cases. Fainter second order spectra only appear in F322W2 + grism observations due to sources near one edge of the field of view.
https://jwst-docs.stsci.edu/display/JTI/NIRCam+Grisms

In [None]:
import os
import jwst
import asdf

# transforms and datamodels
from jwst.transforms.models import NIRCAMForwardColumnGrismDispersion, NIRCAMForwardRowGrismDispersion
from jwst.transforms.models import NIRCAMBackwardGrismDispersion
from jwst.datamodels.wcs_ref_models import NIRCAMGrismModel
from jwst.datamodels import image

# wcs
from jwst import assign_wcs

# print this out as a visual reminder of what version of the jwst pipeline is being referenced
print("Using jwst pipeline version: {}".format(jwst.__version__))

## Make sure that you have set the JWST_NOTEBOOK_DATA environment variable in the terminal from which you started Jupyter Notebook.

The data will be read from that directory, and the pipeline should write to the current working directory, avoiding clobbers.
If you would like to use your own data just substitute the locations below.

In [None]:
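notebook_dir = os.environ['JWST_NOTEBOOK_DATA']
# join the path pieces so this works whether or not the variable ends with a separator;
# the empty last component keeps the trailing separator that the lines below rely on
nircam_data = os.path.join(notebook_dir, 'data', 'nircam', '')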

grism_file = nircam_data + 'nircam_grism_dispersed_image.fits'  # this is a row dispersed grism image
direct_file = nircam_data + 'nircam_grism_direct_image.fits'
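
Before going further, it can help to confirm that the expected files are actually where the notebook will look for them. The following is a minimal sketch and not part of the pipeline: the helper name `find_missing` is hypothetical, and falling back to the current directory when the variable is unset is an assumption made only so the example runs anywhere.

```python
import os

def find_missing(root, relative_paths):
    """Return the relative paths that do not exist under root."""
    return [rel for rel in relative_paths
            if not os.path.exists(os.path.join(root, rel))]

# Assumption: fall back to the current directory when the variable is unset.
example_root = os.environ.get("JWST_NOTEBOOK_DATA", os.curdir)
expected_files = [
    os.path.join("data", "nircam", "nircam_grism_dispersed_image.fits"),
    os.path.join("data", "nircam", "nircam_grism_direct_image.fits"),
]
for rel in find_missing(example_root, expected_files):
    print("missing:", rel)
```

If nothing is printed, both files were found.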

# Examining the grism dispersion model
The following examples are specific to the column grism (GRISMC), but they are valid for the row grism (GRISMR) as well. In fact, since the reference trace image was made for GRISMR, at the end of this notebook you'll see an example that uses its full set of matching reference files.
<table width="100%"><th width="100%" style="text-align:left" bgcolor="d9d9d9"><h2>NIRCAM Column Grism Dispersion</h2></th></table>

### Let's take a closer look at the reference file
Reference files for the JWST pipeline are stored in <a href="https://jwst-crds.stsci.edu/">CRDS</a>.

For NIRCAM, the specwcs reference files can be found by browsing to nircam->specwcs on the page linked above.

In [None]:
# You can specify your own reference file, this is one that is also available from CRDS
specwcs = nircam_data + 'jwst_nircam_specwcs_0011.asdf'  # this is a GRISMC reference file

# open the file into its datamodel; this can be done by giving the datamodel the filename directly
# We can save some of the items to local variables for easy reference using syntax like this:
with NIRCAMGrismModel(specwcs) as f:
    displ = f.displ
    dispx = f.dispx
    dispy = f.dispy
    invdispx = f.invdispx
    invdispy = f.invdispy
    invdispl = f.invdispl
    orders = f.orders

In [None]:
# f is the pointer to the file object and it still exists; and there is some history information that we can look at:
f.history

#### There's a link to the 'homepage' in the description; this repository contains the software, `nircam_reftools.py`, that was used to create the reference file.
Currently, if you look at that repository, you'll notice that the software was restructured and the functions for making the grism reference files were moved to their own module. This will be reflected in the history if that software is used to make a new reference file for delivery.

In [None]:
# what type of reference file is it?
f.reftype

In [None]:
# Not sure what's in the object? You can always look at the full instance 
f.instance

#### If you're creating a new reference file, it's useful to check it with validate() before writing; this validates the reference object against the schema that specifies what it should contain


In [None]:
f.validate()

#### what happens if we try and change orders from a list of numbers to a string?

In [None]:
f['orders'] = 'wrong'

#### compared to what happens if we retain the expected data type but still change the data:


In [None]:
f['orders'] = [1,2,3]

#### this time it should be happy and return no error, because orders is allowed to be a list of numbers


In [None]:
f.validate()
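
Conceptually, the check on `orders` behaves like the sketch below. This is a simplified stand-in written for illustration, assuming only that `orders` must be a list of numbers as described above. `check_orders` is a hypothetical helper and not part of `jwst.datamodels`; the real validation is driven by the schema.

```python
from numbers import Real

def check_orders(orders):
    """Raise if orders is not a list of numbers; return it unchanged otherwise."""
    if not isinstance(orders, list):
        raise TypeError("orders must be a list, got {}".format(type(orders).__name__))
    for value in orders:
        # bool is a subclass of int, so reject it explicitly
        if isinstance(value, bool) or not isinstance(value, Real):
            raise TypeError("orders entries must be numbers, got {!r}".format(value))
    return orders

check_orders([1, 2, 3])        # accepted: a list of numbers
try:
    check_orders("wrong")      # rejected: a string is not a list of numbers
except TypeError as err:
    print("validation failed:", err)
```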

## Create a transform model
We've just looked at the reference file that will support the GWCS model, it contains the building blocks that the full GWCS will need. 

There are several transform models that support NIRCAM:
* NIRCAMForwardColumnGrismDispersion --> this relates to the Column dispersed grism, it moves inputs in the grism frame to the direct image frame
* NIRCAMForwardRowGrismDispersion --> this relates to the Row dispersed grism, it moves inputs in the grism frame to the direct image frame
* NIRCAMBackwardGrismDispersion --> this relates to either Row or Column dispersed grisms, it moves inputs in the direct image frame to the grism frame


#### Hit enter here for the full explanation of the model inputs, the same can be done for the other models

press the 'x' that appears in the upper right of the box to close the information blurb


In [None]:
NIRCAMForwardColumnGrismDispersion?

### We'll create the model by giving it the expected inputs from the reference file we read in above
We already know from the reference data that it should be used to instantiate the column dispersed grism model from NIRCAM module A:

`'instrument': {'module': 'A', 'name': 'NIRCAM', 'pupil': 'GRISMC'}`

In [None]:
model=NIRCAMForwardColumnGrismDispersion(orders, displ, dispx, invdispy)
back=NIRCAMBackwardGrismDispersion(orders, invdispl, dispx, dispy)

# The inverse transform can be added to the model so that one model object can be used to transform in both directions
model.inverse = back

In [None]:
model.orders  # these are the spectral orders with which the model can be used 

### What are the inputs that this model expects?

In [None]:
# model should accept (x, y, x0, y0, order)
model.inputs

<strong>The default direction of the model is forwards, from the grism frame to the direct image frame. So here, the inputs relate to:</strong>

- x: the x-pixel location in the grism image
- y: the y-pixel location in the grism image
- x0: the source object x-location in the direct image
- y0: the source object y-location in the direct image
- order: the spectral order of concern

### What are the outputs that this model gives?

In [None]:
# and the model should return (x, y, wavelength, spectral order)
model.outputs

<strong>The outputs here are in the detector direct image reference frame</strong>
- x: the x-pixel location of the source object in the direct image
- y: the y-pixel location of the source object in the direct image
- wavelength: the wavelength of the corresponding pixel in the grism image
- order: the spectral order the information pertains to

This information allows the model to properly pass and translate image <-> dispersed image <-> image

<strong>You can see the relation a little better by also looking at the inputs and outputs to the inverse model:</strong>

In [None]:
# direct image frame ->  grism image frame
model.inverse.inputs

In [None]:
# grism image frame -> direct image frame
model.inverse.outputs

### Does the model have a name associated with it?

In [None]:
model.name

## Now for something even more specific....
The WFSS models contain transforms related to x, y and wavelength

The reference file stores a list of models that are valid for each spectral order or element.

What the following is telling you is that the first order is contained in the zero index of the model list. This is true for any model list in the reference file: xmodels, ymodels, or lmodels.


In [None]:
# This is used internally to grab the correct transform based on spectral order
model._order_mapping  
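
The idea behind the order mapping can be sketched in plain Python. This is an illustration of the concept only, assuming a reference file whose `orders` list is `[1, 2]`; `build_order_mapping` is a hypothetical helper and not a function from `jwst.transforms`.

```python
def build_order_mapping(orders):
    """Map each spectral order number to its index in the model lists."""
    return {int(order): index for index, order in enumerate(orders)}

example_orders = [1, 2]   # assumed example values
example_mapping = build_order_mapping(example_orders)

# The first-order models live at this index of xmodels, ymodels, and lmodels
print(example_mapping[1])
```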

### What do the orders and polynomials that we see actually relate to?
The pipeline documentation for the `specwcs` reference files used during the `assign_wcs` step can be found here:

https://jwst-pipeline.readthedocs.io/en/latest/jwst/assign_wcs/reference_files.html

For NIRCAM GRISM and TSGRISM modes the specwcs file is in ASDF format with the following members:

- **displ**:	contains the wavelength-dispersion models
- **dispx**:	contains the x-dispersion models
- **dispy**:	contains the y-dispersion models
- **invdispx**:	contains the inverse x-dispersion models
- **invdispy**:	contains the inverse y-dispersion models
- **invdispl**:	contains the inverse wavelength-dispersion transform models
- **orders**:	a list of order numbers that the models relate to, in the same order as the models

"models" above are actual `astropy.modeling` models. The `asdf` format allows transforms to be saved to file and recreated on load.

In [None]:
# For example, let's look at the x-dispersion models; you should see two models in this list, one for each order
model.xmodels 

#### We can see that both orders take the form of a Polynomial 1D, and if you look closer you'll notice that they are unitary because GRISMC disperses in y.

Let's look at the ymodels:

In [None]:
model.ymodels

In [None]:
model.lmodels  # the wavelength relations

You can read more about the Polynomial1D model here: http://docs.astropy.org/en/stable/api/astropy.modeling.polynomial.Polynomial1D.html
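
As a reminder of what a first-degree polynomial model computes, here is a plain-Python sketch that evaluates `c0 + c1*x`. The coefficient values are made up for illustration and are not taken from the reference file, and `eval_poly` is a hypothetical helper rather than an astropy function.

```python
def eval_poly(coeffs, x):
    """Evaluate c0 + c1*x + c2*x**2 + ... using Horner's method."""
    result = 0.0
    for c in reversed(coeffs):
        result = result * x + c
    return result

# Made-up coefficients: c0 = 2.4, c1 = 1.6
example_coeffs = [2.4, 1.6]
print(eval_poly(example_coeffs, 0.0))   # value at x = 0 is just c0
print(eval_poly(example_coeffs, 1.0))   # value at x = 1 is c0 + c1
```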

### Now, if I just wanted to use these models by themselves I could, but for the WFSS modes they are really meant to work together.

In [None]:
model.lmodels[0].inputs  # yah, it says x, it basically just wants one value, this is a linear model in x

In [None]:
model.lmodels[0].outputs

#### We can execute the model by giving it an input value:

In [None]:
model.lmodels[0](1.5)

#### The results should be zero because this is the column model:

In [None]:
model.xmodels[0](10)

#### This calculation should yield a non-zero result:

In [None]:
model.ymodels[0](10)

### Some of the models gave us a number, what do they mean?
Here's a longer explanation of the models used in the WFSS modes:

The models originate from a history of slitless spectroscopy for HST, in packages such as <a href="http://axe-info.stsci.edu/extract_calibrate">aXe</a> and <a href="https://github.com/gbrammer/grizli">Grizli</a>. They have been reformulated for use with JWST, the details of which can be found in this document: http://www.stsci.edu/hst/wfc3/documents/ISRs/WFC3-2017-01.pdf  as well as the code here: https://github.com/npirzkal/GRISMCONF


For example, the inverse wavelength model (`invdispl`, which serves as the `lmodels` of the backward transform) accepts a wavelength as input and returns the normalized location along the trace for that wavelength. This can then be used as input into the xmodel and ymodel to find the exact pixel location of that wavelength, knowing the starting (x0,y0) location of the source object from the direct image associated with the dispersed image. This interaction is what the large GWCS model, created using the reference file information, helps to facilitate.
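
The interaction described above can be sketched in a few lines. This is an illustration under stated assumptions, not the pipeline code: the three functions and all of their coefficients are hypothetical stand-ins for the inverse wavelength, x, and y models of a column-dispersed grism.

```python
def wavelength_to_t(wavelength):
    """Hypothetical inverse wavelength model: wavelength (microns) -> t."""
    return (wavelength - 2.4) / 1.6

def t_to_dx(t):
    """Hypothetical x model: no offset in x for a column-dispersed grism."""
    return 0.0

def t_to_dy(t):
    """Hypothetical y model: pixel offset along the dispersion direction."""
    return -1500.0 + 1000.0 * t

def source_to_trace_pixel(src_x, src_y, wavelength):
    """Direct-image source position plus wavelength -> dispersed-image pixel."""
    t = wavelength_to_t(wavelength)
    return src_x + t_to_dx(t), src_y + t_to_dy(t)

print(source_to_trace_pixel(110.0, 110.0, 3.2))
```

The source position supplies the starting point, and the wavelength determines how far along the trace the returned pixel sits.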

### Let's step through a more detailed example of how this works. We're still just using the pieces we've examined up to this point. Afterwards, we'll create the full GWCS model and show the same procedure; this time the model will translate all the way through to world coordinates.
We'll start with the forward transform, which returns the wavelength for a pixel in the dispersed image, and then use the backward transform, going from an object position taken from the direct image to the pixel location of a specific wavelength associated with that object.

#### Given pixel location (100,100) in the dispersed image that relates to a source at pixel (110,110) in the direct image, for the first order, what is the wavelength?

In [None]:
# (x, y, x0, y0, order) --> (x0, y0, lam, order)
x0, y0, wavelength, order = model(100, 100, 110, 110, 1)
print(x0, y0, wavelength, order)

**Why does the user need to input the coordinates of the object in the direct image?**
This was done for two reasons:
- the user may want to specify directly the object they think the pixel in the dispersed image is associated with in the direct image
- the `extract_1d` code that the pipeline uses works by asking for the wavelength of the pixel in the grism image and looping over pixels. By the time we get to this point in the WFSS pipelines, the code is using cutouts specific to the source and order. The position of the source in full-frame coordinates and the spectral order of the cutout are carried through the process, so it merely asks "for each dispersed image pixel, what's the wavelength?"
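
The second point can be pictured as a loop over the pixels of a cutout in which the source position and order stay fixed. The sketch below uses a hypothetical `pixel_to_wavelength` function with made-up linear coefficients in place of the real forward model; it is not the `extract_1d` implementation.

```python
def pixel_to_wavelength(y, src_y):
    """Hypothetical forward relation for a column-dispersed trace."""
    t = (y - src_y + 1500.0) / 1000.0   # made-up inverse of dy = -1500 + 1000*t
    return 2.4 + 1.6 * t                # made-up wavelength relation, microns

def wavelengths_for_column(y_pixels, src_y):
    """Ask 'what is the wavelength?' for each dispersed pixel of one source."""
    return [pixel_to_wavelength(y, src_y) for y in y_pixels]

example_rows = range(100, 105)
for row, wave in zip(example_rows, wavelengths_for_column(example_rows, 110.0)):
    print(row, round(wave, 4))
```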

#### what about using that information to go the other direction?

In [None]:
# (x0, y0, lam, order) --> (x, y, x0, y0, order)?
model.inverse.inputs

In [None]:
x, y, x0, y0, order = model.inverse(x0, y0, wavelength, order)
print(x, y, x0, y0, order)

*One thing to note above, the x-location in the output is somewhat arbitrary because the dispersion is along the column and the user can choose the x-width to use, the location of the x-pixel on input is returned for reference.*

### Now we'll construct the full transform chain by hand; this is what the full transform model's evaluate() function is doing with the model inputs from the reference file inside the pipeline

In [None]:
# reminder of the current values
print(order, wavelength, x0, y0)

In [None]:
# this combination will translate from the direct image to the dispersed image, what is defined as the backwards direction
iorder = model._order_mapping[int(order)]
t = model.inverse.lmodels[iorder](float(wavelength))
dx = model.inverse.xmodels[iorder](float(t))
dy = model.inverse.ymodels[iorder](float(t))
print(x0+dx, y0+dy, x0, y0, order)

In [None]:
# this combination will translate from the dispersed image to the direct image, the forwards direction
x = 100.
y = 100.
x0 = 110.
y0 = 110.
t = model.ymodels[iorder](y-y0)
dx = model.xmodels[iorder](t)
wavelength = model.lmodels[iorder](t)

print(x0+dx, y0, wavelength, order)

## Now that we've gone over how the models work, we can create our larger model chain that will form the GWCS object and translate all the way to the sky coordinate frame

First we'll import the modules we need:

In [None]:
import gwcs.coordinate_frames as cf
from astropy import units as u
from astropy import coordinates as coord

#### Now we'll create the grism detector frame using `coordinate_frames`, the reference frame has units of pixels since it represents a detector

In [None]:
gdetector = cf.Frame2D(name='grism_detector',
                       axes_order=(0, 1),
                       unit=(u.pix, u.pix))
gdetector

### Now we need to create the transform model that takes us from the direct image pixel location to the sky coordinate frame
For the NIRCAM WFSS modes, this goes through the imaging wcs pipeline. The pipeline knows to use the information from the dispersed exposure, including the filter element that was used, to get the correct distortion solution. 

In order to demonstrate use of the modules outside of STPIPE, I'm going to call the `jwst.assign_wcs` function directly to create the wcs model I need.  

In [None]:
from jwst.assign_wcs import nircam

In [None]:
# open up the grism image, this will be fed to the module for reference
grism_image = image.ImageModel(grism_file)

In [None]:
grism_image.meta.instrument.filter, grism_image.meta.instrument.pupil, grism_image.meta.instrument.detector,grism_image.meta.exposure.type


#### You'll notice I'm using a row dispersed grism image here. Some reference files are independent of the grism, so this only becomes important at the end of this notebook, when we call the full `assign_wcs` on the image and it uses a different `specwcs` reference file. Most of the examples above worked with the GRISMC `specwcs` reference file and the non-grism dependent reference files.

### Reference file types and retrieval

The module will need some reference file information. CRDS returns different reference file types for NRC_IMAGE and NRC_GRISM, we need the imaging mode reference files for part of the grism pipeline. We've told CRDS to assign the same distortion file to both EXP_TYPES so they should match.

In [None]:
# create the step object
assign_wcs_step = assign_wcs.AssignWcsStep()

In [None]:
# get reference files we need from CRDS
distortion = assign_wcs_step.get_reference_file(grism_image, 'distortion')  # distortion is independent of grism
wavelengthrange = assign_wcs_step.get_reference_file(grism_image,'wavelengthrange')  # independent of grism
specwcs = specwcs  # use our grismc local reference file 
print("Reference files for:\nspecwcs: {}\nwavelengthrange: {}\ndistortion: {}".format(specwcs, wavelengthrange, distortion))


#### we'll use the CRDS reference files, except for specwcs, where we'll use our example file.

I've only populated some of the reference files because those are the only ones we need for WFSS. When testing locally, you can specify any reference filename here.

In [None]:
reference_file_names = {'camera': 'N/A',
 'collimator': 'N/A',
 'disperser': 'N/A',
 'distortion': distortion,
 'filteroffset': 'N/A',
 'fore': 'N/A',
 'fpa': 'N/A',
 'ifufore': 'N/A',
 'ifupost': 'N/A',
 'ifuslicer': 'N/A',
 'msa': 'N/A',
 'ote': 'N/A',
 'regions': 'N/A',
 'specwcs': specwcs,
 'wavelengthrange': wavelengthrange}
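
The same dictionary can be built more compactly. This sketch assumes the key names listed above and uses placeholder filenames where the notebook uses the `distortion`, `specwcs`, and `wavelengthrange` variables; `make_reference_names` is a hypothetical helper.

```python
REFERENCE_TYPES = ['camera', 'collimator', 'disperser', 'distortion',
                   'filteroffset', 'fore', 'fpa', 'ifufore', 'ifupost',
                   'ifuslicer', 'msa', 'ote', 'regions', 'specwcs',
                   'wavelengthrange']

def make_reference_names(**overrides):
    """Start every reference type at 'N/A', then fill in the ones provided."""
    unknown = set(overrides) - set(REFERENCE_TYPES)
    if unknown:
        raise KeyError("unknown reference types: {}".format(sorted(unknown)))
    names = dict.fromkeys(REFERENCE_TYPES, 'N/A')
    names.update(overrides)
    return names

# Placeholder filenames, for illustration only
example_names = make_reference_names(distortion='my_distortion.asdf',
                                     specwcs='my_specwcs.asdf')
print(example_names['distortion'], example_names['msa'])
```

Rejecting unknown keys guards against a misspelled reference type silently being ignored.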

In [None]:
image_pipeline=nircam.imaging(grism_image, reference_file_names)
image_pipeline 

In [None]:
# there should be 3 models that are returned in this list
for m in image_pipeline:
    print("model: {}\n\n".format(m))


#### We need to insert the models in the correct places, so let's take note of the celestial frame

In [None]:
# keep the world reference frame which should be at the end
world = image_pipeline.pop()
world

#### Next, we need to make sure ra, dec, wavelength, and the spectral order get passed through the backwards transforms and are returned from the forwards transforms. To do this, we're going to deconstruct the imaging pipeline and add it to our grism pipeline.

These modules will help us:

In [None]:
from astropy.modeling.models import Scale, Identity, Mapping, Const1D

#### GWCS will accept a chain of (reference frame, transform) tuples, so we'll create the first one, for the grism detector reference frame here. This is the model you saw earlier that moved coordinates from the dispersed image to the direct image

In [None]:
grism_pipeline = [(gdetector, model)]

In [None]:
# deconstruct the imaging pipeline and append to the grism pipeline
imagepipe = []
for cframe, trans in image_pipeline:
    trans = trans & (Identity(2))
    imagepipe.append((cframe, trans))
imagepipe.append(world)  # world is already a (frame, transform) pair
grism_pipeline.extend(imagepipe)

In [None]:
# This should show there are 4 models in the chain
for m in grism_pipeline:
    print("model: {}\n\n".format(m))
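
Joining each imaging transform with `Identity(2)` means the two extra values (wavelength and order) ride along unchanged while the transform acts on the pixel coordinates. Here is a plain-Python sketch of that idea, using hypothetical stand-in transforms with made-up numbers rather than the real distortion and pointing models:

```python
def with_passthrough(transform_xy):
    """Wrap an (x, y) transform so wavelength and order pass through untouched."""
    def wrapped(x, y, wavelength, order):
        new_x, new_y = transform_xy(x, y)
        return new_x, new_y, wavelength, order
    return wrapped

def chain(steps):
    """Apply a list of 4-input, 4-output steps one after another."""
    def run(*values):
        for step in steps:
            values = step(*values)
        return values
    return run

# Hypothetical stand-ins for two imaging transforms
shift = with_passthrough(lambda x, y: (x - 1024.0, y - 1024.0))
scale = with_passthrough(lambda x, y: (x * 0.063, y * 0.063))

example_chain = chain([shift, scale])
print(example_chain(1034.0, 1024.0, 3.9, 1))
```

Only the first two values change from step to step; the wavelength and order arrive at the end exactly as they were supplied.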

### Now we can create the full GWCS model

In [None]:
from gwcs.wcs import WCS

In [None]:
grism_wcs = WCS(grism_pipeline)

#### The GWCS model we now have is set up to translate from the dispersed image, to the direct image, to the sky

In [None]:
grism_wcs.input_frame

In [None]:
grism_wcs.output_frame

In [None]:
grism_wcs.available_frames

### Let's try out our new complete model and make sure it returns the same results as we found previously

In [None]:
# reminder of the current values
print(x0, y0, wavelength, order)

In [None]:
grism_wcs(100,100,110,110,1)  # should return (ra, dec, wavelength, order)

### The GWCS object allows us to extract sets of transforms; let's grab the grism-->direct image transform so we can make a direct comparison with previous results that only had the pixel->pixel translation models

In [None]:
g2d = grism_wcs.get_transform('grism_detector','detector')  # forwards

In [None]:
g2d(100,100,110,110,1)  # forwards

In [None]:
g2d.inverse(110.0, 110.0, 3.9269599999722775, 1.0)  # backwards

### We can pull the transform that goes from the sky --> direct image reference frame

In [None]:
w2d = grism_wcs.get_transform('world','detector')  # this is part of the backwards transform

In [None]:
# pick starting values 
ra = -0.04476397838420304  # deg
dec = 0.14901165082101248  # deg
wave = 1.0
order = 1.0
x0, y0, wave, order = w2d(ra, dec, wave, order)
print(x0, y0, wave, order)

#### Do the values roundtrip?

In [None]:
d2w = grism_wcs.get_transform('detector','world')
n_ra, n_dec, n_wave, n_order = d2w(x0, y0, wave, order)

In [None]:
print("differences:\nra: {}\ndec: {}\nwave: {}\norder: {}".format(ra-n_ra, dec-n_dec, wave-n_wave, order-n_order))

#### If the differences are not zero, then there is something going on with either:
* the models that translate from the imaging detector pixels, through the distortion solution and the telescope pointing to the sky
* the values in the header of the image that was used for the example
* the reference files in use
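
Because the transforms are evaluated in floating point, the differences are rarely exactly zero even when everything is consistent. One way to judge the round trip is to compare against a tolerance. The sketch below uses `math.isclose`; the tolerance value and the sample numbers are assumptions chosen for illustration, and `roundtrip_ok` is a hypothetical helper.

```python
import math

def roundtrip_ok(original, recovered, abs_tol=1e-9):
    """True when every recovered value matches its original within abs_tol."""
    return all(math.isclose(a, b, rel_tol=0.0, abs_tol=abs_tol)
               for a, b in zip(original, recovered))

# Made-up values standing in for (ra, dec, wave, order) before and after
before = (-0.0447639, 0.1490116, 1.0, 1.0)
after = (-0.0447639 + 1e-13, 0.1490116 - 1e-13, 1.0, 1.0)
print(roundtrip_ok(before, after))
```

A suitable tolerance depends on the precision you need on the sky; the value used here is only a placeholder.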

## Now let's tie everything together and use the pipeline to assign a WCS to a grism image

In [None]:
from jwst.assign_wcs.assign_wcs import load_wcs

In [None]:
grism_image = image.ImageModel(grism_file)
grism_image.meta.instrument.filter, grism_image.meta.instrument.pupil, grism_image.meta.instrument.detector,grism_image.meta.exposure.type


In [None]:
# This time we'll get all correct reference files from CRDS
distortion = assign_wcs_step.get_reference_file(grism_image, 'distortion')  # distortion is independent of grism
wavelengthrange = assign_wcs_step.get_reference_file(grism_image,'wavelengthrange')  # independent of grism
specwcs = assign_wcs_step.get_reference_file(grism_image,'specwcs')  # we should see the other reference file for GRISMR now
reference_file_names['specwcs'] = specwcs
reference_file_names['wavelengthrange'] = wavelengthrange
reference_file_names['distortion'] = distortion
reference_file_names

In [None]:
doesitblend = load_wcs(grism_image, reference_file_names)

In [None]:
# It blends!
doesitblend.meta.wcs(100,100,110,110,1)  # inputs of x,y,x0,y0,order

In [None]:
doesitblend.meta.wcs.available_frames

### Now we'll fake out the GRISM reference in our test image so you can see the comparative results from the column grism

In [None]:
grism_image.meta.instrument.pupil = 'GRISMC'

In [None]:
distortion = assign_wcs_step.get_reference_file(grism_image, 'distortion')  # distortion is independent of grism
wavelengthrange = assign_wcs_step.get_reference_file(grism_image,'wavelengthrange')  # independent of grism
specwcs = assign_wcs_step.get_reference_file(grism_image,'specwcs')  # now the GRISMC reference file should show up again
reference_file_names['specwcs'] = specwcs  
reference_file_names['wavelengthrange'] = wavelengthrange
reference_file_names['distortion'] = distortion
reference_file_names

In [None]:
doesitblend = load_wcs(grism_image, reference_file_names)

In [None]:
# It blends!
ra, dec, wave, order = doesitblend.meta.wcs(100,100,110,110,1)  # inputs of x,y,x0,y0,order
print(ra, dec, wave, order )

In [None]:
inverse = doesitblend.meta.wcs.get_transform("world","grism_detector")

In [None]:
inverse(ra, dec, wave, order)  # outputs of x, y, x0, y0, order

#### Almost; you can see how far off the roundtrip is now in pixel values, where above we showed the difference in ra/dec values in the sky frame

### Here's a programmatic look at the full GWCS object and transforms it contains

In [None]:
doesitblend.meta.wcs

## Finally, want to save the image with the full GWCS to a new file?
I'm saving this to a filename that reminds us we messed with the actual grism specification.

In [None]:
doesitblend.write('nircam_bad_grism_spec.fits')

## Want to make sure it's all really still there?

In [None]:
check_image = image.ImageModel('./nircam_bad_grism_spec.fits')

In [None]:
check_image.meta.wcs