In [None]:
import scipy.io
from chromatic import *
from chromatic import plt, u, np
import pandas as pd
import numpy as np

import lightkurve as lk
%matplotlib inline

# Read in the data (extracted from .ima files)

In [None]:
def _read_spectra(sav_path):
    """Return the contents of one IDL ``.sav`` file of extracted WFC3 spectra, read quietly."""
    return scipy.io.readsav(sav_path, verbose=False)


# Extracted WFC3 spectra: one save file per grism (G102, G141) and scan direction.
G_102_back_dict = _read_spectra('WFC3_data/G102/Backward_spectra.sav')
G_102_for_dict = _read_spectra('WFC3_data/G102/Forward_spectra.sav')
G_141_back_dict = _read_spectra('WFC3_data/G141/Backward_spectra.sav')
G_141_for_dict = _read_spectra('WFC3_data/G141/Forward_spectra.sav')

# The 'BJD' column of each visit's timestamp table, as a pandas Series.
F21_BJD_times = pd.read_csv('data/F21_bjdtimes.csv').loc[:, 'BJD']
S22_BJD_times = pd.read_csv('data/S22_bjdtimes.csv').loc[:, 'BJD']

# Initialize a per-visit dictionary to organize the data and visit metadata

In [None]:
# Per-visit configuration. Each entry records:
# - the WFC3 grism and the forward/backward scan spectra loaded above;
# - hand-chosen wavelength and time limits;
# - the transit reference time T0;
# - the exposure time;
# - the transit duration and orbital period (astropy quantities via chromatic's `u`).
visits = {
    'F21': {
        'Grism': 'G141',
        'Forward': G_141_for_dict, # spectra from G141/Forward_spectra.sav
        'Backward': G_141_back_dict, # spectra from G141/Backward_spectra.sav
        'wave_lower': 1.13, # determined by hand; presumably microns -- TODO confirm units
        'wave_upper': 1.66,
        'time_lower': 2459455.708, # NOTE(review): a BJD time window; not used in the cells shown here
        'time_upper': 2459455.737,
        'T0 (BJD_TDB)': 2459455.98, # transit reference time; drawn as the T0 line in the plots below
        'exp (s)': 4.970405, # exposure time in seconds (per the key name)
        'time cutoff': 2459455.77, # for removing the first orbit
        'duration': 3.5 * u.hour, # presumably the transit duration
        'period': 8.463 * u.day # orbital period
            },
    'S22': {
        'Grism': 'G102',
        'Forward': G_102_for_dict, # spectra from G102/Forward_spectra.sav
        'Backward': G_102_back_dict, # spectra from G102/Backward_spectra.sav
        'wave_lower': 0.80, # lower limit starts where... no: G102 covers 0.80-1.13, G141 starts at 1.13
        'wave_upper': 1.13,
        'time_lower': 2459684.215, # NOTE(review): a BJD time window; not used in the cells shown here
        'time_upper': 2459684.242,
        'T0 (BJD_TDB)': 2459684.48, # This is 27 planetary orbits after the first transit: (2459684.48 - 2459455.98) / 8.463 ~= 27.0
        'exp (s)': 9.662994, # exposure time in seconds (per the key name)
        'time cutoff': 2459684.27, # for removing the first orbit
        'duration': 3.5 * u.hour, # presumably the transit duration
        'period': 8.463 * u.day # orbital period
    }
}

In [None]:
# Search the archive for TESS light curves of AU Mic (a lightkurve SearchResult).
# The cells below download entry [1], which they label "TESS Sector 27".
# NOTE(review): that index-to-sector mapping depends on the archive's ordering; confirm.
search_result = lk.search_lightcurve('AU Mic', mission='TESS')

In [None]:
# Overlay each WFC3 visit window on the TESS light curve. The light curve is shifted
# forward by an integer number of 4.86 d steps so that it overlaps the visit.
# NOTE(review): 4.86 d is presumably AU Mic's stellar rotation period -- confirm the source.
ROTATION_STEP = 4.86 * u.day
# Number of 4.86 d steps applied to each visit. A dict lookup fails loudly for an unknown
# visit, whereas the old if-chain silently reused the previous visit's `n`.
N_ROTATIONS = {'F21': 86, 'S22': 133}

# Download and bin the TESS light curve once; each visit shifts its own copy below.
# NOTE(review): entry [1] is assumed to be the Sector 27 product -- confirm.
tess_lc_binned = search_result[1].download().bin(time_bin_size=0.005)

for visit in ['F21', 'S22']:

    for direction in ['Backward']:

        print(visit, direction)
        print('')

        n = N_ROTATIONS[visit]

        # Load data tables
        visit_data = visits[visit]

        t0 = visit_data['T0 (BJD_TDB)'] * u.day
        trimmed = read_rainbow(f'data/{visit}_{direction}_trimmed.rainbow.npy')

        # Work on a copy so the time shifts below do not accumulate across visits.
        lc = tess_lc_binned.copy()
        lc.time = lc.time + 2457000.0 * u.day  # add the BTJD offset so .value reads as BJD
        lc.time = lc.time + (n * ROTATION_STEP)  # shifted copy that should overlap the WFC3 visit
        normlc = lc.normalize()

        normlc.plot()
        # Window starting 1 d before T0 and spanning one 4.86 d step.
        plt.xlim(t0.value - 1, t0.value + 3.86)
        plt.axvline(t0.value, color='darkred', zorder=100)
        # Shade the time span covered by the trimmed WFC3 data.
        plt.axvspan(trimmed.time.value.min(), trimmed.time.value.max(), zorder=-500, alpha=0.6)
        plt.title(f'{visit} {direction}: shifted TESS light curve')
        plt.show()

In [None]:
# For each visit, compare the WFC3 white-light curve with a TESS-based stellar-rotation
# model. The model is the TESS light curve shifted onto the visit and binned onto the
# WFC3 time grid.
# NOTE(review): 4.86 d is presumably AU Mic's stellar rotation period -- confirm the source.
ROTATION_STEP = 4.86 * u.day
# Number of 4.86 d steps applied to each visit.
N_ROTATIONS = {'F21': 86, 'S22': 133}
# Hand-chosen [start, stop) index range of binned TESS points overlapping each visit.
TESS_INDEX_RANGES = {'F21': (305, 394), 'S22': (321, 410)}

# Download and bin the TESS light curve once; each visit shifts its own copy below.
# NOTE(review): entry [1] is assumed to be the Sector 27 product -- confirm.
tess_lc_binned = search_result[1].download().bin(time_bin_size=0.005)

for visit in ['F21', 'S22']:

    for direction in ['Backward']:

        print(visit, direction)
        print('')

        n = N_ROTATIONS[visit]

        # Load data tables
        visit_data = visits[visit]

        t0 = visit_data['T0 (BJD_TDB)'] * u.day
        trimmed = read_rainbow(f'data/{visit}_{direction}_trimmed.rainbow.npy')

        # Work on a copy so the time shifts below do not accumulate across visits.
        lc = tess_lc_binned.copy()
        # NOTE(review): if lc.time is an astropy Time in BTJD format, this shifts the instant so
        # that .value reads numerically as BJD (lc.time.jd would be the direct conversion).
        lc.time = lc.time + 2457000.0 * u.day
        lc.time = lc.time + (n * ROTATION_STEP)  # shifted copy that should overlap the WFC3 visit
        normlc = lc  # deliberately not normalized here; the TESS flux is median-normalized below

        ini, fini = TESS_INDEX_RANGES[visit]
        flux = normlc.flux.value[ini:fini]
        err = normlc.flux_err.value[ini:fini]
        time = normlc.time.value[ini:fini]

        # Integrated 'white light' curve: sum the flux along the wavelength axis (axis 0),
        # weighting by the median wavelength step.
        dw = np.nanmedian(np.diff(trimmed.wavelength.value))
        white_light_curve = np.nansum(trimmed.flux * dw, axis=0)
        # FIXME(review): sqrt(N) is a Poisson error only if white_light_curve is in counts, and the
        # *dw factor makes that doubtful. Consider propagating trimmed.uncertainty instead.
        white_light_curve_err = np.sqrt(white_light_curve)

        # Median-normalized TESS flux binned onto the WFC3 time grid.
        rotation_model = bintogrid(x=time,
                                   y=(flux / np.nanmedian(flux)),
                                   newx=trimmed.time.value)

        wlc_median = np.nanmedian(white_light_curve)
        wlc_norm = white_light_curve / wlc_median
        wlc_norm_err = white_light_curve_err / wlc_median

        plt.figure(figsize=(6, 4))
        plt.errorbar(trimmed.time.value, wlc_norm,
                     yerr=wlc_norm_err, fmt='', label='WFC3 Data')
        plt.plot(rotation_model['x'], rotation_model['y'], label='TESS Sector 27')
        plt.errorbar(trimmed.time.value, wlc_norm / rotation_model['y'],
                     yerr=wlc_norm_err / rotation_model['y'], fmt='',
                     label='WFC3 Data Normalized by stellar rotation')
        plt.axvline(t0.value, color='darkred', zorder=100, label='T0')
        plt.xlabel('Time (BJD_TDB)')
        plt.ylabel('Normalized flux')
        plt.title(f'{visit} {direction}')
        plt.legend()
        plt.show()
        # close() (not clf()) so no empty figure is left behind to render as stray output.
        plt.close()

# Fit a quadratic baseline to the out-of-transit white-light curve (the transit is masked)

In [None]:
# scipy's least-squares fitter is named `curve_fit`; `curvefit` does not exist and raised ImportError.
from scipy.optimize import curve_fit

In [None]:
from lmfit import Model

In [None]:
# Select out-of-transit exposures of the last visit processed by the white-light-curve loop
# above (S22, its final iteration). The mask runs from 0.065 d before to 0.085 d after T0.
# NOTE(review): relies on `trimmed`, `t0` and `white_light_curve` left over from that loop.
pre_transit = trimmed.time.value < (t0.value - 0.065)
post_transit = trimmed.time.value > (t0.value + 0.085)
ok_times = pre_transit | post_transit
x = trimmed.time.value[ok_times]
y = white_light_curve[ok_times]


def quadratic_baseline_transit(x, a, b, c):
    """Quadratic baseline ``a*x**2 + b*x + c``.

    Only the baseline is modelled. The transit is excluded from the fit by the
    ``ok_times`` mask rather than by a transit term.

    Parameters
    ----------
    x : array-like
        Independent variable (here: time relative to T0, in days).
    a, b, c : float
        Quadratic, linear and constant coefficients.

    Returns
    -------
    array-like
        The baseline evaluated at ``x``.
    """
    baseline = a * x**2 + b * x + c
    return baseline


# Create a model based on the quadratic function
model = Model(quadratic_baseline_transit)

# Start flat at the median flux. The model is linear in a, b and c, so the fit does not
# depend on these guesses.
params = model.make_params(a=0, b=0, c=np.nanmedian(y))

# Fit in time relative to T0: on raw BJD values (~2.46e6 d), a*x**2 is severely ill-conditioned.
# result.best_fit stays aligned with `x`; the coefficients refer to (x - t0).
result = model.fit(y, params, x=x - t0.value)

In [None]:
# Print the fitted baseline coefficients and goodness-of-fit statistics
print(result.fit_report())

# Overplot the best-fit quadratic baseline on the out-of-transit white-light flux
plt.scatter(x, y, label='Data')
plt.plot(x, result.best_fit, label='Best Fit', color='red')
plt.legend()
plt.xlabel('Time (BJD_TDB)')
plt.ylabel('White-light flux')
plt.title('Quadratic Fit using lmfit')
plt.show()

In [None]:
# Out-of-transit white-light flux divided by the fitted quadratic baseline.
# A good fit leaves the points scattered around 1.
plt.scatter(x, y / result.best_fit)
plt.axhline(1, color='gray', linestyle='--')
plt.xlabel('Time (BJD_TDB)')
plt.ylabel('Flux / quadratic baseline')
plt.title('Baseline-detrended white-light flux')
plt.show()