# Techniques to Simulate WFSS Spectra

This notebook demonstrates basic techniques to simulate WFSS dispersed spectra. The method is relatively straightforward: an image of a source is used to determine which pixels contain the source and how bright each of them is. This information is then used to compute the flux of each pixel in flam units ($erg/s/cm^2/A$). The World Coordinate System (WCS) of the imaging data and of the WFSS observation is then used to compute the location of each pixel in the reference frame of the dispersed WFSS observation. We then use the GRISMCONF functions and a wavelength vector to find where each of these wavelength segments is dispersed to on the WFSS observation. Since the detector pixel grid is unlikely to align with the computed coordinates in the WFSS observation, we use the Sutherland-Hodgman algorithm, as implemented in <a href="https://github.com/spacetelescope/pypolyclip">pypolyclip</a>, to quickly compute the fraction of each projected dispersed pixel that falls on the actual WFSS observed pixels.

## Pre-Requisites

This notebook builds on the simpler Box Extraction notebook where we introduced the general concepts of spectral extraction, as well as the use of the <A HREF="https://github.com/npirzkal/GRISMCONF">GRISMCONF</A> module, which provides us with a low-level interface to the calibration polynomials for WFSS modes.

This notebook uses the GRISMCONF module and GRISMCONF NIRCam Configuration files. 
* GRISMCONF can be obtained from <a href="https://github.com/npirzkal/GRISMCONF">here</a>. It can also be installed using the command "pip install grismconf"
* NIRCam WFSS configuration files can be obtained from <a href="https://github.com/npirzkal/GRISM_NIRCAM">here</a>. V9 of the configuration files were delivered to STScI in the Summer of 2023 and represent the latest version of the NIRCam WFSS Calibration as of the writing of this document. All of the files in the V9 sub-directory should be manually downloaded and stored somewhere locally, e.g. ./GRISMDATA

In addition to the standard numpy, astropy, scipy, and matplotlib packages, this notebook also uses the ```jwst``` pipeline package, ```photutils``` for background estimation and source detection, the ```pypolyclip``` package for clipping polygons against a pixel grid, ```nf9```, a high-level interface to SAOImageDS9 using pyds9(), and ```tqdm```, a convenience tool that returns an iterator that acts exactly like the original iterable but prints a dynamically updating progress bar every time a value is requested.

To install the ```jwst``` package, follow instructions here: [jwst package installation](https://github.com/spacetelescope/jwst/tree/master?tab=readme-ov-file#detailed-installation)

To install the other packages, use ```pip```, e.g.:

```pip install pypolyclip```

## Imports

In [None]:
import os
import requests
from copy import deepcopy

import numpy as np
from pypolyclip import clip_multi
from astropy.io import fits
from astropy.convolution import convolve
from scipy.sparse import coo_matrix
from photutils.segmentation import make_2dgaussian_kernel
from photutils.segmentation import detect_sources
from photutils.background import Background2D, MedianBackground
import matplotlib.pyplot as plt
import tqdm
import grismconf

from jwst import datamodels
from jwst.assign_wcs import AssignWcsStep
from jwst.flatfield import FlatFieldStep
from jwst.photom import PhotomStep

## Define Functions and Parameters

Define the directory containing the WFSS configuration files

In [None]:
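# Local directory holding the V9 configuration files (see Pre-Requisites)
cPath = "./GRISMDATA"
# STScI-internal location of the same files:
#cPath = "/grp/jwst/wit/nircam/reference_files/specwcs/V9/V9"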

Define a function to download a named file via the MAST API to the current directory. The function includes authentication logic, but this example uses public data, so no MAST API token is required.

In [None]:
def get_jwst_file(name, mast_api_token=None, overwrite=False):
    """Retrieve a JWST data file from MAST archive."""
    # If the file already exists locally, don't redownload it, unless the
    # user has set the overwrite keyword
    if os.path.isfile(name):
        if not overwrite:
            print(f'{name} already exists locally. Skipping download.')
            return
        else:
            print(f'{name} exists locally. Re-downloading.')

    mast_url = "https://mast.stsci.edu/api/v0.1/Download/file"
    params = dict(uri=f"mast:JWST/product/{name}")
    if mast_api_token:
        headers = dict(Authorization=f"token {mast_api_token}")
    else:
        headers = {}
    r = requests.get(mast_url, params=params, headers=headers, stream=True)
    r.raise_for_status()
    with open(name, "wb") as fobj:
        for chunk in r.iter_content(chunk_size=1024000):
            fobj.write(chunk)

Define a function that will run the Assign WCS and Flat Field steps of the pipeline on an input rate file.

In [None]:
def run_pipeline_steps(filename):
    """Run assign_wcs, followed by flat fielding"""
    assign_wcs = AssignWcsStep.call(filename)

    # In order to apply the imaging mode flat field reference file to the data,
    # we need to trick CRDS by temporarily changing the pupil value to be CLEAR
    reset_pupil = False
    if 'GRISM' in assign_wcs.meta.instrument.pupil:
        true_pupil = deepcopy(assign_wcs.meta.instrument.pupil)
        assign_wcs.meta.instrument.pupil = 'CLEAR'
        reset_pupil = True

    # Run the flat field step
    flat = FlatFieldStep.call(assign_wcs, save_results=True)

    # Set the pupil back to the original value now that flat fielding is complete
    if reset_pupil:
        flat.meta.instrument.pupil = true_pupil
        flat.save(flat.meta.filename)

    # Return the name of the output file, as well as the datamodel
    return flat.meta.filename, flat

## Download the Data

We start with a simple pair of imaging and WFSS data. These were manually selected and they point at the same field on the sky using the same NIRCam module, channel, and cross filter.

In [None]:
# First, download the imaging and WFSS files from MAST
imaging_file = "jw01076109001_02102_00001_nrcalong_cal.fits"
wfss_file = "jw01076109001_02101_00001_nrcalong_rate.fits"
get_jwst_file(imaging_file)
get_jwst_file(wfss_file)

## Run the Assign WCS and Flat Field Steps

Using the calibration pipeline, we run the [Assign WCS](https://jwst-pipeline.readthedocs.io/en/latest/jwst/assign_wcs/index.html) step to get the WCS object for coordinate transformations, followed by the [Flat Field](https://jwst-pipeline.readthedocs.io/en/latest/jwst/flatfield/index.html) correction for our science data.

In [None]:
wfss_flat_file, wfss_data = run_pipeline_steps(wfss_file)

## Load the Data into Models

Read some information from the imaging data and WFSS data. We need to know which module, channel, cross filter, and grism we are looking at. We also need to find the values needed to convert the surface brightness units of the imaging cal files into units of $erg/s/cm^2/A$.

In [None]:
image_model = datamodels.open(imaging_file)
imaging_data = image_model.data

wfss_model = datamodels.open(wfss_flat_file)

In [None]:
FILTER = image_model.meta.instrument.filter
MODULE = image_model.meta.instrument.module
PUPIL = image_model.meta.instrument.pupil
PHOTUJA2 = image_model.meta.photometry.conversion_microjanskys
PIXAR_SR = image_model.meta.photometry.pixelarea_steradians

print(f"IMAGING FILTER: {FILTER}, MODULE: {MODULE}, PUPIL: {PUPIL}")
print(f"Pixel area: {PIXAR_SR} sr")

In [None]:
WFSS_FILTER = wfss_model.meta.instrument.filter
WFSS_MODULE = wfss_model.meta.instrument.module
WFSS_PUPIL = wfss_model.meta.instrument.pupil

print(f"WFSS FILTER: {WFSS_FILTER}, MODULE: {WFSS_MODULE}, PUPIL: {WFSS_PUPIL}")

## Convert from $MJy/sr$ to $erg/s/cm^2/A$

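We compute the conversion between pixel values in the imaging data, which are in MJy/sr, and flam units of $erg/s/cm^2/A$ (per pixel). Multiplying the values in our calibrated image file by this factor gives the flam value of each pixel of an object detected in our imaging data. The factor multiplies by the pixel area to obtain MJy per pixel, by $10^6$ to convert MJy to Jy, and then divides by $3.3356\times10^4\,\lambda^2$ (with $\lambda$ in Å) to convert $f_\nu$ in Jy into $f_\lambda$; here $\lambda = 44210$ Å, the pivot wavelength of F444W. Because a single $f_\lambda$ value is used for each pixel, the simulation below treats each pixel as having a spectrum that is flat in $f_\lambda$ across the bandpass.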

In [None]:
PHOTFLAM = PIXAR_SR * 1e6 /3.3356e4/44210**2
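The same conversion can be wrapped in a small helper that takes the pivot wavelength as an argument, so it can be reused for other imaging filters. This is a hypothetical function, not part of the notebook or of any package, and the pivot wavelength of the filter has to be supplied by the user.

```python
def mjysr_to_flam_factor(pixar_sr, pivot_um):
    """Factor converting an imaging pixel value in MJy/sr to flam (erg/s/cm^2/A) per pixel.

    pixar_sr : pixel area in steradians (PIXAR_SR above)
    pivot_um : pivot wavelength of the imaging filter, in microns
    """
    jy_per_pixel = pixar_sr * 1e6   # MJy/sr * sr = MJy per pixel; 1 MJy = 1e6 Jy
    pivot_aa = pivot_um * 1e4       # microns -> Angstroms
    # f_nu [Jy] = 3.33564e4 * lambda^2 [A^2] * f_lambda [erg/s/cm^2/A]
    return jy_per_pixel / (3.33564e4 * pivot_aa**2)
```

For example, `mjysr_to_flam_factor(PIXAR_SR, 4.421)` reproduces the `PHOTFLAM` value computed above, up to the precision of the $3.3356\times10^4$ constant.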

## Create a Segmentation Map

Since we want to disperse each of the pixels comprising a given source, we first need to get a segmentation map of all the objects. This can be done relatively easily using the ```photutils``` package.

In [None]:
bkg_estimator = MedianBackground()
bkg = Background2D(imaging_data, (50, 50), filter_size=(21, 31),bkg_estimator=bkg_estimator)
imaging_data -= bkg.background 

In [None]:
threshold = 50 * bkg.background_rms

In [None]:
kernel = make_2dgaussian_kernel(3.0, size=5)
convolved_data = convolve(imaging_data, kernel)

In [None]:
segment_map = detect_sources(convolved_data, threshold, npixels=10)

In [None]:
plt.imshow(segment_map,origin="lower")

## Simulate the Dispersion

Here, we want to show how to simulate the dispersion of only one source, so we pick one. In order to simulate a full WFSS observation, what we show here needs to be done for every source in the field. Simulating all the dispersed spectra is also a way to mask out spectra when estimating the dispersed background level during subsequent extraction. It also allows you to estimate the amount of spectral contamination by overlapping spectra.
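Once every source has been simulated, the per-source models can be combined into a contamination estimate and a background mask. A minimal sketch, assuming a hypothetical dictionary ```simulated_by_id``` that maps each segmentation ID to its own simulated WFSS array in $DN/s$ (each built the same way as the ```simulated``` array for the single source in this notebook), and an arbitrary masking level ```mask_level```:

```python
import numpy as np


def contamination_and_mask(simulated_by_id, target_id, mask_level=0.0):
    """Combine per-source simulated WFSS images (DN/s), keyed by source ID.

    Returns
    -------
    contamination : 2D array, summed model of every source except target_id
    mask : 2D bool array, True wherever any simulated source exceeds mask_level
    """
    stack = np.stack(list(simulated_by_id.values()))
    contamination = stack.sum(axis=0) - simulated_by_id[target_id]
    mask = np.any(stack > mask_level, axis=0)
    return contamination, mask
```

Pixels where ```mask``` is True would be excluded when estimating the dispersed background, and ```contamination``` gives the expected overlap from other spectra at the position of the target's spectrum.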

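Below, we find where our sources are in the segmentation map. The code offers three ways to choose a source: by its segmentation ID, at random, or by a pixel coordinate that falls on the source. Here we use the last option to manually choose a source; uncomment one of the other blocks to use a different selection.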

In [None]:
# # Option 1: pick source 47
# find_sources = np.argwhere(segment_map.data == 47) # (row, column) indices of the pixels in source 47
# indices = np.ravel_multi_index([find_sources[:, 0], find_sources[:, 1]], segment_map.data.shape)
# ID = segment_map.data[np.unravel_index(indices[0], segment_map.data.shape)]
# print(f"We picked object {ID}")

# # Option 2: pick a random source pixel (larger sources are therefore more likely to be picked)
# find_sources = np.argwhere(segment_map.data != 0) # (row, column) indices of all source pixels
# indices = np.ravel_multi_index([find_sources[:, 0], find_sources[:, 1]], segment_map.data.shape)
# random_source = np.random.choice(indices)
# ID = segment_map.data[np.unravel_index(random_source, segment_map.data.shape)]
# print(f"We picked object {ID}")

# Option 3 (used here): pick the source containing a given pixel (x = column, y = row)
# xd,yd = 405,1465
xd,yd = 1575,89
ID = segment_map.data[yd,xd]
print(f"We picked object {ID}")
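The selection options can also be written more directly by working with the labels themselves. A minimal sketch with hypothetical helper names; ```source_ids``` returns the same values as the ```labels``` attribute of the photutils ```SegmentationImage```. Unlike Option 2 above, ```random_source_id``` gives every source the same chance of being picked, regardless of its size.

```python
import numpy as np


def source_ids(segmentation):
    """Sorted non-zero labels present in a segmentation array (0 = background)."""
    ids = np.unique(segmentation)
    return ids[ids != 0]


def random_source_id(segmentation, rng):
    """Pick one source label uniformly at random."""
    return int(rng.choice(source_ids(segmentation)))


def source_id_at(segmentation, x, y):
    """Label of the pixel at column x, row y (numpy arrays are indexed [row, column])."""
    return int(segmentation[y, x])
```

For example, ```ID = random_source_id(segment_map.data, np.random.default_rng())``` or ```ID = source_id_at(segment_map.data, 1575, 89)```.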

Get the pixel coordinates and their surface brightness values (in $MJy/sr$) for this source:

In [None]:
ok = segment_map.data == ID
yds,xds = np.nonzero(ok)
cds = imaging_data[ok]

Check what this source looks like in the imaging data, alongside its segmentation map.

In [None]:
min_x = np.min(xds)
max_x = np.max(xds)
min_y = np.min(yds)
max_y = np.max(yds)

fig,axs = plt.subplots(1,2,figsize=(15,5))
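# Add 1 to the upper bounds so that the last row and column of the source are included in the cutout
axs[0].imshow(imaging_data[min_y:max_y+1,min_x:max_x+1],origin="lower")
axs[1].imshow(segment_map.data[min_y:max_y+1,min_x:max_x+1],origin="lower");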

All the information we have for this source is within the reference frame of the imaging data, but we want to know where each of these pixels are in the WFSS observation, which could be at slightly different pointings. As we did when performing a basic box extraction (see the NIRCam WFSS Box Extraction notebook), this is handled using the WCS of both imaging and WFSS observations -- but this time we compute the positions of all of the pixels in the source and not simply the position of the peak of this source.

In [None]:
imaging_to_world = image_model.meta.wcs.get_transform('detector','world')

In [None]:
wfss_to_pix = wfss_model.meta.wcs.get_transform('world','detector')

We compute the R.A. and Dec of each of the input pixels:

In [None]:
ras,decs = imaging_to_world(xds,yds)

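We can now compute the center coordinates of the imaging pixels in the WFSS data. In the WFSS WCS, the 'detector' frame is the undispersed (direct-image) pixel frame, so the wavelength and order passed in are simply carried through the transform and do not change ```xs``` and ```ys```: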

In [None]:
wavelength, order = 2.5, 0
xs,ys,wav,ord = wfss_to_pix(ras,decs,wavelength,order)

In [None]:
plt.scatter(xs,ys)
plt.xlabel("WFSS columns")
plt.ylabel("WFSS rows")

We initialize the ```grismconf``` config object. It contains the information and polynomials describing the dispersion of the grism, as well as the corresponding inverse sensitivity curve.

In [None]:
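# Build the configuration file name from the WFSS observation; the last letter of the pupil (GRISMR or GRISMC) gives the grism direction
grismconf_file = os.path.join(cPath,f"NIRCAM_{WFSS_FILTER}_mod{WFSS_MODULE}_{WFSS_PUPIL[-1]}.conf")
print(f"Using the grismconf file {grismconf_file}")
C = grismconf.Config(grismconf_file)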

Plot the inverse sensitivity, which shows the wavelength range and shape of the sensitivity. It is defined in units of $DN/s$ per flam ($erg/s/cm^2/A$).

In [None]:
plt.plot(C.SENS_data["+1"][0],C.SENS_data["+1"][1])
plt.grid()
plt.xlabel("Wavelength (micron)")
plt.ylabel(r"DN/s per $erg/s/cm^2/A$")

When simulating this dispersed spectrum, we need to account for the wavelength of the light being dispersed, so each of the pixels above is numerically dispersed at a set of discrete wavelengths. The wavelength range covered by the disperser is available from the ```grismconf``` configuration as the WRANGE attribute of the config object we initialized above.

In [None]:
wmin = C.WRANGE["+1"][0]
wmax = C.WRANGE["+1"][1]

print(f"The wavelength range to consider is {wmin} to {wmax}")

In [None]:
dlam = 0.001  # wavelength step, in microns
lams = np.arange(wmin,wmax,dlam)

print(f"We are using {len(lams)} values of wavelengths")

Next, we need to process each pixel of the object. The following shows the process for a single pixel.

In [None]:
i = 10

We start by computing the ```t``` values corresponding to the wavelengths (lams) we are considering. Refer to the Box Extraction notebook for additional background information about this.

In [None]:
ts = C.INVDISPL("+1",xs[i],ys[i],lams)
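GRISMCONF describes each trace parametrically: for a source at (x, y), a parameter t maps to a wavelength through DISPL and to detector offsets through DISPX and DISPY, while INVDISPL inverts DISPL to get t from a wavelength. The sketch below mimics that structure with a made-up linear model that ignores the dependence on (x, y). The lowercase function names are hypothetical stand-ins and the numbers are not NIRCam calibration values.

```python
import numpy as np

# Made-up, purely linear trace model for illustration only. The real GRISMCONF
# polynomials also depend on the source position (x, y) and can be higher order.
L0, L1 = 3.8, 5.1          # wavelength (micron) at t = 0 and t = 1
DX0, DX1 = 100.0, 1400.0   # x offset (pixels) from the source at t = 0 and t = 1


def displ(t):
    """t -> wavelength (the role DISPL plays)."""
    return L0 + (L1 - L0) * t


def invdispl(lam):
    """wavelength -> t (the role INVDISPL plays); the exact inverse of displ."""
    return (lam - L0) / (L1 - L0)


def dispx(t):
    """t -> x offset from the source position (the role DISPX plays)."""
    return DX0 + (DX1 - DX0) * t


toy_lams = np.arange(3.9, 5.0, 0.001)   # wavelength grid, in microns
toy_ts = invdispl(toy_lams)             # t for each wavelength
toy_dx = dispx(toy_ts)                  # where each wavelength lands, relative to the source
```

With these made-up numbers, each 0.001 µm step moves the trace by one pixel, which is why the choice of ```dlam``` controls how finely the dispersed polygons sample the trace.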

This computes, for each wavelength, where the bottom-left corner of our pixel lands in the WFSS observation. Here we treat the transformed coordinate (```xs[i]```, ```ys[i]```) as that corner.

In [None]:
xgsA = C.DISPX("+1",xs[i],ys[i],ts) + xs[i]
ygsA = C.DISPY("+1",xs[i],ys[i],ts) + ys[i]

The next three cells compute the other three corners (bottom right, top right, and top left):

In [None]:
xgsB = C.DISPX("+1",xs[i]+1,ys[i],ts) + xs[i]+1
ygsB = C.DISPY("+1",xs[i]+1,ys[i],ts) + ys[i]

In [None]:
xgsC = C.DISPX("+1",xs[i]+1,ys[i]+1,ts) + xs[i]+1
ygsC = C.DISPY("+1",xs[i]+1,ys[i]+1,ts) + ys[i]+1

In [None]:
xgsD = C.DISPX("+1",xs[i],ys[i]+1,ts) + xs[i]
ygsD = C.DISPY("+1",xs[i],ys[i]+1,ts) + ys[i]+1

We reorganize these into lists of polygon corners, which the ```pypolyclip``` module uses to compute their overlap with the pixel grid of the WFSS observation. Although we are looking at a single input source pixel, we compute its dispersion at many wavelengths, so the result is a list of many pixels (polygons) to project onto our rectilinear WFSS pixel grid.

In [None]:
pxs = [ [xgsA[ii],xgsB[ii],xgsC[ii],xgsD[ii]] for ii in range(len(xgsA))]

In [None]:
pys = [ [ygsA[ii],ygsB[ii],ygsC[ii],ygsD[ii]] for ii in range(len(ygsA))]
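The four corner computations and the two list comprehensions can be folded into one vectorized helper that returns arrays of shape (number of wavelengths, 4). A sketch, assuming the dispersion functions are passed in with an (x, y, t) signature, for example ```lambda x, y, t: C.DISPX("+1", x, y, t)```; ```corner_polygons``` is a hypothetical name. If ```clip_multi``` requires plain lists rather than 2D arrays, convert the results with ```.tolist()```.

```python
import numpy as np


def corner_polygons(x0, y0, ts, dispx, dispy):
    """Dispersed corners of the unit pixel with lower-left corner (x0, y0).

    dispx, dispy : callables (x, y, t) -> offset, evaluated at every t in ts
    Returns (pxs, pys), each of shape (len(ts), 4), with corners ordered
    bottom left, bottom right, top right, top left (A, B, C, D above).
    """
    ts = np.asarray(ts, dtype=float)
    corners = [(0, 0), (1, 0), (1, 1), (0, 1)]
    # broadcast_to lets a dispersion function return a scalar (e.g. a constant offset)
    pxs = np.stack([np.broadcast_to(dispx(x0 + dx, y0 + dy, ts), ts.shape) + x0 + dx
                    for dx, dy in corners], axis=1)
    pys = np.stack([np.broadcast_to(dispy(x0 + dx, y0 + dy, ts), ts.shape) + y0 + dy
                    for dx, dy in corners], axis=1)
    return pxs, pys
```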

Plotting the resulting set of dispersed pixels lets us check that the trace is properly sampled, with no gaps between consecutive polygons. As this figure shows, dispersing the input pixel at discrete wavelengths produces a series of dispersed pixels laid across the WFSS detector grid, and each of them generally straddles several WFSS pixels. We therefore need to compute the proper contribution, in $DN/s$, of each dispersed pixel to each of the WFSS detector pixels. Note that we show only one source pixel being dispersed; for a full object, each source pixel is dispersed in the same way, so multiple dispersed pixels contribute to the final counts in each WFSS detector pixel.

In [None]:
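fig,ax = plt.subplots(1,1,figsize=(15,3))
for ii in range(len(pxs)):
    # Close each polygon for plotting without modifying pxs/pys, which are passed to clip_multi below
    tx = pxs[ii] + [pxs[ii][0]]
    ty = pys[ii] + [pys[ii][0]]
    ax.plot(tx,ty)

# Center the view on the polygon in the middle of the wavelength range
mid = (len(pxs) - 1)/2
x0 = pxs[int(mid)][0]

# Put a tick, and hence a grid line, on every WFSS pixel boundary in the displayed range
ax.set_xticks(np.arange(np.floor(x0-15), np.ceil(x0+15)+1))
ax.set_xlim(x0-15,x0+15)
ax.set_xlabel("WFSS columns")
ax.set_ylabel("WFSS Rows")
ax.grid()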

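We can now use the ```pypolyclip.clip_multi``` function to compute how much of each dispersed pixel (the colored boxes above) falls onto each WFSS pixel (shown as the grid above). It returns the column and row indices of the overlapped WFSS pixels (```xc```, ```yc```), the area of each overlap (```area```), and one slice per input polygon (```slices```) that selects that polygon's entries in these arrays.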

In [None]:
xc, yc, area, slices = clip_multi(pxs, pys, [2048,2048])

We can do this for all the pixels in this object and keep track of the relevant information, such as the wavelength and the fraction of the original imaging pixel flux that falls onto each pixel of the simulated WFSS array.

In [None]:
xcs = []
ycs = []
alams = []
flams = []

all_pxs = []
all_pys = []
all_flams = []
all_counts = []

# Go through all the input source pixels, in the WFSS reference frame.
for i in tqdm.tqdm(range(len(xs))):
    
    # Convert the imaging surface brightness of this pixel (MJy/sr) into flam (erg/s/cm^2/A)
    flam = cds[i]*PHOTFLAM

    # Disperse this pixel at each of the len(lams) wavelengths. This results in len(lams) projected pixels contributing to the final WFSS data
    ts = C.INVDISPL("+1",xs[i],ys[i],lams)
    xgsA = C.DISPX("+1",xs[i],ys[i],ts) + xs[i]
    ygsA = C.DISPY("+1",xs[i],ys[i],ts) + ys[i]
    xgsB = C.DISPX("+1",xs[i]+1,ys[i],ts) + xs[i]+1
    ygsB = C.DISPY("+1",xs[i]+1,ys[i],ts) + ys[i]
    xgsC = C.DISPX("+1",xs[i]+1,ys[i]+1,ts) + xs[i]+1
    ygsC = C.DISPY("+1",xs[i]+1,ys[i]+1,ts) + ys[i]+1
    xgsD = C.DISPX("+1",xs[i],ys[i]+1,ts) + xs[i]
    ygsD = C.DISPY("+1",xs[i],ys[i]+1,ts) + ys[i]+1

    # Use the corners of the dispersed pixels to compute the WFSS pixels to which they contribute, and by how much
    pxs = [ [xgsA[ii],xgsB[ii],xgsC[ii],xgsD[ii]] for ii in range(len(xgsA))]
    pys = [ [ygsA[ii],ygsB[ii],ygsC[ii],ygsD[ii]] for ii in range(len(ygsA))]
    xc, yc, area, slices = clip_multi(pxs, pys, [2048,2048])

    # Bookkeeping to track the wavelength of each of the areas being projected onto the WFSS pixel grid.
    # The inner index is j so that the outer loop index i is not overwritten.
    tlams = np.zeros(len(xc))
    for j in range(len(slices)):
        tlams[slices[j]] = lams[j]

    # Store the flux, wavelength, and where they should end up on the WFSS pixel grid. Note the values in xcs and ycs are not unique
    xcs.extend(xc.tolist())
    ycs.extend(yc.tolist())
    flams.extend((flam*area).tolist())
    alams.extend(tlams.tolist())

    # Save for plotting later (only used for the plot below). all_counts holds one DN/s value per dispersed polygon, i.e. per wavelength in lams
    all_pxs.append(pxs)
    all_pys.append(pys)
    all_flams.append(flam)
    all_counts.append(flam * C.SENS["+1"](lams) * dlam * 10000 )
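The inner loop that fills ```tlams``` from ```slices``` can also be written with ```np.repeat```. This sketch assumes that the slices returned by ```clip_multi``` are contiguous, non-overlapping, and stored in polygon order; that is an assumption about its output, so the helper (a hypothetical name) checks it and raises an error otherwise.

```python
import numpy as np


def wavelengths_per_piece(lams, slices):
    """Wavelength of every clipped piece, given one slice per input polygon."""
    starts = [s.start for s in slices]
    stops = [s.stop for s in slices]
    # np.repeat is only equivalent to the loop if the slices tile 0..N in order
    if starts and (starts[0] != 0 or starts[1:] != stops[:-1]):
        raise ValueError("slices are not contiguous and ordered")
    lengths = [stop - start for start, stop in zip(starts, stops)]
    return np.repeat(np.asarray(lams), lengths)
```

Inside the loop, ```tlams = wavelengths_per_piece(lams, slices)``` would then replace the three bookkeeping lines.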

At this point, we have a list of WFSS pixels (xcs, ycs), the flux falling on these pixels (flams, in flam units), and the wavelength of the light contained in them (alams). In our simulation, we do not want to project flux units but $DN/s$, so we convert the input flam values into $DN/s$ using the inverse sensitivity. This is the reverse of the relation used in the Box Extraction notebook, where extracted $DN/s$ were converted into flam flux units.

In [None]:
s = C.SENS["+1"](alams)
counts = flams * s * dlam * 10000 

# Note: the factor of 10000 accounts for dlam being in micron while we want A since the inverse sensitivity is defined per A.
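Written out, the conversion in the cell above evaluates, for each projected piece $k$ of a dispersed pixel,

$$
N_k = f_\lambda \, a_k \, S(\lambda_k) \, \Delta\lambda_{\text{Å}}, \qquad \Delta\lambda_{\text{Å}} = 10^{4}\,\Delta\lambda_{\mu\text{m}}
$$

where $N_k$ is in $DN/s$, $f_\lambda$ is the flam value of the source pixel (```cds[i]*PHOTFLAM```), $a_k$ is the overlap area returned by ```clip_multi``` (so $f_\lambda a_k$ is the corresponding entry of ```flams```), $S(\lambda_k)$ is the inverse sensitivity ```C.SENS["+1"]``` at the wavelength of the piece, and $\Delta\lambda_{\mu\text{m}}$ is ```dlam```. The simulated value of a WFSS pixel is the sum of $N_k$ over all pieces $k$ that land on it.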

In [None]:
print(f"There are {len(counts)} dispersed bits of pixels to combine into a final WFSS pixel grid")

At this point, we have a large list of $DN/s$ values and the positions where they should be added onto our simulated WFSS observation to build the full dispersed spectrum of our source. The xcs, ycs coordinate list contains duplicates: adjacent wavelengths of the same input pixel overlap the same WFSS pixels, and different pixels of the object deposit different wavelengths into the same WFSS pixel (the object's "self-contamination").

The following plot shows the dispersed input pixels, using blue outlines, projected onto the final WFSS pixels, and shaded in black proportionally to their flux.

In [None]:
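fig,ax = plt.subplots(1,1,figsize=(15,3))
for i in range(len(all_pxs)):
    # DN/s of each dispersed polygon of this source pixel; negative values (from background subtraction) are not shaded
    c = np.clip(all_counts[i], 0, None)
    cmax = np.nanmax(c) if np.isfinite(c).any() else 0.0
    for j in range(len(all_pxs[i])):
        # Copy and close the polygon so that the stored corner lists are not modified
        tx = all_pxs[i][j] + [all_pxs[i][j][0]]
        ty = all_pys[i][j] + [all_pys[i][j][0]]
        ax.plot(tx,ty,color='b',alpha=0.02)
        # Shade in proportion to the flux, up to alpha=0.1
        alpha_val = c[j]/cmax/10 if cmax > 0 else 0.0
        if np.isnan(alpha_val):
            alpha_val = 0.0
        ax.fill(tx,ty,color='k',alpha=alpha_val)

# Put a tick, and hence a grid line, on every WFSS pixel boundary in the displayed range
x0 = pxs[int(mid)][0]
ax.set_xticks(np.arange(np.floor(x0-15), np.ceil(x0+15)+1))
ax.set_xlim(x0-15,x0+15)
ax.grid()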

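To quickly combine all of these counts at their WFSS pixel locations, we can use ```scipy.sparse.coo_matrix```, which is fast and efficient: duplicate (row, column) entries are summed when the matrix is converted to a dense array with ```toarray()```. The ```ok``` mask keeps only the pieces that fall on the 2048x2048 detector.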

In [None]:
xcs = np.array(xcs)
ycs = np.array(ycs)

ok = (xcs>=0) & (xcs<2048) &  (ycs>=0) & (ycs<2048) 
simulated = coo_matrix((counts[ok],(ycs[ok],xcs[ok])),shape=(2048,2048)).toarray()
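A toy example shows why ```coo_matrix``` is a good fit here: duplicate coordinates are accumulated, exactly what is needed when many dispersed pieces land on the same WFSS pixel. ```np.add.at``` gives the same result on a dense array, whereas plain fancy-index assignment such as ```img[rows, cols] += vals``` would keep only one of the duplicate contributions. The arrays below are made-up demonstration values.

```python
import numpy as np
from scipy.sparse import coo_matrix

# Three pieces land on pixel (row 0, column 1); one lands on (row 2, column 3)
demo_rows = np.array([0, 0, 2, 0])
demo_cols = np.array([1, 1, 3, 1])
demo_vals = np.array([0.5, 0.25, 2.0, 0.25])

# coo_matrix keeps the duplicate entries, and toarray() sums them
demo_img_coo = coo_matrix((demo_vals, (demo_rows, demo_cols)), shape=(3, 4)).toarray()

# Equivalent accumulation with an unbuffered in-place add on a dense array
demo_img_add = np.zeros((3, 4))
np.add.at(demo_img_add, (demo_rows, demo_cols), demo_vals)
```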

## Plot the Simulated Spectrum

Now we plot the simulated spectrum for this source.

In [None]:
fig,ax = plt.subplots(1,1,figsize=(15,3))
ax.imshow(simulated, origin="lower", aspect='auto',vmin=0,vmax=2)
ax.set_xlim(pxs[int(mid)][0]-15,pxs[int(mid)][0]+15)
ax.set_ylim(pys[int(mid)][0]-15,pys[int(mid)][0]+15)
# plt.xticks(range(0, len(pxs)));

We can compare with the real data.

In [None]:
fig,ax = plt.subplots(1,1,figsize=(15,3))
ax.imshow(wfss_data.data,origin="lower", aspect='auto',vmin=0,vmax=2)
ax.set_xlim(pxs[int(mid)][0]-15,pxs[int(mid)][0]+15)
ax.set_ylim(pys[int(mid)][0]-15,pys[int(mid)][0]+15)

And we can plot the difference between real and simulated data.

In [None]:
fig,ax = plt.subplots(1,1,figsize=(15,3))
ax.imshow(wfss_data.data-simulated,origin="lower", aspect='auto',vmin=0,vmax=2)
ax.set_xlim(pxs[int(mid)][0]-15,pxs[int(mid)][0]+15)
ax.set_ylim(pys[int(mid)][0]-15,pys[int(mid)][0]+15)

In [None]:
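# Sum each row over all columns; the real row sums also include the background and any other sources on that row
sim_profile = np.sum(simulated,axis=-1)
plt.plot(sim_profile,label="Simulated")
plt.plot(np.nansum(wfss_data.data,axis=-1),label="Real")

# Use ymin/ymax rather than min/max to avoid shadowing the Python built-ins
ymax = np.max(sim_profile)
ymin = np.min(sim_profile)
plt.xlim(pys[int(mid)][0]-15,pys[int(mid)][0]+15)
plt.ylim(ymin-100,ymax+2000)
plt.xlabel("WFSS rows")
plt.ylabel("DN/s summed along each row")
plt.legend()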

In [None]:
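# Same comparison, with the real profile shifted down by an ad hoc 700 DN/s so that its baseline
# lines up with the simulation (the real row sums also include background and other sources)
sim_profile = np.sum(simulated,axis=-1)
plt.plot(sim_profile,label="Simulated")
plt.plot(np.nansum(wfss_data.data,axis=-1)-700,label="Real")

ymax = np.max(sim_profile)
ymin = np.min(sim_profile)
plt.xlim(pys[int(mid)][0]-15,pys[int(mid)][0]+15)
plt.ylim(ymin-100,ymax+2000)
plt.xlabel("WFSS rows")
plt.ylabel("DN/s summed along each row")
plt.legend()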

## About This Notebook

**Author**: Nor Pirzkal, ESA/AURA Level III Astronomer <br>
**Created On**: 2024-06-11