# STIS data reduction pipeline
Sai Krishanth Pulikesi Mannan

The functions createmask and combproframes are adapted from code by Robert M. Thompson (https://github.com/douglase/reduce50coron/blob/epseri/stishubble.py).

This notebook has been tested with Python 3.10 and the following package versions:

<ol>
<li>Pyklip: 2.6</li>
<li>Scipy: 1.7.3</li>
<li>Pandas: 1.4.2</li>
<li>Astropy: 5.2.1</li>
<li>Numpy: 1.21.5</li>
<li>Matplotlib: 3.5.1</li>
<li>Multiprocess: 0.70.14</li>
</ol>

The following packages are not installed via pip; they are available here:

1. [pyklip](https://pyklip.readthedocs.io/en/latest/install.html)
2. [nmf_imaging](https://github.com/seawander/nmf_imaging)
3. [radonCenter](https://github.com/seawander/centerRadon)

In [1]:
import os
import math
import warnings
import radonCenter
import nmf_imaging
import numpy as np
import pandas as pd
import scipy.ndimage
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap,SymLogNorm
from multiprocess import Pool
from astropy import wcs
from astropy import stats
import astropy.units as u
from astropy.io import fits
from astropy.nddata import Cutout2D
from astropy import convolution as conv
from astropy.visualization import astropy_mpl_style
from astropy.utils.exceptions import AstropyWarning
from pyklip.klip import rotate as pyklipRotate

In [2]:
# Wedge A1.0 exposures excluded from the reference library: their frames are
# smaller than the others and their star centres are far off nominal.
badrefs = [
    f"{rootname}_flt.fits"
    for rootname in (
        "o64w12040", "o64w11010", "o64w12020", "o64wa2030",
        "o64w12010", "o64w12030", "o64wa2020", "o64wa2040",
        "o64w11030", "o64wa2010", "o64wa1030", "o64w11040",
        "o64wa1020", "o64wa1040", "o64wa1010", "o64w11020",
    )
]

In [3]:
# Inputs and NMF settings:
#   trgdir / refdir : directories holding the target and reference files
#   diskmodel       : disk model injected into the targets (optional step)
#   ncores          : number of worker processes for frame preparation
#   n_components    : number of NMF components
#   maxiters        : maximum number of NMF iterations
trgdir = 'eps_eri_trg'
refdir = 'del_eri_ref'
# os.listdir returns entries in arbitrary order; sort them so the frame order
# (which later cells index by position) is reproducible. Hidden entries such
# as .DS_Store or .ipynb_checkpoints are skipped.
trgfiles = sorted(f for f in os.listdir(trgdir) if not f.startswith('.'))
reffiles = sorted(
    f for f in os.listdir(refdir)
    if not f.startswith('.') and f not in trgfiles and f not in badrefs
)

diskmodel = 'disk_models/sst_backman_model_HST_inc_34.fits'

ncores = 8
n_components = 10
maxiters = 1e3

In [4]:
#Function to create the algorithmic mask applied to frames. To change properties of the mask, it is suggested to
#start with wedgeWarcsF and legWarcsF (they adjust the thickness of the central wedge and
#the diffraction spikes respectively). The "shift" variables shift the mask left or right.
def createmask(maskshape, voff, hoff, apertureshape):
    """
    Build the algorithmic mask (1 = pixel used, 0 = pixel masked).

    The frame starts at 0, a centred rectangle derived from ``apertureshape``
    is set to 1, and then the central wedge and the four diagonal
    diffraction-spike "legs" are set back to 0.

    Args:
        maskshape     : (H, W) shape of the mask to build.
        voff          : vertical offset in pixels; rounded and used to shift
                        the legs up/down.
        hoff          : horizontal offset in pixels; ``round(hoff - 203)``
                        shifts both the wedge and the legs left/right.
        apertureshape : (rows, cols) of the rectangle set to 1 before the
                        wedge and legs are cut out.

    Returns:
        numpy.ndarray of shape ``maskshape`` holding 0.0 / 1.0 values.
    """
    # Wedge half-width grows linearly with the absolute row index.
    slope = (((3.0-0.5)/2)/50)
    H = maskshape[0]
    W = maskshape[1]
    mvc = int(H/2) # mask vertical center
    mhc = int(W/2) # mask horizontal center

    maskFR = np.zeros(maskshape)
    avc = int ( apertureshape[0] / 2 ) # aperture vertical half-size
    ahc = int ( apertureshape[1] / 2 ) # aperture horizontal half-size
    maskFR[mvc-avc+1+1:mvc+avc-1,mhc-ahc+1:mhc+ahc-1] = 1

    wedgeWarcsF   = 1.2 # wedge width in arcsec (default: 1.0)
    legWarcsF     = 0.8 # leg width in arcsec   (default: 0.2)
    wedgeshiftI   = 0 # extra wedge shift, pixels
    leghorzshiftI = 0 # extra horizontal leg shift, pixels
    legvertshiftI = 0 # extra vertical leg shift, pixels
    wedgeWpix     = round(wedgeWarcsF/0.05) # 0.05 arcsec/pixel
    legWpix       = round(legWarcsF/0.05)
    halfwedgeWpix = round(wedgeWpix/2)
    mhoff = hoff - 202 - 1
    leghorzshiftI = leghorzshiftI + round(mhoff)
    legvertshiftI = legvertshiftI + round(voff)
    wedgeshiftI = wedgeshiftI + round(mhoff)

    # Wedge A: zero a band around column wedgeX, rows above and below mvc.
    wedgeX = mhc + int(wedgeshiftI)
    wedgeY = mvc
    for row in range(mvc):
        halfwidthtop = int(halfwedgeWpix+((wedgeY+row)*slope))
        for col in range(halfwidthtop):
            maskFR[wedgeY+row][wedgeX+col] = 0
            maskFR[wedgeY+row][wedgeX-col] = 0

    for row in range(mvc):
        halfwidthbottom = int(halfwedgeWpix + ((wedgeY-row)*slope))
        for col in range(halfwidthbottom):
            maskFR[wedgeY-row][wedgeX+col] = 0
            maskFR[wedgeY-row][wedgeX-col] = 0

    # Spider legs / diffraction spikes: four diagonals from the centre.
    # Lower-half rows are addressed with negative indices (counted from the
    # bottom edge), exactly as in the original adaptation.
    S        = int(leghorzshiftI)
    lvs      = int(legvertshiftI)
    half_leg = int(legWpix/2)
    legX     = mhc
    legY     = mvc

    for _ in range(mvc + abs(lvs)):
        legY = legY + 1
        legX = legX + 1
        for col in range(legWpix):
            left_col  = W - 1 - legX + S - col + half_leg
            right_col = legX + S + col - half_leg
            # Columns are checked against the real width W (not a fixed 239)
            # and against 0, so a negative index can no longer silently wrap
            # around to the opposite edge of the mask.
            left_ok  = 0 <= left_col < W
            right_ok = 0 <= right_col < W
            if left_ok and legY + lvs < H:
                maskFR[legY-1 + lvs][left_col] = 0 # UL
            if right_ok and legY + lvs < H:
                maskFR[legY-1 + lvs][right_col] = 0 # UR
            if right_ok and - legY + lvs > - H - 1:
                maskFR[- legY+1 + lvs][right_col] = 0 # LR
            if left_ok and - legY + lvs > - H - 1:
                maskFR[- legY+1 + lvs][left_col] = 0 # LL

    return maskFR

In [5]:
#Function that combines the processed frames. Writes 3 FITS files which contain the rotated residuals,
#their median, and their standard deviation respectively.
def _write_primary_fits(data, filename):
    """Write ``data`` as the primary HDU of ``filename``, overwriting it."""
    fits.HDUList([fits.PrimaryHDU(data)]).writeto(filename, overwrite=True)


def combproframes(DC, filenames, PAs):
    """
    Derotate residual frames and write the combined products to disk.

    Each frame of ``DC`` is sigma-clipped (sigma=2, maxiters=5) and rotated by
    its angle ``PAs[k]`` about the frame centre with pyklip's ``rotate``.
    Three FITS files are written to the working directory:

    * ``nmf_nuAllResidFRbo.fits``  - every rotated frame, in input order;
    * ``nmf_nuNanMedianFRbs.fits`` - NaN-ignoring median over the frames;
    * ``nmf_nuNanStddevFRbs.fits`` - NaN-ignoring standard deviation.

    Args:
        DC        : array of 2-D residual frames (must have ``.shape``).
        filenames : per-frame source file names; must have at least
                    ``len(DC)`` entries (only its length is reported).
        PAs       : per-frame rotation angles; ``PAs[k]`` is used for ``DC[k]``.

    Raises:
        ValueError: if ``DC`` is empty or ``filenames``/``PAs`` are shorter
                    than ``DC``.
    """
    from operator import itemgetter

    print("DC.shape       : ", DC.shape)
    print("len(filenames) : ", len(filenames))
    print("len(PAs)       : ", len(PAs))
    if len(DC) == 0:
        raise ValueError("combproframes: no frames to combine")
    if len(filenames) < len(DC) or len(PAs) < len(DC):
        raise ValueError("combproframes: filenames/PAs shorter than DC")

    # (rotated frame, clipped std) per input frame, in input order.
    records = []
    for k, frame in enumerate(DC):
        clipped = stats.sigma_clip(frame, sigma=2, maxiters=5)
        # Std of the already-clipped frame; only used to order frames below.
        clipped_std = stats.sigma_clipped_stats(clipped, sigma=2, maxiters=5)[2]
        # NOTE(review): `clipped` is a masked array; confirm pyklipRotate honours
        # its mask (if it interpolates the underlying data, clipped pixels are
        # not removed from the combination).
        center = (clipped.shape[1] / 2, clipped.shape[0] / 2)
        records.append((pyklipRotate(clipped, PAs[k], center), clipped_std))

    in_input_order = np.array([rec[0] for rec in records])
    print("nuFRRAbo.shape : ", in_input_order.shape)
    _write_primary_fits(in_input_order, "nmf_nuAllResidFRbo.fits")

    # Stable sort by clipped std (kept so the reductions see the same order).
    by_std = np.array([rec[0] for rec in sorted(records, key=itemgetter(1))])
    print("nuFRRAbs.shape      : ", by_std.shape)

    median_frame = np.nanmedian(by_std, axis=0)
    print("nuNanMedianFR       : ", median_frame)
    print("nuNanMedianFR.shape : ", median_frame.shape)
    _write_primary_fits(median_frame, "nmf_nuNanMedianFRbs.fits")

    std_frame = np.nanstd(by_std, axis=0)
    print("nuNanStddevFR       : ", std_frame)
    print("nuNanStddevFR.shape : ", std_frame.shape)
    _write_primary_fits(std_frame, "nmf_nuNanStddevFRbs.fits")

In [6]:
#Processes the target and reference frames, one file per call, so it can run in a process pool.
#Follows "Post-processing of the HST STIS coronagraphic observations" by Ren et al. closely. Changes not advised.

# DQ bits whose pixels are repaired with a 3x3 median filter.
# NOTE(review): presumably hot pixel (16), saturated (256) and cosmic-ray
# rejected (8192) in the STIS DQ convention - confirm against the STIS handbook.
_BAD_DQ_BITS = 16 | 256 | 8192


def _replace_flagged_pixels(frame, dq):
    """Return a float64 copy of ``frame`` whose DQ-flagged pixels are replaced
    by their 3x3 median-filtered value.

    DQ is treated as a bit field: a pixel is flagged when ANY bit of
    ``_BAD_DQ_BITS`` is set (so e.g. 16 + 8192 is also repaired).
    """
    flagged = (np.asarray(dq).astype(np.int64) & _BAD_DQ_BITS) != 0
    smoothed = scipy.ndimage.median_filter(frame, size=3)
    return np.where(flagged, smoothed, frame).astype(np.float64)


def _embed_in_canvas(cutout, size, border=1):
    """Centre ``cutout`` (minus a ``border``-pixel frame) on a zero
    ``size`` x ``size`` canvas and return the canvas."""
    first_row = int(size / 2 - cutout.shape[0] / 2)
    first_col = int(size / 2 - cutout.shape[1] / 2)
    canvas = np.zeros((size, size))
    canvas[first_row + border:first_row - border + cutout.shape[0],
           first_col + border:first_col - border + cutout.shape[1]] = \
        cutout[border:-border, border:-border]
    return canvas


def prep_frames_parallel(filename):
    """
    Reduce every usable SCI/ERR/DQ triplet of one FITS file.

    For each triplet whose EXPTIME does not exceed the file's median EXPTIME:
    cut a 110-row strip around (CRPIX1, CRPIX2), repair DQ-flagged pixels,
    weight SCI by sqrt(radius) from the cutout position, locate the star with
    radonCenter, divide SCI and ERR by THAT frame's EXPTIME, shift the star to
    the frame centre, cut a 110 x 213 box and embed it in a square zero canvas
    of side floor(sqrt(110**2 + 213**2)) = 239.

    Args:
        filename : bare file name; must be in the global ``trgfiles`` (read
                   from ``trgdir``) or ``reffiles`` (read from ``refdir``).

    Returns:
        (trgs, trgs_err, refs, refs_err, PAs, fnames). Target files fill
        trgs/trgs_err, reference files fill refs/refs_err; PAs and fnames get
        one entry per frame in either case. Errors are printed together with
        the filename and whatever was collected so far is returned, so one
        bad file does not abort a pool run.
    """
    trgs, trgs_err, refs, refs_err, PAs, fnames = [], [], [], [], [], []

    try:
        if filename in trgfiles:
            path = os.path.join(trgdir, filename)
        elif filename in reffiles:
            path = os.path.join(refdir, filename)
        else:
            raise ValueError("file is in neither trgfiles nor reffiles")

        sci_frames, err_frames, dq_frames, frame_exptimes = [], [], [], []
        with fits.open(path) as hdul:
            CRPIX1 = hdul['SCI'].header['CRPIX1']
            CRPIX2 = hdul['SCI'].header['CRPIX2']
            NAXIS2 = hdul['SCI'].header['NAXIS2']
            wcsO = wcs.WCS(hdul[1].header)
            rot_angle = np.rad2deg(math.atan2(wcsO.wcs.cd[1][0], wcsO.wcs.cd[0][0]))
            wcspa = 180*np.sign(rot_angle) - rot_angle

            # Extensions come in SCI/ERR/DQ triplets starting at index 1.
            triplet_starts = range(1, len(hdul), 3)
            medexp = np.median([hdul[i].header['EXPTIME'] for i in triplet_starts])
            for i in triplet_starts:
                exptime = hdul[i].header['EXPTIME']
                if exptime > medexp:
                    continue  # skip exposures longer than the median
                sci_frames.append(hdul[i].data)
                err_frames.append(hdul[i+1].data)
                dq_frames.append(hdul[i+2].data)
                # Each frame keeps its own exposure time for normalisation.
                frame_exptimes.append(exptime)

        lh = 110 #for A1.0
        lw = 213 #for A1.0
        side = math.floor(math.sqrt(pow(lh, 2) + pow(lw, 2)))

        for sci, err, dq, exptime in zip(sci_frames, err_frames, dq_frames, frame_exptimes):
            cut_sci = Cutout2D(sci, (CRPIX1, CRPIX2), (110, 2000), wcs=None)
            cut_err = Cutout2D(err, (CRPIX1, CRPIX2), (110, 2000), wcs=None)
            cut_dq = Cutout2D(dq, (CRPIX1, CRPIX2), (110, 2000), wcs=None)
            x_guess, y_guess = cut_sci.position_cutout

            sci = _replace_flagged_pixels(cut_sci.data, cut_dq.data)
            err = _replace_flagged_pixels(cut_err.data, cut_dq.data)

            # Weight by sqrt(distance from the initial centre guess).
            # NOTE(review): the weighted SCI frame (not the unweighted one) is
            # what gets normalised and returned - confirm the weighting is meant
            # to persist beyond the centring step.
            rows, cols = np.indices(sci.shape)
            radius = np.sqrt((rows - y_guess) ** 2 + (cols - x_guess) ** 2)
            sci = sci * radius ** 0.5

            (x_cen, y_cen) = radonCenter.searchCenter(sci, x_guess, y_guess, size_window=math.floor(NAXIS2/2), size_cost=7, theta=[45, 135])

            sci = sci / exptime
            err = err / exptime

            voff = sci.shape[0]/2 - y_cen
            hoff = sci.shape[1]/2 - x_cen
            shifted_sci = scipy.ndimage.shift(sci, np.array([voff, hoff]))
            shifted_err = scipy.ndimage.shift(err, np.array([voff, hoff]))

            sc_data_2d = Cutout2D(shifted_sci, position=(shifted_sci.shape[1]/2, shifted_sci.shape[0]/2), size=(lh, lw), wcs=None)
            er_data_2d = Cutout2D(shifted_err, position=(shifted_err.shape[1]/2, shifted_err.shape[0]/2), size=(lh, lw), wcs=None)

            SCIcanvas = _embed_in_canvas(sc_data_2d.data, side)
            ERRcanvas = _embed_in_canvas(er_data_2d.data, side)

            if filename in trgfiles:
                trgs.append(SCIcanvas)
                trgs_err.append(ERRcanvas)
            else:
                refs.append(SCIcanvas)
                refs_err.append(ERRcanvas)
            PAs.append(wcspa)
            fnames.append(filename)

    except Exception as e:
        # Deliberate best effort: report and return what was collected.
        print(e, filename)

    return trgs, trgs_err, refs, refs_err, PAs, fnames

In [7]:
# Inputs to createmask(). stisPos is the (x, y) star position on the STIS
# frame; tweak it for slightly different masks. Changing scishape is not advised.
scishape = (110, 1024)
stisPos  = (309, 54)
# Star offset from the frame centre as (vertical, horizontal); the trailing
# +0 / +2 are hand-set pixel adjustments.
voffstis, hoffstis = (
    scishape[0] / 2 - stisPos[1] + 0,  # vertical
    scishape[1] / 2 - stisPos[0] + 2,  # horizontal
)

In [8]:
%%time
# Run prep_frames_parallel over the target and reference files on `ncores`
# worker processes. AstropyWarnings (e.g. from WCS parsing) are silenced here
# and, through the pool initializer, in every worker: with the "spawn" start
# method workers do not inherit the parent's warning filters.
warnings.simplefilter('ignore', AstropyWarning)
with Pool(ncores, initializer=warnings.simplefilter,
          initargs=('ignore', AstropyWarning)) as p:
    results1 = p.map(prep_frames_parallel, trgfiles)
    results2 = p.map(prep_frames_parallel, reffiles)

CPU times: user 236 ms, sys: 1.13 s, total: 1.37 s
Wall time: 1min 48s


In [9]:
# Flatten the per-file outputs of prep_frames_parallel into lists/arrays.
# Each entry of results1/results2 is (trgs, trgs_err, refs, refs_err, PAs, fnames).
# PAs and fnames hold the target entries first, followed by the reference entries.
# The (239, 239) mask shape matches the square canvas built in
# prep_frames_parallel (floor(sqrt(110**2 + 213**2)) = 239).
mask = createmask((239,239),voffstis,hoffstis,(110,213))

trgs = np.array([frame for res in results1 for frame in res[0]])
trgs_err = np.array([frame for res in results1 for frame in res[1]])
refs = np.array([frame for res in results2 for frame in res[2]])
refs_err = np.array([frame for res in results2 for frame in res[3]])
PAs = [pa for res in (*results1, *results2) for pa in res[4]]
fnames = [name for res in (*results1, *results2) for name in res[5]]

In [11]:
# Rank reference frames by their median Euclidean (Frobenius) distance to the
# target frames and keep the closest 10 percent - at least one, so a small
# reference set never yields an empty NMF library.
best = max(1, int(0.1*len(refs)))
medians = [np.median([np.linalg.norm(ref - trg) for trg in trgs]) for ref in refs]

# Stable sort: equal distances keep their original reference order.
best_idx = np.argsort(medians, kind='stable')[:best]

new_refs = refs[best_idx]
new_refs_err = refs_err[best_idx]

In [12]:
# Optional disk injection - comment out this cell to skip it.
# Every frame set is scaled by 0.1765; the disk model, cut to 239 x 239 around
# pixel (128, 128), is divided by the same factor, boosted x100 and added to
# the targets only.
trgs, trgs_err, new_refs, new_refs_err = (
    np.array(frames) * 0.1765
    for frames in (trgs, trgs_err, new_refs, new_refs_err)
)

image_data = fits.getdata(diskmodel)
image_data_1 = Cutout2D(image_data, position=(128, 128), size=(239, 239))
disk = image_data_1.data / 0.1765 * 100

trgs = trgs + disk

In [None]:
%%time
# Build the NMF basis from the selected references (their errors passed as
# ref_err) with the algorithmic mask, one component at a time.
components = nmf_imaging.NMFcomponents(
    new_refs,
    ref_err=new_refs_err,
    mask=mask,
    n_components=n_components,
    maxiters=maxiters,
    oneByOne=True,
)

In [None]:
%time
# Model each target frame with the NMF components, let NMFbff choose the
# fraction passed to NMFsubtraction, and store the residual in `results`.
# NOTE(review): a bare `%time` line magic times an empty statement; `%%time`
# as the first line would time this whole cell.
results = np.zeros(trgs.shape)

for k, (trg, trg_err) in enumerate(zip(trgs, trgs_err)):
    print(f"Number {k+1} of {len(trgs)}")
    model = nmf_imaging.NMFmodelling(
        trg=trg,
        trg_err=trg_err,
        components=components,
        n_components=n_components,
        mask_components=mask,
        maxiters=maxiters,
        trgThresh=0.0,
    )
    best_frac = nmf_imaging.NMFbff(trg, model, mask)
    results[k] = nmf_imaging.NMFsubtraction(trg, model, mask, frac=best_frac)

In [None]:
# Combine the individual residual frames using their WCS position angles.
# `fnames` and `PAs` list the target entries first, then the reference
# entries; pass only the target part so every residual frame is paired with
# its own metadata and the reported lengths agree.
combproframes(results, fnames[:len(results)], PAs[:len(results)])

In [None]:
# Show the NaN-median combined residual written by combproframes.
nmf_nuNanMedianFRbs = fits.getdata("nmf_nuNanMedianFRbs.fits")
plt.figure(figsize=[20, 10])
plt.title("nmf_nuNanMedianFRbs")
plt.imshow(nmf_nuNanMedianFRbs, cmap="plasma", vmin=0, vmax=20)
plt.xlim([0, nmf_nuNanMedianFRbs.shape[1]])
plt.ylim([0, nmf_nuNanMedianFRbs.shape[0]])
plt.colorbar()

In [None]:
# Show the NaN-std map written by combproframes.
nmf_nuNanStddevFRbs = fits.getdata("nmf_nuNanStddevFRbs.fits")
plt.figure(figsize=[20, 10])
plt.title("nmf_nuNanStddevFRbs")
# Draw the image once: an extra argument-less imshow would be painted over
# this one with default cmap/limits, and the colorbar would describe it.
plt.imshow(nmf_nuNanStddevFRbs, cmap="plasma", vmin=0, vmax=2)
plt.xlim([0, nmf_nuNanStddevFRbs.shape[1]])
plt.ylim([0, nmf_nuNanStddevFRbs.shape[0]])
plt.colorbar()

In [17]:
# Shared plot styling and constants for the figures below.
mpl.rcParams.update({
    'hatch.linewidth': 0.2,
    'font.serif': 'serif',
    'image.origin': 'lower',
})
zero_color = 'white'
norm = SymLogNorm(6e-9)
semi_sym_cm = LinearSegmentedColormap.from_list('my cmap', ['black', zero_color, 'purple'])
PIXELSCL = 0.05071 * u.arcsecond  # plotting pixel scale

In [None]:
# Reopen the cube of all rotated residual frames and report its shape.
# The HDUList stays open because later cells read nmf_cube_2019[0].data.
nmf_fn = "nmf_nuAllResidFRbo.fits"
nmf_cube_2019 = fits.open(nmf_fn)
print(np.shape(nmf_cube_2019[0].data))

In [None]:
# Split the residual cube into consecutive frame ranges ("visits") delimited
# by `breaks`, median-combine each range and show one panel per range.
# NOTE(review): the hard-coded 95-frame spacing assumes the frames of a visit
# are contiguous in the cube - confirm against the upstream frame order.
breaks=[0,95,95*2,95*3,95*4-2,95*5-4,95*6]
import matplotlib
plt.figure(figsize=[15,18])
residual_cube = nmf_cube_2019[0].data
# Mean +/- 5 sigma of the cube; kept for reference, the panels use fixed limits.
vmin = np.nanmean(residual_cube) - np.nanstd(residual_cube) * 5
vmax = np.nanmean(residual_cube) + np.nanstd(residual_cube) * 5
n_visits = len(breaks) - 1
visit = np.empty([n_visits, residual_cube.shape[1], residual_cube.shape[2]])
# Default colormap with NaNs drawn black; built once (later cells reuse it).
current_cmap = matplotlib.colormaps[plt.rcParams["image.cmap"]].copy()
current_cmap.set_bad(color="k")
for i in range(1, len(breaks)):
    plt.subplot(math.ceil(n_visits / 2), 2, i)
    plt.title(str([breaks[i-1], breaks[i]]))
    visit[i-1, :, :] = np.nanmedian(residual_cube[breaks[i-1]:breaks[i], :, :], axis=0)
    plt.imshow(visit[i-1, :, :], cmap=current_cmap, vmin=0.1, vmax=100)
    plt.colorbar()

In [None]:
# Mean of the visit medians, split into visit[:3] (right panel, "2nd visits")
# and visit[3:] (left panel, "1st visits").
plt.figure(figsize=[7, 3], dpi=200)
for panel, label, group in ((122, "2nd visits", visit[:3]),
                            (121, "1st visits", visit[3:])):
    plt.subplot(panel)
    plt.title(label)
    plt.imshow(np.nanmean(group, axis=0), vmax=20, vmin=0,
               interpolation="nearest", cmap=current_cmap)
    plt.colorbar()

In [None]:
# For each reference visit j, show every visit minus visit j on a log scale
# (1..50); one figure per j.
# The NaN-aware colormap is loop-invariant, so it is built once.
current_cmap = matplotlib.colormaps["gnuplot"].copy()
current_cmap.set_bad(color="k")
n_panels = len(breaks) - 1
for j in range(1, len(breaks)):
    plt.figure(figsize=[20,7])

    for i in range(1, len(breaks)):
        plt.subplot(math.ceil(n_panels / 3), 3, i)
        plt.title(str([breaks[i-1], breaks[i]]))
        delta = visit[i-1, :, :] - visit[j-1, :, :]
        plt.imshow(delta, norm=matplotlib.colors.LogNorm(1, 50),
                   cmap=current_cmap, interpolation="nearest")
        plt.colorbar()
    plt.suptitle("$orbit_n-orbit_{}$".format(j)+"\n"+nmf_fn)

In [None]:
# Rank residual frames by a per-frame noise metric, keep the quietest ~90 %
# ("good") and set the rest aside ("bad"), then plot the metric.
if "no_delta" in nmf_fn:
    # crop 164 pixels from every spatial edge
    subtracted = nmf_cube_2019[0].data[:, 164:-164, 164:-164]
else:
    subtracted = nmf_cube_2019[0].data[:, :, :]

# Per-frame metric: the std (across rows) of each row's std.
# NOTE(review): this differs from the std over all pixels of a frame
# (np.nanstd(subtracted, axis=(1, 2))) - confirm which metric is intended.
sub_std_frm = np.nanstd(np.nanstd(subtracted, axis=2), axis=1)

sub_ordered = np.sort(sub_std_frm)

# ~90th-percentile cutoff (counts); the index is clamped so short cubes do
# not index past the end.
cutoff = sub_ordered[min(int(round(len(sub_ordered)*.9)), len(sub_ordered) - 1)]

print(cutoff)

good_frames = np.where(sub_std_frm < cutoff)[0]
print("n  frames:"+str(len(good_frames)))

good_std = np.nanstd(subtracted[good_frames, :, :], axis=0)

# ">=" so frames equal to the cutoff are counted as discarded instead of
# falling into neither set.
bad_frames = np.where(sub_std_frm >= cutoff)[0]
bad_std = np.nanstd(subtracted[bad_frames, :, :], axis=0)

bfraction = np.size(bad_frames)/np.size(sub_std_frm)
bad_Median = np.nanmedian(subtracted[bad_frames, :, :], axis=0)

gfraction = np.size(good_frames)/np.size(sub_std_frm)
gfrac_str = "{:.2g}".format((np.round(gfraction*100)))

plt.figure(dpi=350)
plt.subplot(121)
plt.plot(sub_ordered)
plt.plot([0, len(sub_ordered)], [cutoff, cutoff], label="cutoff")
plt.ylabel("$stdev$(residual)")
plt.xlabel("frames ordered by std(residual)")
plt.legend()
plt.yscale("log")

plt.subplot(122)
hist_counts = plt.hist(sub_ordered, bins="auto")[0]
plt.ylabel("frames")
plt.plot([cutoff, cutoff], [0, max(hist_counts.max(), 1)], label="cutoff")
plt.yscale("log")
plt.xscale("log")
plt.xlabel("$std$(residual)")
plt.legend()
plt.suptitle(gfrac_str+"% above threshold\n"+nmf_fn)
plt.tight_layout()

In [None]:
# 3x2 summary figure of used ("good") vs discarded ("bad") frames; also
# writes three FITS products named after nmf_fn. Each plt.colorbar() attaches
# to the most recent image, so statement order matters here.
# NOTE(review): gfraction is the fraction of frames below the cutoff (the used
# frames), although the title text says "above threshold".
plt.figure(figsize=[8,10],dpi=200)
plt.subplot(321)

plt.suptitle(gfrac_str+"% above threshold\n"+nmf_fn)
plt.title("median used")
# Median over the used frames (also saved below and reused by later cells).
good_median=np.nanmedian(subtracted[good_frames,:,:],axis=0)
plt.imshow(good_median,#norm=matplotlib.colors.LogNorm(0.001,good_median.max()),
               #cmap=plt.cm.magma,
           interpolation='nearest',
              vmin=0.01,
           vmax=10)#np.nanmax(good_median))
plt.colorbar()
plt.subplot(322)
plt.title("median discarded")


plt.imshow(bad_Median,#norm=matplotlib.colors.LogNorm(0.001,good_median.max()),
           vmin=0.01,
           vmax=10
               #cmap=plt.cm.magma,
          )#np.nanmax(good_median))
# Save median(used), median(discarded) and std(used); the percentage in each
# file name is the rounded fraction of frames in that set.
fits.writeto(nmf_fn.replace("/","_")
             +"good_frames{}percent.fits".format(int(np.round(gfraction*100))),
             good_median,overwrite=True)
fits.writeto(nmf_fn.replace("/","_")
             +"bad_frames{}percent.fits".format(int(np.round(bfraction*100))),
             bad_Median,overwrite=True)
fits.writeto(nmf_fn.replace("/","_")
             +"goodSTD_frames{}percent.fits".format(int(np.round(gfraction*100))),
             good_std,
             overwrite=True)

# Colorbar for the "median discarded" panel drawn above.
plt.colorbar()
plt.subplot(323)

# Per-pixel std of used frames, on the same upper limit as the good median.
plt.imshow(good_std,vmin=0,cmap=plt.cm.gnuplot,vmax=np.nanmax(good_median),)

plt.title("$stdev$(used frames)")

plt.colorbar()

plt.subplot(324)
plt.imshow(bad_std,vmin=0,
           vmax=np.nanmax(good_median),
           cmap=plt.cm.gnuplot)
plt.title("$stdev$(discarded frames)")
plt.colorbar()

plt.subplot(325)

# Pixel-wise median/std ratio; pixels with zero std give inf/NaN.
plt.imshow(good_median/good_std,vmin=0,vmax=2,cmap=plt.cm.plasma,)
plt.title("$SNR$(used frames)")


plt.colorbar()
plt.subplot(326)
plt.imshow(bad_Median/bad_std,vmin=0,vmax=2,cmap=plt.cm.plasma,)
plt.title("$SNR$(discarded frames)")
plt.colorbar()

In [24]:
def displ_scale(array,ps=.05*u.arcsec,
                d=None,
                ax=None,
                cmap=plt.cm.viridis,
                grid=False,
                cbar=True,
                cbar_label=None,
                xticks=None,
                **kwargs):
    """Display a 2-D image on ``ax`` with axes labelled in arcseconds.

    Args:
        array      : 2-D image, a plain array or an astropy Quantity (shown in
                     decomposed units).
        ps         : pixel scale as an angular Quantity (default 0.05 arcsec).
        d          : optional factor; if given, a secondary top axis labels
                     the x ticks multiplied by ``d`` (rounded to 0.1) as "AU".
        ax         : axes to draw on (default: ``plt.subplot(111)``).
        cmap       : colormap passed to ``imshow``.
        grid       : draw a thin grid on ``ax``.
        cbar       : draw a colorbar next to ``ax``.
        cbar_label : colorbar title; defaults to the Quantity's unit.
        xticks     : optional x tick positions (arcsec) to set on ``ax``.
        **kwargs   : forwarded to ``imshow`` (e.g. vmin, vmax, norm).
    """
    if ax is None:
        ax = plt.subplot(111)
    nx = int(array.shape[1])
    ny = int(array.shape[0])
    halfpix = ps.to(u.arcsec).value*0.5
    # NOTE(review): adding half a pixel on both sides makes the extent span
    # (n + 1) pixels rather than n - confirm the intended centring convention.
    extent = [-(nx/2*ps).to(u.arcsec).value-halfpix,
              (nx/2*ps).to(u.arcsec).value+halfpix,
              -(ny/2*ps).to(u.arcsec).value-halfpix,
              +(ny/2*ps).to(u.arcsec).value+halfpix]

    is_quantity = isinstance(array, u.Quantity)
    data = array.decompose().value if is_quantity else array
    # Draw on the requested axes, not whatever pyplot considers current.
    image = ax.imshow(data, interpolation='nearest',
                      extent=extent, cmap=cmap, **kwargs)
    if xticks is not None:
        ax.set_xticks(xticks)

    if cbar:
        cb = ax.figure.colorbar(image, ax=ax, format='%.3g')
        if cbar_label is not None:
            cb.ax.set_title(cbar_label, fontsize=9)
        elif is_quantity:
            cb.ax.set_title(array.decompose().unit, fontsize=9)

    if grid:
        ax.grid(linewidth=0.1)
    if d is not None:
        # Secondary x axis at the same tick positions, labelled in AU.
        axAU = ax.twiny()
        arcsec = ax.get_xticks()
        axAU.set_xticks(arcsec)
        axAU.set_xticklabels(np.round(arcsec*d, 1))
        axAU.set_xlabel("AU")
        axAU.set_xlim(np.array(ax.get_xlim()))

In [None]:
# Side-by-side display in arcsec: median of used frames (left, c/s) and its
# pixel-wise SNR = median / std (right).
plt.figure(figsize=[8.5,3.8],dpi=220)
ax1 = plt.subplot(122)
displ_scale(good_median/good_std,
            vmin=0,
            vmax=5,
            grid=False,
            ax=ax1,
            cbar_label="SNR",
            cmap=plt.cm.gnuplot)
ax1.set_xlabel('arcsec')
ax1.set_yticklabels("")

ax = plt.subplot(121)
# Files whose name contains "klip" get a colour scale starting at -1 instead
# of 0 (presumably because KLIP residuals can be negative); this value is now
# actually passed to displ_scale.
vmin = -1 if "klip" in nmf_fn else 0
displ_scale(good_median,
            vmin=vmin,
            vmax=20,
            grid=True,
            ax=ax,
            cbar_label="c/s",
            cmap="viridis")
ax.set_xlabel('arcsec')
plt.tight_layout()

In [33]:
# Save the median of the used frames as a single-HDU FITS file.
fits.PrimaryHDU(good_median).writeto('wbackman.fits', overwrite=True)