## Scattered-light subtraction for KCWI data

This notebook can be used in place of Stage 2 of the KCWI data reduction pipeline to obtain a better subtraction of scattered light.  
https://github.com/Keck-DataReductionPipelines/KcwiDRP

##### Preparation

There are two input files:
1. The output intensity image from KCWI stage 1, e.g. kb190101_00011_int.fits
2. A data file indicating the starting and ending column number of the gaps, e.g. kb190101_00011_gaps.dat

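Notes on the gap file:
- The gap bounds only need to be approximate.
- In most cases, one gap file works for multiple exposures.
- Format: one gap per line, given as two whitespace-separated integers, `start end`. The values are used directly as column indices into the image array, and a line with `start` equal to `end` marks a single column.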
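
As a minimal sketch of this format, the snippet below writes a small gap file and reads it back with NumPy. The column values are made up for illustration and the file name only follows the naming pattern used above; real bounds should come from inspecting the image.

```python
import os
import tempfile

import numpy as np

# Hypothetical gap bounds (start, end) in column indices; not taken from real data
example_gaps = [(5, 20), (150, 165), (300, 300)]

gap_path = os.path.join(tempfile.mkdtemp(), "kb190101_00011_gaps.dat")
with open(gap_path, "w") as f:
    for start, end in example_gaps:
        f.write(f"{start} {end}\n")

# Read the file back as an (n_gaps, 2) array of column bounds
gap = np.loadtxt(gap_path, dtype=int, ndmin=2)
print(gap.shape)
```

Reading with `np.loadtxt` is only one convenient way to check the file; the functions below parse it line by line.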

##### Functions:
1. Main scattered-light subtraction function: `scat_sub(int_file,gap_file,model,**kwargs)`
2. Slice refit function: `refit_slice(int_file,gap_file,slice_num,model,**kwargs)`
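
For reference, the two choices of `model` correspond to the profiles sketched below, each fitted together with a 2nd-order polynomial. The snippet evaluates them with plain NumPy at the default starting values (`amp=1`, `cent=1000`, `width=150`, `level=5`), using the formulas documented for astropy's `Moffat1D` (with `alpha=2`) and `RickerWavelet1D`. The helper names are illustrative only and are not part of this notebook.

```python
import numpy as np

def moffat_profile(x, amp=1.0, cent=1000.0, width=150.0, alpha=2.0):
    # Moffat profile: amp * (1 + ((x - cent) / width)**2) ** (-alpha)
    return amp * (1.0 + ((x - cent) / width) ** 2) ** (-alpha)

def ricker_profile(x, amp=1.0, cent=1000.0, width=150.0):
    # Ricker wavelet: amp * (1 - r**2) * exp(-r**2 / 2), with r = (x - cent) / width
    r2 = ((x - cent) / width) ** 2
    return amp * (1.0 - r2) * np.exp(-r2 / 2.0)

def quadratic(x, c0=5.0, c1=0.0, c2=0.0):
    # 2nd-order polynomial background; c0 plays the role of `level`
    return c0 + c1 * x + c2 * x ** 2

x = np.arange(2048)
moffat_guess = moffat_profile(x) + quadratic(x)
ricker_guess = ricker_profile(x) + quadratic(x)

# The Moffat profile is positive everywhere, while the Ricker wavelet
# dips below zero more than one `width` away from its center.
print(moffat_guess.max(), ricker_guess.min())
```

With real data, a call would look like `scat_sub('/data/kb190101_00011_int.fits', '/data/kb190101_00011_gaps.dat', 'moffat')`, followed by `refit_slice` for any slice whose fit looks poor.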

##### Return:
1. Plots of the fit, displayed and saved as PNG files (the plot path is built from the `int_file` path, so the plots land next to the input file)
2. Reduced file '*intd.fits' in the same directory as the input file
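
To make the processing concrete, the sketch below mimics on a synthetic image how the rows are divided into slices of `dy` rows and how one sample per gap is taken within a slice: the flux is summed over the rows of the slice, the minimum over the gap columns is kept, and the result is divided by the number of rows. The image size, the flat background and the gap bounds are assumptions chosen for illustration, and the helper `sample_gaps` is hypothetical. Only NumPy is needed.

```python
import numpy as np

# Synthetic stand-in for an intensity image (shape is an assumption)
yn, xn, dy = 2056, 2048, 257
image = np.full((yn, xn), 5.0)
image[:, 160] = 4.0  # one darker column inside the second gap below

# Row ranges of the slices: full slices of dy rows, the last one runs to yn
nslice = int(np.ceil(yn / dy))
rows = [(i * dy, min((i + 1) * dy, yn)) for i in range(nslice)]

# Hypothetical gap bounds (start, end) as 0-based column indices, with start < end
gaps = [(5, 20), (150, 165)]

def sample_gaps(image, gaps, yi, yf):
    """Return gap centers and the mean flux per row sampled in each gap."""
    xs, ys = [], []
    for start, end in gaps:
        column_sums = image[yi:yf, start:end].sum(axis=0)
        xs.append((start + end) // 2)
        ys.append(column_sums.min() / (yf - yi))
    return np.array(xs), np.array(ys)

xs, ys = sample_gaps(image, gaps, *rows[0])
print(nslice, rows[-1], xs, ys)
```

These sampled points are what the Moffat or Ricker model plus polynomial is then fitted to, one slice at a time.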

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import astropy.io.fits as fits
from astropy.modeling import models, fitting
import shutil

In [2]:
# Scattered-light subtraction with Moffat model or Ricker wavelet
# ----------------------------------------------------------------------------------------
# This function performs scattered-light subtraction for KCWI data and can be used
# instead of Stage 2 of the KCWI data reduction pipeline. It fits the scattered light
# with a 1D 2nd-order polynomial plus a Moffat 1D or Ricker wavelet 1D profile.
# The returned plots help to judge whether a good fit was obtained and serve as
# a reference for tweaking the parameters.
# ----------------------------------------------------------------------------------------
# Call:
#    scat_sub(int_file,gap_file,model,*dy=257,*amp=1,*cent=1000,*width=150,*level=5)
# ----------------------------------------------------------------------------------------
# Inputs:
#   Required:
#      int_file:   (str) path to file, e.g. '/data/kb190101_00011_int.fits'
#      gap_file:   (str) path to gap file, e.g. '/data/kb190101_00011_gaps.dat'
#         model:   (str) model to fit, 'moffat' or 'ricker'
#   Optional:
#            dy:   (int) number of rows for each slice, the default value is 257 (8 slices)
#           amp: (float) amplitude for Moffat/Ricker model, default amp = 1
#          cent: (float) center of the model, x position of the maximum, default cent = 1000
#         width: (float) width of the model/wavelet, default width = 150
#         level: (float) C0 in 2nd order polynomial, default level = 5
# ----------------------------------------------------------------------------------------
# Outputs:
#          plots:  original data points and fitted model, saved as <int_file prefix>--<model>.png
#   reduced_data:  saved in the same directory as the input data as *intd.fits
# ----------------------------------------------------------------------------------------
# Notes:
#  1. This code does not use data points from the central gap.
#  2. The Moffat model is preferred for data in which the peak of the scattered light
#     falls on one side of the central gap. Otherwise, try the Ricker model to see if
#     a good fit can be obtained.
#     It is always a good idea to decide which to use after examining the data.
#  3. Both models can be sensitive to the initial conditions. See [6] in Reference.
#  4. Use the refit function below if the fit in some slices is not satisfactory.
# ----------------------------------------------------------------------------------------
# Reference:
#  1. Rupke D., Scattered-light subtraction code
#      https://github.com/drupke/ifsred/blob/master/kcwi/ifsr_kcwiscatsub.pro
#  2. KCWI Data Reduction Pipeline, Stage 2
#      https://github.com/Keck-DataReductionPipelines/KcwiDRP
#  3. Bernstein B., Fernandez-Granda C., Deconvolution of Point Sources: A Sampling
#     Theorem and Robustness Guarantees, Comm. Pure Appl. Math., vol 72, May 2018,
#     doi:10.1002/cpa.21805
#      https://arxiv.org/pdf/1707.00808.pdf
#  4. Grundahl F., Sørensen A.N., Detection of scattered light in telescopes,
#     Astron. Astrophys. Suppl. Ser. 116, 367-371 (1996), doi:10.1051/aas:1996119
#      https://aas.aanda.org/articles/aas/pdf/1996/05/ds1089.pdf
#  5. Morrissey P., Matuszewski M., Martin D.C., et al., THE KECK COSMIC WEB
#     IMAGER INTEGRAL FIELD SPECTROGRAPH, doi:10.3847/1538-4357/aad597
#      https://arxiv.org/pdf/1807.10356.pdf
#  6. Moré J.J., The Levenberg-Marquardt algorithm: Implementation and theory, 1977,
#      https://www.osti.gov/servlets/purl/7256021/
# ----------------------------------------------------------------------------------------
# Updates: 
#    Aug 21, 2020  W.Ning  Modified to save plots. Updated documentation.
#    Aug 18, 2020  W.Ning  Updated documentation.
# ----------------------------------------------------------------------------------------


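def slice_fit(data,gap,model,yi,yf,amp,cent,width,level):
    # np.int was removed in NumPy 1.24; the built-in int does the same job
    x = []; y = []; yi = int(yi); yf = int(yf)
    # One sample per gap: flux summed over the slice rows, minimum across the gap columns
    for i in gap:
        if i[1]-i[0]==0:
            x.append(i[0])
            y.append(np.sum(data[yi:yf,int(i[0])]))
        else:
            x.append(int((i[0]+i[1])/2))
            y.append(np.min(np.sum(data[yi:yf,int(i[0]):int(i[1])],axis=0)))
    # Peak profile plus a 2nd-order polynomial background
    if model == 'moffat':
        m_init = models.Moffat1D(amplitude=amp, x_0=cent, gamma=width,alpha=2) + models.Polynomial1D(degree=2,c0=level)
    elif model == 'ricker':
        m_init = models.RickerWavelet1D(amplitude=amp, x_0=cent, sigma=width) + models.Polynomial1D(degree=2,c0=level)
    else:
        raise ValueError("model must be 'moffat' or 'ricker'")
    fit_m = fitting.LevMarLSQFitter()
    # Fit the mean flux per row (summed flux divided by the number of rows)
    m = fit_m(m_init, x, np.divide(y,yf-yi))
    return x,y,m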

def fit_plot(nslice,dy,result,name):
    fig, axes = plt.subplots(nrows=int(np.ceil(nslice/2)),ncols=2,figsize=(16,6*int(np.ceil(nslice/2))))
    fig.suptitle(name,fontsize=14,y=0.92)
    axs = axes.flat
    for ax in axs[nslice:]:
        ax.remove()
    axes = axs[:nslice]
    x = np.arange(0,2048)+1
    for i in np.arange(nslice,dtype=int):
        axes[i].scatter(result.x[i],np.divide(result.y[i],dy))
        axes[i].plot(x,result.model[i](x))
    fig.show()
    fig.savefig(name+'.png')

def scat_sub(int_file,gap_file,model,dy=257,amp=1,cent=1000,width=150,level=5,**kwargs):
    # Swap only the last '_int' tag, so an 'int' elsewhere in the path is left alone
    file_r = "_intd".join(int_file.rsplit("_int",1))
    name = int_file.split("_int")[0]+'--'+model
    shutil.copy2(int_file,file_r)
    
    # Read the gap file: one "start end" pair of column indices per line
    with open(gap_file, "r") as gaps:
        gaps_data = [line.split() for line in gaps if line.strip()]
    gap = np.zeros((len(gaps_data),2))
    for i in np.arange(len(gaps_data)):
        gap[i,0] = int(gaps_data[i][0])
        gap[i,1] = int(gaps_data[i][1])
    
    with fits.open(file_r, mode='update') as hdu:
        data = hdu[0].data
        
        # Generate Slices
        yn,xn = np.shape(data)
        nslice = int(np.ceil(yn/dy))
        row = np.empty((nslice,2),dtype=int)
        for i in np.arange(nslice-1):
            row[i] = [i*dy,(i+1)*dy]
        row[nslice-1] = [(nslice-1)*dy,yn]
        
        # Fit for each slice and Plot the result
        result = np.recarray((nslice,),dtype=[('x','O'),('y','O'),('model','O')])
        for i in np.arange(nslice):
            result.x[i], result.y[i], result.model[i] = slice_fit(data,gap,model,row[i,0],row[i,1],amp,cent,width,level)
        fit_plot(nslice,dy,result,name)
        
        # Subtract the scattered light; the with block writes the changes and closes the file
        for i in np.arange(nslice,dtype=int):
            scat = result.model[i](np.arange(xn,dtype=int))
            data[row[i,0]:row[i,1],:] -= scat

In [3]:
# Redo the subtraction for a slice with an unsuccessful fit
# ----------------------------------------------------------------------------------------
# This function performs a refit for the specified slice and updates the reduced file
# without touching other slices.
# The returned plot helps to judge whether a good fit was obtained and serves as
# a reference for tweaking the parameters.
# ----------------------------------------------------------------------------------------
# Call:
#    refit_slice(int_file,gap_file,slice_num,model,*dy=257,*amp=1,*cent=1000,*width=150,*level=5)
# ----------------------------------------------------------------------------------------
# Inputs:
#   Required:
#      int_file:   (str) path to file, e.g. '/data/kb190101_00011_int.fits'
#      gap_file:   (str) path to gap file, e.g. '/data/kb190101_00011_gaps.dat'
#     slice_num:   (int) which slice to refit, counted from 1 (slice 1 starts at row 0)
#         model:   (str) model to fit, 'moffat' or 'ricker'
#   Optional:
#            dy:   (int) number of rows for each slice, the default value is 257 (8 slices)
#           amp: (float) amplitude for Moffat/Ricker model, default amp = 1
#          cent: (float) center of the model, x position of the maximum, default cent = 1000
#         width: (float) width of the model/wavelet, default width = 150
#         level: (float) C0 in 2nd order polynomial, default level = 5
# ----------------------------------------------------------------------------------------
# Outputs:
#          plots:  data points and fitted model, saved as <prefix>--<model>_<slice_num>.png
#   updated_data:  updates the reduced file written by scat_sub (*intd.fits)
# ----------------------------------------------------------------------------------------
# Notes:
#  1. dy should be kept consistent to make sure you are working on the right slice.
#     That is, if the default dy=257 was not used, you need to input dy used in scat_sub.
# ----------------------------------------------------------------------------------------
# Updates: 
#    Aug 20, 2020   W.Ning  Created


def refit_slice(int_file,gap_file,slice_num,model,dy=257,amp=1,cent=1000,width=150,level=5,**kwargs):
    # Swap only the last '_int' tag, so an 'int' elsewhere in the path is left alone
    file_r = "_intd".join(int_file.rsplit("_int",1))
    name = int_file.split("_int")[0]+'--'+model
    
    # Read the gap file: one "start end" pair of column indices per line
    with open(gap_file, "r") as gaps:
        gaps_data = [line.split() for line in gaps if line.strip()]
    gap = np.zeros((len(gaps_data),2))
    for i in np.arange(len(gaps_data)):
        gap[i,0] = int(gaps_data[i][0])
        gap[i,1] = int(gaps_data[i][1])
    
    # Slice to fit (slice_num is 1-based; only the last slice runs to the top row,
    # matching the slicing in scat_sub)
    refit = fits.getdata(int_file)
    yn,xn = np.shape(refit)
    if slice_num<np.ceil(yn/dy):
        row = [(slice_num-1)*dy,slice_num*dy]
    else:
        row = [(slice_num-1)*dy,yn]
    
    # Refit the specified slice
    result = np.recarray((1,),dtype=[('x','O'),('y','O'),('model','O')])
    result.x[0], result.y[0], result.model[0] = slice_fit(refit,gap,model,row[0],row[1],amp,cent,width,level)
    
    # Plot the fit
    p0 = plt.figure(figsize=(8,5))
    x = np.arange(xn)+1
    # Normalize by the actual slice height, which can differ from dy in the last slice
    plt.scatter(result.x[0],np.divide(result.y[0],row[1]-row[0]))
    plt.plot(x,result.model[0](x))
    plt.title(name,fontsize=12);
    p0.show()
    p0.savefig(name+'_'+str(slice_num)+'.png')
    
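    with fits.open(file_r, mode='update') as hdu:
        data = hdu[0].data
        # Evaluate on the same 0-based column grid used for the fit and in scat_sub
        scat = result.model[0](np.arange(xn,dtype=int))
        # Rebuild the slice from the original image so the earlier scat_sub
        # subtraction is replaced rather than subtracted a second time
        data[row[0]:row[1],:] = refit[row[0]:row[1],:] - scat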