<a id="title_ID"></a>
# JWST Pipeline Validation Notebook: Outlier Detection for MIRI MRS

<span style="color:red"> **Instruments Affected**</span>: MIRI


Tested on MIRI simulated data.



#### Author: Isha Nayak

This notebook checks the outlier detection step of calwebb_spec3 for the MIRI MRS Channel 1 long wavelength band.

First, the notebook selects ten locations on the detector and confirms that these pixels fall within the Channel 1 long wavelength range (6.42 - 7.51 microns). Outliers with values from 3000 to 30000 are then injected at these locations, so that the pipeline is tested on a range of outlier strengths. After the outliers are injected into the detector frame, the modified exposure is saved to a new file.
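The injection pattern can be summarized as a mapping from detector pixel to outlier value. This is a minimal sketch assuming the geometry used in the injection cell of this notebook (3x3 pixel blocks at columns 355-357, first rows starting at 20 and stepping by 100, values equal to 3000 times the outlier number); `outlier_pixels` is a hypothetical helper, not a pipeline function.

```python
# Hypothetical helper: map each injected (row, column) pixel to its outlier value.
# Assumes ten 3x3 blocks at columns 355-357, rows starting at 20 and stepping by 100,
# with values 3000, 6000, ..., 30000 (the layout used in this notebook).
def outlier_pixels(n_outliers=10, skip=100, x0=355, y0=20, step_value=3000):
    pixels = {}
    for n in range(n_outliers):
        value = step_value * (n + 1)
        for x in range(x0, x0 + 3):                            # detector columns
            for y in range(y0 + n * skip, y0 + 3 + n * skip):  # detector rows
                pixels[(y, x)] = value
    return pixels


pixels = outlier_pixels()
print(len(pixels))                                 # 90
print(min(pixels.values()), max(pixels.values()))  # 3000 30000
```

Keying by `(row, column)` matches the `data[j, i]` (row, column) indexing that NumPy uses for the SCI array.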

We then run the cube build step (without outlier detection) and give the output the suffix 'before', since it is produced before the outlier detection step. Next, we run the outlier detection and cube build steps of calwebb_spec3; the output of this second run has the suffix 'after', since it is produced after outlier detection has been applied.
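To illustrate why an injected value present in only one exposure can be caught, here is a deliberately simplified sketch of median-based outlier flagging across aligned exposures. It is *not* the algorithm implemented in the jwst `outlier_detection` step: it assumes the exposures are already on a common pixel grid, that the noise level is known, and it uses a hypothetical threshold `nsigma`.

```python
import numpy as np


def flag_outliers(stack, noise, nsigma=10.0):
    """Return a boolean mask (True = outlier) for an aligned stack of exposures.

    stack : array of shape (n_exposures, ny, nx), assumed already on a common grid
    noise : assumed noise level (a single scalar here, for simplicity)
    """
    # The per-pixel median across exposures is barely affected by one outlying exposure
    median = np.median(stack, axis=0)
    # Flag any pixel that deviates from that median by more than nsigma * noise
    return np.abs(stack - median) > nsigma * noise


rng = np.random.default_rng(0)
stack = 100.0 + rng.normal(0.0, 1.0, size=(4, 20, 20))  # four exposures with unit noise
stack[0, 5:8, 5:8] += 3000.0                             # 3x3 outlier in the first exposure only

mask = flag_outliers(stack, noise=1.0)
print(mask[0, 5:8, 5:8].all(), int(mask.sum()))
```

With four exposures, the median of each pixel is set by the three unaffected values, so only the injected block stands out.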

The image and the DQ mask in the detector frame are shown to confirm that the proper (x, y) coordinates have been chosen. Then, for each of the ten inserted outliers, the cube slice containing the expected outlier is shown with and without the outlier detection step.

Outliers at varying flux levels are detected and removed by the pipeline. The flux of the central source can differ by as much as 35% between the runs with and without outlier detection, so we use 35% as the pass/fail threshold.
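The percentage used for the pass/fail test is 100*(before - after)/before, computed for each of the ten cube slices. A minimal sketch of the criterion, assuming (as in the final cell of this notebook) that the notebook fails only when more than one slice exceeds the threshold; `percent_change` and `passes` are hypothetical helper names.

```python
def percent_change(before, after):
    """Percentage drop in summed flux from the 'before' cube to the 'after' cube."""
    return 100.0 * (before - after) / before


def passes(before_flux, after_flux, threshold=35.0, max_exceed=1):
    """Return True unless more than `max_exceed` slices drop by more than `threshold` percent."""
    n_exceed = sum(
        1 for b, a in zip(before_flux, after_flux) if percent_change(b, a) > threshold
    )
    return n_exceed <= max_exceed


# Illustrative numbers only (not pipeline results)
print(percent_change(200.0, 150.0))            # 25.0
print(passes([200.0, 200.0], [150.0, 100.0]))  # True: only one slice exceeds 35%
print(passes([200.0, 200.0], [100.0, 100.0]))  # False: two slices exceed 35%
```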

In [None]:
import os
if 'CRDS_CACHE_TYPE' in os.environ:
    if os.environ['CRDS_CACHE_TYPE'] == 'local':
        os.environ['CRDS_PATH'] = os.path.join(os.environ['HOME'], 'crds', 'cache')
    elif os.path.isdir(os.environ['CRDS_CACHE_TYPE']):
        os.environ['CRDS_PATH'] = os.environ['CRDS_CACHE_TYPE']
print('CRDS cache location: {}'.format(os.environ.get('CRDS_PATH', 'CRDS_PATH not set')))

In [None]:
# Basic system utilities for interacting with files
import glob, sys, os, time

# Astropy utilities for opening FITS and ASCII files
from astropy.io import fits
from astropy.io import ascii

# Astropy utilities for making plots
from astropy.visualization import (LinearStretch, LogStretch, ImageNormalize, ZScaleInterval)

# Numpy for doing calculations
import numpy as np

# Matplotlib for making plots
import matplotlib.pyplot as plt
from matplotlib import rc

# JWST pipelines
from jwst.pipeline import Detector1Pipeline
from jwst.pipeline import Spec2Pipeline
from jwst.pipeline import Spec3Pipeline

# Individual JWST pipeline steps
from jwst.assign_wcs import AssignWcsStep
from jwst.background import BackgroundStep
from jwst.flatfield import FlatFieldStep
from jwst.srctype import SourceTypeStep
from jwst.straylight import StraylightStep
from jwst.fringe import FringeStep
from jwst.photom import PhotomStep
from jwst.cube_build import CubeBuildStep
from jwst.extract_1d import Extract1dStep
from jwst.cube_skymatch import CubeSkyMatchStep
from jwst.master_background import MasterBackgroundStep
from jwst.outlier_detection import OutlierDetectionStep

# JWST pipeline utilities
from jwst.datamodels import dqflags
from jwst import datamodels
from jwst.associations import asn_from_list as afl
from jwst.associations.lib.rules_level2_base import DMSLevel2bBase
from jwst.associations.lib.rules_level3_base import DMS_Level3_Base
import stcal

# MIRIcoord for detector to pixel conversion
import miricoord
import miricoord.mrs.mrs_tools as mt

# Box download imports 
from astropy.utils.data import download_file
from pathlib import Path
from shutil import move
from os.path import splitext

### Create a temporary location for the data

In [None]:
# Create a temporary directory to hold notebook output, and change the working directory to that directory.
from tempfile import TemporaryDirectory
import os
data_dir = TemporaryDirectory()
os.chdir(data_dir.name)

# For info, print out where the script is running
print("Running in {}".format(os.getcwd()))

In [None]:
# Check JWST version
import jwst
print(jwst.__version__ )

In [None]:
# Check stcal version
print(stcal.__version__ )

### Read in data from Box

In [None]:
# Download each Box file and move it to the requested local file name
def get_box_files(file_list):
    for box_url,file_name in file_list:
        if 'https' not in box_url:
            box_url = 'https://stsci.box.com/shared/static/' + box_url
        downloaded_file = download_file(box_url)
        if Path(file_name).suffix == '':
            ext = splitext(box_url)[1]
            file_name += ext
        move(downloaded_file, file_name)
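The URL and file-name handling in `get_box_files` can be checked without downloading anything. The sketch below is a hypothetical pure-function version of that logic (`resolve_box_download` is not part of any library) that returns the URL and local name the function would use; `abc123.fits` is a made-up Box file ID.

```python
from os.path import splitext
from pathlib import Path

BOX_PREFIX = 'https://stsci.box.com/shared/static/'


def resolve_box_download(box_url, file_name):
    """Return the (full URL, local file name) pair that get_box_files would use."""
    if 'https' not in box_url:
        box_url = BOX_PREFIX + box_url       # bare Box file IDs get the shared-link prefix
    if Path(file_name).suffix == '':
        file_name += splitext(box_url)[1]    # borrow the extension from the URL
    return box_url, file_name


print(resolve_box_download('abc123.fits', 'my_cal'))
```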

In [None]:
# Get the files from Box
file_urls = ['https://stsci.box.com/shared/static/7va325g09uesfh9wb569sedcefll4cv3.fits',
             'https://stsci.box.com/shared/static/jsyf3k5frn4w3zzw7hrqbbe92iszgye5.fits',
             'https://stsci.box.com/shared/static/6tlbhwco7qz98qvnew6x0797cugp68re.fits',
             'https://stsci.box.com/shared/static/w3oj50ei9py7i9e9jmi3d2oyr40127us.fits']
file_names = ['det_image_seq1_MIRIFUSHORT_12LONGexp1_cal.fits',
              'det_image_seq2_MIRIFUSHORT_12LONGexp1_cal.fits',
              'det_image_seq3_MIRIFUSHORT_12LONGexp1_cal.fits',
              'det_image_seq4_MIRIFUSHORT_12LONGexp1_cal.fits']
box_download_list = list(zip(file_urls, file_names))
get_box_files(box_download_list)

### Start data processing

In [None]:
# Find the downloaded _cal.fits files (calibrated exposures used as Spec3 input)
sstring='det*cal.fits'
calfiles=sorted(glob.glob(sstring))
print(calfiles)

In [None]:
# For each of the ten outlier locations, convert the detector (x, y) pixels to wavelength
# and check that they fall in the expected Channel 1 long range (6.42 - 7.51 microns)
skip=100
for n in range(10):
    lambda_wav=[]
    for i in range(355, 358):                      # x (column) pixels
        for j in range(20+(n*skip), 23+(n*skip)):  # y (row) pixels
            values=mt.xytoabl([i],[j],'1C')
            lambda_wav.append(values['lam'])
    print(n+1)
    print(min(lambda_wav))
    print(max(lambda_wav))
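The cell above prints the minimum and maximum wavelength of each 3x3 block for visual inspection. A minimal sketch of turning that inspection into an explicit check, assuming the wavelengths returned by `mt.xytoabl` have been collected into a list of plain floats in microns; `all_in_range` and `CH1_LONG_RANGE` are hypothetical names.

```python
# Channel 1 long wavelength range quoted at the top of this notebook, in microns
CH1_LONG_RANGE = (6.42, 7.51)


def all_in_range(wavelengths, wav_range=CH1_LONG_RANGE):
    """Return True if every wavelength (micron) lies inside wav_range (inclusive)."""
    lo, hi = wav_range
    return all(lo <= w <= hi for w in wavelengths)


print(all_in_range([6.50, 7.00, 7.40]))  # True
print(all_in_range([6.30, 7.00]))        # False
```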

In [None]:
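# Inject outliers into the first exposure
hdu=fits.open(calfiles[0])
data=hdu['SCI'].data

# Insert ten 3x3-pixel outliers with values from 3000 to 30000 (3000 * outlier number)
skip=100
for n in range(10):
    for i in range(355, 358):                      # x (column) pixels
        for j in range(20+(n*skip), 23+(n*skip)):  # y (row) pixels
            data[j,i]=3000*(n+1)
hdu['SCI'].data=data

# Save the modified exposure to a new file, with 'cal' replaced by 'od_test' in its name;
# the original _cal file is left unchanged
hdu.writeto(calfiles[0].replace('cal','od_test'),overwrite=True)
hdu.close()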

In [None]:
# Define a useful function to write out a Lvl3 association file from an input list
def writel3asn(files,asnfile,prodname,**kwargs):
    asn = afl.asn_from_list(files,rule=DMS_Level3_Base,product_name=prodname)
    if ('bg' in kwargs):
        for bgfile in kwargs['bg']:
            asn['products'][0]['members'].append({'expname': bgfile, 'exptype':'background'})
    _, serialized = asn.dump()
    with open(asnfile, 'w') as outfile:
        outfile.write(serialized)
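`writel3asn` appends background exposures to the members of the first product. The sketch below builds a simplified dictionary with only that `products`/`members` layout, to show where science and background entries end up; it is an illustration under that assumption, not the full association written by `asn_from_list` (which includes additional metadata keys omitted here), and `minimal_l3_asn` is a hypothetical helper.

```python
import json


def minimal_l3_asn(science_files, product_name, background_files=()):
    """Build a simplified Level 3 association dictionary (illustration only)."""
    members = [{'expname': f, 'exptype': 'science'} for f in science_files]
    # Background exposures are appended to the same product's member list
    members += [{'expname': f, 'exptype': 'background'} for f in background_files]
    return {'products': [{'name': product_name, 'members': members}]}


asn = minimal_l3_asn(['exp1_cal.fits', 'exp2_cal.fits'], 'od', background_files=['bg_cal.fits'])
print(json.dumps(asn, indent=2))
```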

In [None]:
# Create an association file
testfiles=calfiles.copy()
testfiles[0]=str.replace(calfiles[0],'cal','od_test')
writel3asn(testfiles,'od.json','od')

In [None]:
# Run it through cube building calling the result 'od_before'
cb=CubeBuildStep()
cb.call('od.json',channel='1',save_results=True,output_file='od_before')

In [None]:
# Run this association through the Spec3 pipeline with master background subtraction and
# MRS image matching skipped, so that the 'od_after' cube is built after outlier detection
spec3=Spec3Pipeline()
spec3.save_results = True
spec3.master_background.skip = True
spec3.mrs_imatch.skip = True
spec3.outlier_detection.save_intermediate_results = True
spec3.outlier_detection.scale = '2.0 2.0'
spec3.cube_build.channel='1'
spec3.cube_build.output_file='od_after'
spec3('od.json')

In [None]:
# Show the image and mask
hdu=fits.open('det_image_seq1_MIRIFUSHORT_12LONGexp1_od_test_a3001_crf.fits')
flux=hdu['SCI'].data
dq=hdu['DQ'].data

# Use a classic ZScale normalization
norm = ImageNormalize(flux, interval=ZScaleInterval(),stretch=LinearStretch())

rc('axes', linewidth=2)            
fig, (ax1,ax2) = plt.subplots(1,2, figsize=(7,7),dpi=100)

# Plot the data to visually check outliers were inserted correctly
ax1.imshow(flux, cmap='gray',norm=norm,origin='lower')
ax1.set_title('2d SCI array')
ax1.set_xlabel('X pixel')
ax1.set_ylabel('Y pixel')
ax1.set_xlim(340,370)
ax1.set_ylim(0, 300)

ax2.imshow(dq, cmap='gray',vmin=0,vmax=1,origin='lower')
ax2.set_title('2d DQ array')
ax2.set_xlabel('X pixel')
ax2.set_xlim(340,370)
ax2.set_ylim(0, 300)

In [None]:
# The plots above are hard to read at this scale, so print the flux and DQ flag
# at one pixel (row 21 + 100*n, column 355) of each inserted outlier
for n in range(10):
    y, x = 21 + 100*n, 355
    print('flux:', flux[y,x], 'DQ Flag:', dq[y,x], 'DQ Name:',
          dqflags.dqflags_to_mnemonics(dq[y,x], mnemonic_map=datamodels.dqflags.pixel))
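A DQ value is a bit field, so a pixel flagged by outlier detection can carry several mnemonics at once. The sketch below decodes a value with a small subset of the pixel DQ bits; the bit values are assumed to match `datamodels.dqflags.pixel` and should be checked against that dictionary in your installed version. `decode_dq` and `DQ_SUBSET` are hypothetical names.

```python
# Subset of the JWST pixel DQ bit assignments (bit value -> mnemonic);
# assumed to match datamodels.dqflags.pixel, which defines many more flags
DQ_SUBSET = {
    1: 'DO_NOT_USE',
    2: 'SATURATED',
    4: 'JUMP_DET',
    8: 'DROPOUT',
    16: 'OUTLIER',
}


def decode_dq(value, mapping=DQ_SUBSET):
    """Return the mnemonics of all bits in `mapping` that are set in `value`."""
    return [name for bit, name in sorted(mapping.items()) if value & bit]


print(decode_dq(17))  # ['DO_NOT_USE', 'OUTLIER']
print(decode_dq(0))   # []
```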

In [None]:
# Close files
hdu.close()

In [None]:
# Cube without outlier rejection
hdu1=fits.open('od_before_ch1-long_s3d.fits')
flux1=hdu1['SCI'].data

# Cube with outlier rejection
hdu2=fits.open('od_after_ch1-long_s3d.fits')
flux2=hdu2['SCI'].data

# Use a classic ZScale normalization
norm = ImageNormalize(flux1, interval=ZScaleInterval(),stretch=LinearStretch())

In [None]:
# Compare the cube slice containing each inserted outlier without (left) and with (right)
# outlier detection. The wavelength-plane indices are specific to this data set.
slices=[1, 126, 249, 369, 486, 600, 712, 821, 927, 1030]
inputs=[3000*(n+1) for n in range(10)]

rc('axes', linewidth=2)
for k, value in zip(slices, inputs):
    fig, (ax1,ax2) = plt.subplots(1, 2, figsize=(7,7),dpi=100)

    # Plot data
    ax1.imshow(flux1[k,:,:], cmap='gray',norm=norm,origin='lower')
    ax1.set_title('No Outlier Det., Input {}'.format(value))

    ax2.imshow(flux2[k,:,:], cmap='gray',norm=norm,origin='lower')
    ax2.set_title('With Outlier Det., Input {}'.format(value))
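The cube plane indices used for these comparisons are hard-coded. Assuming a linear spectral axis described by the standard FITS keywords CRVAL3, CDELT3 and CRPIX3, the plane nearest a given wavelength can be computed instead. `plane_index` and `plane_wavelength` are hypothetical helpers, and the header values in the example are made up.

```python
def plane_index(wavelength, crval3, cdelt3, crpix3=1.0):
    """0-based cube plane whose center is closest to `wavelength` (linear spectral WCS)."""
    # FITS convention: lambda = CRVAL3 + (p - CRPIX3) * CDELT3 for 1-based pixel p
    p = (wavelength - crval3) / cdelt3 + crpix3
    return int(round(p)) - 1


def plane_wavelength(index, crval3, cdelt3, crpix3=1.0):
    """Central wavelength of the 0-based plane `index`."""
    return crval3 + (index + 1 - crpix3) * cdelt3


# Hypothetical header values, for illustration only
print(plane_index(6.5, crval3=6.4, cdelt3=0.001))  # 100
```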


In [None]:
# Plot the spectrum of the point source before and after outlier detection
def cube_spectrum(filename):
    # Return wavelength (micron) and flux (Jy) summed over spaxels 10-29 in both spatial axes
    image, header = fits.getdata(filename, header=True)
    num_chan = header["NAXIS3"]

    # Wavelength of each plane: the first plane is at CRVAL3 and planes are CDELT3 apart
    wavelength = header["CRVAL3"] + header["CDELT3"]*np.arange(num_chan)

    # Sum the surface brightness (MJy/sr) over the box, then convert to Jy:
    # multiply by 1e6 (MJy -> Jy) and by the spaxel solid angle PIXAR_SR (sr)
    flux = np.sum(image[:, 10:30, 10:30], axis=(1, 2))*(10**6)*header["PIXAR_SR"]
    return wavelength, flux

# Spectrum without outlier detection
d1, a1 = cube_spectrum('od_before_ch1-long_s3d.fits')
plt.plot(d1, a1, '-', color='blue', lw=1, label='without outlier detection')

# Spectrum with outlier detection
d2, a2 = cube_spectrum('od_after_ch1-long_s3d.fits')
plt.plot(d2, a2, '-', color='black', lw=1, label='with outlier detection')

# Edit plot settings
plt.xlabel('Wavelength (micron)')
plt.ylabel('Flux (Jy)')
plt.ylim(0.007, 0.020)
plt.legend()
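The flux conversion in the spectrum cell multiplies the summed surface brightness by 1e6 and by the spaxel solid angle. A worked sketch with made-up numbers, assuming cube values in MJy/sr and a solid angle per spaxel in steradians; `box_flux_jy` is a hypothetical helper.

```python
def box_flux_jy(surface_brightness_mjy_sr, pixar_sr):
    """Sum surface brightness (MJy/sr) over spaxels and convert to flux density (Jy)."""
    # 1 MJy = 1e6 Jy, and each spaxel subtends pixar_sr steradians
    return sum(surface_brightness_mjy_sr) * 1e6 * pixar_sr


# Illustrative values: four spaxels of 2 MJy/sr each, spaxel solid angle 1e-12 sr
print(box_flux_jy([2.0, 2.0, 2.0, 2.0], 1e-12))  # approximately 8e-06 Jy
```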

In [None]:
# Pass/fail criteria: sum the flux of the central source (spaxels 15-29 in both spatial axes)
# in each outlier's cube slice, before and after outlier detection
slices=[1, 126, 249, 369, 486, 600, 712, 821, 927, 1030]
before_flux=[np.sum(flux1[k,15:30,15:30]) for k in slices]
after_flux=[np.sum(flux2[k,15:30,15:30]) for k in slices]

# Print the flux before, the flux after, and the percentage drop for each slice
for before, after in zip(before_flux, after_flux):
    print(before,',',after,',',100*(before-after)/(before))

In [None]:
# Determine whether the notebook passes: it fails if the central-source flux drops by
# more than 35% in more than one of the ten slices
count=0
for i in range(0,10):
    if 100*(before_flux[i]-after_flux[i])/(before_flux[i]) > 35:
        count=count+1

if count>1:
    print('This notebook does not pass.')
else:
    print('This notebook passes.')

In [None]:
# Close files
hdu1.close()
hdu2.close()