# Exploring the TSGRISM mode for NIRCAM
### Time Series Grism observations which can be done with subarrays

### Make sure that you have set the JWST_NOTEBOOK_DATA environment variable in the terminal from which you started Jupyter Notebook.

The data will be read from that directory, and the pipeline should write to the current working directory so that the input data are not overwritten.
If you would like to use your own data, just substitute the locations below.

In [None]:
# Locate the notebook data. JWST_NOTEBOOK_DATA must be set in the environment
# before Jupyter is started; a KeyError here means it was not.
import os  # imported here because this cell runs before the main import cell

notebook_dir = os.environ['JWST_NOTEBOOK_DATA']
# os.path.join copes with a missing trailing slash on the environment variable.
# The empty last component keeps a trailing separator, so file names can still
# be appended to nircam_data with "+" in later cells.
nircam_data = os.path.join(notebook_dir, 'nircam', '')

In [None]:
# plotting, the inline must come before the matplotlib import
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib.patches as patches

# Plot defaults for the whole notebook: one small font size for every text
# element, on a large, high-resolution canvas.
_small_font = 6
params = {
    'legend.fontsize': _small_font,
    'figure.figsize': (8, 8),
    'figure.dpi': 150,
    'axes.labelsize': _small_font,
    'axes.titlesize': _small_font,
    'xtick.labelsize': _small_font,
    'ytick.labelsize': _small_font,
}
plt.rcParams.update(params)


# python general
import os
import sys
import numpy as np

# astropy modules
from astropy.io import fits
from astropy.table import QTable
from astropy.wcs.utils import skycoord_to_pixel
import photutils

# jwst 
import jwst
from jwst.datamodels import image, CubeModel
from jwst.assign_wcs import nircam

print("Using jwst version: {}".format(jwst.__version__))
print(sys.version)

In [None]:
tsgrism_image=nircam_data + 'jw12345001001_01101_00001_nrcalong_rateints.fits'

In [None]:
gim=CubeModel(tsgrism_image)

In [None]:
gim.meta.instrument.pupil, gim.meta.instrument.filter,gim.meta.exposure.type, gim.meta.instrument.module

In [None]:
gim.shape

In [None]:
gim.wavelength

In [None]:
gim.meta.wcsinfo.crpix1, gim.meta.wcsinfo.crpix2, gim.meta.wcsinfo.crval1, gim.meta.wcsinfo.crval2, gim.meta.wcsinfo.v2_ref, gim.meta.wcsinfo.v3_ref, gim.meta.wcsinfo.v3yangle, gim.meta.wcsinfo.roll_ref

## Take a look at the image we have

In [None]:
# Show the first two planes of the data cube side by side.
ys, xs = gim.shape[1:]
fig, (ax, ax2) = plt.subplots(1, 2, figsize=(8, 8), dpi=150)

image_name = tsgrism_image.split("/")[-1]
for plane, axis in enumerate((ax, ax2)):
    # The title carries the file name plus the index of the plane being shown.
    axis.set_title(image_name + "[{}]".format(plane), fontsize=8)
    axis.imshow(gim.data[plane], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

fig.tight_layout()

### Get the reference files for use with our image

In [None]:
# Build the reference-file dictionary that nircam.tsgrism() is given below.
# The import cell only binds `jwst` and `jwst.assign_wcs.nircam`, so the bare
# name `assign_wcs` has to be imported here before it can be used.
from jwst import assign_wcs

# The step instance is used only to look up the reference files for this image.
step = assign_wcs.AssignWcsStep()
distortion = step.get_reference_file(gim, 'distortion')
specwcs = step.get_reference_file(gim, 'specwcs')

# Every reference type starts out as 'N/A'; only the two looked up above are
# replaced with real file names.
reference_file_names = dict.fromkeys(
    ['camera', 'collimator', 'disperser', 'distortion', 'filteroffset',
     'fore', 'fpa', 'ifufore', 'ifupost', 'ifuslicer', 'msa', 'ote',
     'regions', 'specwcs', 'v2v3', 'wavelengthrange'],
    'N/A',
)
reference_file_names['distortion'] = distortion
reference_file_names['specwcs'] = specwcs

In [None]:
reference_file_names

### We need to do the subarray transform on either side of the grism trace. This is blurred in the grism() assign_wcs because the regular grism mode doesn't allow subarrays.

In [None]:
print(gim.meta.subarray.name)
tsgrism_pipeline=nircam.tsgrism(gim,reference_file_names)  # create the wcs pipeline

In [None]:
tsgrism_pipeline

In [None]:
from gwcs.wcs import WCS
tswcs=WCS(tsgrism_pipeline)

#### This will detail the available transformation frames

In [None]:
tswcs.available_frames

### In grism time-series mode, the Module A Grism R is used to disperse the target's spectrum along (parallel to) detector rows. The grism is used in conjunction with one of 4 wide filters in the long wavelength channel (2.4–5.0 µm): F277W, F322W2, F356W, and F444W.

In [None]:
tswcs  # the three inputs are x, y, order

### Translate pixel 100,100 in the tso grism image for source location crpix1, crpix2, and order 1

In [None]:
tswcs(100,100,1)  # returns ra, dec, wave, order

In [None]:
gim.get_fits_wcs()
#gim.meta.wcsinfo.crval1, gim.meta.wcsinfo.crval2, gim.meta.wcsinfo.crpix1, gim.meta.wcsinfo.crpix2

#### Change the location. The wavelength should change, but the RA and Dec should stay the same

In [None]:
tswcs(200,200,1)

#### This should return the same x value and CRPIX2 for y, since the transforms for NIRCAM return 0 for GRISMR in the y so that the extraction box can be chosen around that

In [None]:
tswcs.invert(3.2472519999724887, 1.0)

#### This should return an error about an invalid order, since order 3 is not available

In [None]:
tswcs(100,100,3)

In [None]:
tswcs.input_frame, tswcs.output_frame

In [None]:
tswcs.get_transform('world','v2v3')

In [None]:
tswcs.get_transform('v2v3', 'world')

# Now go back to the image and see if we can get extract_2d to go to the correct place

In [None]:
# Show the first plane of the data cube on its own.
ys, xs = gim.shape[1:]
fig = plt.figure(figsize=(8, 8), dpi=150)
ax = fig.add_subplot(1, 1, 1)
# 'box-forced' is no longer accepted by matplotlib and raises ValueError;
# 'box' is the supported value.
ax.set_adjustable('box')
ax.set_title(tsgrism_image.split("/")[-1] + "[0]", fontsize=8)
ax.imshow(gim.data[0], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

fig.tight_layout()

In [None]:
from jwst.extract_2d import extract_2d_step, extract_2d

In [None]:
# Rebuild the reference-file dictionary, this time looking the files up through
# the extract_2d step and including the wavelengthrange file.
step = extract_2d_step.Extract2dStep()

# Start with every reference type set to 'N/A' ...
reference_file_names = dict.fromkeys(
    ['camera', 'collimator', 'disperser', 'distortion', 'filteroffset',
     'fore', 'fpa', 'ifufore', 'ifupost', 'ifuslicer', 'msa', 'ote',
     'regions', 'specwcs', 'v2v3', 'wavelengthrange'],
    'N/A',
)
# ... then fill in the three types this notebook actually needs.
for reftype in ('distortion', 'specwcs', 'wavelengthrange'):
    reference_file_names[reftype] = step.get_reference_file(gim, reftype)

In [None]:
reference_file_names

## Set the wavelengthrange reference file to the one we updated, which hasn't yet been accepted into CRDS

In [None]:
# run assign_wcs step on the image to attach the gwcs object we created
# load_wcs is imported explicitly: the bare name `assign_wcs` is not bound by
# the import cell, so `assign_wcs.assign_wcs.load_wcs` would raise NameError.
from jwst.assign_wcs.assign_wcs import load_wcs

tso_wcs_assigned = load_wcs(gim, reference_files=reference_file_names)

In [None]:
tso_wcs_assigned.meta.instrument.filter

In [None]:
from jwst.datamodels import WavelengthrangeModel
wrm = WavelengthrangeModel(reference_file_names['wavelengthrange'])

In [None]:
print(wrm.wavelengthrange)
print()
print(wrm.waverange_selector)

In [None]:
import asdf
# Open the wavelengthrange reference file directly with asdf and display its tree.
# NOTE(review): the handle returned by asdf.open is never closed in this notebook.
test_read=asdf.open(reference_file_names['wavelengthrange'])
test_read.tree

In [None]:
# Read the wavelengthrange reference file and pull out the limits for
# order 1 with the F444W filter.
with WavelengthrangeModel(reference_file_names['wavelengthrange']) as f:
    if f.meta.instrument.name != 'NIRCAM':
        raise ValueError("Wavelengthrange reference file not for NIRCAM!")
    wavelengthrange = f.wavelengthrange
    waverange_selector = f.waverange_selector
    orders = f.order
    extract_orders = f.extract_orders
    print(f.meta.instrument.name)

print(wavelengthrange)
# Each entry is indexed as (order, filter, value, value); the last two are
# presumably the minimum and maximum wavelength -- TODO confirm against the
# reference file schema.
range_select = [(x[2], x[3]) for x in wavelengthrange if (x[0] == 1 and x[1] == 'F444W')]
if not range_select:
    # Same exception type as the bare pop() would raise, but with a message
    # that says what is actually missing.
    raise IndexError("No order 1 / F444W entry found in the wavelengthrange reference file")
lmin, lmax = range_select.pop()
print(lmin, lmax)

In [None]:
from jwst.datamodels import SpecwcsModel
print(reference_file_names['specwcs'])
spec=SpecwcsModel(reference_file_names['specwcs'])


In [None]:
spec.instance

In [None]:
x2d = extract_2d.extract_tso_object(tso_wcs_assigned, reference_files=reference_file_names, extract_height=64)

In [None]:
x2d.instance

In [None]:
x2d.data.shape

In [None]:
x2d.meta.wcs

In [None]:
x2d.meta.wcs(887, 35)

In [None]:
x2d.meta.wcs.invert(4.586513215999723, 1.0)

In [None]:
x2d.meta.wcsinfo.instance

In [None]:
x2d.ysize, x2d.xsize, x2d.shape, gim.shape

In [None]:
x2d.wavelength

In [None]:
# Compare the input image with the extracted cutout in ONE figure.
# The original cell created a second figure half-way through, which left each
# panel alone in its own figure, and then drew on `ax2`, an axis belonging to
# a figure from a much earlier cell.
fig = plt.figure(figsize=(10, 10), dpi=150)

# Left panel: first plane of the full input image.
ys, xs = gim.shape[1:]
ax3 = fig.add_subplot(1, 2, 1)
ax3.set_title(tsgrism_image.split("/")[-1] + "[0]", fontsize=8)
ax3.imshow(gim.data[0], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

# Right panel: first plane of the extracted cutout, titled with the filter,
# spectral order and source position recorded on the extraction.
zs, ys, xs = x2d.data.shape
ax = fig.add_subplot(1, 2, 2)
xpos, ypos = x2d.source_xpos, x2d.source_ypos
title = x2d.meta.instrument.filter+" order {0}\nx={1} y={2}".format(x2d.meta.wcsinfo.spectral_order,
                                                                    xpos,
                                                                    ypos)
ax.set_title(title, fontsize=8)
ax.imshow(x2d.data[0, :, :], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

fig.tight_layout()

In [None]:
tso_wcs_assigned.data.shape

In [None]:
x2d.data.shape

In [None]:
x2d.source_xpos, x2d.source_ypos

In [None]:
x2d.xsize, x2d.ysize

## We should also be able to call the extraction and override the extract_orders for the filter; this should produce a SlitModel output for just the order specified

In [None]:
x2d_single = extract_2d.extract_tso_object(tso_wcs_assigned, reference_files=reference_file_names, extract_orders=[1])

In [None]:
x2d_single.meta.model_type  # should be SlitModel this time

In [None]:
x2d_single.data.shape # should exist and be 3D

In [None]:
x2d_single.meta.wcsinfo.spectral_order

In [None]:
x2d_single.meta.wcs

In [None]:
# Compare the input image with the single-order cutout in ONE figure.
# The original cell created a second figure half-way through, so the two
# panels ended up in separate figures.
fig = plt.figure(figsize=(10, 10), dpi=150)

# Left panel: first plane of the full input image.
ys, xs = gim.shape[1:]
ax3 = fig.add_subplot(1, 2, 1)
ax3.set_title(tsgrism_image.split("/")[-1] + "[0]", fontsize=8)
ax3.imshow(gim.data[0], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

# Right panel: first plane of the single-order cutout. The title is built
# entirely from x2d_single; the original read the filter from x2d.
zs, ys, xs = x2d_single.data.shape
ax = fig.add_subplot(1, 2, 2)
xpos, ypos = x2d_single.source_xpos, x2d_single.source_ypos
title = x2d_single.meta.instrument.filter+" order {0}\nx={1} y={2}".format(x2d_single.meta.wcsinfo.spectral_order,
                                                                           xpos,
                                                                           ypos)
ax.set_title(title, fontsize=8)
ax.imshow(x2d_single.data[0, :, :], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

fig.tight_layout()


## Check that height can be specified

In [None]:
from jwst.extract_2d.grisms import extract_tso_object

### The source should always be returned at pixel 34, even if the extract height doesn't allow for that.

In [None]:
x2d_single = extract_tso_object(tso_wcs_assigned, reference_files=reference_file_names, extract_orders=[1], extract_height=50)

In [None]:
x2d_single.meta.model_type  # should be SlitModel this time

In [None]:
x2d_single.data.shape # should exist and be 3D

In [None]:
x2d_single.meta.wcsinfo.spectral_order, x2d_single.ysize, x2d_single.xsize

In [None]:
# Compare the input image with the extract_height=50 cutout in ONE figure.
# The original cell created a second figure half-way through, so the two
# panels ended up in separate figures.
fig = plt.figure(figsize=(10, 10), dpi=150)

# Left panel: first plane of the full input image.
ys, xs = gim.shape[1:]
ax3 = fig.add_subplot(1, 2, 1)
ax3.set_title(tsgrism_image.split("/")[-1] + "[0]", fontsize=8)
ax3.imshow(gim.data[0], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

# Right panel: first plane of the cutout. The title is built entirely from
# x2d_single; the original read the filter from x2d.
zs, ys, xs = x2d_single.data.shape
ax = fig.add_subplot(1, 2, 2)
xpos, ypos = x2d_single.source_xpos, x2d_single.source_ypos
title = x2d_single.meta.instrument.filter+" order {0}\nx={1} y={2}".format(x2d_single.meta.wcsinfo.spectral_order,
                                                                           xpos,
                                                                           ypos)
ax.set_title(title, fontsize=8)
ax.imshow(x2d_single.data[0, :, :], origin='lower', extent=[0, xs, 0, ys], vmin=-3, vmax=3)

fig.tight_layout()


In [None]:
x2d.write('blah.fits')  # extract2d saves a SlitModel

In [None]:
from jwst import datamodels
blah=datamodels.open('blah.fits')

In [None]:
blah  # should have been read in as a SlitModel