Based on JWST JDAT NIRSpec MOS Optimal Spectral Extraction



In [None]:
%matplotlib widget

from glob import glob
import numpy as np
from jwst.datamodels import ImageModel, MultiSpecModel
from astropy.io import fits
from astropy.modeling import models, fitting
from astropy.visualization import astropy_mpl_style, simple_norm
from specutils import Spectrum1D
from scipy.interpolate import interp1d, RegularGridInterpolator
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from ipywidgets import interact
import ipywidgets as widgets
import os
import tarfile
import urllib.request
plt.style.use(astropy_mpl_style) #use the style we imported for matplotlib displays


# select object and open it

In [None]:
# Identify the target: observation number and (5-digit) source id.
obsid, objid = '200', '14552'

# Filename stem shared by the products of this source.
# NOTE(review): the stem begins with '/' and there is no separator between
# it and the 's2d.fits' / 'x1d.fits' suffixes appended below -- confirm that
# this matches the file names on disk.
example_file = f'/jw2767-o{obsid}_s{objid}_nirspec_clear-prism'

# Paths of the 2D (s2d) and 1D (x1d) products; the empty first argument to
# os.path.join leaves the stem unchanged.
s2d_file, x1d_file = (os.path.join('', example_file + suffix)
                      for suffix in ('s2d.fits', 'x1d.fits'))


In [None]:

# Open the s2d product as a data model and keep handles on the three
# planes that the rest of the notebook reads.
data_model = ImageModel(s2d_file)
resampled_2d_image, weights_2d_image, err_2d_image = (
    data_model.data,  # image plane that gets displayed and extracted
    data_model.wht,   # weight plane
    data_model.err,   # error plane; it is squared later to get the variance
)

# numpy shapes are (rows, columns), i.e. (y, x) -- note the swap of x and y
image_shape = resampled_2d_image.shape
print(image_shape)

In [None]:
# Power-law display stretch computed from the full 2D image; reused by the
# region-selection figure further down.
norm = simple_norm(resampled_2d_image, stretch='power')
# image_shape is (rows, columns), so this is columns / (2 * rows); it is
# passed to imshow as the pixel aspect.
aspect_ratio = image_shape[1] / (2 * image_shape[0])
fig1 = plt.figure() # we save these in dummy variables to avoid spurious Jupyter Notebook output
img1 = plt.imshow(resampled_2d_image, cmap='gray', aspect=aspect_ratio, 
                  norm=norm, interpolation='none')
clb1 = plt.colorbar()

# 1D extraction

## Define a region for extraction

In [None]:
fig2 = plt.figure(figsize=(9,9)) # we want the largest figure that will fit in the notebook
img2 = plt.imshow(resampled_2d_image, cmap='gray', aspect=aspect_ratio, 
                  norm=norm, interpolation='none') # reuse norm from earlier

# create region box and slider
# The box starts out covering the whole image: origin (0, 0), full width
# and full height (image_shape is (rows, columns)).
region_x = region_y = 0
region_h, region_w = image_shape
region_rectangle = Rectangle((region_x, region_y), region_w, region_h, 
                             facecolor='none', edgecolor='b', linestyle='--')
# Attach the dashed outline to the axes that imshow just drew on.
current_axis = plt.gca()
current_axis.add_patch(region_rectangle)

# interactive widget controls
def region(x1=0, y1=0, x2=region_w-1, y2=region_h-1):
    """
    Move the dashed selection rectangle to the box spanned by two corners.

    Parameters
    ----------
    x1, y1 : int
        Pixel coordinates of one corner of the box.
    x2, y2 : int
        Pixel coordinates of the opposite corner.

    Notes
    -----
    The four sliders move independently, so ``x2 < x1`` or ``y2 < y1`` can
    be selected. The corners are therefore sorted, and the box is kept at
    least one pixel wide and high, so that the rectangle read back by the
    extraction cell never describes an empty or negative-size region.
    For ``x2 > x1`` and ``y2 > y1`` the bounds are exactly
    ``(x1, y1, x2 - x1, y2 - y1)``.
    """
    left, right = sorted((x1, x2))
    bottom, top = sorted((y1, y2))
    region_rectangle.set_bounds(left, bottom,
                                max(right - left, 1), max(top - bottom, 1))
    plt.draw()
    
# One slider per corner coordinate, each given as (min, max, step):
# x1/y1 run over 0..size-2 and x2/y2 over 1..size-1, in steps of 1.
interact1 = interact(region, x1=(0, region_w-2, 1), y1=(0, region_h-2, 1), 
                    x2=(1, region_w-1, 1), y2=(1, region_h-1, 1))

In [None]:
# Read the extraction box back from the interactive rectangle.
# (Comment these lines out if interactivity is not desired.)
(x, y), w, h = (region_rectangle.xy,
                region_rectangle.get_width(),
                region_rectangle.get_height())
# Non-interactive alternative: uncomment and set the box by hand.
# x = y = 0
# h, w = image_shape

print(x, y, x + w, y + h)

# Index grids spanning the box; er_x / er_y hold pixel coordinates in the
# full image and are reused later for the WCS lookup.
er_y, er_x = np.mgrid[y:y+h, x:x+w]

# Cut the same box out of the image, the weight plane and the error plane
# (the error is squared to give a variance).
extraction_region = resampled_2d_image[er_y, er_x]
weights_region = weights_2d_image[er_y, er_x]
extraction_region_var = np.square(data_model.err[er_y, er_x])

er_ny, er_nx = np.shape(extraction_region)

# columns / (3 * rows), used as the imshow pixel aspect from here on
aspect_ratio = er_nx / (3. * er_ny)

# Show the cut-out with its own power-law stretch.
er_norm = simple_norm(extraction_region, stretch='power')
fig3 = plt.figure()
img3 = plt.imshow(extraction_region, norm=er_norm, aspect=aspect_ratio,
                  cmap='gray', interpolation='none')
clb3 = plt.colorbar()

## create a profile for extraction

In [None]:
# Number of columns coadded to build the cross-dispersion profile.
slice_width = 30
# Start the slice at column 100, clamped to the extraction region so that a
# narrow region cannot leave the starting column out of range. (The earlier
# `er_nx // 2` assignment was dead code: it was overwritten immediately.)
initial_column = max(0, min(100, er_nx - 1))
def kernel_slice_coadd(width, column_idx, image=None):
    """
    Average `width` columns of an image around `column_idx`.

    The window starts ``width // 2`` columns to the left of `column_idx`
    and spans `width` columns, which is the same box that `make_slice`
    draws. It is clipped to the image edges, and the sum is divided by the
    number of columns that actually contributed, so the profile keeps its
    scale when the window hangs over an edge.

    Parameters
    ----------
    width : int
        Number of columns to coadd.
    column_idx : int
        Column on which the window is centred.
    image : 2D array, optional
        Image to slice, indexed as (row, column). Defaults to the
        notebook-level `extraction_region`.

    Returns
    -------
    1D array
        One value per image row. All zeros if the window lies entirely
        outside the image.
    """
    if image is None:
        image = extraction_region
    n_rows, n_columns = image.shape

    # Half-open column window [first, last), clipped to the image. Using
    # start + width (rather than column_idx + width // 2) keeps odd widths,
    # including width == 1, from losing a column, and clipping at n_columns
    # lets the last column take part.
    start = column_idx - width // 2
    first = max(0, start)
    last = min(n_columns, start + width)
    if last <= first:
        return np.zeros(n_rows, dtype=float)

    return image[:, first:last].sum(axis=1) / (last - first)

# Profile for the starting slice; it is what the profile panel shows before
# any slider is moved.
slice_0 = kernel_slice_coadd(slice_width, initial_column)

In [None]:
# Two side-by-side panels: iax4 shows the extraction-region image and pax4
# receives the coadded profile plot.
fig4, (iax4, pax4) = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
plt.subplots_adjust(hspace=0.15, top=0.95, bottom=0.05)
img4 = iax4.imshow(extraction_region, cmap='gray', aspect=aspect_ratio, 
                  norm=er_norm, interpolation='none')

#create slice box
def make_slice(width, column_idx):
    """
    Return the bounds ``(x, y, width, height)`` of the slice box.

    The box is `width` columns wide, starts ``width // 2`` columns to the
    left of `column_idx`, begins at row 0 and is `er_ny` rows high.
    """
    left_edge = column_idx - width // 2
    return (left_edge, 0, width, er_ny)

# Outline the starting slice on the image panel.
*sxy, sw, sh = make_slice(slice_width, initial_column)
slice_rectangle = Rectangle(sxy, sw, sh, linestyle='--',
                            edgecolor='b', facecolor='none')
iax4.add_patch(slice_rectangle)

# Plot the coadded profile of that slice against the row index.
xd_pixels = np.arange(er_ny)
(lin4,) = pax4.plot(xd_pixels, slice_0, 'k-')
pax4.set_xlabel('Cross-dispersion pixel')
pax4.set_ylabel('Coadded signal')

# Sliders for the slice centre (0 .. er_nx-1) and width (1 .. er_nx-1).
column_slider = widgets.IntSlider(value=initial_column, min=0, max=er_nx-1, step=1)
width_slider = widgets.IntSlider(value=slice_width, min=1, max=er_nx-1, step=1)

#interactive controls
def slice_update(column_idx, width):
    """
    Redraw the slice box and the coadded profile for new slider values.

    `column_idx` is the slice centre and `width` the number of columns.
    """
    # move the dashed box on the image panel
    box_x, box_y, box_w, box_h = make_slice(width, column_idx)
    slice_rectangle.set_bounds(box_x, box_y, box_w, box_h)

    # swap in the new profile, then rescale the axes to fit it
    profile = kernel_slice_coadd(width, column_idx)
    lin4.set_ydata(profile)
    pax4.relim()
    pax4.autoscale_view()

    plt.draw()

# Bind the two sliders to slice_update, so that moving either one redraws
# the slice box and the profile.
interact2 = interact(slice_update, column_idx=column_slider, width=width_slider)

In [None]:
# Freeze the current slider settings: the coadded profile that will be
# fitted, and the matching slice box in extraction-region coordinates.
kernel_slice = kernel_slice_coadd(width_slider.value, column_slider.value)
bbox_extraction=  make_slice(width_slider.value, column_slider.value) ## for fig at the end

In [None]:
# Row index of the profile peak. nanargmax is used because np.argmax returns
# the position of the first NaN as soon as the slice contains one; an
# all-NaN slice raises ValueError here instead of continuing silently.
max_pixel = np.nanargmax(kernel_slice)
# Initial width guess. NOTE: despite the name it is passed unconverted as
# the Moffat `gamma` and the Gaussian `stddev`.
fwhm = 1.

# Unit-amplitude trial profiles centred on the peak.
moffat_profile = models.Moffat1D(amplitude=1, gamma=fwhm, x_0=max_pixel, alpha=1)
gauss_profile = models.Gaussian1D(amplitude=1, mean=max_pixel, stddev=fwhm)

# Compare both against the profile normalised to its peak value.
fig5 = plt.figure()
kern5 = plt.plot(xd_pixels, kernel_slice / kernel_slice[max_pixel], label='Kernel Slice')
moff5 = plt.plot(xd_pixels, moffat_profile(xd_pixels), label='Moffat Profile')
gaus5 = plt.plot(xd_pixels, gauss_profile(xd_pixels), label='Gaussian Profile')
lgd5 = plt.legend()

In [None]:
# Use the Gaussian as the PSF template. This is a plain assignment, so
# psf_template and gauss_profile are the same object and the amplitude
# change below applies to both names.
psf_template = gauss_profile
# Scale the template to the peak value of the coadded profile.
psf_template.amplitude = kernel_slice[max_pixel]
print(psf_template)
# If deblending multiple sources, add more PSF templates here:




In [None]:
# Degree-2 polynomial kept available as a background component; it is only
# part of the kernel if it is added to extraction_kernel.
background_poly = models.Polynomial1D(2)
print(background_poly)

In [None]:
### ignore the background component for now
extraction_kernel = psf_template# + background_poly
print(extraction_kernel)

In [None]:
fitter = fitting.LevMarLSQFitter()

# Fit only the finite samples of the profile: NaN or inf values make the
# least-squares objective non-finite and break the fit. When the slice has
# no such values this is the same fit as before.
finite_rows = np.isfinite(kernel_slice)
fit_extraction_kernel = fitter(extraction_kernel,
                               xd_pixels[finite_rows],
                               kernel_slice[finite_rows])
print(fit_extraction_kernel)

# Evaluate the fitted kernel on every row, including any excluded above.
fit_line = fit_extraction_kernel(xd_pixels)

fig6, (fax6, fln6) = plt.subplots(nrows=2, ncols=1, figsize=(8, 12))
plt.subplots_adjust(hspace=0.15, top=0.95, bottom=0.05)

# Top panel: the kernel components. A single model has no sub-models to
# index, so it is plotted as a whole; a compound kernel is split into its
# first (PSF) and second (background) components.
if fit_extraction_kernel.n_submodels==1:
    psf6 = fax6.plot(xd_pixels, fit_extraction_kernel(xd_pixels), label="PSF")
else:
    psf6 = fax6.plot(xd_pixels, fit_extraction_kernel[0](xd_pixels), label="PSF")
    poly6 = fax6.plot(xd_pixels, fit_extraction_kernel[1](xd_pixels), label="Background")
sum6 = fax6.plot(xd_pixels, fit_line, label="Composite Kernel")
lgd6a = fax6.legend()

# Bottom panel: the data against the fitted kernel.
lin6 = fln6.plot(xd_pixels, kernel_slice, label='Kernel Slice')
fit6 = fln6.plot(xd_pixels, fit_line, 'o', label='Extraction Kernel')
lgd6b = fln6.legend()

## if the spatial profile needs to be varied as a function of wavelength

In [None]:
# from astropy.stats import sigma_clip

# n_bin = 100
# bin_width = er_nx // n_bin
# bin_centers = np.arange(0, er_nx, bin_width+1, dtype=float) + bin_width // 2
# binned_spectrum = np.hstack([extraction_region[:, i:i+bin_width+1].sum(axis=1)[:, None] 
#                                  for i in range(0, er_nx, bin_width+1)])
# bin_fwhms = np.zeros_like(bin_centers, dtype=float)

# for y in range(bin_centers.size):
#     bin_fit = fitter(fit_extraction_kernel, xd_pixels, binned_spectrum[:, y])
#     if fit_extraction_kernel.n_submodels==1:
#         bin_fwhms[y] = bin_fit.stddev.value
#     else:
#         bin_fwhms[y] = bin_fit.stddev_0.value
    
# bin_ny, bin_nx = binned_spectrum.shape
# bin_ar = bin_nx / (3 * bin_ny)

# fig_fwhm, ax_fwhm = plt.subplots(nrows=2, ncols=1, figsize=(6, 10))
# plt.subplots_adjust(hspace=0.05)
# fwhm_img = ax_fwhm[0].imshow(binned_spectrum, aspect=bin_ar, interpolation='none',
#                              cmap='gray')
# fwhm_plot = ax_fwhm[1].plot(bin_centers, bin_fwhms)
# xlbl_fwhm = ax_fwhm[1].set_xlabel("Bin center (px)")
# ylbl_fwhm = ax_fwhm[1].set_ylabel("FWHM (arcsec)")

## this gives an option to fix rectification issues if needed


In [None]:
# The spectrum has already been rectified, so the trace centre is modelled
# as a constant (degree-0 polynomial).
trace_center_model = models.Polynomial1D(0)

# Take the centre from the fitted PSF profile: the parameter is `mean` when
# the kernel is a single model and `mean_0` when it is a compound one.
trace_center_model.c0 = (
    fit_extraction_kernel.mean
    if fit_extraction_kernel.n_submodels == 1
    else fit_extraction_kernel.mean_0
).value

print(trace_center_model)


## create a noise spectrum

In [None]:
# scale = 1.0 # adjust this if and when the NIRSpec PIXFRAC changes

# # We want any pixel with 0 weight to be excluded from the calculation
# # in the next step, so we'll use masked array operations.
# bad_pixels = weights_region == 0
# masked_wht = np.ma.array(weights_region, mask=bad_pixels)
# variance_image = np.ma.divide(1., weights_region * scale**4)

# # variance_image =  np.ma.multiply(error_region**2 , weights_region)

In [None]:
# from copy import copy

# fig_var = plt.figure()
# palette = copy(plt.cm.gray)
# palette.set_bad('r', alpha=0.7)
# var_norm = simple_norm(variance_image, stretch='log', min_cut=0.006, max_cut=0.1)
# img_var = plt.imshow(variance_image, interpolation='none', aspect=aspect_ratio, norm=var_norm, cmap=palette)
# plt.colorbar()


# ### red regions should be avoided

## generate 1D spectrum

In [None]:
# Outputs, one entry per column of the extraction region. Columns that
# cannot be extracted keep the initial value 0 in both arrays.
spectrum = np.zeros(er_nx, dtype=float)  # extracted flux
error = np.zeros(er_nx, dtype=float)     # uncertainty of the extracted flux

column_pixels = np.arange(er_nx)
trace_centers = trace_center_model(column_pixels) # calculate our trace centers array

# Name of the PSF-centre parameter: `mean` for a single-model kernel and
# `mean_0` for a compound one. Assigning `mean_0` on a single model would
# not update the profile centre.
center_param = 'mean' if fit_extraction_kernel.n_submodels == 1 else 'mean_0'

# Loop over columns. The loop variable is `col` rather than `x`, so that the
# extraction-box origin `x` is not overwritten.
for col in column_pixels:
    # create the kernel for this column, using the fit trace centers
    kernel_column = fit_extraction_kernel.copy()
    setattr(kernel_column, center_param, trace_centers[col])
    # If accounting for a varying FWHM, also set the width here, e.g.
    # kernel_column.stddev_0 = fwhm_fit(col)  (`stddev` for a single model)
    kernel_values = kernel_column(xd_pixels)

    # isolate the relevant column in the spectrum and variance images
    variance_column = extraction_region_var[:, col] # remember that numpy arrays are row, column
    image_pixels = extraction_region[:, col]

    # Use the same set of pixels in the weighted sum and in its
    # normalization: finite data with a finite, positive variance.
    valid = (np.isfinite(image_pixels)
             & np.isfinite(variance_column)
             & (variance_column > 0))
    if not valid.any(): # this column isn't valid, so we'll skip it
        continue

    weights = kernel_values[valid] / variance_column[valid]

    # calculate the kernel normalization
    g_x = np.sum(kernel_values[valid] * weights)
    if not np.isfinite(g_x) or g_x <= 0:
        continue

    # and now sum the weighted column
    spectrum[col] = np.sum(image_pixels[valid] * weights) / g_x
    error[col] = (1. / g_x)**0.5
    

    

### create wavelength separately from data header

In [None]:
# Evaluate the data model's WCS on the pixel grid of the extraction region.
wcs = data_model.meta.wcs
# print(wcs.__repr__())
# The third output is unpacked into `wavelength_2d` rather than `y`, so the
# extraction-box origin `y` is not overwritten by a 2D array.
alpha_C, delta_C, wavelength_2d = wcs(er_x, er_y)
# One wavelength per column, taken from the first row of the region.
wavelength = wavelength_2d[0]

## plot 1d spectrum

In [None]:
fig7 = plt.figure()

# Show the extraction region once more, this time without an explicit
# colormap or normalization.
plt.imshow(extraction_region, aspect=aspect_ratio)

In [None]:
# Column-wise median of the extraction region (one value per column); it is
# referenced by the commented-out comparison plot further down.
spec_1d = np.median(extraction_region, axis=0)

In [None]:
# NOTE: this rebinds the name fig7, which the previous figure cell also uses.
fig7 = plt.figure()
# Extracted spectrum in blue against wavelength ...
spec7 = plt.plot(wavelength, spectrum, 'b-')

# ... and its uncertainty in red on the same axes.
err7 = plt.plot(wavelength, error, 'r-')


# plt.plot(wavelength, spec_1d/1e12, 'g-')

## Save 1d extraction

In [None]:
def batch_save_extracted_spectrum(filename, wavelength, spectrum, error=None):
    """
    Quick & dirty FITS dump of an extracted spectrum.
    Replace with your preferred output format & function.

    Parameters
    ----------
    filename : str
        Output path; an existing file is overwritten.
    wavelength, spectrum : array-like
        Written to the 'wavelength' and 'spectrum' columns.
    error : array-like, optional
        Written to the 'error' column. When omitted, the notebook-level
        `error` array is used, which is what this function always did
        before the parameter existed.

    Raises
    ------
    NameError
        If `error` is omitted and no notebook-level `error` is defined.
    """
    if error is None:
        # Legacy fallback: the 'error' column used to be read from a global.
        try:
            error = globals()['error']
        except KeyError:
            raise NameError("name 'error' is not defined") from None

    # One single-precision ('E') table column per array.
    columns = [
        fits.Column(name=column_name, format='E', array=values)
        for column_name, values in (('wavelength', wavelength),
                                    ('spectrum', spectrum),
                                    ('error', error))
    ]
    hdu = fits.BinTableHDU.from_columns(fits.ColDefs(columns))
    hdu.writeto(filename, overwrite=True)

In [None]:
# Output name built from the observation and source ids.
filename = f'onedspec_obs{obsid}_{objid}.fits'

# Write the wavelength and spectrum passed here; the function adds the
# 'error' column itself.
batch_save_extracted_spectrum(filename , wavelength, spectrum)

## save a fig with extraction process

In [None]:
def batch_plot_output(resampled_image, extraction_bbox, 
                kernel_slice, kernel_model,
                wavelength, spectrum, filename):
    """
    Convenience function for summary output figures,
    allowing visual inspection of the results from
    each file being processed.

    Parameters
    ----------
    resampled_image : 2D array
        Image shown in the top panel, indexed as (row, column).
    extraction_bbox : sequence of 4 numbers
        ``(x, y, width, height)`` of the dashed box drawn on the image.
    kernel_slice : 1D array
        Coadded cross-dispersion profile.
    kernel_model : callable
        Fitted kernel; it is evaluated at ``np.arange(kernel_slice.size)``.
    wavelength, spectrum : 1D arrays
        Extracted spectrum plotted in the bottom panel.
    filename : str
        Output path. It is also used as the figure title.

    The figure is closed even when plotting or saving raises, so a failure
    in a batch run does not leave open figures behind.
    """
    fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=(8, 12))
    try:
        fig.suptitle(filename)

        # Subplot 1: Extraction region
        ny, nx = resampled_image.shape
        power_norm = simple_norm(resampled_image, 'power')
        ax1.imshow(resampled_image, interpolation='none',
                   aspect=nx / (2 * ny), norm=power_norm, cmap='gray')
        rx, ry, rw, rh = extraction_bbox
        ax1.add_patch(Rectangle((rx, ry), rw, rh, facecolor='none',
                                edgecolor='b', linestyle='--'))

        # Subplot 2: Kernel fit
        xd_pixels = np.arange(kernel_slice.size)
        ax2.plot(xd_pixels, kernel_slice, label='Kernel Slice')
        ax2.plot(xd_pixels, kernel_model(xd_pixels), 'o',
                 label='Extraction Kernel')
        ax2.legend()

        # Subplot 3: Extracted spectrum
        ax3.plot(wavelength, spectrum)

        fig.savefig(filename, bbox_inches='tight')
    finally:
        plt.close(fig)

In [None]:
# Save the summary figure next to the FITS table, swapping only the file
# extension. (str.replace('fits', 'pdf') would also rewrite any other
# occurrence of 'fits' in the path.)
batch_plot_output(extraction_region, bbox_extraction,
                  kernel_slice, fit_extraction_kernel,
                  wavelength, spectrum,
                  os.path.splitext(filename)[0] + '.pdf')

In [None]:
# Display the fitted extraction kernel. `kernel_model` exists only as a
# parameter name inside batch_plot_output, so evaluating it at notebook
# level raised NameError.
fit_extraction_kernel