In [None]:
from itertools import product
import pickle

import ipywidgets as ipw
import numpy as np

from matplotlib import pyplot as plt
from astropy.table import Table

from stellarphot.transit_fitting import TransitModelFit
from stellarphot.transit_fitting.gui import *
from stellarphot.io import TOI
from stellarphot.settings.fits_opener import FitsOpener
from stellarphot.plotting import plot_many_factors
from stellarphot import PhotometryData
from astropy.timeseries import BinnedTimeSeries, TimeSeries, aggregate_downsample
from astropy.time import Time
from astropy.table import Table, Column
from astropy import units as u

### 0. Get some data

In [None]:
# File chooser for the photometry/flux table.
fits_openr = FitsOpener(
    title="Select your photometry/flux file",
    filter_pattern=["*.csv", "*.fits", "*.ecsv"],
)

# File chooser for the TESS/TOI information saved as JSON.
fits_openr2 = FitsOpener(
    title="Select your TESS info file",
    filter_pattern=["*.json"],
)

# Filter selector. It starts disabled with placeholder options; the
# update_filter_list callback fills it in once a photometry file is picked.
passband = ipw.Dropdown(
    description="Filter",
    options=["gp", "ip"],
    disabled=True,
)

# Vertical container for the widgets above.
box = ipw.VBox()

def update_filter_list(change):
    """Fill the passband dropdown with the passbands present in the chosen file.

    Used as an observer callback on the photometry file chooser. The
    ``change`` notification itself is not inspected; the selected path is
    read from ``fits_openr`` instead.
    """
    available = sorted(set(Table.read(fits_openr.path)["passband"]))
    passband.options = available
    passband.disabled = False
    # Preselect the alphabetically first passband found in the data.
    passband.value = available[0]


# Refresh the passband list whenever a different photometry file is chosen.
fits_openr.file_chooser.observe(update_filter_list, names="_value")

# Stack both file choosers and the filter dropdown, then show the container.
box.children = [
    fits_openr.file_chooser,
    fits_openr2.file_chooser,
    passband,
]
box

In [None]:
# 👉 File with photometry, including flux
photometry_file = fits_openr.path

# 👉 File with exoplanet info in
tess_info_output_file = fits_openr2.path

### 👇👇👇 use this to exclude some data (only if needed!) 👇👇👇

In [None]:
# Discard data with BJD earlier than this. The default is far before any real
# observation, so nothing is dropped.
use_no_data_before = Time(2400000, format="jd", scale="tdb")

# Intervals of data to discard, given as [start, end] pairs. The default is a
# zero-length interval, so nothing is dropped.
use_no_data_between = [
    [
        Time(2400000, format="jd", scale="tdb"),
        Time(2400000, format="jd", scale="tdb"),
    ]
]

# Discard data with BJD later than this. The default is far in the future.
use_no_data_after = Time(2499999, format="jd", scale="tdb")

In [None]:
# Read the photometry table and the target's TOI information (stored as JSON).
photometry = PhotometryData.read(photometry_file)
tess_info = TOI.model_validate_json(tess_info_output_file.read_text())

In [None]:
# Build a mask of points to discard from the limits set in the previous cell,
# reporting how many points each limit removes. A point caught by several
# limits is only counted under the first one that catches it.
phot_times = Time(photometry["bjd"], format="jd", scale="tdb")

bad_data = phot_times < use_no_data_before
n_dropped = bad_data.sum()

if n_dropped > 0:
    print(f"👉👉👉👉 Dropping {n_dropped} data points before BJD {use_no_data_before}")

# Apply every excluded interval in the list, not just the first one.
# Endpoints themselves are kept (strict inequalities).
for exclude_start, exclude_end in use_no_data_between:
    bad_data = bad_data | ((exclude_start < phot_times) & (phot_times < exclude_end))

    new_dropped = bad_data.sum() - n_dropped

    if new_dropped:
        print(
            f"👉👉👉👉 Dropping {new_dropped} data points between BJD {exclude_start} and {exclude_end}"
        )

    n_dropped += new_dropped

bad_data = bad_data | (phot_times > use_no_data_after)

new_dropped = bad_data.sum() - n_dropped

if new_dropped:
    print(f"👉👉👉👉 Dropping {new_dropped} data points after BJD {use_no_data_after}")

n_dropped += new_dropped

photometry = photometry[~bad_data]

### You may need to alter some of the settings here

In [None]:
# These affect the fitting that is done

# Width of the time bins used when binning the light curve
bin_size = 5 * u.min

# Mid-transit time: keep it fixed? If it is fitted instead, it is confined to a
# window of total width transit_time_range centered on the predicted value.
keep_fixed_transit_time = True
transit_time_range = 60 * u.min

# Planet radius (rp, in units of the stellar radius): keep it fixed?
keep_fixed_radius_planet = False

# Orbital radius (a, in units of the stellar radius): keep it fixed?
keep_fixed_radius_orbit = False

# Detrending switches -- set True to fit (and remove) a trend against:
fit_airmass = False  # airmass
fit_spp = False  # sky background per pixel
fit_width = False  # star width, i.e. changes in focus

In [None]:
# Transit ephemeris and target details, taken from the TOI info file.

# Orbital period
period = tess_info.period

# Reference mid-transit time (epoch)
epoch = tess_info.epoch  # e.g. Time(2458761.602894, scale='tdb', format='jd')

# Transit duration
duration = tess_info.duration

# Transit depth in parts per thousand (ppt). ExoFOP-TESS lists the depth in
# ppm, so divide that number by 1000 if entering a value by hand.
depth = tess_info.depth_ppt

# Object name used in the plot title
obj = f"TIC {tess_info.tic_id}"

# Filter to analyze: use the passband picked in the file-selection widget, so
# that choice actually controls which data are fitted (it was previously
# ignored in favor of a hard-coded value).
phot_filter = passband.value

In [None]:
# These affect spacing of lines on final plot
high = 1.06
low = 0.82
scale = 0.15 * (high - low)
shift = -0.72 * (high - low)

In [None]:
# Mask for the target star, which is star_id 1 in this table.
target_star = photometry["star_id"] == 1

# Metadata from the target's first row: the first time in the series (no
# changes needed here), the observation date and the exposure time.
then = Time(photometry["bjd"][target_star][0], format="jd", scale="tdb")
date_obs, exposure_time = (
    photometry[column][target_star][0] for column in ("date-obs", "exposure")
)

The data need to be normalized before fitting. First, keep only the measurements of the target star in the selected filter.

In [None]:
# Rows observed through the chosen filter, combined with the target-star mask
# from the previous cell.
band_filter = photometry["passband"] == phot_filter
target_and_filter = target_star & band_filter

In [None]:
# Keep only the target star in the chosen filter. NOTE: this rebinds
# `photometry`, so re-running this cell without re-running the cells above
# would apply a mask built for the unfiltered (longer) table.
photometry = photometry[target_and_filter]

In [None]:
# Predicted mid-time of the first transit after the start of the data.
# np.floor is used rather than int(), because int() truncates toward zero and
# would skip a transit when the data precede the catalog epoch
# (negative cycle numbers).
orbits_since_epoch = ((then - epoch) / period).to_value(u.dimensionless_unscaled)
cycle_number = int(np.floor(orbits_since_epoch)) + 1
that_transit = cycle_number * period + epoch
that_transit

In [None]:
# Predicted ingress, mid-transit and egress times.
start = that_transit - duration / 2
mid = that_transit
end = that_transit + duration / 2

# Points outside the predicted transit define the baseline flux level.
outside_transit = (photometry["bjd"] < start) | (photometry["bjd"] > end)

# Scale the flux so that its mean outside transit is 1. This uses the
# reciprocal of the mean; the mean of the reciprocals is biased high for
# noisy data.
normalization_factor = 1 / np.nanmean(photometry["relative_flux"][outside_transit])
normalized_flux = Column(photometry["relative_flux"] * normalization_factor, name="relative_flux")

### Bin Data

Binning needs:
* the data table
* `start`
* `end`
* `bin_size`

In [None]:
# Sanity check: number of scaled flux-error values (one per photometry row).
(normalization_factor * photometry["relative_flux_error"].value).shape[0]

In [None]:
# Time series to be binned. The scaled flux errors go into their own series
# so they can be binned with add_quad, not the default aggregation used for
# the flux and auxiliary columns.
# (Unused first/last time values were removed here. They were computed as
# `Time - 2400000`, which makes astropy guess that the bare number is in days.)
t_ob = Time(photometry["bjd"], scale="tdb", format="jd")
ts = TimeSeries(
    [
        normalized_flux,
        photometry["airmass"],
        photometry["xcenter"],
        photometry["sky_per_pix_avg"],
        photometry["width"],
    ],
    time=t_ob,
)
ts2 = TimeSeries(
    [
        Column(
            data=normalization_factor * photometry["relative_flux_error"].value,
            name="relative_flux_error",
        )
    ],
    time=t_ob,
)


def add_quad(x):
    """Uncertainty of the mean of a bin, from per-point uncertainties.

    Returns ``sqrt(sum(x**2)) / n``, i.e. the individual uncertainties added
    in quadrature and divided by the number of points ``n``. NaN entries are
    ignored in both the sum and ``n``, so a bin with some missing errors is
    not treated as more precise than it is.

    Parameters
    ----------
    x : scalar or array-like
        Uncertainties of the points in one bin. A scalar is treated as a
        single point; the binning routine can pass one for single-point bins.

    Returns
    -------
    float
        The combined uncertainty, or NaN if no finite values are present.
        Previously an all-NaN bin gave 0, which turned into an infinite
        weight later on.
    """
    values = np.atleast_1d(x)
    n_valid = np.count_nonzero(~np.isnan(values))
    if n_valid == 0:
        return np.nan
    return np.sqrt(np.nansum(values**2)) / n_valid


# Bin both series into bins of width bin_size. The flux and auxiliary columns
# use aggregate_downsample's default aggregation; the errors are combined
# with add_quad.
binned = aggregate_downsample(ts, time_bin_size=bin_size)
binned2 = aggregate_downsample(
    ts2,
    time_bin_size=bin_size,
    aggregate_func=add_quad,
)

### 1. Create the model object

In [None]:
# Object that holds the data, the transit model and the fit results.
mod = TransitModelFit()

### 2. Load some data

Here we will just load times and normalized flux. You can also set width, spp (sky per pixel) and airmass. The only two that must be set are times and flux.

If you have set `mod.spp`, `mod.width` or `mod.airmass` then those things will be included in the fit. Otherwise, they are ignored.

**The weights are important to include.** Here they are the inverse of the binned flux uncertainties.

In [None]:
# Drop bins whose binned flux is NaN.
not_empty = ~np.isnan(binned["relative_flux"])

# Times are given to the model as BJD - 2400000. Each binned flux is an
# average over its whole bin, so it is stamped with the bin *center*.
# Stamping it with the bin start would shift the light curve early by half
# a bin.
mod.times = (np.asarray(binned.time_bin_center.jd) - 2400000)[not_empty]
mod.data = binned["relative_flux"].value[not_empty]
# Weight each bin by the inverse of its binned flux uncertainty.
mod.weights = 1 / (binned2["relative_flux_error"].value)[not_empty]

### 3. Set up the model

You should be able to get the parameters for this from TTF. There are more parameters you can set; press `Shift-Tab` inside the argument list to pull up the docstring, which lists and explains them all.

In [None]:
# Initial transit model. Times are BJD - 2400000, the depth is in parts per
# thousand, and the period and duration are converted to days.
mod.setup_model(
    period=period.to("day").value,
    duration=duration.to("day").value,
    depth=depth,
    t0=mid.jd - 2400000,
)

### 3.25 Set up airmass, etc

In [None]:
# Auxiliary series for detrending. They are trimmed to the same non-empty
# bins as the flux, so all arrays handed to the model line up.
mod.airmass = np.array(binned["airmass"])[not_empty]  # airmass per bin
mod.width = np.array(binned["width"])[not_empty]  # star width (focus) per bin
mod.spp = np.array(binned["sky_per_pix_avg"])[not_empty]  # sky per pixel per bin

### 3.5 Constrain the fits if you want

#### Exoplanet parameters

In [None]:
# The mid-transit time (BJD - 2400000) may move at most half of
# transit_time_range to either side of the prediction. It is held fixed
# entirely when keep_fixed_transit_time is True.
mod.model.t0.bounds = [
    (mid.jd - 2400000) - transit_time_range.to("day").value / 2,
    (mid.jd - 2400000) + transit_time_range.to("day").value / 2,
]
mod.model.t0.fixed = keep_fixed_transit_time

# Orbital radius and planet radius: fixed or free, per the settings cell.
mod.model.a.fixed = keep_fixed_radius_orbit
mod.model.rp.fixed = keep_fixed_radius_planet

#### Detrending parameters

In [None]:
# A detrending coefficient is only fitted when its fit_* switch is True;
# otherwise it is held fixed.
mod.model.spp_trend.fixed = not fit_spp  # sky background
mod.model.airmass_trend.fixed = not fit_airmass  # airmass
mod.model.width_trend.fixed = not fit_width  # star width / focus

In [None]:
# Human-readable summary of which detrending terms are fitted; it is used in
# the legend of the final plot. ("Width" was previously misspelled, and the
# typo showed up in the published plot.)
detrended_by = []
if fit_airmass:
    detrended_by.append("Airmass")

if fit_spp:
    detrended_by.append("SPP")

if fit_width:
    detrended_by.append("Width")

detrended_by = (
    ("Detrended by: " + ",".join(detrended_by)) if detrended_by else "No detrending"
)

### 4. Run the fit

In [None]:
# Fit the model to the binned data, using the fixed/free settings above.
mod.fit()

### 5. Let's try a plot....

In [None]:
# Quick look: binned data (points), best-fit model (line) and the predicted
# ingress/egress times (dashed red lines), all against BJD - 2400000.
plt.plot(mod.times, mod.data, ".")
plt.plot(mod.times, mod.model_light_curve())
plt.vlines(
    [start.jd - 2400000, end.jd - 2400000],
    0.98,
    1.02,
    colors="r",
    linestyle="--",
    alpha=0.5,
)
plt.title("Data and fit")
plt.grid()

In [None]:
# Display the model, including its current (fitted) parameter values.
mod.model

In [None]:
# mod._fitter.fit_info

In [None]:
# Data and model with all detrending terms removed (detrend_by="all").
flux_full_detrend, flux_full_detrend_model = (
    mod.data_light_curve(detrend_by="all"),
    mod.model_light_curve(detrend_by="all"),
)

In [None]:
# Scatter (RMS) values quoted in the legend of the final plot.

# Detrended flux re-normalized to a mean of 1, and its scatter.
rel_detrended_flux = flux_full_detrend / np.mean(flux_full_detrend)
rel_detrended_flux_rms = np.std(rel_detrended_flux)

# Residual scatter about the transit fit. Data and model are compared on the
# same scale: the model is not re-normalized, so the data must not be either.
rel_model_rms = np.std(flux_full_detrend - flux_full_detrend_model)

# Scatter of the binned, normalized (not detrended) data.
rel_flux_rms = np.std(mod.data)

In [None]:
# y-axis tick positions (and so grid lines) every 0.02 in relative flux,
# covering the final plot's y range.
grid_y_ticks = np.arange(low, high, 0.02)

In [None]:
# Final summary plot. From top to bottom: the normalized flux, the detrended
# flux, the detrended flux with the transit model, and the auxiliary series
# drawn by plot_many_factors, each offset vertically. Times are BJD - 2400000.

# Create exactly one figure. An extra plt.figure() before plt.subplots()
# would leave an empty figure open.
fig, ax = plt.subplots(1, 1, figsize=(8, 11))

plt.plot(
    (photometry["bjd"] - 2400000 * u.day).jd,
    normalized_flux,
    "b.",
    label=f"rel_flux_T1 (RMS={rel_flux_rms:.5f})",
    ms=4,
)

# bin_size is a Quantity, so it already prints with its unit (e.g. "5.0 min").
plt.plot(
    mod.times,
    flux_full_detrend - 0.04,
    ".",
    c="r",
    ms=4,
    label=f"rel_flux_T1 ({detrended_by})(RMS={rel_detrended_flux_rms:.5f}), (bin size={bin_size})",
)

plt.plot(
    mod.times,
    flux_full_detrend - 0.08,
    ".",
    c="g",
    ms=4,
    label=f"rel_flux_T1 ({detrended_by} with transit fit)(RMS={rel_model_rms:.5f}), (bin size={bin_size})",
)
plt.plot(
    mod.times,
    flux_full_detrend_model - 0.08,
    c="g",
    ms=4,
    label=f"rel_flux_T1 Transit Model ([P={mod.model.period.value:.4f}], "
    f"(Rp/R*)^2={(mod.model.rp.value)**2:.4f}, \na/R*={mod.model.a.value:.4f}, "
    f"[Tc={mod.model.t0.value + 2400000:.4f}], "
    f"[u1={mod.model.limb_u1.value:.1f}, u2={mod.model.limb_u2.value:.1f}])",
)

plot_many_factors(photometry, shift, scale)

# Predicted ingress/egress markers, each labelled with its fractional day.
plt.vlines(start.jd - 2400000, low, 1.025, colors="r", linestyle="--", alpha=0.5)
plt.vlines(end.jd - 2400000, low, 1.025, colors="r", linestyle="--", alpha=0.5)
plt.text(
    start.jd - 2400000,
    low + 0.0005,
    f"Predicted\nIngress\n{start.jd-2400000-int(start.jd - 2400000):.3f}",
    horizontalalignment="center",
    c="r",
)
plt.text(
    end.jd - 2400000,
    low + 0.0005,
    f"Predicted\nEgress\n{end.jd-2400000-int(end.jd - 2400000):.3f}",
    horizontalalignment="center",
    c="r",
)

plt.ylim(low, high)
plt.xlabel("Barycentric Julian Date (TDB)", fontname="Arial")
plt.ylabel("Relative Flux (normalized)", fontname="Arial")
plt.title(
    f"{obj}.01   UT{date_obs}\nPaul P. Feder Observatory 0.4m ({phot_filter} filter, {exposure_time} exp, fap 10-25-40)",
    fontsize=14,
    fontname="Arial",
)
plt.legend(loc="upper center", frameon=False, fontsize=8, bbox_to_anchor=(0.6, 1.0))
ax.set_yticks(grid_y_ticks)
plt.grid()

# Name the file after the UTC date of the first data point and the filter
# actually analyzed. Both were previously hard-coded ("20200701", "gp").
plt.savefig(
    f"TIC{tess_info.tic_id}-01_{then.utc.strftime('%Y%m%d')}_Paul-P-Feder-0.4m_{phot_filter}_lightcurve.png",
    facecolor="w",
)

In [None]:
# Number of parameters in the fit (presumably only the free ones -- confirm
# in TransitModelFit).
mod.n_fit_parameters

In [None]:
# Names of the detrendable quantities; each has a matching "<name>_trend"
# model parameter, which evaluate_fits below toggles between fixed and free.
mod._all_detrend_params

In [None]:
def evaluate_fits(mod):
    """Refit with every fixed/free combination of detrending terms; tabulate BICs.

    For each combination over ``mod._all_detrend_params``, the matching
    ``<param>_trend`` model parameter is either fixed (after being reset to
    0.0) or left free, and then the model is refit.

    Note: ``mod`` is modified in place. When this returns, it holds the
    settings and fit of the last combination tried, in which every trend is
    free.

    Parameters
    ----------
    mod : TransitModelFit
        Model with data loaded and ``setup_model`` already called.

    Returns
    -------
    Table
        One row per combination, with columns ``"Fit this param?"`` (a summary
        such as ``"<param>: True, <param>: False"``) and ``"BIC"``.
    """
    trendable = mod._all_detrend_params
    summaries = []
    bic_values = []

    for fixed_flags in product([True, False], repeat=len(trendable)):
        parts = []
        for name, is_fixed in zip(trendable, fixed_flags):
            trend_name = f"{name}_trend"
            trend_param = getattr(mod.model, trend_name)
            if is_fixed:
                # A fixed trend contributes nothing.
                setattr(mod.model, trend_name, 0.0)
            trend_param.fixed = is_fixed
            parts.append(f"{name}: {not is_fixed}")

        summaries.append(", ".join(parts))
        mod.fit()
        bic_values.append(mod.BIC)

    return Table(data=[summaries, bic_values], names=["Fit this param?", "BIC"])

In [None]:
# Compare every detrending combination, lowest (best) BIC first.
bic_table = evaluate_fits(mod)
bic_table.sort("BIC")
bic_table