### Test whether the goodness-of-fit criterion developed in ramp_utils/fitter.py can successfully identify bad fits in a couple of problematic WFC3 situations

This notebook uses several features from the `nonlinear_bkg` notebook by J. Mack, which is (or will soon be) included in the wfc3tools notebooks and/or the STScI notebooks

In [None]:
import numpy as np
import matplotlib.pyplot as plt

from astropy.io import fits
%matplotlib notebook

### The g.o.f. computing function

In [None]:
'''
The function that computes the likelihood of a ramp given the best fit value,
in the assumption that the flux is constant.
Can be used to "reject" the hypothesis that the flux is constant, and thus to provide a flag for reprocessing
This is an adaptation of the goodness_of_fit method in the fitter.py class.

Rationale: obtain the probability of observing the given differences in electron (from the IMA files),
and compare such probability with the same probability for a certain number of random ramps,
generated starting from the fitted flux value and given the read noise.
It basically measures whether the observed electrons are consistent with poisson + read noise


    :electron_rate: the mean flux (e/s) from flt files, a scalar
    
    :dt: time intervals, numpy array of length NSAMP-1
    
    :diffs: difference of electrons in read_j minus read_{j-1}, divided by the corresponding dt (i.e. a rate; the code multiplies it back by dt), length NSAMP-1
    
    :good_intervals: intervals that were used in the fit, i.e. those w/o cosmic rays,
                     numpy array of booleans, length NSAMP-1

    :RON: the readout noise in electrons (per single read, not in CDS mode)
    
    :FW: full-well capacity
    
    :ncompare: the number of simulations from which to obtain probabilities to be compared with the probability of the actual observation
    
    :nsig: number of sigmas to establish the range of integration for the Poisson*Gaussian integral
'''


from scipy.stats import poisson, norm

def goodness_of_fit(electron_rate,dt,diffs,good_intervals,RON,FW=80000,ncompare=10000,nsig=3):
    '''
    Log-likelihood of an observed ramp under a constant-flux model, together
    with a Monte Carlo p-value.

    For every good interval the probability of the observed electron difference
    is the sum over integer electron counts x of
    Poisson(x | electron_rate*dt) * Gaussian(x | observed difference, sqrt(2)*RON).
    The same quantity is evaluated for ncompare simulated differences
    (Poisson draw plus Gaussian read noise) and the p-value is the fraction of
    simulated ramps whose total log-likelihood is lower than the observed one.

    :electron_rate: the fitted mean flux, a scalar

    :dt: time intervals, array of length NSAMP-1

    :diffs: per-interval rates; diffs[i]*dt[i] is used as the observed
            electron difference in interval i

    :good_intervals: array of booleans, only intervals flagged True are used

    :RON: readout noise of a single read; sqrt(2)*RON is used for a difference

    :FW: upper cap of the summation grid (full-well capacity)

    :ncompare: number of simulated ramps

    :nsig: half-width of the summation grid, in units of sqrt(electron_rate*dt)

    Returns (lprob, pval): the observed log-likelihood and the p-value.

    NOTE(review): there is no guard against electron_rate*dt < 0, for which the
    Poisson model is undefined; confirm that callers only pass non-negative rates.
    '''

    # A difference of two reads carries the read noise of both reads
    sigma_diff = np.sqrt(2)*RON
    read_noise = norm(loc=0.,scale=sigma_diff)

    lprob = 0.
    lprob_v = np.zeros(ncompare)

    for i in range(len(dt)):
        if not good_intervals[i]:
            continue

        mu = electron_rate*dt[i]
        poiss_distr = poisson(mu=mu)

        # Integer grid covering mu +/- nsig*sqrt(mu), clipped to [0, FW].
        # np.arange excludes `high`, exactly as in the original implementation.
        low  = np.floor(np.maximum(0,mu-nsig*np.sqrt(mu))).astype(np.int_)
        high = np.ceil(np.minimum(FW,mu+nsig*np.sqrt(mu))).astype(np.int_)
        integral_x = np.arange(low,high,1)
        poiss_pmfs = poiss_distr.pmf(integral_x)

        # Observed interval
        observed = diffs[i]*dt[i]
        prob = np.sum(poiss_pmfs*norm.pdf(integral_x,loc=observed,scale=sigma_diff))
        lprob += np.log(prob)

        # Simulated intervals: Poisson signal plus read noise of the difference.
        # The two rvs calls are kept in this order so that a seeded run draws
        # the same random numbers as before.
        rv = poiss_distr.rvs(size=ncompare) + read_noise.rvs(size=ncompare)
        prob_v = np.zeros(ncompare)
        for k in range(ncompare):
            # norm.pdf is called directly: building a frozen distribution for
            # every simulated value is needless overhead
            gauss_pdfs = norm.pdf(integral_x,loc=rv[k],scale=sigma_diff)
            prob_v[k] = np.sum(poiss_pmfs*gauss_pdfs)

        lprob_v += np.log(prob_v)

    # Fraction of simulated ramps that are less likely than the observed one.
    # np.float64 replaces np.float_, which no longer exists in NumPy 2.0.
    BM = lprob > lprob_v
    pval = np.sum(BM).astype(np.float64)/ncompare

    return lprob,pval

def goodness_of_fit_gauss(electron_rate,sigma_electron_rate,dt,diffs,good_intervals,RON=20,ncompare=10000):
    '''
    Gaussian version of the goodness-of-fit test.

    For every good interval the observed electron difference diffs[i]*dt[i] is
    compared with a Gaussian centred on electron_rate*dt[i] whose standard
    deviation is sqrt((sigma_electron_rate*dt[i])**2 + 2*RON**2).
    The p-value is the fraction of ncompare simulated ramps, drawn from the same
    Gaussians, whose total log-likelihood is lower than the observed one.

    :electron_rate: the fitted mean flux, a scalar

    :sigma_electron_rate: uncertainty on electron_rate, a scalar

    :dt: time intervals, array of length NSAMP-1

    :diffs: per-interval rates; diffs[i]*dt[i] is the observed difference

    :good_intervals: array of booleans, only intervals flagged True are used

    :RON: readout noise of a single read

    :ncompare: number of simulated ramps

    Returns (lprob, pval): the observed log-likelihood and the p-value.
    '''

    lprob = 0.
    lprob_v = np.zeros(ncompare)

    for i in range(len(dt)):
        if not good_intervals[i]:
            continue

        err = np.sqrt( (sigma_electron_rate*dt[i])**2 + 2*RON**2)
        gauss_distr = norm(loc=electron_rate*dt[i],scale=err)

        # logpdf instead of log(pdf): the pdf of a strong outlier underflows to
        # zero, which would turn the log-likelihood into -inf
        lprob += gauss_distr.logpdf(diffs[i]*dt[i])

        rv = gauss_distr.rvs(size=ncompare)
        lprob_v += gauss_distr.logpdf(rv)

    # Fraction of simulated ramps that are less likely than the observed one.
    # np.float64 replaces np.float_, which no longer exists in NumPy 2.0.
    BM = lprob > lprob_v
    pval = np.sum(BM).astype(np.float64)/ncompare

    return lprob,pval

## Tests using simulated data

### Function to add variable background to a ramp

In [None]:
from scipy.interpolate import interp1d

def get_vbg_electrons(times,vbg_cr,meas,mean_bg_cr=None):
    '''
    Given a tabulated form for the variable background time dependency,
    generate the cumulative number of background electrons at each read of the ramp

    :times:
        times at which the countrate is tabulated

    :vbg_cr:
        values of the time variable background at those times

    :meas:
        an object exposing the read times as meas.RTS.read_times
        (a RampMeasurement object)

    :mean_bg_cr:
        if not None, the interpolated countrate is rescaled so that its
        time average over the ramp equals this value

    Returns an array with one entry per read time: 0 for the first read, then
    the running sum of the Poisson-drawn electrons of each read interval.

    NOTE(review): the quadratic interpolation may undershoot below zero for
    steep tabulated curves, in which case np.random.poisson raises; confirm the
    tabulated values used in practice keep the interpolated rate non-negative.
    '''

    read_times = np.asarray(meas.RTS.read_times)

    #Create an interpolator from the tabulated values and interpolate at the ramp read times
    bg_int = interp1d(times,vbg_cr,'quadratic')
    varbg = bg_int(read_times)

    #Normalize if requested
    if mean_bg_cr is not None:
        # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
        # `or` short-circuits, so the old name is only touched on older NumPy
        trapezoid = getattr(np,'trapezoid',None) or np.trapz
        ramp_length = read_times[-1]-read_times[0]
        t_avg = trapezoid(varbg,read_times) / ramp_length
        varbg = varbg/t_avg * mean_bg_cr

    #Expected electrons per interval: mean of the rates at the two ends times the interval length
    interval_rates = 0.5*(varbg[1:]+varbg[:-1])
    interval_lengths = read_times[1:]-read_times[:-1]
    bg_e = np.random.poisson(lam=interval_rates*interval_lengths)

    #Get the total accumulated electrons (nothing accumulated at the first read)
    return np.cumsum(np.concatenate(([0],bg_e)))


### Simulate a ramp

In [None]:
from ramp_utils.ramp import RampTimeSeq,RampMeasurement
from ramp_utils.fitter import IterativeFitter


# Inputs of the simulation: the flux passed to RampMeasurement and a dictionary
# of extra events (presumably cosmic-ray hits, given the CRdict keyword)
myflux = 1
CRdict = {'times':[220., 700.,980],'counts':[300,20,290]}

# A WFC3 ramp with 15 samples and a SPARS100 sequence
myramp = RampTimeSeq('HST/WFC3/IR',15,samp_seq='SPARS100')

# Build the simulated measurement and show it
mymeas = RampMeasurement(myramp,myflux,CRdict=CRdict)
mymeas.test_plot()

### If needed, add variable background and/or cosmic rays

In [None]:
# Tabulated background: ten relative count rates on ten evenly spaced epochs
cbg = np.array([1.0,1.2,1.5,1.3,2.0,3.5,4.7,3.4,2.0,1.5])
tbg = np.linspace(0,1500,10)

extra_bg = {
    'times': tbg,
    'vbg_er': np.power(cbg,1.0),
    'mean_bg_er': 1.0,
}
# Switch: the next line disables the extra background; remove it to inject it
extra_bg = None

if extra_bg is not None:
    ebh = get_vbg_electrons(
        extra_bg['times'],
        extra_bg['vbg_er'],
        mymeas,
        mean_bg_cr=extra_bg['mean_bg_er'],
    )
    mymeas.add_background(ebh)
 

In [None]:
### Perform the fit

In [None]:
# Fit the simulated measurement with the iterative fitter
fit_options = {'one_iteration_method':'Nelder-Mead'}
myfitter = IterativeFitter(mymeas,fitpars=fit_options)

# `gi` is reused further down as the mask of good intervals
fit_outputs = myfitter.perform_fit(CRthr=4)
err, count, gi, crlcount = fit_outputs

myfitter.test_plot()

### G.o.f. according to my original method

In [None]:
# Run the fitter's own goodness-of-fit, then report the rate, the statistic and the p-value
myfitter.goodness_of_fit(mode='full-likelihood')
gof_summary = (myfitter.mean_electron_rate, myfitter.gof_stat, myfitter.gof_pval)
print(*gof_summary)


### G.o.f. according to the modified version that can be used with ***calwf3***

In [None]:
# Pull the inputs of the stand-alone routine out of the fitter object
ramp_meas = myfitter.RM
group_times = ramp_meas.RTS.group_times
noisy_reads = ramp_meas.noisy_electrons_reads

# Interval lengths and per-interval rates (electron differences divided by dt)
dt = group_times[1:] - group_times[:-1]
diffs = (noisy_reads[1:] - noisy_reads[:-1]) / dt

good_intervals = gi
electron_rate = myfitter.mean_electron_rate
RON = ramp_meas.RON_e

lprob, pval = goodness_of_fit(
    electron_rate, dt, diffs, good_intervals, RON,
    FW=80000, ncompare=1000, nsig=3,
)
print(lprob, pval)


## Now do the test on real data

### Read and display a good and a "bad" image from program 12242, visit BF

In [None]:
path = './Helium_data/'

# First extension of the two FLT files
flt1 = fits.getdata(path+'ibohbfb7q_flt.fits',ext=1)
flt2 = fits.getdata(path+'ibohbfb9q_flt.fits',ext=1)

# Two panels side by side
fig = plt.figure(figsize=(12,6))
ax1 = fig.add_subplot(1,2,1)
ax2 = fig.add_subplot(1,2,2)

panels = (
    (ax1, flt1, 'ibohbfb7q (Linear Bkg)'),
    (ax2, flt2, 'ibohbfb9q (Non-linear Bkg)'),
)
for axis, image, title in panels:
    # Grey-scale stretch anchored on the NaN-ignoring median of each image
    axis.imshow(image, vmin=-0.25+np.nanmedian(image), vmax=.5+np.nanmedian(image),
                cmap='Greys_r', origin='lower')
    axis.set_title(title, fontsize=20)

plt.tight_layout()

### Rearrange IMAs and FLTs to be fed to the g.o.f. routine

In [None]:
'''
Function that takes a path and a rootname, then reads and rearranges the FLT and IMA files
so that they can be fed to the goodness-of-fit function
'''


def rearrange(path, rootname):
    '''
    Read <path><rootname>_flt.fits and <path><rootname>_ima.fits and flatten
    them into per-pixel arrays for the goodness-of-fit functions.

    :path: directory prefix, concatenated as is with rootname (it must
           therefore end with a path separator)

    :rootname: file name without the '_flt.fits' / '_ima.fits' suffix

    Returns, in this order:
        diffs   (npix, NSAMP-1) per-interval rates from consecutive IMA reads
        gintv   (npix, NSAMP-1) booleans, False where the interval is flagged
        rates   flattened FLT 'SCI' extension
        fltdq   flattened FLT 'DQ' extension
        error   flattened FLT 'ERR' extension
        dt      (NSAMP-1,) interval lengths
        mt      (NSAMP-1,) interval mid-times
        wfc3ron mean of the READNSEA..READNSED header keywords
    where npix is the number of pixels left after trimming 5 pixels from each
    edge of the IMA images.
    '''

    # Context managers guarantee both files are closed, also on errors.
    # Everything returned is computed or copied while the files are open.
    with fits.open(path+rootname+'_flt.fits') as flt, fits.open(path+rootname+'_ima.fits') as ima:

        wfc3ron = 0.25*(flt[0].header['READNSEA']+flt[0].header['READNSEB']+flt[0].header['READNSEC']+flt[0].header['READNSED'])

        ns = ima[0].header['NSAMP']

        # Trimmed image size taken from the data rather than hard-coded to
        # 1014x1014, so that other frame sizes work too
        ny, nx = ima['SCI',1].data[5:-5,5:-5].shape

        # np.float64 replaces np.float_, which no longer exists in NumPy 2.0
        diffs = np.empty([ny,nx,ns-1],dtype=np.float64)
        gintv = np.ones([ny,nx,ns-1],dtype=np.bool_)
        dt    = np.empty(ns-1)
        mt    = np.empty(ns-1)

        for j in range(ns-1):

            # Extension versions are walked downwards: ns-j is the start of
            # the interval, ns-j-1 its end
            te = ima['TIME',ns-j-1].header['PIXVALUE']
            ts = ima['TIME',ns-j].header['PIXVALUE']
            ee = ima['SCI',ns-j-1].data[5:-5,5:-5]
            es = ima['SCI',ns-j].data[5:-5,5:-5]
            ge = ima['DQ',ns-j-1].data[5:-5,5:-5]
            gs = ima['DQ',ns-j].data[5:-5,5:-5]

            # SCI values are multiplied by their times before differencing,
            # then divided by the interval length: a per-interval rate
            diffs[:,:,j] = (te*ee-ts*es)/(te-ts)

            # Interval rejected when DQ goes from exactly 0 to exactly 8192.
            # NOTE(review): pixels carrying any other DQ bit are never
            # rejected by this equality test - confirm that this is intended
            BM = (gs == 0) & (ge == 8192)
            gintv[BM,j] = False

            dt[j] = te-ts
            mt[j] = 0.5*(te+ts)

        # flatten() copies, so these arrays stay valid after the file is closed
        rates = flt['SCI',1].data.flatten()
        fltdq = flt['DQ',1].data.flatten()
        error = flt['ERR',1].data.flatten()

    # One row per pixel, one column per interval
    diffs = diffs.reshape(-1, diffs.shape[-1])
    gintv = gintv.reshape(-1, gintv.shape[-1])

    return diffs,gintv,rates,fltdq,error,dt,mt,wfc3ron

### Test that the rearrangement works

In [None]:
d,g,r,fdq,err,dt,mt,ron = rearrange(path,'ibohbfb7q')

j,k = 5,9
# Indexes of the ramps that are flagged (cosmic ray) in interval k
i = np.nonzero(~g[:,k])

# Work on the j-th of those pixels
pixel = i[0][j]
flagged = ~g[pixel,:]

print(g[pixel,:])
f=plt.figure()
# Per-interval rates against mid-time, coloured by the flag, with the FLT rate as a line
plt.scatter(mt,d[pixel,:],c=flagged,cmap='winter');
plt.axhline(r[pixel],c='#bb3311');

### Run the g.o.f. test on a list of wfc3 images

In [None]:
import time
from multiprocessing import Pool

# Configuration of the run
n_jobs = 30          # number of worker processes
nmax = 10000000      # maximum number of pixels processed per image
method = 'Gauss'     # 'full' -> goodness_of_fit, 'Gauss' -> goodness_of_fit_gauss
results = []

imlist = ['ibohbfb7q','ibohbfb9q']

# Validate the method before spawning workers or reading any file
# (an assert would be stripped when running with -O)
if method not in ('full','Gauss'):
    raise ValueError("method must be 'full' or 'Gauss', got {!r}".format(method))

mypool = Pool(n_jobs)
try:
    for rootname in imlist:

        print('Starting',rootname)
        diffs_l,good_intervals_l,electron_rate_l,fltdq_l,error_l,dt,mt,RON = rearrange(path,rootname)

        # One positional-argument list per pixel, limited to the first nmax pixels
        if method == 'full':
            # trailing values are FW=80000, ncompare=100, nsig=3
            inputs = [ [electron_rate,dt,diffs,good_intervals,RON,80000,100,3]
                       for electron_rate,diffs,good_intervals
                       in zip(electron_rate_l[:nmax],diffs_l[:nmax],good_intervals_l[:nmax])
                     ]
            gof_function = goodness_of_fit
        else:
            # trailing value is ncompare=1000
            inputs = [ [electron_rate,error,dt,diffs,good_intervals,RON,1000]
                       for electron_rate,error,diffs,good_intervals
                       in zip(electron_rate_l[:nmax],error_l[:nmax],diffs_l[:nmax],good_intervals_l[:nmax])
                     ]
            gof_function = goodness_of_fit_gauss

        # Only the parallel computation is timed; tasks are sent in chunks of 5
        ts = time.time()
        lprob_l,pval_l = map(list, zip(*mypool.starmap(gof_function,inputs,5)))
        te = time.time()

        results.append([lprob_l,pval_l])

        print('Elapsed time [minutes]: {}'.format((te-ts)/60.))

    # Normal completion: no more tasks will be submitted
    mypool.close()
except BaseException:
    # Any failure or interrupt: stop the workers instead of leaking them
    mypool.terminate()
    raise
finally:
    mypool.join()


In [None]:
f = plt.figure()

# The first image defines the bin edges (200 bins) ...
n,b,h = plt.hist(results[0][1],bins=200,histtype='step',label=imlist[0])

# ... which are reused for every other image, so the histograms are comparable
for i,r in enumerate(imlist):
    if i == 0:
        continue
    plt.hist(results[i][1],bins=b,histtype='step',label=r);
plt.legend()

In [None]:
import pickle,bz2

# Bundle the image list with the per-image [log-likelihoods, p-values] lists
dictosave = {'imlist':imlist,'results':results}

# Pickle the bundle into a bz2-compressed file
with bz2.BZ2File('./Simulations_results/GOF_Helium.pbz2', 'w') as f:
    pickle.dump(dictosave,f)

