# WFSS Spectra Simulation for Contamination Correction

This notebook demonstrates basic techniques to simulate WFSS dispersed spectra, given source locations from an imaging mode exposure and a generalized World Coordinate System (gWCS) from an accompanying WFSS exposure.

Such simulations, when applied to all sources in the field, are crucial to estimate the contamination of overlapping spectra, as well as to mask dispersed spectral traces when trying to estimate the background.

This notebook builds on the simpler <a href="https://github.com/spacetelescope/jdat_notebooks/blob/main/notebooks/NIRCam/NIRCam_WFSS_Box_extraction/BoxExtraction_using_Grismconf_CRDS.ipynb">Box Extraction notebook</a>, where we introduced the general concepts of spectral extraction as well as the <A HREF="https://github.com/npirzkal/GRISMCONF">GRISMCONF</A> module, which provides a low-level interface to the gWCS model for WFSS modes.

The simulation method is relatively straightforward and begins with an imaging mode exposure.

* We first determine which pixels contain the source and the signal levels in those pixels.

* We then use the gWCS from the imaging mode file and also the WFSS observation to compute the corresponding location of each pixel in the frame of the WFSS observation. 

* Using the GRISMCONF functions and a wavelength vector, we next simulate the dispersion of the source. For each of the source pixels, we move through the wavelength vector and calculate the location where each wavelength value is dispersed to. We calculate the signal level associated with each wavelength position based on the flux of each input pixel in $F_{\lambda}$ units ($erg/s/cm^2/A$). 

* Since the calculated coordinates of the wavelength values will not, in general, align with the detector pixel grid, we use the <a href="https://github.com/spacetelescope/pypolyclip">Sutherland-Hodgman algorithm</A> to compute the fraction of each projected dispersed pixel that overlaps onto each detector pixel. Summing the signals from each of these overlapping fractions gives the simulated signal in each detector pixel.

Author: N. Pirzkal <br>
Date created: 12 December 2024

## Table of Contents
1. [Set CRDS Path and Server](#Set-CRDS-Path-and-Server)
2. [Package Imports](#Package-Imports)
3. [Define Functions and Parameters](#Define-Functions-and-Parameters)
4. [Download Data](#Download-Data)
5. [Run Pipeline Steps](#Run-Pipeline-Steps)
6. [Detect Sources](#Detect-Sources)
7. [Simulate spectrum of one source](#Simulate-spectrum-of-one-source)
      * [Locate source in imaging and WFSS data](#Locate-source-in-imaging-and-WFSS-data)
      * [Get wavelength information](#Get-wavelength-information)
      * [Simulate the dispersion of a single pixel](#Simulate-the-dispersion-of-a-single-pixel)
      * [Disperse all the pixels for our source](#Disperse-all-the-pixels-for-our-source)

## Set CRDS Path and Server

Before running the pipeline steps, we need to ensure our CRDS environment is configured. This includes defining a CRDS cache directory in which to keep the reference files that will be used by the calibration pipeline.

If the root directory for the local CRDS cache has not already been set, it will be created in the home directory. This needs to be done before importing the crds package or any package that has crds as a dependency.

In [None]:
import os

In [None]:
# Check whether the local CRDS cache directory has been set.
# If not, set it to a 'crds' directory in the user's home directory
if (os.getenv('CRDS_PATH') is None):
    os.environ['CRDS_PATH'] = os.path.join(os.path.expanduser('~'), 'crds')
    
# Check whether the CRDS server URL has been set.  If not, set it.
if (os.getenv('CRDS_SERVER_URL') is None):
    os.environ['CRDS_SERVER_URL'] = 'https://jwst-crds.stsci.edu'

# Echo the CRDS path and server in use
print('CRDS local filepath:', os.environ['CRDS_PATH'])
print('CRDS file server:', os.environ['CRDS_SERVER_URL'])

# Import crds after setting up the required environment variables
from crds import client
if client.get_crds_server() != os.environ['CRDS_SERVER_URL']:
    client.set_crds_server(os.environ['CRDS_SERVER_URL'])
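
As a compact alternative, the same "set only if missing" logic can be sketched with `setdefault`, which assigns a value only when the key is absent. The helper name `configure_crds_env` is hypothetical and is not part of the `crds` package.

```python
import os


def configure_crds_env(env=None):
    """Fill in CRDS settings only where the user has not already set them."""
    env = os.environ if env is None else env
    # setdefault leaves any existing value untouched
    env.setdefault('CRDS_PATH', os.path.join(os.path.expanduser('~'), 'crds'))
    env.setdefault('CRDS_SERVER_URL', 'https://jwst-crds.stsci.edu')
    return env['CRDS_PATH'], env['CRDS_SERVER_URL']


# Example with a plain dictionary standing in for os.environ
example_env = {'CRDS_PATH': '/tmp/my_cache'}
example_path, example_url = configure_crds_env(example_env)
```

Passing a dictionary makes the behavior easy to inspect without touching the real environment; calling the helper with no argument would act on `os.environ`.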

## Package Imports

In [None]:
from astropy.convolution import convolve
from astropy.io import fits
from copy import deepcopy
import matplotlib.pyplot as plt
import numpy as np
import requests
from scipy.sparse import coo_matrix
import tqdm

import grismconf
from jwst import datamodels
from jwst.assign_wcs import AssignWcsStep
from jwst.flatfield import FlatFieldStep
from jwst.photom import PhotomStep
from photutils.background import Background2D, MedianBackground
from photutils.segmentation import make_2dgaussian_kernel
from photutils.segmentation import detect_sources
from pypolyclip import clip_multi

## Define Functions and Parameters

Define a function to download a named file via the MAST API to the current directory. The function includes authentication logic, but this example uses public data, so no MAST API token is required.

In [None]:
def get_jwst_file(name, mast_api_token=None, overwrite=False):
    """Retrieve a JWST data file from MAST archive.

    Parameters
    ----------
    name : str
        Name of the file to download from MAST
        
    mast_api_token : str
        MAST API token. Required only for proprietary data
        
    overwrite : bool
        If True and the requested file already exists locally, the file will be re-downloaded. If False,
        the existing local copy is kept and the download is skipped
    """
    # If the file already exists locally, don't redownload it, unless the
    # user has set the overwrite keyword
    if os.path.isfile(name):
        if not overwrite:
            print(f'{name} already exists locally. Skipping download.')
            return
        else:
            print(f'{name} exists locally. Re-downloading.')

    mast_url = "https://mast.stsci.edu/api/v0.1/Download/file"
    params = dict(uri=f"mast:JWST/product/{name}")
    if mast_api_token:
        headers = dict(Authorization=f"token {mast_api_token}")
    else:
        headers = {}
    r = requests.get(mast_url, params=params, headers=headers, stream=True)
    r.raise_for_status()
    with open(name, "wb") as fobj:
        for chunk in r.iter_content(chunk_size=1024000):
            fobj.write(chunk)
    if os.path.isfile(name):
        print(f"{name} successfully downloaded")

Define a function that will run the assign_wcs, flat field, and photom steps on an input rate file.

In [None]:
def run_pipeline_steps(filename):
    """Run the assign_wcs, flat field, and photom calibration steps on the given file.
    If the file contains WFSS data, trick the pipeline to use the imaging mode flat
    field reference file.
    
    Parameters
    ----------
    filename : str
        Name of the input file upon which the steps will be run
        
    Returns
    -------
    filename : str
        Name of the output file saved by the pipeline steps
        
    photom : jwst.datamodels.ImageModel
        Datamodel instance containing the calibrated data
    """
    assign_wcs = AssignWcsStep.call(filename)

    # In order to apply the imaging mode flat field reference file to the data,
    # we need to trick CRDS by temporarily changing the pupil value to be CLEAR
    reset_pupil = False
    if 'GRISM' in assign_wcs.meta.instrument.pupil:
        true_pupil = deepcopy(assign_wcs.meta.instrument.pupil)
        assign_wcs.meta.instrument.pupil = 'CLEAR'
        reset_pupil = True

    # Run the flat field step
    flat = FlatFieldStep.call(assign_wcs, save_results=True)

    # Run the photom step to populate the name of the WFSS sensitivity reference file
    photom = PhotomStep.call(flat, save_results=True)

    # Set the pupil back to the original value now that flat fielding is complete
    if reset_pupil:
        photom.meta.instrument.pupil = true_pupil
        photom.save(photom.meta.filename)
    
    # Return the name of the output file, as well as the datamodel
    return photom.meta.filename, photom

In [None]:
def show_2d_spec(data, xlim=(200, 240), ylim=(1705, 1730), vmin=0, vmax=3):
    """Show a 2D image with colorbar. Designed to display the 2D real and
    simulated spectra

    Parameters
    ----------
    data : numpy.ndarray
        2D image
    xlim : tuple
        2-tuple of beginning and ending x-coordinates for the display
    ylim : tuple
        2-tuple of beginning and ending y-coordinates for the display
    vmin : float
        Signal corresponding to minimum display scale
    vmax : float
        Signal corresponding to maximum display scale
    """
    fig, ax = plt.subplots(1, 1, figsize=(15, 3))
    cax = ax.imshow(data, origin="lower", aspect='auto', cmap='viridis', vmin=vmin, vmax=vmax)
    ax.set_xlim(xlim[0], xlim[1])  # change to (0, 700) to see the entire spectrum
    ax.set_ylim(ylim[0], ylim[1])
    plt.xticks(range(xlim[0], xlim[1], 5))
    plt.xlabel("Dispersion coordinate (pixel)")
    plt.ylabel("Cross dispersion coordinate (pixel)")
    colorbar = fig.colorbar(cax, orientation='vertical', pad=0.01)
    colorbar.ax.set_ylabel('Signal', labelpad=10, rotation=270)

## Download Data

We start with a simple pair of imaging and WFSS files. These were manually selected: they point at the same field on the sky and use the same NIRCam module, channel, and cross filter.

In [None]:
# First, download the imaging and WFSS files from MAST
imaging_file = "jw01076109001_02102_00001_nrcalong_cal.fits"
wfss_file = "jw01076109001_02101_00001_nrcalong_rate.fits"
get_jwst_file(imaging_file)
get_jwst_file(wfss_file)

## Run Pipeline Steps

We want to assign a WCS, apply a flat-field, and flux calibrate the WFSS data. For this, we use the run_pipeline_steps() function defined above. This will call the appropriate pipeline steps. We apply the imaging mode flat field to the WFSS file since the flat fields are not wavelength dependent.

In [None]:
# Run AssignWcsStep, FlatFieldStep, and PhotomStep on the WFSS rate file
wfss_file, wfss_data = run_pipeline_steps(wfss_file)

# Extract the WFSS pixel data from the datamodel instance
wfss_data = wfss_data.data

In [None]:
print(f"Pipeline-processed WFSS file is {wfss_file}")

Read some information from the imaging data and WFSS data. We need to know which module, channel, cross filter, and grism we are looking at. We also need the values required to convert the surface brightness units of the imaging cal files into units of $erg/s/cm^2/A$.

In [None]:
img_hdr0 = fits.getheader(imaging_file, 0)
img_hdr1 = fits.getheader(imaging_file, 1)
wfss_hdr = fits.getheader(wfss_file)

In [None]:
FILTER = img_hdr0["FILTER"]
MODULE = img_hdr0["MODULE"]
PUPIL = img_hdr0["PUPIL"]
PIXAR_SR = img_hdr1["PIXAR_SR"]
print(f"IMAGING FILTER: {FILTER}, MODULE: {MODULE}, PUPIL: {PUPIL}, Size of pixel: {PIXAR_SR} steradians")

In [None]:
FILTER = wfss_hdr["FILTER"]
MODULE = wfss_hdr["MODULE"]
PUPIL = wfss_hdr["PUPIL"]
print(f"WFSS FILTER: {FILTER}, MODULE: {MODULE}, PUPIL: {PUPIL}")

Note that in this simple example, the imaging and WFSS data use the same cross filter and the same NIRCam module.

Compute the conversion between pixel values in the imaging data, which are in MJy/sr, and $F_{\lambda}$ units of $erg/s/cm^2/A$ (per pixel). Multiplying the values in our cal imaging file by this conversion factor (called PHOTFLAM) gives the $F_{\lambda}$ value of each pixel of an object detected in our imaging data.

NIRCam filter pivot wavelengths are listed in Tables 2 and 3 on the <a href="https://jwst-docs.stsci.edu/jwst-near-infrared-camera/nircam-instrumentation/nircam-filters#gsc.tab=0">NIRCam Filters</a> page of the JWST User Documentation.

In [None]:
Pivot_wavelength = 44010  # Angstroms
PHOTFLAM = 1e6 * PIXAR_SR / 3.3356e4 / Pivot_wavelength**2
print(f'PHOTFLAM is {PHOTFLAM}')
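
To show where the constants in this factor come from, here is a minimal sketch that spells out the unit conversion step by step using $F_{\lambda} = F_{\nu} c / \lambda^2$. The function name `mjysr_to_flam` and the example pixel area of 9.3e-14 sr are assumptions made for illustration; the real pixel area comes from the PIXAR_SR header keyword.

```python
C_ANGSTROM_PER_S = 2.99792458e18  # speed of light in Angstrom/s


def mjysr_to_flam(surface_brightness, pixar_sr, pivot_angstrom):
    """Convert a surface brightness in MJy/sr for one pixel to F_lambda (erg/s/cm^2/A)."""
    jy = surface_brightness * 1e6 * pixar_sr           # MJy/sr -> Jy in this pixel
    fnu = jy * 1e-23                                    # Jy -> erg/s/cm^2/Hz
    return fnu * C_ANGSTROM_PER_S / pivot_angstrom**2   # F_nu -> F_lambda


# Hypothetical pixel with 1 MJy/sr and an assumed pixel area of 9.3e-14 sr
flam_pixel = mjysr_to_flam(1.0, 9.3e-14, 44010)

# The same quantity using the compact expression from the cell above
flam_shortcut = 1.0 * 1e6 * 9.3e-14 / 3.3356e4 / 44010**2
```

The two results agree to within the rounding of the 3.3356e4 constant, which is approximately $1 / (10^{-23} c)$ with $c$ expressed in Angstroms per second.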

## Detect Sources

Since we want to disperse each pixel comprising a given source, we first need to get a segmentation map of all the objects in the image. We follow the general method presented in photutils' <a href="https://photutils.readthedocs.io/en/stable/api/photutils.segmentation.detect_sources.html">detect_sources() function documentation</a> to create the segmentation map. 

In [None]:
imaging_data = fits.getdata(imaging_file)

In [None]:
# Create a 2D background model of the scene, and subtract
bkg_estimator = MedianBackground()
bkg = Background2D(imaging_data, (50, 50), filter_size=(21, 31), bkg_estimator=bkg_estimator)
imaging_data -= bkg.background 

In [None]:
# Use the RMS of the background to set the threshold value for source detection
threshold = 50 * bkg.background_rms

In [None]:
# Convolve the image with a 2D Gaussian kernel
kernel = make_2dgaussian_kernel(3.0, size=5)
convolved_data = convolve(imaging_data, kernel)

In [None]:
# Detect the sources in the imaging data
segment_map = detect_sources(convolved_data, threshold, npixels=10)
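
Conceptually, a segmentation map groups connected pixels that lie above the threshold and keeps only groups with at least `npixels` members. The following pure-Python sketch illustrates that idea on a tiny image. It is a simplified stand-in and not the photutils implementation: the function name `label_sources` is hypothetical, a scalar threshold is used, and 8-connectivity is assumed.

```python
from collections import deque


def label_sources(image, threshold, npixels):
    """Label 8-connected groups of pixels above threshold; 0 marks background."""
    ny, nx = len(image), len(image[0])
    labels = [[0] * nx for _ in range(ny)]
    next_label = 1
    for y0 in range(ny):
        for x0 in range(nx):
            if image[y0][x0] <= threshold or labels[y0][x0] != 0:
                continue
            # Flood fill starting from this seed pixel
            group, queue = [(y0, x0)], deque([(y0, x0)])
            labels[y0][x0] = -1  # temporary "visited" marker
            while queue:
                y, x = queue.popleft()
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        yy, xx = y + dy, x + dx
                        if (0 <= yy < ny and 0 <= xx < nx
                                and labels[yy][xx] == 0
                                and image[yy][xx] > threshold):
                            labels[yy][xx] = -1
                            group.append((yy, xx))
                            queue.append((yy, xx))
            # Keep the group only if it contains enough pixels
            keep = len(group) >= npixels
            for y, x in group:
                labels[y][x] = next_label if keep else -2  # -2 marks rejected pixels
            if keep:
                next_label += 1
    # Rejected pixels are reported as background
    return [[max(value, 0) for value in row] for row in labels]


toy_image = [
    [0, 0, 0, 0, 0],
    [0, 9, 9, 0, 0],
    [0, 9, 9, 0, 7],
    [0, 0, 0, 0, 0],
]
toy_labels = label_sources(toy_image, threshold=5, npixels=3)
```

In this toy image the 2x2 block becomes source 1, while the isolated bright pixel is rejected because it has fewer than `npixels` connected members.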

Below we see that the segmentation map shows a large number of sources spread across the detector.

In [None]:
# Show the segmentation map
plt.imshow(segment_map, origin="lower", vmin=1, vmax=2)

## Simulate spectrum of one source

Here, we show how to simulate the dispersion of only one source. To simulate a full WFSS observation, the steps shown here need to be repeated for every source in the field. Simulating all the dispersed spectra also provides a way to mask out spectra when estimating the dispersed background level during subsequent extraction, and it allows an estimate of the amount of spectral contamination from overlapping spectra.

For this example, we choose the source at coordinates (x,y) = (405,1465) and show how to simulate its spectrum.

We first get its segmentation map ID and create a list of all pixels associated with the source in the imaging data.

In [None]:
xd, yd = 405, 1465
ID = segment_map.data[yd, xd]
print(f"Object ID is: {ID}")

### Locate source in imaging and WFSS data

Working with the imaging mode file, create a list of coordinates for all of the source's pixels, along with their flux values (in MJy/sr).

In [None]:
ok = segment_map.data == ID
yds, xds = np.nonzero(ok)
cds = imaging_data[ok]
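
The indexing pattern used above can be checked on a small made-up array. This sketch assumes nothing beyond NumPy; the arrays `toy_segm` and `toy_image` are invented for illustration. Note that `np.nonzero` returns row (y) indices first and column (x) indices second, in the same order as the values selected by the boolean mask.

```python
import numpy as np

# Hypothetical 4x5 segmentation map containing sources 1 and 2
toy_segm = np.array([[0, 0, 0, 0, 0],
                     [0, 2, 2, 0, 1],
                     [0, 2, 2, 0, 1],
                     [0, 0, 0, 0, 0]])
toy_image = np.arange(20, dtype=float).reshape(4, 5)

toy_ok = toy_segm == 2
toy_rows, toy_cols = np.nonzero(toy_ok)   # rows (y) first, then columns (x)
toy_values = toy_image[toy_ok]            # same row-major order as the indices
```

Each `(toy_rows[k], toy_cols[k])` pair therefore addresses the pixel whose value is `toy_values[k]`, which is what allows the coordinates and fluxes to be carried together through the rest of the simulation.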

Show the source, as well as its segmentation map, from the imaging data. On the left we see that the source appears to be a point source. The segmentation map shows a collection of roughly 6x6 pixels that have been identified as part of this source.

In [None]:
min_x = np.min(xds)
max_x = np.max(xds)
min_y = np.min(yds)
max_y = np.max(yds)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
axs[0].imshow(imaging_data[min_y:max_y + 1, min_x:max_x + 1], origin="lower")
axs[1].imshow(segment_map.data[min_y:max_y + 1, min_x:max_x + 1], origin="lower")
axs[0].set_title("Source")
axs[1].set_title("Segmentation map")

All the information we have for this source is within the reference frame of the imaging data. But we need to know where the flux for this source is in the WFSS observation. This is handled using the gWCS of both imaging and WFSS observations. With these, we translate the location of each source pixel in the imaging data into the corresponding pixel location in the WFSS data.

In [None]:
# Open the imaging file using the JWST datamodels, and retrieve the WCS information
imaging_wcs = datamodels.open(imaging_file)
imaging_to_world = imaging_wcs.meta.wcs.get_transform('detector', 'world')

In [None]:
# Get the WCS information associated with the WFSS file
wfss_wcs = datamodels.open(wfss_file)
wfss_to_pix = wfss_wcs.meta.wcs.get_transform('world', 'detector')

For reference, below we see the pixel coordinates in the imaging data. The pixels corresponding to the source range between x values of 402 and 408, and y values of 1461 and 1468.

In [None]:
plt.scatter(xds, yds)
plt.xlabel("Imaging columns")
plt.ylabel("Imaging rows")

Compute the R.A. and Dec of each of the source's imaging mode input pixels

In [None]:
ras, decs = imaging_to_world(xds, yds)

Now compute the pixel coordinates of where the **undispersed** source would be in the WFSS data. The gWCS transform requires an input wavelength and spectral order, and it returns these same values along with the pixel coordinates. Here we choose a wavelength of 3.56 microns and a spectral order of 1. The returned pixel coordinates do not depend on wavelength or spectral order, so any values can be used.

In [None]:
xs, ys, xxx, yyy = wfss_to_pix(ras, decs, 3.56, 1)

Show the pixel location of the **undispersed** source in the WFSS data. Note that the location is significantly different from that in the imaging data. In this case, the source is located between x values of 117 and 123, and y values of 1740 and 1746.

In [None]:
plt.scatter(xs, ys)
plt.xlabel("WFSS columns")
plt.ylabel("WFSS rows")

### Get wavelength information

When simulating this dispersed spectrum, we need to consider the wavelength of the light that is being dispersed. Each of the pixels above will therefore be numerically dispersed across a range of discrete wavelengths.

To get the needed wavelength information, we initialize a grismconf Config object. This contains the information and polynomials describing the dispersion of the disperser as well as the corresponding inverse sensitivity curve.

In [None]:
C = grismconf.Config(wfss_file)

Show the inverse sensitivity, which includes the wavelength range and shape of the sensitivity. This is defined in units of DN/s per $F_{\lambda}$ ($erg/s/cm^2/A$). The curve shows significant sensitivity between about 3.85 and 5.05 microns.

In [None]:
plt.plot(C.SENS_data["+1"][0], C.SENS_data["+1"][1])
plt.grid()
plt.xlabel("Wavelength (micron)")
plt.ylabel(r"DN/s per $erg/s/cm^2/A$")

We use the grismconf configuration object to quickly get the wavelength range that corresponds to the disperser. This is stored in the WRANGE attribute.

In [None]:
wmin = C.WRANGE["+1"][0]
wmax = C.WRANGE["+1"][1]

print(f"The wavelength range to consider is {wmin} to {wmax} microns")

Compute the dispersion in units of wavelength per pixel. Grismconf can give us the derivative of the wavelength, as well as the derivative of the pixel position, with respect to the $t$ parameter. Dividing the first by the second gives the dispersion. Details of the $t$ parameter are given in the <a href="https://github.com/spacetelescope/jdat_notebooks/blob/main/notebooks/NIRCam/NIRCam_WFSS_Box_extraction/BoxExtraction_using_Grismconf_CRDS.ipynb">Box Extraction Notebook</a>. As noted there, $t$ is a normalized parameter, where values of $t = 0$ and $t = 1$ correspond to the blue and red edges of a dispersed spectrum.

While the dispersion varies slightly across the area of the detector covered by a spectrum, we use the dispersion at a $t$ value of 0.5, which corresponds to the middle of the spectral range.

In [None]:
dlam = C.DDISPL("+1", 1000, 1000, 0.5) / C.DDISPX("+1", 1000, 1000, 0.5)
print(f"Dispersion is {dlam * 10000} Angstroms per pixel")
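
The ratio used above is an application of the chain rule, $d\lambda/dx = (d\lambda/dt) / (dx/dt)$. The following sketch illustrates it with a toy linear trace. All numbers and the function names `displ`, `dispx`, and `derivative` are made up for illustration and do not come from the grismconf calibration.

```python
# Toy linear trace (all numbers are invented for illustration)
X_START, X_END = 100.0, 1300.0   # pixel offsets of the trace at t = 0 and t = 1
W_START, W_END = 3.8, 5.0        # wavelengths in microns at t = 0 and t = 1


def displ(t):
    """Wavelength (microns) at parameter t."""
    return W_START + (W_END - W_START) * t


def dispx(t):
    """Pixel offset along the dispersion direction at parameter t."""
    return X_START + (X_END - X_START) * t


def derivative(func, t, eps=1e-6):
    """Central finite-difference derivative of func at t."""
    return (func(t + eps) - func(t - eps)) / (2 * eps)


# Microns per pixel: (d lambda / dt) / (dx / dt)
dlam_toy = derivative(displ, 0.5) / derivative(dispx, 0.5)
```

For this toy trace, 1.2 microns are spread over 1200 pixels, so the ratio evaluates to 0.001 microns (10 Angstroms) per pixel.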

Next, create an array of wavelength values at which the simulated spectrum will be calculated. For this, we must pick a wavelength step. Ideally this step should be smaller than the native dispersion of the grism. We therefore pick half of the dispersion value calculated above.

In [None]:
dlam = dlam / 2
lams = np.arange(wmin, wmax, dlam)

print(f"We are using {len(lams)} values of wavelength")

In [None]:
print(f'First and last elements of the wavelength array: {lams[0]:.5f} microns, {lams[-1]:.5f} microns')

### Simulate the dispersion of a single pixel

With the wavelength information in hand, we can create a simulated dispersion of each object pixel.

Below, we show the process for a single pixel. We choose a pixel relatively close to the center of the source.

In [None]:
i = 22
print(f"Use pixel (x, y) = ({xs[i]}, {ys[i]}) for single pixel dispersion")

We start by computing the $t$ values corresponding to the wavelengths (lams) we are considering. Refer to the Box Extraction notebook for additional background details on the $t$ value.

In [None]:
ts = C.INVDISPL("+1", xs[i], ys[i], lams)

Next, we create an array of polygons representing the locations of the dispersed signal coming from our selected input pixel.

The cell below computes the array of coordinates, in the WFSS observation, of the bottom left corner of our pixel.

In [None]:
xgsA = C.DISPX("+1", xs[i], ys[i], ts) + xs[i]
ygsA = C.DISPY("+1", xs[i], ys[i], ts) + ys[i]

The following three cells compute the locations of the other three corners:

In [None]:
xgsB = C.DISPX("+1", xs[i] + 1, ys[i], ts) + xs[i] + 1
ygsB = C.DISPY("+1", xs[i] + 1, ys[i], ts) + ys[i]

In [None]:
xgsC = C.DISPX("+1", xs[i] + 1, ys[i] + 1, ts) + xs[i] + 1
ygsC = C.DISPY("+1", xs[i] + 1, ys[i] + 1, ts) + ys[i] + 1

In [None]:
xgsD = C.DISPX("+1", xs[i], ys[i] + 1, ts) + xs[i]
ygsD = C.DISPY("+1", xs[i], ys[i] + 1, ts) + ys[i] + 1

Re-organize the corner coordinates into lists of polygon vertices, which the pypolyclip module uses to compute each polygon's overlap with the pixel grid of the WFSS observation. While we are looking at a single input source pixel, we are computing this at many different wavelength values, so the result is a list of many pixels/polygons to project onto our WFSS rectilinear pixel grid.

In [None]:
pxs = [[xgsA[ii], xgsB[ii], xgsC[ii], xgsD[ii]] for ii in range(len(xgsA))]

In [None]:
pys = [[ygsA[ii], ygsB[ii], ygsC[ii], ygsD[ii]] for ii in range(len(ygsA))]
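
The four corner computations and the re-organization into per-wavelength vertex lists can be wrapped in one helper. The sketch below uses a hypothetical `ToyConfig` class as a stand-in for the grismconf object so that it runs on its own; its straight horizontal trace and the helper name `dispersed_polygons` are assumptions for illustration, not part of grismconf.

```python
import numpy as np


class ToyConfig:
    """Hypothetical stand-in for grismconf.Config with a straight, horizontal trace."""

    def DISPX(self, order, x0, y0, t):
        return 100.0 + 1200.0 * np.asarray(t)   # made-up x offset in pixels

    def DISPY(self, order, x0, y0, t):
        return np.zeros_like(np.asarray(t, dtype=float))  # no y offset


def dispersed_polygons(config, order, x, y, ts):
    """Return per-wavelength lists of the four dispersed corner coordinates."""
    # Corner offsets: bottom-left, bottom-right, top-right, top-left
    corners = [(0, 0), (1, 0), (1, 1), (0, 1)]
    xg = [config.DISPX(order, x + dx, y + dy, ts) + x + dx for dx, dy in corners]
    yg = [config.DISPY(order, x + dx, y + dy, ts) + y + dy for dx, dy in corners]
    # Transpose from (corner, wavelength) to (wavelength, corner)
    return np.transpose(xg).tolist(), np.transpose(yg).tolist()


toy_pxs, toy_pys = dispersed_polygons(ToyConfig(), "+1", 120.0, 1743.0, [0.0, 0.5, 1.0])
```

Each entry of `toy_pxs` and `toy_pys` holds the four x or y vertices of one dispersed pixel, listed counter-clockwise from the bottom left corner.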

Below we create a figure showing the resulting locations of dispersed pixels. We dispersed the single input pixel using an array of wavelengths that is oversampled by a factor of 2 relative to the native dispersion of the grism. That wavelength array is translated into an array of pixel positions which overlay the WFSS detector grid. 

For clarity, we zoom in on a 40 pixel wide area in the dispersion direction. This shows that our single imaging mode pixel will be dispersed along a nearly horizontal line (seen as the rainbow colored boxes) in this area. Zooming out on this plot, by changing the xlim values in the cell below, will show the entire set of pixels corresponding to the dispersed input pixel. For consistency, we show similar 40 pixel wide plots for the other figures in the next two sections.

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 3))
for k in tqdm.tqdm(range(len(pxs))):
    # Close each outline by repeating its first vertex. Build new lists so
    # that pxs and pys, which are passed to clip_multi below, stay unchanged.
    tx = pxs[k] + [pxs[k][0]]
    ty = pys[k] + [pys[k][0]]
    plt.plot(tx, ty)

plt.xticks(range(0, len(pxs)))

plt.xlim(200, 240)  # change to (0, 700) to see the entire spectrum
plt.grid()
plt.xlabel("Dispersion coordinate (pixel)")
plt.ylabel("Cross dispersion coordinate (pixel)")

We can now use pypolyclip.clip_multi to compute how much of each dispersed pixel (colored boxes above) falls onto each pixel in the WFSS image pixel grid (shown as the gray grid above). Details are available on the <A HREF="https://github.com/spacetelescope/pypolyclip">pypolyclip page</A>.

In [None]:
xc, yc, area, slices = clip_multi(pxs, pys, [2048, 2048])
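
To make the area computation concrete, here is a minimal pure-Python sketch of Sutherland-Hodgman clipping of one convex polygon against individual detector pixels, followed by a shoelace-formula area. It is an illustration of the technique only, not the pypolyclip implementation. The function names are hypothetical, and the sketch assumes that pixel (ix, iy) covers the square from (ix, iy) to (ix + 1, iy + 1).

```python
import math


def clip_to_pixel(polygon, ix, iy):
    """Clip a convex polygon to the unit pixel spanning (ix, iy) to (ix+1, iy+1)."""
    # Each clipping edge is described by (axis, bound, keep_greater)
    edges = [(0, ix, True), (0, ix + 1, False), (1, iy, True), (1, iy + 1, False)]
    for axis, bound, keep_greater in edges:
        if not polygon:
            break
        clipped = []
        for k, cur in enumerate(polygon):
            prev = polygon[k - 1]
            cur_in = cur[axis] >= bound if keep_greater else cur[axis] <= bound
            prev_in = prev[axis] >= bound if keep_greater else prev[axis] <= bound
            if cur_in != prev_in:
                # This polygon edge crosses the boundary: add the crossing point
                f = (bound - prev[axis]) / (cur[axis] - prev[axis])
                clipped.append((prev[0] + f * (cur[0] - prev[0]),
                                prev[1] + f * (cur[1] - prev[1])))
            if cur_in:
                clipped.append(cur)
        polygon = clipped
    return polygon


def polygon_area(polygon):
    """Area of a simple polygon from the shoelace formula."""
    total = 0.0
    for k, (x1, y1) in enumerate(polygon):
        x0, y0 = polygon[k - 1]
        total += x0 * y1 - x1 * y0
    return abs(total) / 2


def overlap_areas(polygon):
    """Map each overlapped pixel (ix, iy) to the area of the polygon inside it."""
    vx, vy = zip(*polygon)
    areas = {}
    for ix in range(math.floor(min(vx)), math.ceil(max(vx))):
        for iy in range(math.floor(min(vy)), math.ceil(max(vy))):
            area = polygon_area(clip_to_pixel(list(polygon), ix, iy))
            if area > 0:
                areas[(ix, iy)] = area
    return areas


# A unit square offset from the pixel grid by (0.25, 0.5)
toy_square = [(10.25, 5.5), (11.25, 5.5), (11.25, 6.5), (10.25, 6.5)]
toy_areas = overlap_areas(toy_square)
```

The toy square overlaps four pixels, and because it has unit area, the four overlap areas sum to one. This is the property that lets the areas be used directly as flux fractions.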

In the following section, we will repeat this workflow on all pixels, and use the area information to scale the flux values for all output pixels.

### Disperse all the pixels for our source

Note that the figure above shows a single source pixel being dispersed. For a full source, each of the input source pixels should be similarly dispersed, resulting in multiple dispersed pixels contributing to the final counts in each of the detector pixels in the WFSS image. We will also need to compute and attribute the proper flux, in DN/s, to each of the WFSS detector pixels from each input pixel. This is done below for our selected object.

We must keep track of all the information such as the wavelength and fraction of the original imaging pixel flux that falls onto the WFSS simulated pixel array for each imaging pixel.

In [None]:
xcs = []
ycs = []
alams = []
flams = []

all_pxs = []
all_pys = []
all_flams = []
all_counts = []

# Loop over the input source pixels in the WFSS reference frame.
for i in tqdm.tqdm(range(len(xs))):
    # Use the imaging signal (MJy/sr) in this pixel to compute its flux in F_lambda units
    counts = cds[i]
    flam = counts * PHOTFLAM

    # Disperse this pixel using len(lams) wavelengths. This results in len(lams) projected
    # pixels contributing to the final WFSS data
    ts = C.INVDISPL("+1", xs[i], ys[i], lams)
    xgsA = C.DISPX("+1", xs[i], ys[i], ts) + xs[i]
    ygsA = C.DISPY("+1", xs[i], ys[i], ts) + ys[i]
    xgsB = C.DISPX("+1", xs[i] + 1, ys[i], ts) + xs[i] + 1
    ygsB = C.DISPY("+1", xs[i] + 1, ys[i], ts) + ys[i]
    xgsC = C.DISPX("+1", xs[i] + 1, ys[i] + 1, ts) + xs[i] + 1
    ygsC = C.DISPY("+1", xs[i] + 1, ys[i] + 1, ts) + ys[i] + 1
    xgsD = C.DISPX("+1", xs[i], ys[i]+1, ts) + xs[i]
    ygsD = C.DISPY("+1", xs[i], ys[i]+1, ts) + ys[i] + 1

    # Use the corners of the dispersed pixels, and compute the WFSS pixels which they
    # overlap, and by how much
    pxs = [[xgsA[ii], xgsB[ii], xgsC[ii], xgsD[ii]] for ii in range(len(xgsA))]
    pys = [[ygsA[ii], ygsB[ii], ygsC[ii], ygsD[ii]] for ii in range(len(ygsA))]
    xc, yc, area, slices = clip_multi(pxs, pys, [2048, 2048])

    # Bookkeeping to track the wavelength of each of the areas being projected onto
    # the WFSS pixel grid. A separate index avoids shadowing the outer loop variable.
    tlams = np.zeros(len(xc))
    for k in range(len(slices)):
        tlams[slices[k]] = lams[k]

    # Store the flux, wavelength, and where they should end up on the WFSS pixel grid.
    # Note the values in xcs and ycs are not unique
    xcs.extend(xc.tolist())
    ycs.extend(yc.tolist())
    flams.extend((flam * area).tolist())
    alams.extend(tlams.tolist())

    # Save for plotting later. Only used for plot below.
    all_pxs.append(pxs)
    all_pys.append(pys)
    all_flams.append(flam)
    # One value per dispersed polygon (i.e. per wavelength in lams), to match all_pxs
    all_counts.append(flam * C.SENS["+1"](lams) * dlam * 10000)

At this point, we have a list of WFSS pixels (xcs, ycs), the flux falling on these pixels (flams, in $F_{\lambda}$ units), and the wavelength of the light contained in them (alams).
In our simulation, we do not want to project flux units but rather DN/s, so we convert the flams values into DN/s. This is the reverse of the operation used in the Box Extraction notebook, where extracted DN/s were converted into $F_{\lambda}$ flux units.

In [None]:
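# Note: the factor of 10000 below accounts for dlam being in microns, while we
# want Angstroms since the inverse sensitivity is defined per Angstrom.
# Convert the Python lists to arrays so the arithmetic is explicitly element-wise.
s = C.SENS["+1"](np.asarray(alams))
counts = np.asarray(flams) * s * dlam * 10000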
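
As a quick sanity check of the units in this conversion, the sketch below applies the same expression to a single made-up wavelength bin. The function name `flam_to_rate` and all of the numbers are hypothetical and chosen only to make the arithmetic easy to follow.

```python
MICRON_TO_ANGSTROM = 1e4


def flam_to_rate(flam, sensitivity, dlam_micron):
    """Convert F_lambda (erg/s/cm^2/A) in one wavelength bin to DN/s."""
    # The bin width is converted to Angstroms to match the sensitivity units
    return flam * sensitivity * dlam_micron * MICRON_TO_ANGSTROM


# Made-up numbers: flux of 2e-19, sensitivity of 1e17, and a 5 Angstrom bin
toy_rate = flam_to_rate(2e-19, 1e17, 0.0005)
```

With these values the flux times sensitivity is 0.02, and the 5 Angstrom bin width brings the result to 0.1 DN/s.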

In [None]:
print(f"There are {len(counts)} dispersed bits of pixels to combine into a final WFSS pixel grid")

We now have a large list of DN/s values and where they should be added onto our simulated WFSS observation in order to simulate the full dispersed spectrum of our source.
There are duplicate entries in the (xcs, ycs) coordinate list because different wavelengths from different source pixels land on the same detector pixel (the object's "self-contamination").

The following plot shows the dispersed input pixels, using blue outlines, projected onto the final WFSS pixels, which are shown as the gray grid. The dispersed pixels are shaded in black proportionally to their flux (in DN/s).

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(15, 3))
for i in tqdm.tqdm(range(len(all_pxs))):
    for j in range(len(all_pxs[i])):
        
        tx = all_pxs[i][j][:]
        tx.append(tx[0])
        ty = all_pys[i][j][:]
        ty.append(ty[0])
        plt.plot(tx, ty, color='b', alpha=0.02)
        c = all_counts[i]
        c[c < 0] = 0
        plt.fill(tx, ty, color='k', alpha=c[j])
plt.grid()
plt.xlim(200, 240)  # change to (0, 700) to see the entire spectrum
plt.xticks(range(200, 240))
plt.xlabel("Dispersion coordinate (pixel)")
plt.ylabel("Cross dispersion coordinate (pixel)")

To quickly combine all of these counts at each of their WFSS pixel locations, we can use scipy.sparse.coo_matrix, which is fast and efficient:

In [None]:
xcs = np.array(xcs)
ycs = np.array(ycs)

# Ignore counts and coordinates that are outside of the detector
ok = (xcs >= 0) & (xcs < 2048) & (ycs >= 0) & (ycs < 2048) 
simulated = coo_matrix((counts[ok], (ycs[ok], xcs[ok])), shape=(2048, 2048)).toarray()
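
The key property relied on here is that entries sharing the same (row, column) coordinates are summed when the sparse matrix is converted to a dense array. As an illustration, the same accumulation can be sketched with NumPy alone using `np.add.at`; the arrays below are made up for the example.

```python
import numpy as np

# Made-up fragments: two of them land on the same pixel (row 1, column 2)
toy_rows = np.array([1, 1, 0])
toy_cols = np.array([2, 2, 3])
toy_vals = np.array([0.5, 0.25, 1.0])

toy_grid = np.zeros((2, 4))
np.add.at(toy_grid, (toy_rows, toy_cols), toy_vals)   # unbuffered add: duplicates accumulate
```

By contrast, a plain `toy_grid[toy_rows, toy_cols] += toy_vals` would not accumulate the two contributions to the repeated pixel, which is why an accumulating operation is needed.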

Show the resulting 2D simulated spectrum for this source. We see a nearly horizontal trace across the field of view. Increasing the x range of the plot to (0, 700) will reveal the full trace.

In [None]:
# Change xlim to (0, 700) to see the entire spectrum
show_2d_spec(simulated, xlim=(200, 241), ylim=(1705, 1730), vmin=0, vmax=4)

Show the real data at the same scale. This trace also appears as a nearly horizontal line. The signal level in the real data is slightly higher than that of the simulated data.

In [None]:
# Change xlim to (0, 700) to see the entire spectrum
show_2d_spec(wfss_data, xlim=(200, 241), ylim=(1705, 1730), vmin=0, vmax=4)

A good way to check our simulated image is to subtract the simulation from our data. The figure below shows the real minus simulated difference image. The difference in peak level of the trace is visible. The alignment of the simulated and real traces is also apparent, as the width of the difference is constant from left to right.

In [None]:
# Change xlim to (0, 700) to see the entire spectrum
show_2d_spec(wfss_data - simulated, xlim=(200, 241), ylim=(1705, 1730), vmin=0, vmax=4)

Sum the simulated spectrum and also the real data in the dispersion direction to check if the trace profiles are well aligned. The figure below shows the summed real data (orange) and simulation (blue line). The left/right alignment of the two peaks shows that the location of the simulated trace is very close to that of the real trace.

In [None]:
plt.plot(np.nansum(simulated, axis=-1), color='blue')
plt.plot(np.nansum(wfss_data, axis=-1), color='orange')
plt.xlim(1705, 1730)
plt.ylim(0, 3000)
plt.grid()
plt.ylabel("Summed Counts (DN/s)")
plt.xlabel("Cross dispersion coordinate (pixel)")
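
The alignment check can also be expressed numerically by comparing the peak rows of the two collapsed profiles. The sketch below does this on small made-up arrays; `toy_sim` and `toy_real` are invented stand-ins for the simulated and observed images.

```python
import numpy as np

# Toy 5x6 images: rows are cross-dispersion, columns are dispersion
toy_sim = np.zeros((5, 6))
toy_sim[2, :] = 1.0                  # simulated trace on row 2
toy_real = np.full((5, 6), 0.2)      # constant background
toy_real[2, :] += 1.1                # real trace on the same row
toy_real[0, 0] = np.nan              # a bad pixel, ignored by nansum

sim_profile = np.nansum(toy_sim, axis=-1)    # sum along the dispersion direction
real_profile = np.nansum(toy_real, axis=-1)
peak_offset = int(np.argmax(real_profile)) - int(np.argmax(sim_profile))
```

A `peak_offset` of zero means the two traces peak on the same detector row. The nonzero wings of `real_profile` come from the background, which the simulation does not include.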

While the simulation appears to be astrometrically correct, the real data also contain the dispersed background, which is absent from the simulation, so the summed signal in the simulation is lower than that in the data.
Ideally this would be handled with a model of the dispersed background. Here, for simplicity, and because we are looking at a region where we know the dispersed background is relatively featureless and flat, we instead use the simulation to create a mask and then compute the background level per pixel.
Below we mask pixels in the simulated image above a signal level of 0.0001 DN/s. This masks the trace. The remaining pixels, visible below in black, will be used to calculate the background level.

In [None]:
# Masked pixels, set to NaN, appear white in the display
mask = simulated > 0.0001
tmp = simulated * 1.
tmp[mask] = np.nan
# Change xlim to (0, 700) to see the entire spectrum
show_2d_spec(tmp, xlim=(200, 241), ylim=(1705, 1730), vmin=0, vmax=1)

Compute the median dispersed background level.

In [None]:
bck_level = np.nanmedian(wfss_data[~mask])
print(f"Background estimated to be {bck_level} DN/s per pixel")
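
The masking and median steps can be verified on a small made-up example in which the true background is known. The arrays and the background value of 0.2 below are invented for illustration.

```python
import numpy as np

# Toy simulated image with a trace on row 2
sim_img = np.zeros((5, 6))
sim_img[2, :] = 1.0

# Toy "observed" image: flat background of 0.2 plus a trace, with one bad pixel
obs_img = np.full((5, 6), 0.2)
obs_img[2, :] += 1.1
obs_img[0, 0] = np.nan

trace_mask = sim_img > 0.0001                       # pixels touched by the simulated trace
background = np.nanmedian(obs_img[~trace_mask])     # median of the remaining pixels
residual_profile = np.nansum(obs_img - background, axis=-1)
```

Because the trace pixels are excluded, the median recovers the flat background level, and the collapsed, background-subtracted profile is close to zero away from the trace.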

Now plot the collapsed simulated data (blue) and background-subtracted real data (orange). With the background removed, the two peaks line up, confirming that the simulated and real traces are well aligned. The peak of the simulated data is lower than that of the real data, showing that the simulation slightly underestimates the signal in the real trace.

In [None]:
plt.plot(np.nansum(wfss_data - bck_level, axis=-1), color='orange')
plt.xlim(1705, 1730)
plt.ylim(-100, 2500)
plt.plot(np.nansum(simulated, axis=-1), color='blue')
plt.grid()
plt.ylabel("Summed Counts (DN/s)")
plt.xlabel("Cross dispersion coordinate (pixel)")

After subtracting the background, the simulated spectrum is a good match to the real data. The 2D simulated spectrum can now be used to subtract contamination of overlapping spectra, as well as to mask dispersed spectral traces when trying to estimate the background.

<img style="float: right;" src="https://raw.githubusercontent.com/spacetelescope/notebooks/master/assets/stsci_pri_combo_mark_horizonal_white_bkgd.png" alt="Space Telescope Logo" width="200px"/>