# make_cube

This notebook contains the code for building a data cube of spectral-line parameters (centroid, depth, FWHM, and integrated flux), RVs, and NEID activity indices, together with their errors, from a list of spectral lines and a list of NEID FITS files. The line list and the data files themselves are not included in this notebook.

In [None]:
import numpy as np
from astropy.io import fits
from numpy.polynomial import polynomial as P
from scipy import interpolate
from scipy.constants import c
from scipy.optimize import leastsq
from scipy import stats
from scipy.optimize import curve_fit
from tqdm import tqdm
import xarray as xr

#model fitting
from scipy.signal import savgol_filter
from scipy.interpolate import CubicSpline
from astropy.modeling import models, fitting
from astropy import modeling

In [None]:
#chunked_continuum_fit function written by Arpita Roy

#this function divides the spectrum into "chunks", takes a high percentile of the flux in each chunk, and fits a
#polynomial over these chunk values. the spectrum is then normalized by dividing out this estimate of the continuum
def chunked_continuum_fit(x, y, percentile_cut=95.0, nchunks=15, deg=5, smooth_window=11, smooth_order=1):
    """
    Estimate the continuum of a spectrum by fitting a polynomial to chunk percentiles.

    Arguments
    ---------
    x : array-like
        Spectral axis (e.g. wavelength), one value per pixel.
    y : array-like
        Flux values, same length as `x`. NaNs are ignored within each chunk.
    percentile_cut : float, optional
        Percentile of `y` used as the continuum level of each chunk.
        [default: 95.0]
    nchunks : int, optional
        Number of contiguous pixel chunks of (nearly) equal size. [default: 15]
    deg : int, optional
        Degree of the polynomial fit to the smoothed chunk values. [default: 5]
    smooth_window : int, optional
        Savitzky-Golay window length, in chunks, used to smooth the chunk values
        before the polynomial fit (so the estimate does not "fall" into lines).
        It is reduced to the largest odd number not exceeding
        min(`smooth_window`, `nchunks`). If the result is not larger than
        `smooth_order`, no smoothing is applied. [default: 11]
    smooth_order : int, optional
        Polynomial order of the Savitzky-Golay filter. [default: 1]

    Returns
    -------
    continuum : numpy.ndarray
        The fitted polynomial evaluated at `x`. Divide `y` by it to normalize.
    cs : scipy.interpolate.CubicSpline
        Cubic spline through the (unsmoothed) chunk values.
    chunks_x : numpy.ndarray
        Position of each chunk: midpoint of its first and last `x` value.
    chunks_y : numpy.ndarray
        The `percentile_cut` percentile of `y` in each chunk.
    """
    x = np.asarray(x)
    y = np.asarray(y)
    npixels = len(x)
    pixels_per_chunk = npixels / nchunks
    chunks_x = np.zeros(nchunks)
    chunks_y = np.zeros(nchunks)
    for i in range(nchunks):
        # int() truncation keeps the last boundary at npixels; min() is only a safeguard
        chunk_i1 = int(i * pixels_per_chunk)
        chunk_i2 = min(int((i + 1) * pixels_per_chunk), npixels)
        chunk_x = x[chunk_i1:chunk_i2]
        chunks_y[i] = np.nanpercentile(y[chunk_i1:chunk_i2], percentile_cut)
        chunks_x[i] = 0.5 * (chunk_x[0] + chunk_x[-1])

    # savgol_filter (interp mode) rejects a window longer than the data, so clamp it to nchunks
    window = min(smooth_window, nchunks)
    if window % 2 == 0:
        window -= 1
    if window > smooth_order:
        smooth_y = savgol_filter(chunks_y, window, smooth_order)
    else:
        smooth_y = chunks_y
    pfit = np.polyfit(chunks_x, smooth_y, deg=deg)
    cs = CubicSpline(chunks_x, chunks_y)
    return np.polyval(pfit, x), cs, chunks_x, chunks_y

In [None]:
#simple (unnormalized) gaussian profile, used as the model for line fitting
def gauss(x, a, x0, sigma):
    """
    Evaluate a Gaussian of amplitude `a`, center `x0` and width `sigma` at `x`.

    Works element-wise on scalars or numpy arrays. The peak value is `a`, not
    normalized to unit area.
    """
    squared_offset = (x - x0)**2
    exponent = -squared_offset / (2 * sigma**2)
    return a * np.exp(exponent)

In [1]:
#cube code written by Sarah Jiang, adapted from code written by Justin Otor

# Activity indices in the row order of each NEID file's ACTIVITY table (row[1] = value, row[2] = error)
_ACTIVITY_INDICES = ('cahk', 'hei_1', 'hei_2', 'nai', 'ha06_1', 'ha06_2', 'ha16_1', 'ha16_2',
                     'cai_1', 'cai_2', 'cairt1', 'cairt2', 'cairt3', 'nainir', 'padelta')

# 'param' coordinate of the cube: every quantity immediately followed by its error
_CUBE_PARAMS = tuple(name + suffix
                     for name in ('centroid', 'depth', 'fwhm', 'int_flux', 'rv') + _ACTIVITY_INDICES
                     for suffix in ('', '_err'))

_INITIAL_WING = 0.2          # initial half-width (A) of the gaussian-fit window; also the initial sigma guess
_WING_STEP = 0.01            # amount (A) the half-width shrinks on each refit
_MAX_NARROWING_STEPS = 18    # maximum number of refits after the initial fit
_CENTROID_TOL = 0.01         # max allowed |mean fitted centroid - listed wavelength| (A)
_FLUX_RV_WING = 3000         # half-width (m/s) of the integrated-flux window
_FLUX_NPOINTS = 50           # number of interpolation points in the integrated-flux window
_FWHM_PER_SIGMA = 2 * np.sqrt(2 * np.log(2))


def _nan_to_zero(arr):
    """Return a copy of `arr` with NaNs replaced by 0 (the FITS data are left untouched)."""
    arr = np.asarray(arr)
    return np.where(np.isnan(arr), 0, arr)


def _select_order(wave, line):
    """
    Return the index of the order to use for `line`, or None if no order contains it.

    `wave` is the (order, pixel) SCIWAVE array with NaNs already set to 0. An order
    matches when searchsorted places `line` neither before its first pixel nor after
    its last. Orders overlap, so one match -> that order, three matches -> the middle
    one, otherwise the first two matches are compared and the one whose center
    (mean of first and last wavelength) is closer to `line` wins.
    """
    locs = np.apply_along_axis(np.searchsorted, 1, wave, line)
    orders, = np.nonzero((locs > 0) & (locs < wave.shape[1]))
    if len(orders) == 0:
        return None
    if len(orders) == 1:
        return int(orders[0])
    if len(orders) == 3:
        return int(orders[1])
    first, second = orders[0], orders[1]
    first_diff = abs(line - (wave[first, 0] + wave[first, -1]) / 2)
    second_diff = abs(line - (wave[second, 0] + wave[second, -1]) / 2)
    return int(first if first_diff < second_diff else second)


def _load_line_spectra(file_list, line):
    """
    Extract the order containing `line` from every file and shift it to the stellar rest frame.

    Returns (waves, specs, errs) as 2-D (file, pixel) arrays:
    - wavelengths shifted by (CCFRVMOD - order's SSBRV barycentric correction),
    - blaze-corrected flux,
    - blaze-corrected 1-sigma errors.
    NaNs in SCIWAVE, SCIFLUX and SCIVAR are replaced by 0. Returns None if any file
    has no order containing `line`. Each file is closed before the next is opened.
    """
    waves, specs, errs = [], [], []
    for path in file_list:
        with fits.open(path) as hdul:
            wave = _nan_to_zero(hdul['SCIWAVE'].data)
            order = _select_order(wave, line)
            if order is None:
                return None
            flux = _nan_to_zero(hdul['SCIFLUX'].data[order])
            # SCIVAR holds the flux variance; take its square root to get 1-sigma errors
            sigma = np.sqrt(_nan_to_zero(hdul['SCIVAR'].data[order]))
            blaze = hdul['SCIBLAZE'].data[order]
            barycorr = hdul['PRIMARY'].header['SSBRV' + str(order + 52).zfill(3)]  #add 52 to get header order
            rv = hdul['CCFS'].header['CCFRVMOD']  #measured rv, km/s

            delta_rv = (rv - barycorr) * 1000  #convert km/s to m/s
            avg_wvl = (wave[order, 0] + wave[order, -1]) / 2
            delta_lambda = delta_rv / (c / avg_wvl)  #doppler shift at the order's mean wavelength

            waves.append(wave[order] - delta_lambda)
            specs.append(flux / blaze)
            errs.append(sigma / blaze)
    return np.vstack(waves), np.vstack(specs), np.vstack(errs)


def _normalize_spectra(waves, specs, errs):
    """Divide each spectrum and its errors by its chunked continuum fit (90th percentile, 100 chunks, deg 1)."""
    specs_norm, errs_norm = [], []
    for wave, spec, err in zip(waves, specs, errs):
        continuum = chunked_continuum_fit(wave, spec, percentile_cut=90.0, nchunks=100, deg=1)[0]
        specs_norm.append(spec / continuum)
        errs_norm.append(err / continuum)
    return np.vstack(specs_norm), np.vstack(errs_norm)


def _integrated_flux(waves, specs_norm, errs_norm, line):
    """
    Integrated normalized flux and its error in a +/-3 km/s window around `line`.

    Each spectrum is cubically interpolated onto 50 evenly spaced wavelengths spanning
    the window and summed (NaNs ignored). The error is the quadrature sum of the
    interpolated (flux + error) minus the interpolated flux. Raises ValueError (from
    interp1d) when the window falls outside a spectrum or its wavelengths repeat.
    """
    lambda_wing = _FLUX_RV_WING / (c / line)
    flux_range = np.linspace(line - lambda_wing, line + lambda_wing, _FLUX_NPOINTS)
    fluxes = np.zeros(len(waves))
    flux_errs = np.zeros(len(waves))
    for i, wave in enumerate(waves):
        interp_flux = interpolate.interp1d(wave, specs_norm[i], kind='cubic')(flux_range)
        interp_plus = interpolate.interp1d(wave, specs_norm[i] + errs_norm[i], kind='cubic')(flux_range)
        fluxes[i] = np.nansum(interp_flux)
        flux_errs[i] = np.sqrt(np.nansum((interp_plus - interp_flux)**2))
    return fluxes, flux_errs


def _fit_gaussians(waves, specs_norm, errs_norm, line, wing):
    """
    Fit `gauss` to the flipped line profile (1 - flux) of every spectrum within +/-`wing` of `line`.

    Each spectrum is clipped from the pixel nearest line-wing up to (not including)
    the pixel nearest line+wing. The fit uses p0=[0.5, line, wing], the clipped
    normalized errors as `sigma`, and bounds depth in [0, 1], centroid in line +/- 1,
    sigma in [0, 1].

    Returns a dict of per-spectrum arrays ('depth', 'depth_err', 'centroid',
    'centroid_err', 'fwhm', 'fwhm_err'), or None as soon as any clipped window contains
    a 0 (NaN placeholder) or any fit fails -- a line is only kept if it fits everywhere.
    """
    keys = ('depth', 'depth_err', 'centroid', 'centroid_err', 'fwhm', 'fwhm_err')
    fit = {key: np.zeros(len(waves)) for key in keys}
    for i, wave in enumerate(waves):
        start = int(np.nanargmin(np.abs(wave - (line - wing))))
        stop = int(np.nanargmin(np.abs(wave - (line + wing))))
        x = wave[start:stop]
        spec = specs_norm[i][start:stop]
        if 0 in x or 0 in spec:
            return None
        try:
            popt, cov = curve_fit(gauss, x, 1 - spec, p0=[0.5, line, wing], sigma=errs_norm[i][start:stop],
                                  bounds=[[0, line - 1, 0], [1, line + 1, 1]])
        except Exception:  # any fit failure (non-convergence, non-finite input, empty window) rejects the line
            return None
        perr = np.sqrt(np.diag(cov))
        fit['depth'][i], fit['depth_err'][i] = popt[0], perr[0]
        fit['centroid'][i], fit['centroid_err'][i] = popt[1], perr[1]
        fit['fwhm'][i], fit['fwhm_err'][i] = _FWHM_PER_SIGMA * popt[2], _FWHM_PER_SIGMA * perr[2]
    return fit


def create_cube(file_list, line_list, cube_name='',
                    lstart=0, lend=None):
    """
    Create a data cube from a list of NEID FITS files.

    For every line in ``line_list[lstart:lend]``, the order containing the line is
    taken from every file, shifted to the stellar rest frame, blaze-corrected and
    continuum-normalized. The integrated flux is measured in a +/-3 km/s window, and
    a gaussian is fit within +/-0.2 A. If the mean fitted centroid is more than
    0.01 A from the listed wavelength, the window is narrowed in 0.01 A steps (at
    most 18 times) and refit. A line keeps zeros for its line parameters if it is not
    covered by every file, cannot be interpolated or fit in every file, or its
    centroid never converges. The CCF RV and the 15 activity indices of each file
    are repeated for every line.

    Arguments
    ---------
    file_list : list, required
        A list of file paths leading to FITS files observational data from NEID

    line_list : numpy array or other sliceable 1-D sequence, required
        Wavelengths of the spectral lines of interest, in the units of SCIWAVE
        (the window sizes above assume Angstrom).

    cube_name : str, optional
        The name of the DataArray that will hold the data cube.

    lstart : int, optional
        The first index in `line_list` to be considered for the data cube.
        [default: 0]

    lend : int, optional
        End index (exclusive, Python slice semantics) in `line_list`; the lines
        ``line_list[lstart:lend]`` make up the cube's `line` axis. None includes
        all lines from `lstart` onward.
        [default: None]

    Returns
    -------
    cube : xarray.DataArray
        Dimensions ('param', 'line', 'time'). `param` holds centroid, depth, fwhm,
        int_flux, rv and the activity indices, each followed by its error; `line`
        holds the selected wavelengths; `time` holds each file's OBSJD (unit 'JD').
    """
    selected_lines = line_list[lstart:lend]
    lines = list(selected_lines)
    nlines = len(lines)
    nfiles = len(file_list)

    param_index = {name: i for i, name in enumerate(_CUBE_PARAMS)}
    #(param, file, line); transposed to (param, line, time) when the cube is built
    values = np.zeros((len(_CUBE_PARAMS), nfiles, nlines))

    #timestamps, RVs and activity indices are per observation: read them once and repeat them for every line
    times = []
    for o, path in enumerate(file_list):
        with fits.open(path) as hdul:
            times.append(hdul['PRIMARY'].header['OBSJD'])
            ccf_header = hdul['CCFS'].header
            values[param_index['rv'], o, :] = ccf_header['CCFRVMOD']
            values[param_index['rv_err'], o, :] = ccf_header['DVRMSMOD']
            activity = hdul['ACTIVITY'].data
            for row, name in enumerate(_ACTIVITY_INDICES):
                values[param_index[name], o, :] = activity[row][1]
                values[param_index[name + '_err'], o, :] = activity[row][2]

    for l, line in enumerate(tqdm(lines)):
        spectra = _load_line_spectra(file_list, line)
        if spectra is None:
            print(f'ERROR: line {line} is not covered by any order in at least one file')
            continue
        waves, specs, errs = spectra
        specs_norm, errs_norm = _normalize_spectra(waves, specs, errs)

        try:
            int_fluxes, int_flux_errs = _integrated_flux(waves, specs_norm, errs_norm, line)
        except ValueError:
            print(f'ERROR: could not interpolate the integrated-flux window of line {line}')
            continue

        wing = _INITIAL_WING
        fit = _fit_gaussians(waves, specs_norm, errs_norm, line, wing)
        if fit is None:
            print('ERROR: At least one file has NANs in window')
            continue

        #narrow the window and refit until the mean centroid lies within tolerance of the listed wavelength
        steps = 0
        while abs(np.nanmean(fit['centroid']) - line) > _CENTROID_TOL and steps < _MAX_NARROWING_STEPS:
            wing -= _WING_STEP
            fit = _fit_gaussians(waves, specs_norm, errs_norm, line, wing)
            if fit is None:
                break
            steps += 1
        #discard lines whose refit failed or whose centroid never converged (a NaN mean also fails this test)
        if fit is None or not abs(np.nanmean(fit['centroid']) - line) <= _CENTROID_TOL:
            continue

        for name in ('centroid', 'depth', 'fwhm'):
            values[param_index[name], :, l] = fit[name]
            values[param_index[name + '_err'], :, l] = fit[name + '_err']
        values[param_index['int_flux'], :, l] = int_fluxes
        values[param_index['int_flux_err'], :, l] = int_flux_errs

    cube = xr.DataArray(values.swapaxes(1, 2),
                        dims=('param', 'line', 'time'),
                        coords={'param': list(_CUBE_PARAMS),
                                'line': lines,
                                'time': times,
                                },
                        name=cube_name,
                        )

    #add parameter-specific labels (for time)
    cube.coords['time'].attrs['unit'] = 'JD'
    return cube