In [1]:
import math as mt
import pandas as pd
import seaborn as sb
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
#from IPython.display import display, HTML
from numba import njit, prange
from astropy.io import fits as pf
from emcee import EnsembleSampler
#from .autonotebook import tqdm as notebook_tqdm
from tqdm.auto import tqdm
from corner import corner
from pytransit import QuadraticModel
from pytransit.utils.de import DiffEvol
from pytransit.orbits.orbits_py import as_from_rhop,  i_from_ba
from pytransit.param.parameter import (ParameterSet, GParameter, PParameter, LParameter,NormalPrior as NP, UniformPrior as UP)
import batman
#from test_lc_batman import make_lc

# Globally silence every NumPy floating-point warning category (divide, over,
# under, invalid). seterr returns the previous settings, which the notebook
# displays as this cell's output.
np.seterr(all='ignore')
# np.random.seed(0)  # uncomment to make DE/MCMC runs reproducible

{'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'ignore'}

In [3]:
@njit(parallel = True, cache = False, fastmath = True)
def lnlike_normal_v(o, m , e, model = None):
    """Gaussian log-likelihood of one observed series under many model rows.

    Parameters
    ----------
    o : 1-D array
        Observed values, length ``npt``.
    m : 1-D or 2-D array
        Model values; promoted to 2-D so that row ``i`` is the model for the
        ``i``-th parameter vector. Each row must have ``npt`` entries.
    e : 1-D array
        One standard deviation per model row; ``e[i]`` is applied to every
        point of row ``i``.
    model : unused
        Kept only so existing callers keep working.

    Returns
    -------
    1-D array of length ``m.shape[0]`` whose ``i``-th entry is
    ``sum_j [-ln e_i - 0.5 ln(2 pi) - 0.5 ((o_j - m_ij) / e_i)**2]``.

    Raises
    ------
    ValueError
        If the model rows do not have the same length as ``o``, or if ``e``
        has fewer entries than there are model rows.
    """
    m = np.atleast_2d(m)
    npv = m.shape[0]
    npt = o.size

    # numba does not bounds-check by default: a length mismatch would silently
    # read past the end of an array instead of failing, so check explicitly.
    if m.shape[1] != npt:
        raise ValueError("lnlike_normal_v: model rows and observations differ in length")
    if e.size < npv:
        raise ValueError("lnlike_normal_v: need one error value per model row")

    half_log_2pi = 0.5 * np.log(2.0 * np.pi)
    lnl = np.empty(npv)
    for i in prange(npv):
        # The normalisation term is identical for every point of row i, so it
        # is added once per row instead of once per point.
        chi2 = 0.0
        for j in range(npt):
            r = (o[j] - m[i, j]) / e[i]
            chi2 += r * r
        lnl[i] = -npt * (np.log(e[i]) + half_log_2pi) - 0.5 * chi2
    return lnl


In [4]:


class LPFunction:
    """Log-posterior function for fitting a single transit light curve.

    Combines a batman transit model, a two-parameter ``ParameterSet``
    (``rp_rs`` and ``loge``), a pytransit ``DiffEvol`` optimiser and an emcee
    ``EnsembleSampler``. A parameter vector ``pv`` is laid out as
    ``[rp_rs, loge]``, where ``loge`` is the base-10 logarithm of the
    per-point flux error used by the likelihood.
    """

    def __init__(self, name: str, times: np.ndarray = None, fluxes: np.ndarray = None, h_model: np.ndarray = np.array([])):
        """Set up the transit model, the data and the parameterisation.

        Parameters
        ----------
        name : str
            Label for this log-posterior function.
        times : array-like, optional
            Stored as ``self.times``; not used by the model (see note below).
        fluxes : array-like, optional
            Observed fluxes, compared point-by-point with the model evaluated
            on ``self.time_array``.
        h_model : array, optional
            Unused.
        """
        # pytransit quadratic model; prepared here, but the likelihood below
        # uses the batman model (self.m) instead.
        self.tm = QuadraticModel(klims = (np.sqrt(.1), np.sqrt(.15)), nk = 512, nz = 512)

        self.name = name

        # High-level objects, created lazily by optimize() and sample().
        self.ps = None          # Parameterization
        self.de = None          # Differential evolution optimizer
        self.sampler = None     # MCMC sampler

        self.times = np.asarray(times)
        self.fluxes = np.asarray(fluxes)
        # NOTE(review): the model is evaluated on this fixed 540-point grid, not
        # on `times`, so `fluxes` must be sampled on exactly this grid. Confirm
        # that ignoring `times` is intended.
        self.time_array = np.linspace(-0.5, 0.5, 540)
        self.tm.set_data(self.time_array)

        # batman parameters: only `rp` is varied (see transit_model); the
        # remaining orbital and limb-darkening values are fixed here.
        self.params = batman.TransitParams()
        self.params.t0 = 0
        self.params.per = 4.7361
        self.params.rp = 0
        self.params.a = 4.98
        self.params.ecc = 0
        self.params.w = 90
        self.params.inc = 90
        self.params.limb_dark = "linear"
        self.params.u = [0.5]
        self.m = batman.TransitModel(self.params, self.time_array, max_err = .02)

        # Parameterization: uniform priors on both parameters; the error is
        # sampled uniformly in log10 space.
        self.ps = ParameterSet([GParameter('rp_rs', 'area_ratio', 'A_s', UP(0.03, .2), bounds = (0.03, .2)),
                                GParameter('loge', 'log10_error', '', UP(-10, -2), bounds = (-10, -2))])
        self.ps.freeze()

    def create_pv_population(self, npop = 50):
        """Draw ``npop`` parameter vectors from the priors."""
        return self.ps.sample_from_prior(npop)

    def baseline(self, pv):
        """Multiplicative flux baseline; a constant 1 (no baseline model)."""
        return 1

    def transit_model(self, pv, copy = True):
        """Evaluate the batman light curve for each parameter vector.

        Only the first entry of each vector (rp_rs) is used. ``copy`` is
        unused. Side effect: ``self.params.rp`` is left set to the last
        vector's value.

        Returns an array with one light curve (on ``self.time_array``) per row
        of ``np.atleast_2d(pv)``.
        """
        pv = np.atleast_2d(pv)
        model = []
        for parameter in pv:
            self.params.rp = parameter[0]
            model.append(self.m.light_curve(self.params))
        return np.array(model)

    def flux_model(self, pv):
        """Transit model times baseline."""
        return self.transit_model(pv) * self.baseline(pv)

    def residuals(self, pv):
        """Observed minus modelled flux."""
        return self.fluxes - self.flux_model(pv)

    def set_prior(self, pid: int, prior) -> None:
        """Replace the prior of parameter ``pid``."""
        self.ps[pid].prior = prior

    def lnprior(self, pv):
        """Log prior from the parameter set."""
        return self.ps.lnprior(pv)

    def lnlikelihood(self, pv):
        """Gaussian log-likelihood with per-vector error ``10**loge``."""
        flux_m = self.flux_model(pv)
        errors = 10**(np.atleast_2d(pv)[:,1])
        return lnlike_normal_v(self.fluxes, flux_m, errors)

    def lnposterior(self, pv):
        """Log posterior; non-finite values are mapped to ``-inf``."""
        lnp = self.lnlikelihood(pv) + self.lnprior(pv)
        return np.where(np.isfinite(lnp), lnp, -np.inf)

    def __call__(self, pv):
        return self.lnposterior(pv)

    def optimize(self, niter=200, npop = 500, population = None, label = 'Global optimization', leave = False):
        """Run ``niter`` generations of differential evolution, maximising the log posterior.

        The optimiser is created on the first call only; later calls continue
        the existing population, so ``npop`` and ``population`` only matter
        the first time.
        """
        if self.de is None:
            # NOTE(review): clipping maps the loge bounds (-10, -2) to (-1, -1).
            # The initial population is replaced just below, but confirm that
            # DiffEvol does not use the bounds anywhere else.
            self.de = DiffEvol(self.lnposterior, np.clip(self.ps.bounds, -1, 1), npop, c = 0.1, maximize = True, vectorize = True)
            if population is None:
                self.de._population[:,:] = self.create_pv_population(npop)
            else:
                self.de._population[:,:] = population

        for _ in tqdm(self.de(niter), total = niter, desc = label, leave = leave):
            pass

    def sample(self, niter = 500, thin = 5, label = 'MCMC sampling', reset = True, leave = True):
        """Run, or continue, emcee sampling.

        On the first call the sampler is created with one walker per DE
        population member and started from the DE population. Later calls
        continue from the last stored position of every walker.

        Raises
        ------
        RuntimeError
            If called on a fresh instance before ``optimize``.
        """
        if self.sampler is None:
            if self.de is None:
                raise RuntimeError("call optimize() before sample(): the sampler starts from the DE population")
            self.sampler = EnsembleSampler(self.de.n_pop, self.de.n_par, self.lnposterior, vectorize = True)
            pop0 = self.de.population
        else:
            # The starting point is taken before any reset() clears the chain.
            pop0 = self.sampler.chain[:,-1,:].copy()
        if reset:
            self.sampler.reset()
        for _ in tqdm(self.sampler.sample(pop0, iterations = niter, thin = thin), total = niter, desc = label, leave = leave):
            pass

    def posterior_samples(self, burn: int = 0, thin: int = 1):
        """Flatten the stored chain (after ``burn`` and ``thin``) into a DataFrame."""
        fc = self.sampler.chain[:, burn::thin, :].reshape([-1, self.de.n_par])
        return pd.DataFrame(fc, columns=self.ps.names)

    def plot_mcmc_chains(self, pid: int=0, alpha: float=0.1, thin: int=1, ax=None):
        """Plot every walker's trace for parameter ``pid``.

        Returns the newly created figure, or ``None`` when drawing onto a
        caller-supplied ``ax`` (whose figure layout is left untouched).
        """
        if ax is None:
            fig, ax = plt.subplots()
        else:
            fig = None
        ax.plot(self.sampler.chain[:, ::thin, pid].T, 'k', alpha=alpha)
        if fig is not None:
            fig.tight_layout()
        return fig

    def plot_light_curve(self, model: str = 'de', figsize: tuple = (13, 4)):
        """Plot the data with either the DE best fit ('de') or the posterior-median model ('mc').

        Raises
        ------
        ValueError
            If ``model`` is neither 'de' nor 'mc'.
        """
        if model not in ('de', 'mc'):
            raise ValueError(f"model must be 'de' or 'mc', got {model!r}")
        fig, ax = plt.subplots(figsize = figsize, constrained_layout = True)
        cp = sb.color_palette()

        if model == 'de':
            pv = self.de.minimum_location
            err = 10**pv[1]
        else:
            fc = np.array(self.posterior_samples())
            pv = np.random.permutation(fc)[:300]
            err = 10**np.median(pv[:,1], 0)

        ax.errorbar(self.time_array, self.fluxes, err, fmt = '.', c = cp[4], alpha = 0.75)
        if model == 'de':
            ax.plot(self.time_array, self.flux_model(pv)[0], c = cp[0])
        else:
            flux_pr = self.flux_model(fc[np.random.permutation(fc.shape[0])[:1000]])
            # Row 0 is the median; rows 1-6 are 3/2/1-sigma band edges, available
            # for fill_between if uncertainty bands are wanted.
            flux_pc = np.array(np.percentile(flux_pr, [50, 0.15,99.85, 2.5,97.5, 16,84], 0))
            ax.plot(self.time_array, flux_pc[0], c=cp[0])
        plt.setp(ax, xlim=self.time_array[[0,-1]], xlabel='Time', ylabel='Normalised flux')
        plt.show()
        return fig, ax







In [5]:
def fit_lightcurve(time_array, flux_array, npop=50, de_iter=200, mc_reps=3, mc_iter=1000, thin=10, burn_iter=1000):
    """Fit a transit light curve and summarise the posterior transit depth.

    Runs differential evolution, then an MCMC burn-in of ``burn_iter`` steps,
    then ``mc_reps`` further runs of ``mc_iter`` steps. Every MCMC run resets
    the stored chain, so only the final run's (thinned) samples are kept.

    Parameters
    ----------
    time_array, flux_array : array-like
        Passed straight to ``LPFunction``.
    npop : int
        Differential-evolution population size; the MCMC sampler uses one
        walker per population member.
    de_iter : int
        Number of differential-evolution generations.
    mc_reps : int
        Number of MCMC runs after the burn-in.
    mc_iter : int
        Steps per MCMC run.
    thin : int
        Thinning factor for the stored MCMC samples.
    burn_iter : int
        Steps in the initial burn-in run.

    Returns
    -------
    transit_depth : float
        Mean of ``1e6 * rp_rs**2`` (ppm) over the kept samples.
    std_dev_transit_depth : float
        Standard deviation of the same quantity (NumPy default, ddof=0).
    df : pandas.DataFrame
        Kept samples with the parameter columns plus ``rp_rs2`` (depth, ppm)
        and ``e (ppm)`` (per-point error, ppm).
    """
    lpf = LPFunction('K', time_array, flux_array)

    lpf.optimize(de_iter, npop)
    # lpf.plot_light_curve()  # uncomment to inspect the DE fit
    lpf.sample(burn_iter, thin = thin)
    for _ in range(mc_reps):
        lpf.sample(mc_iter, thin = thin, reset = True, label = 'MCMC sampling')

    fc = lpf.sampler.chain.reshape([-1, lpf.sampler.chain.shape[-1]])
    df = pd.DataFrame(data = fc.copy(), columns = lpf.ps.names)
    df['rp_rs2'] = 1e6*df.rp_rs**2          # transit depth in ppm
    df['e (ppm)'] = 1e6*(10**df.loge)       # per-point flux error in ppm

    transit_depth = np.mean(df['rp_rs2'])
    std_dev_transit_depth = np.std(df['rp_rs2'])
    return transit_depth, std_dev_transit_depth, df

In [7]:


#if __name__ == "__main__":

    #time, flux, _, _, _ = make_lc(1000, .25e-3)
    
    #transit_depth, std_dev_transit_depth, df = fit_lightcurve(time, flux)
    
    #corner(df[['rp_rs2', 'loge']], labels = ['Rp_Rs2', 'Log_error']) # show corner plot
    #plt.show()

