## WFC3 grism model testing

See https://github.com/nden/documentation/blob/master/grisms/JWST_Grisms.ipynb for the JWST version that this notebook is based on.

In [None]:
import asdf
import numpy as np

from dispersion_models import DISPXY_Model, DISPXY_Extension
from transform_models import (WFC3IRForwardGrismDispersion,
                               WFC3IRBackwardGrismDispersion)
from astropy.modeling.models import *
from astropy.modeling.models import math as astmath

from jwst import datamodels
from jwst.assign_wcs import util
from jwst.assign_wcs import nircam

from gwcs import coordinate_frames as cf
from astropy import units as u

In [None]:
# Add the asdf extension for the custom dispersion models
# Register the custom DISPXY asdf extension with the global asdf configuration
# so the dispersion models stored in the specwcs file below can be read back.
asdf.get_config().add_extension(
    DISPXY_Extension()
)

In [None]:
# Read the WFC3/IR spectral-WCS reference file and pull out the per-order
# dispersion models (and their inverses) plus the list of orders.
# NOTE(review): the file handle from asdf.open is never closed. The models may
# reference lazily loaded arrays, so confirm they are fully materialized before
# switching this to a `with` block.
specwcs = asdf.open('wfc3_ir_specwcs.asdf').tree
displ, dispx, dispy = specwcs['displ'], specwcs['dispx'], specwcs['dispy']
invdispl, invdispx, invdispy = (
    specwcs['invdispl'],
    specwcs['invdispx'],
    specwcs['invdispy'],
)
orders = specwcs['order']

# Quick look at what was loaded.
for label, value in (
    ('orders', orders),
    ('dispersion_wavelength', displ),
    ('dispersion_x', dispx),
):
    print(label, value)

In [None]:
# 2-D pixel frame of the dispersed (grism) image; both axes are in pixels.
pixel_units = (u.pix, u.pix)
gdetector = cf.Frame2D(name='grism_detector', axes_order=(0, 1), unit=pixel_units)

In [None]:
# Forward grism dispersion model, built from the wavelength and x/y dispersion
# models read from the specwcs file, keyed by spectral order.
forward_models = dict(lmodels=displ, xmodels=dispx, ymodels=dispy)
det2det = WFC3IRForwardGrismDispersion(orders, **forward_models)

In [None]:
# Backward model, attached as the inverse of the forward model. It uses the
# inverse wavelength model together with the forward x/y dispersion models.
# NOTE(review): invdispx / invdispy are read from the specwcs file but are not
# passed to either model; confirm that is intended.
backward_models = dict(lmodels=invdispl, xmodels=dispx, ymodels=dispy)
det2det.inverse = WFC3IRBackwardGrismDispersion(orders, **backward_models)

In [None]:
# Single-step gwcs pipeline: the grism detector frame and its dispersion model.
# NOTE(review): gwcs pipelines normally end with an (output_frame, None) step;
# this one has no output frame, so only printing the WCS is expected to work.
# Confirm before evaluating wcsobj.
grism_pipeline = []
grism_pipeline.append((gdetector, det2det))

In [None]:
from gwcs import WCS

# Wrap the pipeline in a gwcs WCS object and show its frame/transform summary.
wcsobj = WCS(forward_transform=grism_pipeline)
print(wcsobj)

In [None]:
# Inputs for the dispersion round-trip cells below.
x0, y0 = 917, 800   # reference position; presumably the direct-image source pixel
order = 1           # spectral order to evaluate
x, y = 919, 806     # pixel on the grism image

In [None]:
# Forward dispersion for the test pixel. In the recorded run the output was
# (x0, y0, <wavelength-like value>, order).
forward_output = det2det.evaluate(x, y, x0, y0, order)
forward_output

In [None]:
# Round trip, step 1: feed the forward-dispersion output (x0, y0, wavelength,
# order) straight into the backward model. The values are recomputed here
# instead of pasting the previous cell's printed output, so the cell stays
# correct when the reference file or the test inputs change. The recorded run
# gave 917.0, 800.0, 0.9143279451180034, 1.0.
fwd_x0, fwd_y0, fwd_wavelength, fwd_order = det2det.evaluate(x, y, x0, y0, order)
det2det.inverse.evaluate(fwd_x0, fwd_y0, fwd_wavelength, fwd_order)

In [None]:
# Round trip, step 2: pass the backward-model output (x, y, x0, y0, ...) back
# through the forward model. The backward output is recomputed here rather than
# pasted from printed numbers. As in the original cell, the notebook's `order`
# is passed as the last argument.
# NOTE(review): in the recorded run x came back exactly (919.0) but y came back
# as 800.584..., not the input 806. Presumably the forward model does not use
# its y input; confirm.
fwd_x0, fwd_y0, fwd_wavelength, fwd_order = det2det.evaluate(x, y, x0, y0, order)
recovered = det2det.inverse.evaluate(fwd_x0, fwd_y0, fwd_wavelength, fwd_order)
det2det.evaluate(recovered[0], recovered[1], recovered[2], recovered[3], order)