In [None]:
from itertools import product

import numpy as np

from astropy.timeseries import TimeSeries, aggregate_downsample
from astropy.time import Time
from astropy.table import Table, Column
from astropy import units as u
from matplotlib import pyplot as plt

from stellarphot.transit_fitting import TransitModelFit, TransitModelOptions
from stellarphot.io import TOI
from stellarphot.plotting import plot_transit_lightcurve
from stellarphot.gui_tools.photometry_widget_functions import TessAnalysisInputControls, filter_by_dates


### 0. Get some data

+ Select photometry file with relative flux
+ Select passband
+ Select TESS info file

In [None]:
# Widget for choosing the photometry file, passband and TESS info file (see the list above).
# Later cells read the selections from it (photometry_data_file, phot_data, tic_info_file, passband).
taic = TessAnalysisInputControls()
taic

In [None]:
# 👉 File with photometry, including flux, and the photometry data provided by the widget
photometry_file = taic.photometry_data_file
inp_photometry = taic.phot_data

# 👉 File with the exoplanet (TOI) information, stored as JSON and parsed into a TOI object
tess_info_output_file = taic.tic_info_file
tess_info = TOI.model_validate_json(tess_info_output_file.read_text())

### Get just the target star and some information about it

In [None]:
# Add the BJD column only when every existing BJD value is NaN (i.e. none were computed).
if np.all(np.isnan(inp_photometry["bjd"])):
    inp_photometry.add_bjd_col()


In [None]:
# Light curve for star 1 (presumably the target — confirm) in the selected passband,
# using relative flux, with NaN entries removed.
photometry = (
    inp_photometry
    .lightcurve_for(1, flux_column="relative_flux", passband=taic.passband)
    .remove_nans()
)

### You may need to alter some of the settings here

### Fit settings

+ Do any detrending by a covariate?
+ Which parameters are fixed?

In [None]:
# These affect the fitting that is done.
# Default options are used; `bin_size` (in minutes) is read from here by the binning cell,
# and the whole object is passed to `setup_model`.
model_options = TransitModelOptions()

### Find the out-of-transit (OOT) region and use it to get the normalization factor

In [None]:
# Predicted transit for this observation from the TESS ephemeris, and its start/end.
that_transit = tess_info.transit_time_for_observation(photometry.time)
start = that_transit - tess_info.duration / 2
mid = that_transit
end = that_transit + tess_info.duration / 2

# Points before the predicted start or after the predicted end are out of transit (OOT).
outside_transit = (photometry["bjd"] < start) | (photometry["bjd"] > end)

# Scale the flux by the mean reciprocal OOT flux so the OOT baseline sits near 1;
# the flux errors are scaled by the same factor.
normalization_factor = np.nanmean(1 / photometry["relative_flux"][outside_transit])
normalized_flux = Column(photometry["relative_flux"] * normalization_factor, name="normalized_flux")
norm_flux_error = Column(normalization_factor * photometry["relative_flux_error"].value, name="normalized_flux_error")

# Replace the columns if this cell has already been run, so re-running does not fail
# on duplicate column names.
for column in (normalized_flux, norm_flux_error):
    if column.name in photometry.colnames:
        photometry.remove_column(column.name)
photometry.add_columns([normalized_flux, norm_flux_error])


### Bin Data

Binning needs:
* the data table
* the start time
* the end time
* the bin size (`bin_size`)

The data is binned twice: once to average the flux and covariates in each bin, and once to combine the flux errors.

**TODO: also make the times smaller**

In [None]:
# Bin on the BJD timestamps, interpreted as TDB Julian dates.
t_ob = Time(photometry["bjd"], scale="tdb", format="jd")

# Flux plus the covariates available for detrending; binned below with
# aggregate_downsample's default aggregation.
ts = TimeSeries(
    [
        photometry["normalized_flux"],
        photometry["airmass"],
        photometry["xcenter"],
        photometry["sky_per_pix_avg"],
        photometry["width"],
    ],
    time=t_ob,
)

# The flux errors live in their own series so they can be binned with a different
# aggregation function (add_quad, combining them in quadrature).
ts2 = TimeSeries(
    [Column(data=photometry["normalized_flux_error"], name="normalized_flux_error")],
    time=t_ob,
)


def add_quad(x):
    """Combine per-point errors in quadrature into the error of their mean.

    Used as the ``aggregate_func`` when binning the flux errors. For errors
    ``sigma_i`` of the ``N`` valid (non-NaN, non-masked) points in a bin this
    returns ``sqrt(sum(sigma_i**2)) / N``.

    Parameters
    ----------
    x : array-like or scalar
        Errors falling in one bin. A bare scalar (a single-point bin) is accepted.

    Returns
    -------
    float
        The combined error, or NaN if the bin has no valid points.
    """
    # Treat masked entries like NaNs so both are excluded from the sum AND the count;
    # counting NaNs in the denominator would underestimate the error of the mean.
    errors = np.ma.filled(np.atleast_1d(x).astype(float), np.nan)
    valid = errors[~np.isnan(errors)]
    if valid.size == 0:
        return np.nan
    return np.sqrt(np.sum(valid**2)) / valid.size


# Bin flux and covariates with the default aggregation, and the errors in quadrature.
bin_width = model_options.bin_size * u.min
binned = aggregate_downsample(ts, time_bin_size=bin_width)
binned2 = aggregate_downsample(
    ts2,
    time_bin_size=bin_width,
    aggregate_func=add_quad,
)

# Carry the binned errors over, then keep only bins whose binned flux is not NaN.
binned["normalized_flux_error"] = binned2["normalized_flux_error"]
has_flux = ~np.isnan(binned["normalized_flux"])
binned = binned[has_flux]

## Model, fit, plot

### Create the model 

In [None]:
# Build the transit model and seed it from the TESS ephemeris for this observation.
mod = TransitModelFit()

mod.setup_model(
    binned_data=binned,
    t0=mid.jd - 2400000,                          # predicted mid-transit, BJD - 2400000
    depth=tess_info.depth_ppt,                    # parts per thousand
    duration=tess_info.duration.to(u.day).value,  # days
    period=tess_info.period.to(u.day).value,      # days
    model_options=model_options,
)


### Run the fit

In [None]:
# Fit the model set up above (with the binned data) using the options in model_options.
mod.fit()

### Look at the results

In [None]:
# Data, fitted model, and the predicted transit start/end (red dashed lines).
fig, ax = plt.subplots()
ax.plot(mod.times, mod.data, ".", label="Data")
ax.plot(mod.times, mod.model_light_curve(), label="Model")
# axvline spans the full vertical axis, so the markers remain visible whatever the
# transit depth or baseline (the previous hard-coded 0.98-1.02 range did not).
for edge in (start, end):
    ax.axvline(edge.jd - 2400000, color="r", linestyle="--", alpha=0.5)
ax.set(title="Data and fit", xlabel="BJD - 2400000", ylabel="Normalized flux")
ax.legend()
ax.grid()

### Exclude data by date *if needed*


In [None]:
# The default dates below presumably exclude nothing (JD 2400000 is in 1858 and the
# "between" range has zero length); edit them to drop bad stretches of data.
# NOTE: this only changes `photometry`. The binning and fit above were computed before
# this filter, so re-run the binning cell and the cells after it to use the filtered data.
earliest_allowed = Time(2400000, format="jd", scale="tdb")
latest_allowed = Time(2499999, format="jd", scale="tdb")
excluded_ranges = [
    [Time(2400000, format="jd", scale="tdb"), Time(2400000, format="jd", scale="tdb")],
]

bad_time = filter_by_dates(
    phot_times=photometry["bjd"],
    use_no_data_before=earliest_allowed,
    use_no_data_between=excluded_ranges,
    use_no_data_after=latest_allowed,
)

photometry = photometry[~bad_time]

In [None]:
# Display the model object held by the fitter (after the fit above) for inspection.
mod.model

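### Compare detrending choices with the Bayesian Information Criterion (BIC)

**Caution:** this refits the model once for every combination of fixed/free detrending parameters, and may have side effects on the rest of the notebook.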

In [None]:
def evaluate_fits(mod):
    """Refit ``mod`` for every fixed/free combination of its detrending trends.

    For each combination of the names in ``mod._all_detrend_params``, every
    ``<param>_trend`` parameter of ``mod.model`` is either fixed at 0.0 or
    left free, the model is refit, and ``mod.BIC`` is recorded.

    The parameter values and fixed flags of ``mod.model`` are saved before the
    search and restored afterwards (even if a fit raises), so the model is not
    left in the state of the last combination tried.
    NOTE(review): assumes ``mod.model`` is an astropy.modeling Model; any other
    fit state kept by ``mod`` (e.g. whatever ``mod.BIC`` reads) is not restored.

    Parameters
    ----------
    mod : TransitModelFit
        A model that has already been set up with ``setup_model``.

    Returns
    -------
    astropy.table.Table
        One row per combination, with columns ``"Fit this param?"`` (which
        trends were free) and ``"BIC"``.
    """
    all_trendable = mod._all_detrend_params

    # Snapshot the state this function mutates so it can be put back afterwards.
    saved_values = mod.model.parameters.copy()
    saved_fixed = dict(mod.model.fixed)

    BICs = []
    settings = []
    try:
        for fixed in product([True, False], repeat=len(all_trendable)):
            this_summary = []
            for param, fix in zip(all_trendable, fixed):
                trend_name = f"{param}_trend"
                trend_mod = getattr(mod.model, trend_name)
                if fix:
                    # A fixed trend is held at 0.0 for this fit.
                    setattr(mod.model, trend_name, 0.0)
                trend_mod.fixed = fix
                this_summary.append(f"{param}: {not fix}")

            settings.append(", ".join(this_summary))
            mod.fit()
            BICs.append(mod.BIC)
    finally:
        # Re-read mod.model here in case fit() replaced the model object.
        mod.model.parameters = saved_values
        for name, was_fixed in saved_fixed.items():
            getattr(mod.model, name).fixed = was_fixed

    return Table(data=[settings, BICs], names=["Fit this param?", "BIC"])

In [None]:
# Refit with every fixed/free combination of detrending parameters and list the
# results with the lowest (preferred) BIC first.
bic_table = evaluate_fits(mod)
bic_table.sort("BIC")
bic_table