# Run the MSAEXP pipeline steps for an example dataset from GO-4233 (RUBIES)

Note this is a copy of Gabe Brammer's msaexp example notebook: https://github.com/gbrammer/msaexp/tree/main/docs/examples

We will use this notebook to do reductions in an alternative way. msaexp still uses parts of the STScI pipeline that we used before, but it includes some changes that improve the data quality. 

**Exercise**
Go to https://www.stsci.edu/jwst/science-execution/program-information?id=4233 and explore the public PDF and the APT file to get an idea of the program. 

What is their main scientific objective?

**Pipeline steps**

1. Query MAST and download full-frame exposures (``rate.fits``)
1. Run the preprocessing pipeline through extracting 2D cutouts
1. Combine and rectify the cutouts to a final stack
1. Extract 1D spectrum
1. Fit redshift, line fluxes, etc. (see ``spectral-extractions-2024.ipynb``)

In [None]:
import os

# Configure the CRDS reference-file client: a local cache directory on the shared
# filesystem plus the public JWST CRDS server.  These are set here, before the
# jwst imports in the following cells.
os.environ.update(
    CRDS_PATH='/lustre/JDAP/jdap_data/crds_cache',
    CRDS_SERVER_URL='https://jwst-crds.stsci.edu',
)



In [None]:
try:
    import numba
except ImportError:
    ! pip install numba

In [None]:
import os
import glob
import yaml
import warnings
import time

import numpy as np
import matplotlib.pyplot as plt

import grizli
from grizli import utils, jwst_utils
jwst_utils.set_quiet_logging()  # grizli helper; by its name, reduces jwst pipeline log output -- confirm
utils.set_warnings()  # grizli helper that configures warning filters for the session

import astropy.io.fits as pyfits
import jwst.datamodels
import jwst

import mastquery.jwst

import msaexp
from msaexp import pipeline
import msaexp.slit_combine

# Record the versions of the key packages so the results can be tied to a
# specific software state.
for pkg_name, pkg in (('jwst', jwst), ('grizli', grizli), ('msaexp', msaexp)):
    print(f'{pkg_name} version = {pkg.__version__}')

# Plot defaults used throughout the notebook
plt.rcParams.update({
    'scatter.marker': '.',
    'image.origin': 'lower',
    'image.interpolation': 'nearest',
    'grid.linestyle': ':',
})

# Query MAST for NIRSpec data

Query by program name and download `rate.fits` files with `mastquery`.  You may need to set the `$MAST_TOKEN` environment variable to download proprietary datasets from MAST.  (The query cell below uses `download=False`, so it expects the `rate.fits` files to already be present in the working directory.)

The query can optionally be limited to specific:

 - gratings:  ``prism``, ``g140m``, ``g235m``, ``g395m``, ``g140h``, ``g235h``, ``g395h``
 - filters:  ``clear``, ``f070lp``, ``f100lp``, ``f170lp``, ``f290lp``
 - detectors: ``nrs1``, ``nrs2``
 
 

In [None]:
## Optional
# Look up which NIRSpec MSA slits cover a sky position (RA, Dec; presumably
# decimal degrees) and show program / MSA metadata file / source ID / grating /
# footprint for the slits flagged as containing a source.
ra, dec = 214.97433534766, 52.92461350726

slits_url = f"https://grizli-cutout.herokuapp.com/nirspec_slits?coord={ra},{dec}"
slits = utils.read_catalog(slits_url, format='csv')

has_source = slits['is_source'] == 'True'
slits['program', 'msametfl', 'source_id', 'grating', 'footprint'][has_source]

In [None]:
# --- Dataset selection ------------------------------------------------------
prog = 4233  # JWST observing program ID

# Only process a couple of sources so the example runs quickly
source_ids = [44597, 46811]  # [line emitter, more extended object]

# A single RUBIES mask, selected by its visit ID
outroot = 'rubies-egs61'
mask_query = mastquery.jwst.make_query_filter('visit_id', values=['04233006001'])

# Instrument setup: PRISM only, and only the NRS1 detector for the example
gratings = ['prism']
detectors = ['nrs1']

## Query and download

In [None]:
# Query MAST for the NIRSpec exposures of this program / mask.  With
# download=False this presumably only runs the query, so the ``rate.fits``
# exposures are expected to already be in the working directory.
masks = pipeline.query_program(prog,
    download=False,
    detectors=detectors,
    gratings=gratings,
    extensions=['uncal', 's2d'],
    extra_filters=mask_query,
)

# JWST file names zero-pad the program ID to 5 digits (jw04233..., jw01234...),
# so format it explicitly instead of hard-coding a leading "0".  Sorting gives a
# deterministic processing order (glob order is arbitrary).
files = sorted(glob.glob(f'jw{prog:05d}*rate.fits'))

# Fail early with a clear message instead of cryptic errors in later cells
if not files:
    raise FileNotFoundError(
        f'No jw{prog:05d}*rate.fits files found in {os.getcwd()}; '
        'set download=True above or stage the exposures first.'
    )

print(files)

In [None]:
# Unset DQ=4 pixels to avoid running ``snowblind`` for now.
# (Bit value 4 is JUMP_DET in the JWST DQ flag definitions.)  Subtracting
# ``dq & 4`` clears only that bit and leaves pixels without it untouched, so
# re-running this cell is harmless.  The array is modified in place and written
# back because the file is opened in update mode.
for file in files:
    with pyfits.open(file, mode='update') as hdul_rate:
        dq = hdul_rate['DQ'].data
        dq -= dq & 4
        hdul_rate.flush()

# Initialize pipeline

Exposures are grouped by detector and with a common `MSAMETFL` metadata file for the MSA setup.

## Preprocessing pipeline

1. Apply 1/f correction and identify "snowballs" on the `rate.fits` files
1. Remove "bias" (i.e., simple median) of each exposure
1. Rescale RNOISE array based on empty parts of the exposure
1. Run parts of the Level 2 JWST calibration pipeline ([calwebb_spec2](https://jwst-pipeline.readthedocs.io/en/latest/jwst/pipeline/calwebb_spec2.html#calwebb-spec2)):
    - [AssignWcs](https://jwst-pipeline.readthedocs.io/en/latest/api/jwst.assign_wcs.AssignWcsStep.html) : initialize WCS and populate slit bounding_box data
    - [Extract2dStep](https://jwst-pipeline.readthedocs.io/en/latest/api/jwst.extract_2d.Extract2dStep.html) : identify slits and set slit WCS
    - [FlatFieldStep](https://jwst-pipeline.readthedocs.io/en/latest/api/jwst.flatfield.FlatFieldStep.html#flatfieldstep) : slit-level flat field
    - [PathLossStep](https://jwst-pipeline.readthedocs.io/en/latest/api/jwst.pathloss.PathLossStep.html) : NIRSpec path loss
    - [BarShadowStep](https://jwst-pipeline.readthedocs.io/en/latest/api/jwst.barshadow.BarShadowStep.html#jwst.barshadow.BarShadowStep) : Bar Shadow
        - See also [MSAEXP PR#66](https://github.com/gbrammer/msaexp/pull/66)
    - [PhotomStep](https://jwst-pipeline.readthedocs.io/en/latest/api/jwst.photom.PhotomStep.html) : Photometric calibration
    - Note that the `srctype`, `master_background`, `wavecorr` steps are not performed.  The background subtraction is done manually on the 2D slit cutouts.
1. Parse slit metadata
1. Save slit cutout `SlitModel` files of the last pipeline step performed (`phot` = `PhotomStep`)

## Note! 

When the ``source_ids`` list is specified, the pipeline is only run for those sources in the MSA plan and will be much faster.  Set ``source_ids=None`` to extract *everything*.

In [None]:
# Run the msaexp preprocessing pipeline once per exposure.
#
# The "mode" label is the first four underscore-separated fields of the file
# name joined with dashes (i.e. the name without its trailing ``_rate.fits``).
# When SKIP_COMPLETED is True, exposures whose ``{mode}.slits.yaml`` already
# exists are skipped; set it to False to force every exposure to be reprocessed.
SKIP_COMPLETED = True

for file in files:
    mode = '-'.join(file.split('_')[:4])

    if SKIP_COMPLETED and os.path.exists(f'{mode}.slits.yaml'):
        print(f'Skip preprocessing: {mode}')
        continue

    pipe = pipeline.NirspecPipeline(mode=mode,
                                    files=[file],
                                    source_ids=source_ids,
                                    positive_ids=True,  # Ignore background slits
                                    )

    pipe.full_pipeline(run_extractions=False,
                       initialize_bkg=False,
                       load_saved=None,
                       scale_rnoise=True)

The final products of the preprocessing pipeline are the SlitModel objects stored in individual ``*.phot.*fits`` files.

In [None]:
# Collect the SlitModel cutout files written by the preprocessing for the first
# test source, in sorted order.  The program ID is zero-padded to 5 digits to
# match JWST file naming, so this also works for 5-digit program IDs.
phot_files = sorted(glob.glob(f'jw{prog:05d}*{source_ids[0]}.fits'))

print('\n'.join(phot_files))

In [None]:
# Quick look at each cutout file for the first test source, one panel per file.
# The panel count follows the number of files (at least one), and squeeze=False
# keeps `axes` 2D so a single file also works.  Each datamodel is closed after
# its data are copied for display.
fig, axes = plt.subplots(max(len(phot_files), 1), 1, figsize=(8,6),
                         sharex=True, sharey=True, squeeze=False)

for ax, file in zip(axes[:, 0], phot_files):
    with jwst.datamodels.open(file) as dm:
        ax.imshow(dm.data.copy(), vmin=-0.1, vmax=0.3, aspect='auto', cmap='gray_r')
    ax.grid()

fig.tight_layout(pad=1)

# Exposure combination and spectral extraction

In [None]:
obj = msaexp.slit_combine.SlitGroup?

In [None]:
# Options for grouping the per-exposure cutouts of one source
group_kws = {
    'diffs': True,                # For nod differences
    'undo_barshadow': 2,          # For msaexp barshadow correction
    'min_bar': 0.35,              # minimum allowed value for the (inverse) bar shadow correction
    'position_key': "y_index",
    'trace_with_ypos': True,      # Include expected y shutter offset in the trace
    'trace_from_yoffset': True,
    'flag_profile_kwargs': None,  # Turn off profile flag
}

obj = msaexp.slit_combine.SlitGroup(phot_files, outroot, **group_kws)

In [None]:
# Summarize the group: exposure count, 2D cutout shape and flattened array shape
summary_lines = [
    '',
    f'Number of exposures: {obj.N}',
    f'2D array shape: {obj.sh}',
    f'Flattened data array: {obj.sci.shape}',
    '',
]
print('\n'.join(summary_lines))

Show the exposure data again now with the trace.

Also note that the sky is flatter than before with the updated bar shadow correction.

In [None]:
# One panel per exposure: the flattened obj.sci row reshaped back to the 2D
# cutout shape obj.sh, with the trace (obj.ytr) overplotted.
# squeeze=False keeps `axes` 2D so this also works when obj.N == 1.
fig, axes = plt.subplots(obj.N, 1, figsize=(8,6), sharex=True, sharey=True, squeeze=False)

for i, ax in enumerate(axes[:, 0]):
    ax.imshow(obj.sci[i,:].reshape(obj.sh), vmin=-0.1, vmax=0.3, aspect='auto', cmap='gray_r')
    ax.plot(obj.ytr[i,:], color='magenta', alpha=0.3, lw=4)
    ax.grid()

fig.tight_layout(pad=1)

## Fit the trace profile

The profile is modeled as a (pixel-integrated) Gaussian with a specified width that is added in quadrature with the tabulated 
PSF width.

In [None]:
# Show the documentation of the trace-profile fitting method (IPython help syntax)
obj.fit_all_traces?

In [None]:
# Fit the trace profile of all exposures:
#   offset_degree - order of the trace-offset polynomial
#   x0            - initial guess [gaussian width in pixels x 10, trace offset in pixels]
#   niter         - presumably the number of fit iterations
#   ref_exp       - reference exposure computed by the group
trace_fit_kws = dict(
    offset_degree=0,
    force_positive=False,
    x0=[2, 0.],
    niter=1,
    ref_exp=obj.calc_reference_exposure,
)
fit = obj.fit_all_traces(**trace_fit_kws)

In [None]:
# Show the profile fit: plot_2d_differences draws the 2D data together with `fit`
fig2d = obj.plot_2d_differences(fit=fit)

# Resample the spectra to a rectified pixel grid and get optimal 1D extraction

In [None]:
# Resampling options for the rectified 2D spectrum
drizzle_kws = {
    'step': 1,              # cross dispersion step size
    'ny': 15,               # number of cross dispersion pixels
    'with_pathloss': True,  # use MSAEXP path loss that accounts for source size
    'wave_sample': 1.05,    # wavelength sampling
    'dkws': {'oversample': 16, 'pixfrac': 0.8},
}

# Combine the PRISM group and extract the 1D spectrum
prism_group = {'obj': obj, 'fit': fit}
hdul = msaexp.slit_combine.combine_grating_group(
    {'prism': prism_group},
    ['prism'],
    drizzle_kws=drizzle_kws,
)


In [None]:
# List the HDUs of the combined product (SPEC1D and SCI are used below)
hdul.info()

In [None]:
# Plot the extracted 1D spectrum together with its uncertainty.  Both curves
# share a colour, so the error curve is dashed and labelled to distinguish it
# in the legend.
spec = utils.read_catalog(hdul['SPEC1D'])

fig, ax = plt.subplots(1,1,figsize=(10,5))

pl = ax.plot(spec['wave'], spec['flux'], label='msaexp', color='steelblue', alpha=0.5)
ax.plot(spec['wave'], spec['err'], color=pl[0].get_color(), alpha=0.5,
        linestyle='--', label='uncertainty')

ax.legend()

ax.set_xlabel('wavelength, um')
ax.set_ylabel(r'$f_\nu$ $\mu$Jy')
ax.grid()

In [None]:
# 2D rectified spectrum from the nod-difference combination
sci_2d = hdul['SCI'].data
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.imshow(sci_2d, cmap='gray_r', aspect='auto', vmin=-0.1, vmax=0.2)
fig.tight_layout(pad=0)

## Estimate the sky directly from the spectrum

If the sky is well determined, this can eliminate the need to take differences of the nodded exposures.

In [None]:
# Show the documentation of the sky-estimation method (IPython help syntax)
obj.estimate_sky?

In [None]:
# Fit the sky directly from the slit pixels, masking the source region
estimate_sky_kwargs = {
    'mask_yslit': [[-4.5, 4.5]],  # mask pixels expected to contain the source
    'min_bar': 0.95,
    'df': 81,                     # number of splines; high enough to fit the sky wiggles
    'high_clip': 1.0,
    'make_plot': True,
}

_ = obj.estimate_sky(**estimate_sky_kwargs)

The ``data`` attribute is ``sci - sky2d`` if ``sky2d`` is available

In [None]:
# One panel per exposure of obj.data (sci - sky2d once a sky model exists),
# reshaped back to the 2D cutout shape obj.sh, with the trace overplotted.
# squeeze=False keeps `axes` 2D so this also works when obj.N == 1.
fig, axes = plt.subplots(obj.N, 1, figsize=(8,6), sharex=True, sharey=True, squeeze=False)

for i, ax in enumerate(axes[:, 0]):
    ax.imshow(obj.data[i,:].reshape(obj.sh), vmin=-0.1, vmax=0.3, aspect='auto', cmap='gray_r')
    ax.plot(obj.ytr[i,:], color='magenta', alpha=0.3, lw=4)
    ax.grid()

fig.tight_layout(pad=1)

## Flag outliers based on the cross-dispersion profile

In [None]:
# Show the documentation of the profile-based outlier flagging (IPython help syntax)
obj.flag_from_profile?

In [None]:
# Flag outliers based on the cross-dispersion profile
flag_profile_kwargs = {
    'require_multiple': True,
    'make_plot': True,
    'grow': 2,
    'nfilt': -32,
}
obj.flag_from_profile(**flag_profile_kwargs)

## Turn off exposure differences and do resample and extraction again

In [None]:
# Turn off nod differences so the next combination uses obj.data (sci - sky2d).
# NOTE(review): this mutates `obj` in place, so re-running the earlier
# combination cell afterwards presumably no longer reproduces the nod-diff result.
obj.meta["diffs"] = False

In [None]:
# Same resampling options as for the nod-difference extraction
drizzle_kws = {
    'step': 1,              # cross dispersion step size
    'ny': 15,               # number of cross dispersion pixels
    'with_pathloss': True,  # use MSAEXP path loss that accounts for source size
    'wave_sample': 1.05,    # wavelength sampling
    'dkws': {'oversample': 16, 'pixfrac': 0.8},
}

# Resample and extract again, now without nod differences
prism_group = {'obj': obj, 'fit': fit}
hdul_nodiff = msaexp.slit_combine.combine_grating_group(
    {'prism': prism_group},
    ['prism'],
    drizzle_kws=drizzle_kws,
)


In [None]:
# 2D comparison of the two sky treatments with identical display stretch
comparisons = [
    (hdul, 'Nod diffs'),
    (hdul_nodiff, 'Global sky'),
]
kws = dict(vmin=-0.1, vmax=0.2, aspect='auto', cmap='gray_r')

fig, axes = plt.subplots(2,1,figsize=(10,6), sharex=True, sharey=True)
for panel, (hdul_i, label) in zip(axes, comparisons):
    panel.imshow(hdul_i['SCI'].data, **kws)
    panel.set_ylabel(label)

fig.tight_layout(pad=0.5)

# Combination and extraction wrapped into a single function

In [None]:
# Show the documentation of the combined extraction wrapper (IPython help syntax)
msaexp.slit_combine.extract_spectra?

In [None]:
# Full combination + extraction for the first source, using nod differences
# and the profile-based outlier flagging configured above
group_kws.update(diffs=True, flag_profile_kwargs=flag_profile_kwargs)

target = f'{prog}_{source_ids[0]}'

_ = msaexp.slit_combine.extract_spectra(target=target, root=outroot, **group_kws)

In [None]:
# With sky estimation: replace nod differences by the fitted global sky
group_kws['diffs'] = False

extract_kws = dict(target=target, root=outroot, estimate_sky_kwargs=estimate_sky_kwargs)
_ = msaexp.slit_combine.extract_spectra(**extract_kws, **group_kws)

In [None]:
# List the output files written for this target.  Pure Python, so it works on
# any platform and `target` is never interpolated into a shell command.
print('\n'.join(sorted(glob.glob(f'*{target}.*'))))

# Fitting and analysis

Now go to the ``spectral-extractions-2024.ipynb`` notebook for a demo on fitting the spectra for redshift, lines, etc.

# Expand sky fit for spatially-extended object

In [None]:
# With sky estimation, for the second (more extended) source.  Keep the same
# source mask but lower high_clip from 1.0 to 0.5 for the sky fit.
group_kws['diffs'] = False
target = f'{prog}_{source_ids[1]}'

estimate_sky_kwargs.update(mask_yslit=[[-4.5, 4.5]], high_clip=0.5)

extract_kws = dict(target=target, root=outroot, estimate_sky_kwargs=estimate_sky_kwargs)
_ = msaexp.slit_combine.extract_spectra(**extract_kws, **group_kws)


In [None]:
# Widen the masked source region (asymmetric about the slit centre) to cover the
# extended emission, and use fewer splines for the sky (df 81 -> 51)
estimate_sky_kwargs.update(mask_yslit=[[-4., 7.5]], df=51)

result = msaexp.slit_combine.extract_spectra(
    root=outroot,
    target=target,
    estimate_sky_kwargs=estimate_sky_kwargs,
    **group_kws,
)

In [None]:
# Show profile near the Halpha line
#
# Compare the cross-dispersion (spatial) profile of the continuum with that of
# the line for the extended source.  Each profile is the WHT-weighted mean of
# the rectified 2D SCI array over a wavelength window; "line - cont." subtracts
# the average of the blue and red continuum profiles.

hdu = result['prism']

# Wavelength windows, in the units of the SPEC1D 'wave' column (microns)
wave_cuts = {
    'continuum blue': [2.5, 2.8],
    'continuum red': [3.0, 3.3],
    'line': [2.85, 2.91],
}

spec = utils.read_catalog(hdu['SPEC1D'])

# Spatial axis in arcsec, centred on the middle row of the 2D array
sci_header = hdu['SCI'].header
n_rows = sci_header['NAXIS2']
y0 = (n_rows - 1)/2
y_arcsec = (np.arange(n_rows) - y0)*sci_header['YPIXSCL']

# Row index of SPEC1D, used to map wavelength onto a column of the 2D arrays
column_index = np.arange(len(spec))

fig, ax = plt.subplots(1,1,figsize=(8,5))

profile = {}

for cut, window in wave_cuts.items():
    # Column range covering this wavelength window.  ``np.cast`` was removed in
    # NumPy 2.0, so round and convert with ``astype`` instead.
    i0, i1 = np.round(np.interp(window, spec['wave'], column_index)).astype(int)
    slx = slice(i0, i1)

    data = hdu['SCI'].data[:, slx]
    wht = hdu['WHT'].data[:, slx]
    profile[cut] = np.nansum(data*wht, axis=1) / np.nansum(wht, axis=1)

    ax.plot(y_arcsec, profile[cut],
            alpha=0.5,
            color=('0.8' if cut.startswith('continuum') else 'pink'),
    )

# Mean continuum profile: shown directly and subtracted from the line profile
cont_profile = (profile['continuum red'] + profile['continuum blue']) / 2.

ax.set_ylabel('flux')

ax.fill_between(y_arcsec, y_arcsec*0., cont_profile,
    color='0.4', alpha=0.3,
    label='continuum',
)

ax.fill_between(y_arcsec, y_arcsec*0., profile['line'] - cont_profile,
    color='tomato', alpha=0.3,
    label='line - cont.',
)

ax.legend()

ax.set_xlabel(r'$\Delta y$, arcsec')

ax.set_xlim(-0.8, 0.8)
ax.set_ylim(-0.03, 0.3)

ax.grid()


_ = fig.tight_layout(pad=0.5)




