In [None]:
# Limit OpenMP to a single thread. OMP_NUM_THREADS is read when the OpenMP /
# BLAS runtime is initialised, which can already happen while numpy or the
# theano backend is being imported, so it is set before any of those imports.
import os
os.environ["OMP_NUM_THREADS"] = "1"

import itertools
import os.path as path

import aesara_theano_fallback.tensor as tt
import astropy.io.fits as fits
import astropy.units as u
import arviz as az
import celerite2
import celerite2.terms as terms
import celerite2.theano.terms as theano_terms
import corner
import exoplanet as xo
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc3_ext as pmx

import adhocfitter.astro as astro
import adhocfitter.plotting as plotting
import adhocfitter.priors as priors
import adhocfitter.timeseries as aftimeseries

pd.set_option('display.max_rows', 80)

# Input data locations, relative to the notebook's working directory.
DATA_DIR = '../data'
ISOCHRONES_DIR = path.join(DATA_DIR, 'isochrones')

In [None]:
# Initial (test) values and prior widths for the two planets. Per-planet arrays
# share one ordering: index 0 is the ~9.13-d planet (titled 'c' in the RV plot)
# and index 1 the ~3.10-d planet (titled 'b').
test_epoch = np.array([2110.0659, 1901.719])  # TJD
# Half-widths of the uniform t0 / period priors (see add_uniform_prior).
test_epoch_unc = np.array([0.005, 0.02])
test_period = np.array([9.127055, 3.09833])  # day
test_period_unc = np.array([0.0001, 0.0002])  #np.array([0.00001, 0.000023])
test_last_seen = np.array([2111.0, 1855.3])
# NOTE(review): the two comment lines below look like stale output from an
# earlier run; nothing reads them.
# 2073.55739092,  2073.55767828,  2073.55796009]),
#  array([ 1907.91526705,  1907.91617858,  1907.91709297])
# Shift each epoch by a whole number of periods (floor division picks the
# count) to the last transit at or before `test_last_seen`.
test_epoch = (test_last_seen - test_epoch) // test_period * test_period + test_epoch
print(test_epoch)
test_duration = np.array([4.9, 2.0])  # hour
test_duration_unc = np.array([6., 4.])
# NOTE(review): not referenced in the visible part of the notebook.
test_omega_vec = np.array([0., 1.])

# Planet-to-star radius ratio Rp/R* (converted to Earth radii via r_star).
test_radius_planet = np.array([0.066, 0.022])  # In units of the stellar radius
# Upper bound of the Uniform(0, upper) prior on rp, not a standard deviation.
test_radius_planet_unc = np.array([0.5, 0.5])
test_impact_param = np.array([0.615, 0.756])
test_limb_dark = np.array([0.37, 0.21])

# Estimated values from previous fits
# Mean / sigma of the bounded-normal prior on the initial stellar mass (M_sun).
test_mass_star = 1.05  # HARPS value
test_mass_star_unc = 0.2
# Starting values for the initial [Fe/H] and the MIST EEP.
test_feh = 0.415
test_eep = 380
# A_V extinction
test_av = 0.15  # From Keivan.
test_av_upper = 0.70  # Upper limit from Schlegel 1998 dust map.
# Gaia EDR3 parallax
PARALLAX_GAIA_EDR3 = 5.76044602669829
PARALLAX_GAIA_EDR3_UNC= 0.01048182
# Gaia EDR3 parallax zero-point correction, computed by:
# zpt.get_zpt(
#     phot_g_mean_mag=10.8542,
#     nu_eff_used_in_astrometry=1.5263987,
#     pseudocolour=float('nan'),
#     ecl_lat=-68.38768994372302,
#     astrometric_params_solved=31,  # 5-param solution.
# )
# Corrected parallax = parallax - zpt, with zpt = -0.012446.
PARALLAX_GAIA_EDR3 -= -0.012446

# Planet-to-star mass ratio M_p / M_star, from rough planet masses in Earth masses.
test_mass_planet = (np.array([82., 11.]) * u.M_earth).to(u.M_sun).value / test_mass_star
# Upper bounds of the Uniform(0, upper) mass-ratio prior (see _add_mass_planet).
test_mass_planet_unc = (np.array([300., 100.]) * u.M_earth).to(u.M_sun).value / test_mass_star # Upper limit.

print(test_mass_planet, test_mass_planet_unc)

# TIC 8.1 values
# test_density_star = 0.839021  # g cm^-3
# test_density_star_unc = 0.190382
# test_radius_star = 1.18093  # R_sun
# test_radius_star_unc = 0.0597739
# test_mass_star = 0.98  # M_sun
# test_mass_star_unc = 0.131459

# FEROS values
# test_density_star = 1.052  # g cm^-3
# test_density_star_unc = 0.052
# test_radius_star = 1.104  # R_sun
# test_radius_star_unc = 0.011
# test_mass_star = 1.001  # M_sun
# test_mass_star_unc = 0.026

# print(test_mass_star/test_radius_star**3 * astro.SOLAR_DENSITY*1e-3)

# Test (starting / central) values, keyed by model parameter name.
test_params = {
    "t0": test_epoch,
    "period": test_period,
    "rp": test_radius_planet,
    "b": test_impact_param,
    "ecc_fix": np.array([False, True]),  # True holds that planet's orbit circular.
    "ecc": np.array([0., 0.]),
    "omega": np.array([np.pi/2, np.pi/2]),
    "u": {"tess": test_limb_dark, "Rc": test_limb_dark, "tess_c": test_limb_dark},
#     "rho_star": test_density_star,
    "m_star": test_mass_star,
#     "r_star": test_radius_star,
    "tdur": test_duration,
    "m_planet": test_mass_planet,
    "av": test_av,
    "parallax": PARALLAX_GAIA_EDR3,
    "feh": test_feh,
    "eep": test_eep,
}

# Prior widths keyed by parameter name. The meaning depends on the consumer
# (uniform half-width, normal sigma or uniform upper bound).
prior_unc = {
    "t0": test_epoch_unc,
    "period": test_period_unc,
    "rp": test_radius_planet_unc,
#     "rho_star": test_density_star_unc,
    "m_star": test_mass_star_unc,
#     "r_star": test_radius_star_unc,
    "m_planet": test_mass_planet_unc,
    "tdur": test_duration_unc,
    "av": test_av_upper,
    "parallax": PARALLAX_GAIA_EDR3_UNC,
    "albedo_bond": 0.7,
}

In [None]:
# Stellar spectroscopic priors (HARPS values).
TEFF_PRIOR = 5568.
# TEFF_PRIOR_UNC = 66.
# Adopted Teff width (wider than the commented-out 66 K above).
TEFF_PRIOR_UNC = 100.
LOGG_PRIOR = 4.38
LOGG_PRIOR_UNC = 0.03
FEH_PRIOR = 0.438
FEH_PRIOR_UNC = 0.044

# Gaia EDR3 photometry. Magnitude errors derived from the flux errors are
# floored at 0.02 mag.
MAG_GAIA_EDR3_PRIOR = np.array([10.8542, 11.24296, 10.30405])  # G, BP, RP.
FLUX_GAIA_EDR3_PRIOR = np.array([857564.470952316, 434744.994697565, 599156.243355725])
FLUX_GAIA_EDR3_PRIOR_UNC = np.array([126.9724, 209.5703, 142.4705])
MAG_GAIA_EDR3_PRIOR_UNC = np.maximum(np.array([0.02, 0.02, 0.02]), priors.calculate_magnitude_err(FLUX_GAIA_EDR3_PRIOR, FLUX_GAIA_EDR3_PRIOR_UNC))
print(MAG_GAIA_EDR3_PRIOR_UNC)

MAG_2MASS_PRIOR = np.array([9.715, 9.417, 9.303])  # J, H, K.
MAG_2MASS_PRIOR_UNC = np.array([0.028, 0.026, 0.021])

MAG_WISE_PRIOR = np.array([9.249, 9.301, 9.241, 9.])  # 3.4, 4.6, 12, 22 micron bands
MAG_WISE_PRIOR_UNC = np.array([0.023, 0.02, 0.026, 0.254])

test_params['teff'] = TEFF_PRIOR

# Upper limit on the stellar age in Gyr (soft cut-off in make_stellar_params).
prior_unc['age'] = 10.

# (value, uncertainty) pairs for spectroscopic constraints.
stellar_params = {
    'teff': (TEFF_PRIOR, TEFF_PRIOR_UNC),
    'feh': (FEH_PRIOR, FEH_PRIOR_UNC),
}

# Photometric groups: name -> (magnitudes, uncertainties, options).
# NOTE(review): the third element is interpreted by code not shown here,
# presumably (use-this-group flag, extra systematic error in mag). Confirm.
magnitude_groups = {
    '2mass': (MAG_2MASS_PRIOR, MAG_2MASS_PRIOR_UNC, (True, 0)),
    'gaia': (MAG_GAIA_EDR3_PRIOR, MAG_GAIA_EDR3_PRIOR_UNC, (True, 0.02)),
    'wise': (MAG_WISE_PRIOR, MAG_WISE_PRIOR_UNC, (True, 0.01)),
}

In [None]:
# TESS light curves. Each read_tess_lc call gets the file plus the test period,
# epoch and duration (hours converted to days), and its result is unpacked
# into (time, relative flux, noise).
tess_time, tess_dflux, tess_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_02.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)

# 20-s cadence TESS data.
tess_20s_time, tess_20s_dflux, tess_20s_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_03.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)
# NOTE(review): only the first element of the returned noise is kept here,
# unlike the other two reads. It is presumably a single noise value per light
# curve; confirm this is intended.
tess_20s_noise = tess_20s_noise[0]
tess_20s_times = [tess_20s_time]
tess_20s_dfluxes = [tess_20s_dflux]
tess_20s_noises = [tess_20s_noise]

# The 20-s data binned to 2 min (alternative to the 20-s set).
tess_2m_time, tess_2m_dflux, tess_2m_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_03_binned_to_2_min.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)
tess_2m_times = [tess_2m_time]
tess_2m_dfluxes = [tess_2m_dflux]
tess_2m_noises = [tess_2m_noise]

In [None]:
# Ground-based photometry. read_generic_lc gets the file, the time / flux /
# flux-error column names and, optionally, a list of detrending columns.
# Calls with detrending columns are unpacked into four values (the last is
# the detrending design); calls without them into three.
# ASTEP (Rc). Exposure time 70 s
# NOTE(review): time_epoch presumably declares the zero point of the 'BJD'
# column so it shares the time axis of the other data. Confirm.
astep_time, astep_dflux, astep_unc, astep_design = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_06.csv'),
    'BJD', 'FLUX', 'ERRFLUX', ['SKY'], time_epoch=2450000.)

# LCO zs and B. Exposure time 30 s (except last cadence of B)
lco_zs_time, lco_zs_dflux, lco_zs_unc, lco_zs_design = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_04_zs.tsv'),
    '#BJD_TDB', 'rel_flux_T1_n', 'rel_flux_err_T1_n', ['Sky/Pixel_T1'], delim_whitespace=True)
lco_B_time, lco_B_dflux, lco_B_unc, lco_B_design = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_04_B.tsv'),
    '#BJD_TDB', 'rel_flux_T1_n', 'rel_flux_err_T1_n', ['Width_T1'], delim_whitespace=True)

# PEST Ic and B (no detrending columns).
pest_Ic_time, pest_Ic_dflux, pest_Ic_unc = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_05_Ic.txt'),
    '#BJD_TDB', 'flux', 'flux_err', delim_whitespace=True)
pest_B_time, pest_B_dflux, pest_B_unc = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_05_B.txt'),
    '#BJD_TDB', 'flux', 'flux_err', delim_whitespace=True)

# Datasets entering the fit: ASTEP (Rc) and LCO (zs). The commented-out entries
# re-enable LCO B and PEST Ic/B. If they are enabled, every list below,
# including the supersampling factors, must be extended in step.
ground_times = [astep_time, lco_zs_time] #, lco_B_time, pest_Ic_time, pest_B_time]
ground_dfluxes = [astep_dflux, lco_zs_dflux] #, lco_B_dflux, pest_Ic_dflux, pest_B_dflux]
ground_uncs = [astep_unc, lco_zs_unc] #, lco_B_unc, pest_Ic_unc, pest_B_unc]
# Exposure times in days.
ground_exp_times = list(np.array([70., 30.]) / 3600 / 24) #  30., 60., 60.]) / 3600 / 24)
ground_supersampling_factors = [1] * 2
ground_filters = ['Rc', 'zs'] #'B', 'Ic', 'B']
ground_detrends = [astep_design, lco_zs_design] #, lco_B_design, None, None]

In [None]:
# Assemble the per-dataset light-curve inputs. Every list below is indexed the
# same way: the 30-min TESS light curve first, then the 20-s TESS light
# curve(s), then the ground-based light curves.
lc_times = [tess_time] + tess_20s_times + ground_times
lc_dfluxes = [tess_dflux] + tess_20s_dfluxes + ground_dfluxes
lc_uncs = [tess_noise] + tess_20s_noises + ground_uncs
exposure_times = [0.5/24] + [20./60/60/24]*len(tess_20s_dfluxes) + ground_exp_times  # days
supersampling_factors = [15] + [1]*len(tess_20s_dfluxes) + ground_supersampling_factors
# Alternative configuration using the 20-s data binned to 2 min:
# lc_times = [tess_time] + tess_2m_times + ground_times
# lc_dfluxes = [tess_dflux] + tess_2m_dfluxes + ground_dfluxes
# lc_uncs = [tess_noise] + tess_2m_noises + ground_uncs
# exposure_times = [0.5/24] + [2./60/24]*len(tess_2m_dfluxes) + ground_exp_times
# supersampling_factors = [15] + [1]*len(tess_2m_dfluxes) + ground_supersampling_factors

# Derive the counts from the assembled lists instead of hard-coding them, so
# the filter and detrending entries stay aligned whichever TESS set is used.
# (The detrending list was previously sized from the 2-min set while the 20-s
# set was active.)
num_lc = len(lc_dfluxes)
num_tess_lc = num_lc - len(ground_dfluxes)
filters = ['tess'] * num_tess_lc + ground_filters

# TESS light curves get a free mean-flux offset (None); ground-based ones use
# their detrending regressors.
lc_detrend_series = [None] * num_tess_lc + ground_detrends
assert (len(filters) == len(lc_detrend_series) == len(exposure_times)
        == len(supersampling_factors) == num_lc), "light-curve lists are misaligned"

for filter_name in set(filters):
    test_params['u'][filter_name] = test_limb_dark
    test_params[f'u_{filter_name}'] = test_limb_dark

test_params['mean_flux'] = np.array([0.]*num_lc)
prior_unc['mean_flux'] = np.array([0.5]*num_lc)
test_params['lc_jitter'] = np.array([1e-9]*num_lc)
prior_unc['lc_jitter'] = np.array([1.]*num_lc)

# Display exposure times (seconds), supersampling factors and filters.
np.array(exposure_times)*24*3600, supersampling_factors, filters

In [None]:
# Radial velocities, split by instrument. The per-instrument lists below share
# the order CHIRON, FEROS, HARPS.
rv_table = aftimeseries.read_generic_rv(path.join(DATA_DIR, 'toi_2000_table_08.csv'))
chiron_time, chiron_rv, chiron_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'CHIRON')
feros_time, feros_rv, feros_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'FEROS')
harps_time, harps_rv, harps_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'HARPS')

rv_names = ['chiron', 'feros', 'harps']
rv_times = [chiron_time, feros_time, harps_time]
rv_data = [chiron_rv, feros_rv, harps_rv]
rv_uncs = [chiron_rv_unc, feros_rv_unc, harps_rv_unc]
num_rv_outside = len(rv_data)

# Starting values: mean RV and mean RV error per instrument, tiny jitter.
test_gamma = np.array([np.average(i) for i in rv_data])
test_rv_unc = np.array([np.average(i) for i in rv_uncs])
test_jitter = np.array([1e-3]*len(rv_data))

# RV semi-amplitudes per planet (m/s).
test_params['K'] = np.array([23., 9.])
prior_unc['K'] = np.array([50., 30.])
# Systemic offsets: uniform +/- 1000 around the mean RV (model_set_up_rv).
test_params['rv_gamma'] = test_gamma
prior_unc['rv_gamma'] = np.array([1000.]*num_rv_outside)
test_params['rv_jitter'] = test_jitter
# HalfNormal scales of the per-instrument jitter (model_set_up_rv).
prior_unc['rv_jitter'] = np.array([15., 30., 15.])

# Fit both planet masses from the RVs (see _add_mass_planet).
test_params['rv_fit_planet'] = np.array([True, True])

test_gamma, test_rv_unc, test_jitter, prior_unc['rv_gamma'], prior_unc['rv_jitter']

In [None]:
# Quick-look check of the test parameters: fold the two TESS light curves on an
# orbit built from the test values.
# NOTE(review): rho_star=1.065 is a hard-coded stellar density, presumably in
# exoplanet's default g/cm^3, used only for this plot.
plot_orbit = xo.orbits.KeplerianOrbit(rho_star=1.065,
    **{k: test_params[k] for k in ['period', 't0', 'b', 'ecc', 'omega']})
fig, axs = plotting.plot_multi_planet_folded_light_curve(
    2, plot_orbit, test_params['rp'],
    lc_times[:2], lc_dfluxes[:2], filters[:2],
    # Limb-darkening coefficients stored under 'u_<filter>' keys.
    {k:v for k, v in test_params.items() if k[:2]=='u_'},
    [0.]*len(lc_times),
    exposure_times, supersampling_factors,
    [35]*num_lc, [0.07, 0.08], [0.3, 0.2, 0.2, 0.2])
axs[1][0].set_ylim(-0.0055, 0.0015)
# axs[2][0].set_ylim(-0.0055, 0.0015)

In [None]:
# RV overview figure: one folded panel per planet plus the unfolded RV series
# and its residuals. The mathtext labels are raw strings so that sequences such
# as "\m" and "\," are not parsed as (invalid) Python escapes. Those raise a
# SyntaxWarning on Python 3.12+; the label text is unchanged.
fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(2)
rv_unfolded_ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')
rv_folded_axs[0].set_title('c')
rv_folded_axs[1].set_title('b')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')

# Orbit built from the test parameters, including planet masses, for the model
# curves.
# NOTE(review): rho_star=1.065 is hard-coded for plotting, and test_params
# 'm_planet' is a planet/star mass ratio passed where KeplerianOrbit expects a
# mass. Confirm this is acceptable for a quick-look plot.
plot_orbit = xo.orbits.KeplerianOrbit(rho_star=1.065,
    **{k: test_params[k] for k in ['period', 't0', 'b', 'ecc', 'omega', 'm_planet']})

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': 'gray',
    'elinewidth': 1,
}

# One marker style per instrument, in the order of rv_times / rv_data.
plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^'},
    {'color': '#a12568', 'fmt': 'd'},
    {'color': '#fec260', 'fmt': 'o'},
]

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs, rv_unfolded_ax, rv_residual_ax,
    2, plot_orbit, None, #test_params['K'],
    test_params['rv_gamma'],
    test_params['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    # trends=trend_generator(map_soln, len(rv_data)),
    # model_trend_func=lambda t: polynomial_design_matrix(t, map_soln['rv_time_offset'], 2) @ map_soln['rv_trend_coeff'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')

for ax, p in zip(rv_folded_axs, test_params["period"]):
    ax.text(0.05, 0.05, rf'$P = {p:.5f}\,\mathrm{{d}}$', transform=ax.transAxes)

In [None]:
def make_mist_interpolator(isochrones_dir):
    """Build an interpolator over the MIST v1.2 (vvcrit=0.0) evolutionary tracks.

    The interpolator maps ``(initial [Fe/H], initial mass, EEP)`` to the track
    columns listed in ``MIST_TRACK_COL_SELECT``, in that order.

    Args:
        isochrones_dir: Directory containing
            ``mist/tracks/full_grid_v1.2_vvcrit0.0.npz`` (with ``columns`` and
            ``grid`` members).

    Returns:
        ``(interpolator, column_names)``: the
        ``xo.interp.RegularGridInterpolator`` and the array of selected output
        column names.
    """
    MIST_TRACK_COL_SELECT = ['Mbol', 'logTeff', 'mass', 'logg', 'feh', 'age', 'dt_deep']
    # Grid extent used: the last 5 metallicity nodes, the first 21*6 masses and
    # the first 707 EEPs of every track.
    num_feh = 5
    num_mass = 21*6
    num_eep = 707

    track_file = path.join(isochrones_dir, "mist/tracks/full_grid_v1.2_vvcrit0.0.npz")
    # An NpzFile keeps the archive open and re-reads (decompresses) a member on
    # every key access, so read each member exactly once and close the file.
    with np.load(track_file) as track_archive:
        all_columns = track_archive['columns']
        full_grid = track_archive['grid']

    column_index = {n: i for i, n in enumerate(all_columns)}
    selected_cols = [column_index[n] for n in MIST_TRACK_COL_SELECT]
    selected_names = all_columns[selected_cols]

    track_slice = full_grid[-num_feh:, :num_mass, :num_eep, selected_cols]

    # NOTE(review): assumes the last 5 metallicity nodes are -0.5..+0.5 in steps
    # of 0.25 and that the mass nodes are identical for every metallicity.
    mist_metallicity = np.arange(-0.5, 0.75, 0.25)
    mist_mass = full_grid[0, :num_mass, 0, column_index['initial_mass']]
    mist_eep = np.arange(1., num_eep + 1.)

    mist_interpolate = xo.interp.RegularGridInterpolator(
        (mist_metallicity, mist_mass, mist_eep),
        track_slice,
    )
    return mist_interpolate, selected_names

# The second return value is the array of selected column *names* (despite
# the _MAP suffix); the name is kept for code that already uses it.
mist_interpolate, MIST_TRACK_COL_MAP = make_mist_interpolator(ISOCHRONES_DIR)
MIST_TRACK_COL_MAP

In [None]:
def make_mist_bc_interpolator(isochrones_dir):
    """Build an interpolator over MIST bolometric-correction tables.

    Interpolation axes are the first four index levels of the BC tables
    (Teff, log g, [Fe/H], A_V), restricted by ``MIST_BC_SLICE``. Outputs are
    the 2MASS J/H/Ks and Gaia EDR3 G/BP/RP columns, followed by every column
    of the WISE table.

    Args:
        isochrones_dir: Directory containing ``BC/mist/UBVRIplus.h5`` and
            ``BC/mist/WISE.h5``.

    Returns:
        An ``xo.interp.RegularGridInterpolator`` over that grid.

    Raises:
        ValueError: if the selected rows do not form a complete rectangular grid.
    """
    mist_bc_dir = path.join(isochrones_dir, 'BC', 'mist')
    mist_bc_ubvri = pd.read_hdf(path.join(mist_bc_dir, "UBVRIplus.h5"))
    mist_bc_wise = pd.read_hdf(path.join(mist_bc_dir, "WISE.h5"))

    # Per-level selection. The last level is fixed at 3.1 (presumably R_V; confirm).
    MIST_BC_SLICE = (
        slice(2500., 15000.),
        slice(None),
        slice(0., 0.75),
        slice(None),
        3.1,
    )
    mist_bc_grid_df = pd.merge(
        mist_bc_ubvri[['2MASS_J', '2MASS_H', '2MASS_Ks', 'Gaia_G_EDR3', 'Gaia_BP_EDR3', 'Gaia_RP_EDR3']].loc[MIST_BC_SLICE],
        mist_bc_wise.loc[MIST_BC_SLICE],
        left_index=True, right_index=True,
    )

    # Grid axes: unique node values of the first four index levels.
    grid_axes = [
        np.array(mist_bc_grid_df.index.get_level_values(level).unique())
        for level in range(4)
    ]
    # Derive the reshape target from the data instead of hard-coding it, so a
    # changed slice or table still reshapes correctly. The reshape assumes rows
    # are ordered with level 0 varying slowest, as in a sorted MultiIndex.
    grid_shape = tuple(len(axis) for axis in grid_axes) + (mist_bc_grid_df.shape[1],)
    if int(np.prod(grid_shape)) != mist_bc_grid_df.size:
        raise ValueError(
            f'MIST BC selection is not a complete grid: axes {grid_shape[:-1]} '
            f'imply {int(np.prod(grid_shape[:-1]))} rows, got {len(mist_bc_grid_df)}')
    mist_bc_grid = mist_bc_grid_df.to_numpy().reshape(grid_shape)

    mist_bc_interpolate = xo.interp.RegularGridInterpolator(
        grid_axes,
        mist_bc_grid,
    )
    return mist_bc_interpolate
mist_bc_interpolate = make_mist_bc_interpolator(ISOCHRONES_DIR)

In [None]:
def add_normal_prior(name, test_params, prior_unc, shape=None):
    """Register a Normal prior called ``name`` in the active PyMC3 model.

    The mean is ``test_params[name]`` and the standard deviation is
    ``prior_unc[name]``. ``shape`` is forwarded only when it is given.
    """
    normal_kwargs = {'mu': test_params[name], 'sd': prior_unc[name]}
    if shape is not None:
        normal_kwargs['shape'] = shape
    return pm.Normal(name, **normal_kwargs)

def add_uniform_prior(name, test_params, prior_unc, shape=None):
    """Register a Uniform prior called ``name`` centred on its test value.

    The support is ``test_params[name] -/+ prior_unc[name]`` and the test value
    is the centre. ``shape`` is forwarded only when it is given.
    """
    centre = test_params[name]
    half_width = prior_unc[name]
    uniform_kwargs = dict(
        lower=centre - half_width,
        upper=centre + half_width,
        testval=centre,
    )
    if shape is not None:
        uniform_kwargs['shape'] = shape
    return pm.Uniform(name, **uniform_kwargs)

def _add_fixed_eccentricity(fix_flag):
    """Create eccentricity parameters, holding planets flagged in ``fix_flag`` circular.

    Free planets use a unit-disk vector ``v``: ``e = |v|^2`` and ``v / |v|``
    gives ``(cos w, sin w)``. Fixed planets get ``e = 0`` and
    ``(cos w, sin w) = (0, 1)``, i.e. ``w = pi/2``. Deterministics ``ecc`` and
    ``omega`` are recorded unless every planet is fixed. Must be called inside
    a model context.

    Args:
        fix_flag: Per-planet booleans; True fixes that planet's orbit circular.

    Returns:
        ``(ecc, omega_vec)``, where ``omega_vec[0]`` and ``omega_vec[1]`` are the
        per-planet cos and sin of omega. Returns ``(None, [None, None])`` when
        every planet is fixed.
    """
    if all(fix_flag):
        return None, [None, None]
    if not any(fix_flag):
        num_planets = len(fix_flag)
        # The test value must match the (2, num_planets) parameter shape; a
        # length-2 vector only matches that shape when num_planets == 2.
        ecc_vec = pmx.UnitDisk(
            'sqrt_ecc_vec',
            testval=np.full((2, num_planets), 1e-6),
            shape=(2, num_planets))
        ecc = pm.Deterministic('ecc', tt.sum(ecc_vec*ecc_vec, axis=0))
        omega_vec = ecc_vec / tt.sqrt(ecc)
        pm.Deterministic('omega', tt.arctan2(omega_vec[1], omega_vec[0]))
        return ecc, omega_vec
    # Mixed case: build each planet separately and stack the results.
    ecc_stack = []
    omega_vec_stack = []
    for i, ecc_fix in enumerate(fix_flag):
        if ecc_fix:
            ecc_stack.append(tt.as_tensor_variable(0.))
            omega_vec_stack.append(tt.as_tensor_variable(np.array([0., 1.])))
        else:
            ecc_vec = pmx.UnitDisk(f'sqrt_ecc_vec_{i}', testval=np.array([1e-6, 1e-6]))
            ecc = tt.sum(ecc_vec*ecc_vec, axis=0)
            ecc_stack.append(ecc)
            omega_vec_stack.append(ecc_vec / tt.sqrt(ecc))
    ecc = pm.Deterministic('ecc', tt.stack(ecc_stack, axis=0))
    # Stack along axis 1 so that row 0 holds cos(omega) and row 1 sin(omega).
    omega_vec = tt.stack(omega_vec_stack, axis=1)
    pm.Deterministic('omega', tt.arctan2(omega_vec[1], omega_vec[0]))
    return ecc, omega_vec

def _add_mass_planet(test_params, prior_unc):
    """Add planet-to-star mass-ratio priors for planets fitted with RVs.

    Each fitted planet gets ``Uniform(0, prior_unc['m_planet'][i])`` starting at
    ``test_params['m_planet'][i]``. Planets that are not fitted contribute a
    constant 0.

    Returns:
        The ``m_planet`` variable (one entry per planet), or ``None`` if
        ``rv_fit_planet`` is absent or no planet is fitted.
    """
    if 'rv_fit_planet' not in test_params:
        return None
    fit_flags = test_params['rv_fit_planet']

    # Every planet fitted: a single vector-valued prior.
    if all(fit_flags):
        return pm.Uniform(
            "m_planet", testval=test_params['m_planet'],
            lower=0., upper=prior_unc['m_planet'], shape=len(fit_flags))

    if not any(fit_flags):
        return None

    # Some planets fitted: scalar priors where fitted, zeros elsewhere.
    per_planet = [
        pm.Uniform(
            f'm_planet_{idx}', testval=test_params['m_planet'][idx],
            lower=0., upper=prior_unc['m_planet'][idx],
        ) if is_fitted else tt.as_tensor_variable(0.)
        for idx, is_fitted in enumerate(fit_flags)
    ]
    return pm.Deterministic('m_planet', tt.stack(per_planet, axis=0))

def make_multi_planet_transit_model(model, num_planets, test_params, prior_unc):
    """Add the planetary orbit parameters and derived quantities to ``model``.

    ``model`` must already contain the stellar parameters ``m_star``, ``r_star``
    (solar units) and ``teff``.

    Free parameters added:
        * ``period`` and ``t0`` (uniform around the test values);
        * ``rp``, the planet-to-star radius ratio, uniform on ``[0, prior_unc['rp']]``;
        * ``b``, the impact parameter;
        * the eccentricity terms from ``_add_fixed_eccentricity``;
        * when RV fitting is enabled, the planet-to-star mass ratio
          ``m_planet`` from ``_add_mass_planet``.

    Derived deterministics include ``incl`` (deg), ``rho_star``,
    ``r_planet_earth`` and ``a`` (AU). When masses are fitted, they also include
    ``m_planet_earth``, ``rho_planet`` (g/cm^3) and ``K`` (m/s). ``aor``,
    ``irradiation`` and ``tdur`` (hours) are always added.

    Returns:
        The ``xo.orbits.KeplerianOrbit`` describing all planets.
    """
    with model:
        period = add_uniform_prior(
            'period', test_params, prior_unc, shape=num_planets)
        epoch = add_uniform_prior(
            't0', test_params, prior_unc, shape=num_planets)
        radius_ratio = pm.Uniform(
            "rp", testval=test_params["rp"],
            lower=0., upper=prior_unc["rp"], shape=num_planets)
        # NOTE(review): no explicit shape is passed. Confirm ImpactParameter
        # yields num_planets entries from the vector-valued ror.
        impact_param = xo.distributions.ImpactParameter("b", ror=radius_ratio)
        ecc, omega_vec = _add_fixed_eccentricity(test_params['ecc_fix'])

        # Planet-to-star mass ratio M_p / M_star, or None when no mass is fitted.
        mass_ratio = _add_mass_planet(test_params, prior_unc)

        mass_star = model.m_star
        radius_star = model.r_star

        orbit = xo.orbits.KeplerianOrbit(
            period=period,
            t0=epoch,
            b=impact_param,
            ecc=ecc,
            cos_omega=omega_vec[0],
            sin_omega=omega_vec[1],
            m_star=mass_star,
            r_star=radius_star,
            # ``None * tensor`` would raise a TypeError. With no fitted mass,
            # the planets are massless (exoplanet's default m_planet).
            m_planet=0.0 if mass_ratio is None else mass_ratio*mass_star,
        )
        pm.Deterministic('incl', orbit.incl*(180/np.pi))
        pm.Deterministic('rho_star', orbit.rho_star)

        rsun_to_rearth = (1*u.R_sun).to(u.R_earth).value
        radius_planet_earth = pm.Deterministic(
            'r_planet_earth', radius_ratio*radius_star*rsun_to_rearth)
        if mass_ratio is not None:
            # mass_ratio is M_p / M_star, so multiply by M_star (solar masses)
            # before converting, mirroring the radius conversion above.
            mass_planet_earth = pm.Deterministic(
                "m_planet_earth",
                mass_ratio*mass_star*astro.SOLAR_MASS_IN_EARTH_MASS)
            density_planet_conv = (1. * u.M_earth/u.R_earth**3).to(u.g/u.cm**3).value / (4./3 * np.pi)
            pm.Deterministic('rho_planet', (mass_planet_earth/radius_planet_earth**3)*density_planet_conv)
            rv_unit_conv = (1 * u.R_sun / u.day).to(u.m / u.s).value
            pm.Deterministic('K', orbit.K0*orbit.m_planet*rv_unit_conv)

        semimajor_axis_conv = (1*u.R_sun).to(u.AU).value
        pm.Deterministic('a', orbit.a*semimajor_axis_conv)
        scaled_semimajor_axis = pm.Deterministic(
            'aor',
            astro.calculate_aor_stellar_radius(
                mass_star, radius_star, period),
        )
        pm.Deterministic('irradiation', astro.calculate_irradiation(model.teff, scaled_semimajor_axis, ecc))
        pm.Deterministic(
            'tdur',
            astro.calculate_transit_duration_2(
                period,
                impact_param,
                scaled_semimajor_axis,
                0 if ecc is None else ecc,
                1 if ecc is None else omega_vec[1],
                radius_ratio,
            ) * 24,  # Convert days to hours.
        )
        return orbit

In [None]:
def model_set_up_limb_dark_fixed(model, test_params, filters):
    """Add fixed (non-free) limb-darkening coefficients ``u_<filter>`` per unique filter.

    The coefficients come from ``test_params['u'][filter]`` and are recorded as
    deterministics.

    Returns:
        dict mapping each filter name to an ``xo.LimbDarkLightCurve``.
    """
    light_curve_models = {}
    with model:
        for band in set(filters):
            coeffs = tt.as_tensor_variable(np.array(test_params['u'][band]))
            fixed_u = pm.Deterministic(f"u_{band}", coeffs)
            light_curve_models[band] = xo.LimbDarkLightCurve(fixed_u)
    return light_curve_models

def model_set_up_limb_dark(model, test_params, filters):
    """Add free quadratic limb-darkening coefficients ``u_<filter>`` per unique filter.

    Each uses exoplanet's ``QuadLimbDark`` prior, starting from
    ``test_params["u"][filter]``.

    Returns:
        dict mapping each filter name to an ``xo.LimbDarkLightCurve``.
    """
    light_curve_models = {}
    with model:
        for band in set(filters):
            free_u = xo.distributions.QuadLimbDark(
                f"u_{band}", testval=test_params["u"][band])
            light_curve_models[band] = xo.LimbDarkLightCurve(free_u)
    return light_curve_models

def model_add_transit_light_curves(
    model, orbit, ldlc_objs,
    lc_times, lc_dfluxes, lc_uncs,
    exposure_times, supersampling_factors, filters,
    detrend_series=None,
    keep_light_curve=False,
    test_params=None,
    prior_unc=None):
    """Add one transit light-curve likelihood ``lc_obs_<i>`` per dataset to ``model``.

    Args:
        model: PyMC3 model containing ``rp`` (radius ratio) and ``r_star``.
        orbit: Orbit passed to the limb-darkened light-curve models.
        ldlc_objs: Mapping from filter name to ``xo.LimbDarkLightCurve``.
        lc_times, lc_dfluxes, lc_uncs: Per-dataset times, relative fluxes and
            flux uncertainties.
        exposure_times: Per-dataset exposure time (days), used for integration.
        supersampling_factors: Per-dataset oversampling factor.
        filters: Per-dataset filter name (keys of ``ldlc_objs``).
        detrend_series: Optional per-dataset detrending regressors. A ``None``
            entry uses a free constant offset ``mean_flux_<i>``. Otherwise the
            entry is a 1-D series or a 2-D ``(n_points, n_terms)`` array; an
            intercept column is prepended and coefficients are fitted.
        keep_light_curve: If true, record each raw model as ``lc_pred_<i>``.
        test_params, prior_unc: Dicts providing the ``lc_jitter`` and
            ``mean_flux`` test values and prior widths. They default to the
            notebook-level ``test_params`` / ``prior_unc`` for backward
            compatibility; pass them explicitly to avoid hidden global state.
    """
    if test_params is None:
        test_params = globals()['test_params']
    if prior_unc is None:
        prior_unc = globals()['prior_unc']
    num_light_curve = len(lc_times)
    if detrend_series is None:
        detrend_series = [None] * num_light_curve
    with model:
        # Absolute planet radius, in the same units as r_star.
        radius_planet = model.rp*model.r_star
        # Extra white noise per dataset, added in quadrature to the errors.
        jitter = pm.Uniform(
            "lc_jitter", testval=test_params['lc_jitter'],
            lower=0., upper=prior_unc['lc_jitter'], shape=num_light_curve)
        for i, (lc_time, lc_dflux, lc_noise, texp, supersample, filter_name, detrending) in enumerate(
                zip(lc_times, lc_dfluxes, lc_uncs, exposure_times, supersampling_factors, filters, detrend_series)):
            raw_light_curve = ldlc_objs[filter_name].get_light_curve(
                orbit=orbit,
                r=radius_planet,
                t=lc_time,
                texp=texp,
                oversample=supersample,
            )
            if keep_light_curve:
                pm.Deterministic(f"lc_pred_{i}", raw_light_curve)
            # Sum the per-planet transit signals.
            transit_signal = tt.sum(raw_light_curve, axis=-1)
            if detrending is None:
                mean_flux_lower = test_params['mean_flux'][i] - prior_unc['mean_flux'][i]
                mean_flux_upper = test_params['mean_flux'][i] + prior_unc['mean_flux'][i]
                mean_flux = pm.Uniform(
                    f'mean_flux_{i}', lower=mean_flux_lower, upper=mean_flux_upper, testval=test_params['mean_flux'][i])
                light_curve = transit_signal + mean_flux
            else:
                # column_stack accepts both 1-D and 2-D regressors, whereas
                # hstack raises for a 1-D series.
                detrend_design = np.column_stack((np.ones(len(detrending)), detrending))
                detrend_params = pm.Normal(f'lc_detrend_coeffs_{i}', mu=1e-5, sd=1., shape=detrend_design.shape[1])
                light_curve = transit_signal + detrend_design @ detrend_params
            pm.Normal(
                f"lc_obs_{i}",
                mu=light_curve,
                sd=tt.sqrt(lc_noise*lc_noise + jitter[i]*jitter[i]),
                observed=lc_dflux,
            )
#         pm.Deterministic("log_likelihood_transit", total_obs_likelihood)

In [None]:
def model_set_up_rv(model, num_rvs, test_params, prior_unc):
    """Add per-instrument RV offsets ``rv_gamma`` and jitters ``rv_jitter`` to ``model``.

    ``rv_gamma`` is uniform around ``test_params['rv_gamma']`` with half-width
    ``prior_unc['rv_gamma']``. ``rv_jitter`` is HalfNormal with scale
    ``prior_unc['rv_jitter']``.
    """
    with model:
        add_uniform_prior('rv_gamma', test_params, prior_unc, shape=num_rvs)
        jitter_scale = prior_unc["rv_jitter"]
        pm.HalfNormal('rv_jitter', sigma=jitter_scale, shape=num_rvs)

def polynomial_design_matrix(x, offset, order):
    """Return the columns ``(x - offset)**k`` for ``k = 1 .. order``.

    The constant (k = 0) column is omitted, so the result has shape
    ``(len(x), order)``.
    """
    shifted = x - offset
    powers = np.vander(shifted, N=order + 1, increasing=True)
    # Drop the leading column of ones.
    return powers[:, 1:]

def model_set_up_polynomial_detrend(model, rv_times, order):
    """Add a shared polynomial RV trend (without a constant term) to ``model``.

    Times are centred on the midpoint of the full RV baseline (stored as
    ``rv_time_offset``). Coefficient ``k`` (1-based) has a ``Normal(0, 10**-k)``
    prior, and each instrument's design matrix is stored as
    ``rv_design_matrix_<i>``.

    Raises:
        ValueError: if ``order < 1``.
    """
    if order < 1:
        raise ValueError('Polynomial must be at least of linear order')
    earliest = min(np.min(times) for times in rv_times)
    latest = max(np.max(times) for times in rv_times)
    midpoint = (earliest + latest) / 2.
    coeff_widths = 10. ** -np.arange(1, order + 1)
    with model:
        pm.Deterministic('rv_time_offset', tt.as_tensor_variable(midpoint))
        pm.Normal("rv_trend_coeff", mu=0, sd=coeff_widths, shape=order)
        for idx, times in enumerate(rv_times):
            design = polynomial_design_matrix(times, midpoint, order)
            pm.Deterministic(f'rv_design_matrix_{idx}', tt.as_tensor_variable(design))

def trend_generator(model, num_rv):
    """Yield ``design_matrix @ coefficients`` (the polynomial RV trend) per instrument.

    ``model`` is anything indexable by variable name (a PyMC3 model or a MAP
    solution dict). It must hold ``rv_design_matrix_<i>`` and ``rv_trend_coeff``.
    """
    for index in range(num_rv):
        design = model[f'rv_design_matrix_{index}']
        coefficients = model['rv_trend_coeff']
        yield design @ coefficients

def trend_generator_trace(model, num_rv):
    """Like ``trend_generator``, but multiplies the ``.data`` payloads of the entries.

    Intended for containers whose values wrap arrays in a ``.data`` attribute.
    """
    for index in range(num_rv):
        design = model[f'rv_design_matrix_{index}'].data
        coefficients = model['rv_trend_coeff'].data
        yield design @ coefficients

def gp_generator(model, num_rv):
    """Yield the stored GP prediction ``rv_gp_pred_<i>`` for each of ``num_rv`` instruments."""
    for index in range(num_rv):
        key = f'rv_gp_pred_{index}'
        yield model[key]
        
def gp_matern(map_soln):
    """Build a celerite2 GP with a Matern-3/2 kernel from ``gp_sigma`` / ``gp_rho`` in ``map_soln``."""
    amplitude = map_soln['gp_sigma']
    timescale = map_soln['gp_rho']
    return celerite2.GaussianProcess(
        terms.Matern32Term(sigma=amplitude, rho=timescale))

def gp_sho(map_soln):
    """Build a celerite2 GP with an SHO kernel from ``gp_sigma``, ``gp_rho`` and ``gp_tau``."""
    amplitude = map_soln['gp_sigma']
    period_scale = map_soln['gp_rho']
    damping = map_soln['gp_tau']
    return celerite2.GaussianProcess(
        terms.SHOTerm(rho=period_scale, tau=damping, sigma=amplitude))

def gp_trend_func_generator(kernel_func, map_soln, rv_times, rv_data, rv_uncs, return_unc=False):
    """Condition a GP on MAP RV residuals and return prediction callable(s).

    All instruments are pooled and time-sorted. The GP diagonal is
    ``rv_unc**2 + rv_jitter[i]**2`` and the residuals are
    ``rv - map_soln['rv_pred_<i>']``.

    Args:
        kernel_func: Callable ``map_soln -> GP`` (e.g. ``gp_sho``).
        map_soln: MAP solution with ``rv_jitter`` and ``rv_pred_<i>``.
        rv_times, rv_data, rv_uncs: Per-instrument arrays.
        return_unc: Also return a function giving ``(-stdev, stdev)`` rows.

    Returns:
        ``gp_trend(t)``, or ``(gp_trend, gp_trend_unc)`` if ``return_unc``.
    """
    gp = kernel_func(map_soln)

    stacked_times = np.concatenate(rv_times)
    order = np.argsort(stacked_times)
    jitters = map_soln['rv_jitter']
    variances = np.concatenate(
        [err*err + jitters[k]*jitters[k] for k, err in enumerate(rv_uncs)])
    gp.compute(
        np.ascontiguousarray(stacked_times[order]),
        diag=np.ascontiguousarray(variances[order]),
        quiet=True)

    residuals = np.concatenate(
        [obs - map_soln[f'rv_pred_{k}'] for k, obs in enumerate(rv_data)])[order]

    def gp_trend(t):
        return gp.predict(residuals, t=t)

    def gp_trend_unc(t):
        _, variance = gp.predict(residuals, t=t, return_var=True)
        spread = np.sqrt(variance)
        return np.vstack((-spread, spread))

    return (gp_trend, gp_trend_unc) if return_unc else gp_trend

def model_set_up_gp_matern32(model, rv_times, rv_residuals, rv_uncs, rv_jitters):
    """Model RV residuals with a Matern-3/2 celerite2 GP inside ``model``.

    Adds ``gp_sigma`` ~ Uniform(0, sqrt(1000)) and
    ``gp_rho`` ~ Uniform(15/sqrt(3), 100/sqrt(3)), the marginal likelihood
    ``rv_gp_obs`` over all instruments (time-sorted), and a per-instrument
    prediction ``rv_gp_pred_<i>``.

    Unlike the previous version, the variables are created inside
    ``with model:`` like the other setup functions, so callers no longer need
    to open the model context themselves. Nesting the same context is harmless.

    Args:
        model: The PyMC3 model to extend.
        rv_times: Per-instrument time arrays.
        rv_residuals: Per-instrument residual tensors (data minus model).
        rv_uncs: Per-instrument RV uncertainty arrays.
        rv_jitters: Indexable per-instrument jitter (added in quadrature).
    """
    with model:
        sigma = pm.Uniform('gp_sigma',
            lower=0., upper=np.sqrt(1000), testval=1.)
        rho = pm.Uniform('gp_rho',
            lower=15./np.sqrt(3), upper=100./np.sqrt(3), testval=16.)
        kernel = theano_terms.Matern32Term(sigma=sigma, rho=rho)
        gp = celerite2.theano.GaussianProcess(kernel)

        # celerite2 needs time-ordered inputs, so pool and sort all instruments.
        gp_time = np.concatenate(rv_times)
        gp_sort_args = np.argsort(gp_time)
        gp_time = gp_time[gp_sort_args]
        gp_diag = tt.concatenate(
            [unc*unc + rv_jitters[i]*rv_jitters[i] for i, unc in enumerate(rv_uncs)])[gp_sort_args]
        gp_res = tt.concatenate(rv_residuals)[gp_sort_args]

        gp.compute(gp_time, diag=gp_diag, quiet=True)
        gp.marginal('rv_gp_obs', observed=gp_res)
        # NOTE(review): predictions are evaluated at each instrument's own
        # (possibly unsorted) times. Confirm celerite2 accepts unsorted ``t``.
        for i, rv_time in enumerate(rv_times):
            pm.Deterministic(f'rv_gp_pred_{i}', gp.predict(gp_res, t=rv_time))

gp_model_time = np.linspace(1900, 2400, 256)
def model_set_up_gp_sho(model, test_params, prior_unc, rv_times, rv_data, rv_uncs, keep_model=False):
    """Model correlated RV residual noise with an SHO celerite2 GP inside ``model``.

    Hyperparameters ``gp_sigma``, ``gp_rho`` and ``gp_tau`` are uniform on
    ``prior_unc[name] = (lower, upper)`` and start at ``test_params[name]``.
    Residuals are ``rv - model['rv_pred_<i>']``; the GP diagonal is
    ``rv_unc**2 + rv_jitter[i]**2``. Adds the marginal likelihood
    ``rv_gp_obs``, the per-instrument prediction ``rv_gp_pred_<i>`` at the
    observed times, and, if ``keep_model``, ``rv_gp_model`` on
    ``gp_model_time``.
    """
    def _uniform_hyperparameter(label):
        bounds = prior_unc[label]
        return pm.Uniform(label, lower=bounds[0], upper=bounds[1], testval=test_params[label])

    with model:
        sigma = _uniform_hyperparameter('gp_sigma')
        rho = _uniform_hyperparameter('gp_rho')
        tau = _uniform_hyperparameter('gp_tau')
        kernel = theano_terms.SHOTerm(rho=rho, tau=tau, sigma=sigma)

        # celerite2 needs time-ordered inputs: sort the pooled data once and
        # keep the inverse permutation to restore the original order afterwards.
        pooled_times = np.concatenate(rv_times)
        to_sorted = np.argsort(pooled_times)
        sorted_times = np.ascontiguousarray(pooled_times[to_sorted])

        jitter = model['rv_jitter']
        sorted_variances = tt.concatenate(
            [err*err + jitter[k]*jitter[k] for k, err in enumerate(rv_uncs)])[to_sorted]
        sorted_residuals = tt.concatenate(
            [obs - model[f'rv_pred_{k}'] for k, obs in enumerate(rv_data)])[to_sorted]

        gp = celerite2.theano.GaussianProcess(kernel)
        gp.compute(sorted_times, diag=sorted_variances, quiet=True)
        gp.marginal('rv_gp_obs', observed=sorted_residuals)

        from_sorted = np.argsort(to_sorted)
        prediction = gp.predict(sorted_residuals)[from_sorted]
        start = 0
        for k, times in enumerate(rv_times):
            stop = start + len(times)
            pm.Deterministic(f'rv_gp_pred_{k}', prediction[start:stop])
            start = stop

        if keep_model:
            pm.Deterministic('rv_gp_model', gp.predict(sorted_residuals, t=gp_model_time))

def model_add_rv_instruments(
    model, orbit,
    rv_times, rv_data, rv_uncs,
    trend_func=None,
    keep_rv=False,
    compute_likelihood=True):
    """Add the per-instrument RV model and, optionally, its likelihood to ``model``.

    The model for instrument ``i`` is the sum over planets of
    ``orbit.get_radial_velocity``, plus ``rv_gamma[i]``, plus the trend from
    ``trend_func(model, n_instruments)`` (zero by default).

    Args:
        keep_rv: Record the model as ``rv_pred_<i>``.
        compute_likelihood: Add ``rv_obs_<i>``, a Normal with error
            ``sqrt(rv_unc**2 + rv_jitter[i]**2)``.
    """
    if trend_func is None:
        def trend_func(model, num_rv):
            # Default: no long-term trend for any instrument.
            for _ in range(num_rv):
                yield 0.
    with model:
        instrument_iter = zip(rv_times, rv_data, rv_uncs, trend_func(model, len(rv_times)))
        for i, (rv_time, rv, rv_unc, trend) in enumerate(instrument_iter):
            keplerian = tt.sum(orbit.get_radial_velocity(rv_time), axis=1)
            rv_model = keplerian + model.rv_gamma[i] + trend
            if keep_rv:
                pm.Deterministic(f'rv_pred_{i}', rv_model)

            if not compute_likelihood:
                continue
            jitter = model.rv_jitter[i]
            pm.Normal(
                f"rv_obs_{i}",
                mu=rv_model,
                sd=tt.sqrt(rv_unc * rv_unc + jitter * jitter),
                observed=rv,
            )

def optimize_model(model, passes=None):
    """Find a MAP solution with ``pmx.optimize``, optionally in stages.

    Args:
        model: The PyMC3 model; optimisation starts from its test point.
        passes: Optional sequence of variable-name groups. Each group is
            optimised in turn, starting from the previous result, and a final
            joint optimisation follows. ``None`` runs one joint optimisation.

    Returns:
        The solution dict returned by the last ``pmx.optimize`` call.
    """
    with model:
        if passes is None:
            return pmx.optimize(start=model.test_point)
        first_group, later_groups = passes[0], passes[1:]
        solution = pmx.optimize(
            start=model.test_point, vars=[model[name] for name in first_group])
        for group in later_groups:
            solution = pmx.optimize(start=solution, vars=[model[name] for name in group])
        solution = pmx.optimize(start=solution)
    return solution

In [None]:
# def my_hack_log1pexp(x):
#     return tt.nnet.relu(x) + tt.log1p(tt.exp(-abs(x)))

# def my_hack_expit(x):
#     return tt.exp(-my_hack_log1pexp(-x))

def make_stellar_params(
    model, test_params, prior_unc, mist_interpolate,
    fit_parallax=True, fit_extinction=True,
    mist_unc=0.05, teff_sed_unc=0.025, ap_mag_bol_sed_unc=0.02*2.5/np.log(10),
    feh_unc=None):
    """Add MIST-interpolated stellar parameters and an SED branch to ``model``.

    Free parameters ``m_star_0``, ``feh_0`` and ``eep`` are stacked as
    (feh_0, m_star_0, eep) and passed to ``mist_interpolate.evaluate``. Its
    outputs define ``m_star``, ``age``, ``teff_mist``, ``feh_mist`` and
    ``r_star_mist``. Then ``teff``, ``feh`` and ``r_star`` scatter about those
    MIST values. A separate SED branch (``teff_sed``, ``r_star_sed``) is tied
    back to them.

    Args:
        model: PyMC3 model to extend.
        test_params: dict read for 'm_star', 'feh', 'eep', and (when enabled)
            'parallax' and 'av'.
        prior_unc: dict read for 'm_star' (sd of m_star_0) and 'age' (soft
            upper limit on ``age``). When enabled, it is also read for
            'parallax' (sd) and 'av' (upper bound).
        mist_interpolate: object whose ``evaluate`` returns the seven outputs
            indexed below.
        fit_parallax: add ``parallax``, ``distance``, ``ap_mag_bol``,
            ``ap_mag_bol_sed`` and the likelihood ``ap_mag_bol_obs``.
        fit_extinction: add the extinction parameter ``av``.
        mist_unc: fractional scatter of teff and r_star about the MIST values.
            It also sets the scatter of feh when ``feh_unc`` is None.
        teff_sed_unc: fractional scatter of ``teff_sed`` about ``teff``.
        ap_mag_bol_sed_unc: sd (mag) of ``ap_mag_bol_obs``. The default
            converts a 0.02 fractional bolometric-flux uncertainty to magnitudes.
        feh_unc: optional absolute scatter (dex) of ``feh`` about
            ``feh_mist``. When None, ``|feh_mist| * mist_unc`` is used.
    """
    with model:
        MassBoundNormal = pm.Bound(pm.Normal, lower=0.1, upper=4.)
        initial_mass = MassBoundNormal(
            'm_star_0', mu=test_params['m_star'], sd=prior_unc['m_star'])
        initial_feh = pm.Uniform('feh_0', lower=0., upper=0.5, testval=test_params['feh'])
        # 202: ZAMS, 353: IAMS, 454: TAMS
        eep = pm.Uniform('eep', lower=202., upper=454., testval=test_params['eep'])

        # Interpolator outputs, in order:
        # 'Mbol', 'logTeff', 'mass', 'logg', 'feh', 'age', 'dt_deep'
        interp_input = tt.stack([initial_feh, initial_mass, eep], axis=0)
        interp_output = mist_interpolate.evaluate(interp_input)
        # mag_bol = interp_output[0]
        log10_teff = interp_output[1]
        m_star = interp_output[2]
        logg_star = interp_output[3]
        feh = interp_output[4]
        log10_age = interp_output[5]
        dlog10t_deep = interp_output[6]

        pm.Deterministic('m_star', m_star)
        age = pm.Deterministic('age', 10**(log10_age-9))
        # Steep sigmoid step: drives the log-probability to about -1e100 once
        # age exceeds prior_unc['age'], and is negligible below it.
        pm.Potential('age_limit', -1e100*tt.nnet.sigmoid(3e3*(age-prior_unc['age'])))
        # pm.Uniform('age_obs', lower=0., upper=13.8, observed=age)
        # Adds log(dlog10t_deep * 10**log10_age). This is presumably the
        # EEP -> age Jacobian, so the EEP prior maps to a prior in age (confirm).
        pm.Potential('log_dt_deep', (tt.log(dlog10t_deep) + log10_age*np.log(10)))

        teff_mist = pm.Deterministic('teff_mist', 10**log10_teff)
        teff = pm.Normal('teff', mu=teff_mist, sd=teff_mist*mist_unc)

        pm.Deterministic('feh_mist', feh)
        # [Fe/H] is logarithmic and can be zero or negative (e.g. if the
        # interpolated value falls below feh_0's lower bound of 0). A purely
        # fractional scatter `feh*mist_unc` then gives a non-positive sd, so use
        # its magnitude, or an absolute scatter in dex when feh_unc is given.
        feh_sd = abs(feh)*mist_unc if feh_unc is None else feh_unc
        pm.Normal('feh', mu=feh, sd=feh_sd)

        radius_star_mist = pm.Deterministic(
            'r_star_mist', astro.calculate_stellar_radius(m_star, logg_star))
        radius_star = pm.Normal('r_star', mu=radius_star_mist, sd=radius_star_mist*mist_unc)

        luminosity_star = pm.Deterministic('L_star', astro.calculate_luminosity(teff, radius_star))
        mag_bol = astro.calculate_mag_bolometric(luminosity_star)

        # SED branch: its own teff and radius, loosely tied to the MIST branch.
        teff_sed = pm.Normal('teff_sed', mu=teff, sd=teff*teff_sed_unc)
        radius_star_sed = pm.Uniform('r_star_sed', lower=0., upper=2*radius_star, testval=radius_star)
        luminosity_star_sed = pm.Deterministic(
            'L_star_sed', astro.calculate_luminosity(teff_sed, radius_star_sed))
        mag_bol_sed = astro.calculate_mag_bolometric(luminosity_star_sed)

        pm.Deterministic('logg_star_mist', logg_star)
        pm.Deterministic('logg_star', astro.calculate_stellar_logg(m_star, radius_star))
        pm.Deterministic('logg_star_sed', astro.calculate_stellar_logg(m_star, radius_star_sed))

        if fit_parallax:
            # PositiveNormal = pm.Bound(pm.Normal, lower=0.)
            parallax = pm.Normal('parallax', mu=test_params['parallax'], sd=prior_unc['parallax'])
            pm.Deterministic('distance', 1000./parallax)
            # Equals 5*log10(1000/parallax) - 5 (parallax presumably in mas).
            distance_modulus = 10. - 5. * tt.log10(parallax)
            apparent_mag_bol = pm.Deterministic('ap_mag_bol', mag_bol + distance_modulus)
            apparent_mag_bol_sed = pm.Deterministic('ap_mag_bol_sed', mag_bol_sed + distance_modulus)
            # Ties the MIST and SED branches through their apparent bolometric magnitudes.
            pm.Normal(
                'ap_mag_bol_obs',
                mu=apparent_mag_bol_sed,
                sd=ap_mag_bol_sed_unc,  # Value adopted from 0.02 fractional uncertainty in bolometric flux
                observed=apparent_mag_bol)
        if fit_extinction:
            pm.Uniform('av', lower=0., upper=prior_unc['av'], testval=test_params['av'])

def model_add_stellar_param_obs(model, stellar_params):
    """Add a Normal observation ``<name>_obs`` for each entry of ``stellar_params``.

    ``stellar_params`` maps a model variable name to a ``(value, uncertainty)``
    pair. Each observation is centred on ``model[name]`` with sd ``uncertainty``.
    """
    with model:
        for name, measurement in stellar_params.items():
            observed_value, observed_unc = measurement
            pm.Normal(
                f'{name}_obs',
                mu=model[name],
                sd=observed_unc,
                observed=observed_value,
            )

def model_add_sed(model, mist_bc_interpolate, magnitude_groups):
    """Add broadband-photometry (SED) likelihoods to ``model``.

    Bolometric corrections are evaluated with
    ``mist_bc_interpolate.evaluate`` on (teff_sed, logg_star_sed, feh, av).
    They are subtracted from ``ap_mag_bol_sed`` to give the modelled apparent
    magnitudes, stored as ``ap_mag_phot_sed``. Each group in
    ``magnitude_groups`` (``name -> (mags, uncs, jitter_flag)``) takes the next
    ``len(mags)`` entries of those magnitudes, in dict order. It gets a Normal
    likelihood ``mag_<name>_obs`` whose uncertainties are scaled by the shared
    free parameter ``sed_unc_scale``.

    Raises:
        ValueError: if a group's magnitudes and uncertainties differ in length.
    """
    for group_name, (group_mags, group_uncs, _add_jitter) in magnitude_groups.items():
        if len(group_mags) == len(group_uncs):
            continue
        raise ValueError(f'Length of magnitudes and their uncertainties do not match for {group_name}')

    with model:
        bc_input = tt.stack([model.teff_sed, model.logg_star_sed, model.feh, model.av])
        bolometric_corrections = mist_bc_interpolate.evaluate(bc_input)
        apparent_mags = model.ap_mag_bol_sed - bolometric_corrections
        pm.Deterministic('ap_mag_phot_sed', apparent_mags)
        sed_unc_scale = pm.Uniform('sed_unc_scale', lower=1., upper=4., testval=1.2)

        # Consecutive slices of apparent_mags, one per group. The per-group
        # jitter flag is currently ignored (a per-group jitter term was
        # prototyped and disabled).
        offset = 0
        for group_name, (group_mags, group_uncs, _) in magnitude_groups.items():
            stop = offset + len(group_mags)
            pm.Normal(
                f'mag_{group_name}_obs',
                mu=apparent_mags[offset:stop],
                sd=group_uncs*sed_unc_scale,
                observed=group_mags,
            )
            offset = stop

In [None]:
# Accumulators for successive fits: one entry per model variant.
models, map_solns, traces = [], [], []

In [None]:
# Assemble the joint model: MIST stellar parameters, the two-planet transit
# orbit, limb darkening and light curves, the RV instruments, an SHO GP on the
# RVs, and the stellar-parameter and SED observations. Then find a staged MAP.
model = pm.Model()
make_stellar_params(model, test_params, prior_unc, mist_interpolate, mist_unc=0.03)
model_orbit = make_multi_planet_transit_model(model, 2, test_params, prior_unc)
model_ldlc_objs = model_set_up_limb_dark(model, test_params, filters)
model_add_transit_light_curves(
    model, model_orbit, model_ldlc_objs,
    lc_times, lc_dfluxes, lc_uncs,
    exposure_times, supersampling_factors, filters,
    detrend_series=lc_detrend_series,
    keep_light_curve=False)
model_set_up_rv(
    model, len(rv_data), test_params, prior_unc)

# model_set_up_polynomial_detrend(model, rv_times, 2)
# GP hyperparameter bounds and starting values. These mutate the shared
# prior_unc / test_params dicts before model_set_up_gp_sho is called
# (presumably where they are read; confirm).
prior_unc['gp_sigma'] = np.array([0., 100.])
prior_unc['gp_rho'] = np.array([15., 200.])
prior_unc['gp_tau'] = np.array([0., 200.])
test_params['gp_sigma'] = 30.
test_params['gp_rho'] = 20.
test_params['gp_tau'] = 9.
# Keplerian RV predictions are stored as rv_pred_{i}; no RV likelihood is added
# here (compute_likelihood=False). Presumably the GP setup below supplies it
# (cf. 'rv_gp_obs' in the sampling cell).
model_add_rv_instruments(
    model, model_orbit,
    rv_times, rv_data, rv_uncs,
    # trend_func=trend_generator,
    keep_rv=True,
    compute_likelihood=False,
)
model_set_up_gp_sho(model, test_params, prior_unc, rv_times, rv_data, rv_uncs)

model_add_stellar_param_obs(model, stellar_params)
model_add_sed(model, mist_bc_interpolate, magnitude_groups)
# Record the model log-probability with every sample; it is used later to
# pick the maximum-posterior draw.
with model:
    pm.Deterministic('log_prob', model.logpt)
# Staged MAP: optimize_model frees each inner list of variables in turn,
# then optimizes all variables together.
map_soln = optimize_model(
    model,
    [
        ['mean_flux_0', 'mean_flux_1', 'lc_detrend_coeffs_2', 'lc_detrend_coeffs_3'],
        ['rp'],
        # ['t0', 'period'],
        ['av'],
        ['m_star'],
        ['b', 'm_star', 'r_star', 'feh', 'teff'],
        ['r_star_sed', 'teff_sed', 'parallax'],
        ['lc_jitter', 'mean_flux_0', 'mean_flux_1', 'lc_detrend_coeffs_2', 'lc_detrend_coeffs_3'],
        # ['mean_flux', 'lc_jitter'],
        ['m_planet', 'rv_gamma', 'rv_jitter', 'gp_sigma', 'gp_rho', 'gp_tau'], #'rv_trend_coeff'],
    ]
)

models.append(model)
map_solns.append(map_soln)

In [None]:
# Folded light curves of the first two light-curve datasets at the MAP
# solution, with the orbit rebuilt from map_soln.
plot_orbit = xo.orbits.KeplerianOrbit(#ecc=np.array([0.02, 0]),
    **{k: map_soln[k] for k in ['period', 't0', 'b', 'ecc', 'm_star', 'r_star', 'omega']})
fig, axs = plotting.plot_multi_planet_folded_light_curve(
    2, plot_orbit, map_soln['rp'],
    lc_times[:2], lc_dfluxes[:2], filters[:2],
    {k:v for k, v in map_soln.items() if k[:2]=='u_'},
    [map_soln[f'mean_flux_{i}'] for i in range(2)],
    exposure_times, supersampling_factors,
    [40]*num_lc, [0.06, 0.07], [0.3]*num_lc)

axs[0][1].set_xlim(-0.03, 0.03)
# axs[1][0].set_ylim(-0.001, 0.001)
# axs[0][0].set_ylim(-0.001, 0.001)
axs[0][0].set_ylim(-0.0055, 0.0015)
axs[1][0].set_ylim(-0.0055, 0.0015)

In [None]:
# Folded light curves of the remaining datasets (lc_times[2:]) at the MAP
# solution, using their detrending series and MAP detrend coefficients. The
# mean-flux offsets are passed as zeros.
plot_orbit = xo.orbits.KeplerianOrbit(#ecc=np.array([0.02, 0]),
    **{k: map_soln[k] for k in ['period', 't0', 'b', 'ecc', 'm_star', 'r_star', 'omega']})
fig, axs = plotting.plot_multi_planet_folded_light_curve(
    1, plot_orbit, map_soln['rp'],
    lc_times[2:], lc_dfluxes[2:], filters[2:],
    {k:v for k, v in map_soln.items() if k[:2]=='u_'},
    [0., 0.], #, 0., map_soln['mean_flux_5'], map_soln['mean_flux_6']],
    exposure_times[2:], supersampling_factors[2:],
    [40]*num_lc, [0.03, 0.07], [0.3]*num_lc,
    detrend_series=lc_detrend_series[2:],
    detrend_coeffs=[
        map_soln['lc_detrend_coeffs_2'],
        map_soln['lc_detrend_coeffs_3'],
    ])
# axs[0][1].set_xlim(-0.03, 0.03)
# axs[1][0].set_ylim(-0.001, 0.001)
# axs[0][0].set_ylim(-0.001, 0.001)
# axs[0][0].set_ylim(-0.0055, 0.0015)
# axs[1][0].set_ylim(-0.0055, 0.0015)

In [None]:
# RV figure at the MAP solution: folded panels per planet plus unfolded and
# residual axes. LaTeX labels are raw strings so sequences such as `\m` and
# `\,` are not parsed as (invalid) string escapes.
fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(2)
rv_unfolded_ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')
rv_folded_axs[0].set_title('c')
rv_folded_axs[1].set_title('b')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')

rv_plot_soln = map_soln

# The orbit only carries period/t0/ecc/omega; amplitudes are passed
# separately as rv_plot_soln['K'].
plot_orbit = xo.orbits.KeplerianOrbit(
    **{k: rv_plot_soln[k] for k in ['period', 't0', 'ecc', 'omega']})

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': 'gray',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^'},
    {'color': '#a12568', 'fmt': 'd'},
    {'color': '#fec260', 'fmt': 'o'},
]

rv_trend_func, rv_trend_unc_func = gp_trend_func_generator(
    gp_sho, rv_plot_soln, rv_times, rv_data, rv_uncs, return_unc=True)

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs, rv_unfolded_ax, rv_residual_ax,
    2, plot_orbit, rv_plot_soln['K'],
    rv_plot_soln['rv_gamma'],
    rv_plot_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    # trends=trend_generator(rv_plot_soln, len(rv_data)),
    trends=gp_generator(rv_plot_soln, len(rv_data)),
    model_trend_func=rv_trend_func,
    model_trend_unc_func=rv_trend_unc_func,
    # model_trend_func=gp_trend_func_generator(gp_sho, rv_plot_soln, rv_times, rv_data, rv_uncs),
    # model_trend_func=lambda t: polynomial_design_matrix(t, rv_plot_soln['rv_time_offset'], 2) @ rv_plot_soln['rv_trend_coeff'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')

for ax, p in zip(rv_folded_axs, rv_plot_soln["period"]):
    ax.text(0.05, 0.05, rf'$P = {p:.6f}\,\mathrm{{d}}$', transform=ax.transAxes)

In [None]:
# Staged MAP solution returned by optimize_model.
map_soln

In [None]:
# Sample the posterior starting from the MAP solution. `log_likelihood` lists
# the observed variables whose pointwise log-likelihoods are stored in the
# InferenceData. `parameter_groups` partitions the free parameters for
# pymc3_ext's grouped mass-matrix tuning (presumably dense within each group;
# confirm against the pymc3_ext docs).
# NOTE(review): no random_seed is passed, so the chains are not reproducible.
with model:
    trace = pmx.sample(
        tune=4096,
        draws=4096,
        start=map_soln,
        cores=32,
        chains=32,
        target_accept=0.995,
        return_inferencedata=True,
        idata_kwargs={
            'log_likelihood': [
                'rv_gp_obs',
                'teff_obs',
                'feh_obs',
                'ap_mag_bol_obs',
                'mag_2mass_obs',
                'mag_gaia_obs',
                'mag_wise_obs',
            ],
        },
        parameter_groups=[
            pmx.ParameterGroup([model.period, model.t0]),
            pmx.ParameterGroup([
                model.m_star_0, model.feh_0, model.eep,
                model.rp, model.b, model.sqrt_ecc_vec_0,
                model.r_star, model.r_star_sed, model.teff_sed,
                model.sed_unc_scale, model.av, model.parallax,
            ]),
            pmx.ParameterGroup([
                model.rv_gamma, model.rv_jitter, model.m_planet, model.gp_sigma, model.gp_rho, model.gp_tau]),
            pmx.ParameterGroup([
                model.mean_flux_0, model.mean_flux_1,
                model.lc_detrend_coeffs_2, model.lc_detrend_coeffs_3,
                model.lc_jitter,
                model.u_tess, model.u_Rc, model.u_zs,
            ]),
        ],
    )
traces.append(trace)

# Save only the posterior, log-likelihood and sampler-statistics groups.
trace.to_netcdf(
    '../chains/toi2000_trace_20s_tess_astep_lco_pest_rv_gp_sho_sed_15.nc',
    groups=['posterior', 'log_likelihood', 'sample_stats'],
)

In [None]:
# NOTE(review): this re-saves the current `trace` (the run written above as
# ..._20s_..._sed_15.nc) under a '2m ... _14' filename. That may overwrite an
# earlier chain file; confirm the intended name before running.
trace.to_netcdf(
    '../chains/toi2000_trace_2m_tess_astep_lco_pest_rv_gp_sho_sed_14.nc',
    groups=['posterior', 'log_likelihood', 'sample_stats'],
)

In [None]:
# Posterior group of the sampled InferenceData.
trace['posterior']

In [None]:
# Parameters shown in the trace plot (commented-out names are skipped).
display_vars = [
    # "period", "t0",
    "rp", "b", #"tdur",
    "sqrt_ecc_vec_0", #"ecc", "omega",
    # "r_planet_earth",
    # "m_planet_earth",
    # "rho_planet",
    "m_planet",
    "m_star_0", "feh_0", "age", "eep",
    # "m_star", "r_star", "rho_star", "logg_star", "teff",
    "feh", #"parallax", "av", #"mag_bol",
    "r_star_sed", "teff_sed",
    "sed_unc_scale",
    # "jitter_gaia", "jitter_2mass", "jitter_wise",
    # "u_tess",
    # "mean_flux_0", "mean_flux_1", #"mean_flux_5", "mean_flux_6",
    # "lc_detrend_coeffs_2", "lc_detrend_coeffs_3", #"lc_detrend_coeffs_4",
    "lc_jitter",
    "rv_gamma",
    "rv_jitter",
    # "rv_trend_coeff",
    "gp_sigma", "gp_rho", "gp_tau",
]

axs = az.plot_trace(trace, var_names=display_vars)
plt.tight_layout()
# plt.savefig('plot/toi2000_trace_2m_rv_gp_sho.pdf', dpi=144, bbox_inches='tight')

In [None]:
# Period / epoch corner plot.
# NOTE(review): `truths` are the test values shifted by +1 prior sigma, not the
# test values themselves. Confirm this is intended.
fig = corner.corner(
    trace, var_names=['period', 't0'],
    truths={
        'period': test_params['period'] + prior_unc['period'],
        't0': test_params['t0'] + prior_unc['t0'],
    },
    quantiles=[0.16, 0.5, 0.84],
)

In [None]:
# Test period at -1, 0, +1 prior sigma; rows are planets after the transpose.
(test_params['period'] + np.array([-prior_unc['period'], [0]*2, prior_unc['period']])).T

In [None]:
# Same for the reference epoch t0.
(test_params['t0'] + np.array([-prior_unc['t0'], [0]*2, prior_unc['t0']])).T

In [None]:
prior_unc['period'], prior_unc['t0']

In [None]:
# Recorded output of `prior_unc['period'], prior_unc['t0']`, kept for reference.
# It is written with `np.array` so the cell evaluates in a fresh kernel;
# a bare `array` is not defined in this notebook.
(np.array([0.0001, 0.0002]), np.array([0.005, 0.02]))

In [None]:
# Posterior summary table. The custom stat_funcs add the median and the
# distances from the median down to the 16th and up to the 84th percentile.
display_vars = [
    "period", "t0", "rp", "b", "tdur",
    "sqrt_ecc_vec_0", "ecc", "omega",
    "r_planet_earth",
    "m_planet_earth",
    "rho_planet",
    "K",
    "m_star_0", "feh_0", "eep",
    "m_star", "r_star", "rho_star", "logg_star", "teff",
    "feh", "age", "parallax", "av", "L_star",
    "r_star_mist", "teff_mist", "feh_mist",
    "r_star_sed", "teff_sed", "ap_mag_bol", "ap_mag_bol_sed",
    "sed_unc_scale",
    # "jitter_gaia", "jitter_2mass", "jitter_wise",
    "u_tess", "u_Rc", "u_zs", #"u_B", "u_Ic",
    "mean_flux_0", "mean_flux_1", #"mean_flux_5", "mean_flux_6",
    "lc_detrend_coeffs_2", "lc_detrend_coeffs_3", #"lc_detrend_coeffs_4",
    # "mean_flux",
    "lc_jitter",
    "rv_gamma", "rv_jitter",
    # "rv_trend_coeff",
    "gp_sigma", "gp_rho", "gp_tau",
    "a", "aor", #"mag_bol", "L_star",
]

summary = az.summary(
    trace, var_names=display_vars,
    circ_var_names=["omega"],
    round_to=15,
    hdi_prob=0.68,
    # kind='stats',
    stat_funcs={
        "median": np.median,
        # arviz expects each stat_func to return a single number. A list
        # argument (`[0.16]`) made np.quantile return a length-1 array instead.
        "-": lambda x: np.median(x) - np.quantile(x, 0.16),
        "+": lambda x: np.quantile(x, 0.84) - np.median(x),
    },
    coords={"ecc_dim_0": [0], "omega_dim_0": [0]},
    skipna=True,
    # extend=False,
    # fmt='xarray',
)
summary

In [None]:
# Traces collected so far in this session.
traces

In [None]:
# Was `corner.c`, an unfinished name that raises AttributeError. This shows
# the corner.corner function instead.
corner.corner

In [None]:
# Largest and smallest effective sample size across summary rows, computed
# separately for ess_bulk and ess_tail.
np.max([summary['ess_bulk'], summary['ess_tail']], axis=1), np.min([summary['ess_bulk'], summary['ess_tail']], axis=1)

In [None]:
# Epoch + 19 periods, offset by -0.5/0/+0.5 of 3.877537 h (presumably the
# transit duration in hours, converted to days).
2.082685e3 + 9.127055*19 + 3.877537/24*np.array([-0.5, 0, 0.5])

In [None]:
# Days -> seconds. The trailing "45 parameters" note was bare text (a
# SyntaxError), so it is kept as a comment.
1.946683e-05 * 24 * 3600  # 45 parameters

In [None]:
print(summary)

In [None]:
# Transit depth (rp**2) and its first-order uncertainty 2*rp*sigma_rp, for
# hand-copied rp values (presumably the summary medians/sigmas; confirm).
tmp_rp = np.array([0.064583, 0.021672])
tmp_rp_unc = np.array([0.001105, 0.001115])
tmp_rp**2, 2*tmp_rp*tmp_rp_unc

In [None]:
# Flatten (chain, draw) into one 'sample' dimension and build two point
# estimates as {name: numpy array}. median_soln takes each variable's median
# independently (so it is not itself a posterior draw). max_post_soln is the
# single draw with the highest recorded log_prob.
flat_samples = trace.posterior.stack(sample=("chain", "draw"))
median_soln = {k:v.data for k, v in flat_samples.median(dim='sample').items()}
# median_soln = {k: np.median(v, axis=-1) for k, v in flat_samples.items()}
max_post_index = flat_samples.log_prob.argmax(dim='sample')
max_post_soln = {k:v.data for k, v in flat_samples[{'sample': max_post_index}].items()}

In [None]:
# Eccentricity at a few lower/upper tail quantiles.
flat_samples.ecc.quantile([0.000031671, 0.0015, 0.95, 0.9985], dim="sample")

In [None]:
max_post_soln

In [None]:
# 95th percentile of m_planet_earth. axis=1 is 'sample', which stack()
# appends as the last dimension; this assumes one leading planet dim.
np.percentile(flat_samples['m_planet_earth'], [95.], axis=1)

In [None]:
# 16/50/84% posterior quantiles of every variable.
trace.posterior.quantile(np.array([0.16, 0.50, 0.84]), dim=["chain", "draw"])

In [None]:
flat_samples.median('sample')

In [None]:
# The model defines per-dataset offsets mean_flux_0 / mean_flux_1; there is
# no single 'mean_flux' variable (that key raised KeyError).
median_soln['mean_flux_0'], median_soln['mean_flux_1']

In [None]:
# median_soln values are already numpy arrays, so iterate them directly.
# `.data` on an ndarray is a raw memoryview buffer, not the values array.
for i in median_soln['rv_gamma']:
    print(i)

In [None]:
# Max-posterior periods in reversed planet-index order.
max_post_soln['period'][::-1]

In [None]:
# plot_orbit = xo.orbits.KeplerianOrbit(
#     **{k: max_post_soln[k] for k in ['period', 't0', 'b', 'm_star', 'r_star', 'ecc', 'omega', 'm_planet']})

# Folded light curves at the posterior-median solution.
lc_plot_soln = median_soln

plot_soln = {k: lc_plot_soln[k] for k in ['period', 't0', 'b', 'ecc', 'omega', 'm_planet']}
plot_soln.update({k: lc_plot_soln[k] for k in ['m_star', 'r_star']})
plot_orbit = xo.orbits.KeplerianOrbit(
    **plot_soln)

fig, axs = plotting.plot_multi_planet_folded_light_curve(
    2, plot_orbit, lc_plot_soln['rp'],
    lc_times,
    lc_dfluxes,
    ['tess']*num_lc,
    # median_soln values are already numpy arrays; `v.data` would hand the
    # plotter raw memoryviews instead of the limb-darkening coefficients.
    {k: v for k, v in lc_plot_soln.items() if k[:2]=='u_'},
    [lc_plot_soln['mean_flux_0'], lc_plot_soln['mean_flux_1']],
    # NOTE(review): hard-coded exposure times (presumably 30 min and 2 min, in
    # days) and supersampling factors for two datasets. Confirm they match
    # exposure_times / supersampling_factors.
    [.5/24, 2./60/24], [15, 1],
#     exposure_times, supersampling_factors,
    [40]*num_lc, [0.035, 0.05], [0.3]*num_lc, dpi=600)
axs[0][0].set_ylim(-0.0055, 0.0015)
axs[1][0].set_ylim(-0.0055, 0.0015)
# axs[2][0].set_ylim(-0.0055, 0.0015)
axs[0][0].set_title('c')
axs[0][1].set_title('b')
axs[-1][0].set_xlabel('Phase')
axs[-1][1].set_xlabel('Phase')
for row in axs:
    # Raw string: `\D` and `\m` are LaTeX, not Python escapes.
    row[0].set_ylabel(r'$\Delta\mathrm{flux} / \mathrm{flux}$')
# fig.set_size_inches(6, 3)
# fig.set_dpi(600)
# fig.savefig('plot/toi2000_lc_2m_rv_gp_sho_sed_prelim.png', dpi=144, bbox_inches='tight')

In [None]:
# RV figure at the maximum-posterior draw. LaTeX labels are raw strings so
# sequences such as `\m` and `\,` are not parsed as (invalid) string escapes.
fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(2)
rv_unfolded_ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')

rv_folded_axs[0].set_title('b')
rv_folded_axs[1].set_title('c')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)', labelpad=-1.0)

rv_plot_soln = max_post_soln

# The orbit only carries period/t0/ecc/omega; amplitudes are passed
# separately as rv_plot_soln['K'].
plot_orbit = xo.orbits.KeplerianOrbit(
    **{k: rv_plot_soln[k] for k in ['period', 't0', 'ecc', 'omega']})

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': '#555555',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^', "alpha": 0.5},
    {'color': '#a12568', 'fmt': 'd', "alpha": 0.3},
    {'color': '#fec260', 'fmt': 'o', "alpha": 1.},
]

rv_trend_func, rv_trend_unc_func = gp_trend_func_generator(
    gp_sho, rv_plot_soln, rv_times, rv_data, rv_uncs, return_unc=True)

# Axes are passed reversed so planet index 0 lands on the panel titled 'c'.
plotting.plot_multi_planet_folded_rv(
    rv_folded_axs[::-1], rv_unfolded_ax, rv_residual_ax,
    2, plot_orbit, rv_plot_soln['K'],
    rv_plot_soln['rv_gamma'],
    rv_plot_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    # trends=trend_generator(rv_plot_soln, len(rv_data)),
    trends=gp_generator(rv_plot_soln, len(rv_data)),
    model_trend_func=rv_trend_func,
    model_trend_unc_func=rv_trend_unc_func,
    # model_trend_func=gp_trend_func_generator(gp_sho, rv_plot_soln, rv_times, rv_data, rv_uncs),
    # model_trend_func=lambda t: polynomial_design_matrix(t, rv_plot_soln['rv_time_offset'], 2) @ rv_plot_soln['rv_trend_coeff'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')
rv_unfolded_ax.tick_params(direction="in", labelsize="small")
rv_residual_ax.tick_params(direction="in", labelsize="small")
rv_residual_ax.set_ylim(-30, 30)

for ax, p in zip(rv_folded_axs[::-1], rv_plot_soln["period"]):
    ax.text(0.05, 0.05, rf'$P = {p:.6f}\,\mathrm{{d}}$', transform=ax.transAxes)
    ax.tick_params(direction="in", labelsize="small")
    # ax.set_xlim((-0.5, 0.5))

rv_folded_axs[0].set_ylim(-25, 25)
rv_folded_axs[1].set_ylim(-45, 45)

# fig.savefig('plot/toi2000_rv_2m_rv_gp_sho_sed_prelim.pdf', bbox_inches='tight')

In [None]:
# Raw RVs of the two HARPS data arrays, saved to plot/toi2000_rv_harps_raw.png.
# The axis label is a raw f-string so `\m` is not parsed as a string escape.
fig, ax = plt.subplots(figsize=(8, 4), dpi=120)
ax.errorbar(harps_time_0, harps_rv_0, harps_rv_unc_0, fmt='.', elinewidth=1, ecolor='gray', alpha=0.8)
ax.errorbar(harps_time_1, harps_rv_1, harps_rv_unc_1, fmt='.', elinewidth=1, ecolor='gray', alpha=0.8)
ax.set_xlabel(rf'$\mathrm{{BJD}}_\mathrm{{TDB}} - {TESS_EPOCH}$')
ax.set_ylabel('RV (m/s)')
ax.set_title('TOI-2000 HARPS raw RVs')
fig.savefig('plot/toi2000_rv_harps_raw.png')

In [None]:
harps_time_0

In [None]:
# with model:
# NOTE(review): `trace_detrend` is not defined in this part of the notebook.
# Confirm it is created in an earlier cell before running.
fig = corner.corner(
    trace_detrend,
    var_names=display_vars)
# fig.savefig('plot/toi2000_corner_lc_only_no_rhostar_prior.pdf')

In [None]:
# Transit-parameter corner for planet index 1 (titled 'b' in the light-curve
# figures). Contour levels 1-exp(-s**2/2) enclose s = 0.5..2.5 sigma of a
# 2D Gaussian.
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=["period", "t0", "rp", "b", "tdur"], # "sqrt_ecc_vec_0"],
    coords={"period_dim_0": [1], "t0_dim_0": [1], "b_dim_0": [1], "rp_dim_0": [1], "ecc_dim_0": [1], "omega_dim_0": [1], "tdur_dim_0":[1]},
)
# fig.savefig("plot/toi2000_corner_20s_planet_b_9.png", dpi=300)

In [None]:
# Same for planet index 0, with sqrt_ecc_vec_0 instead of tdur.
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=["period", "t0", "rp", "b", "sqrt_ecc_vec_0"],
    coords={"period_dim_0": [0], "t0_dim_0": [0], "rp_dim_0": [0], "b_dim_0": [0], "ecc_dim_0": [0], "omega_dim_0": [0], "K_dim_0": [0]},
)
# fig.savefig('plot/toi2000_20s_

In [None]:
# Corner plot of the stellar parameters. The MIST-predicted quantities are
# stored as *_mist by make_stellar_params, and the bolometric magnitude as
# ap_mag_bol. The old names r_star_pred / teff_pred / feh_pred / mag_bol are
# not created there and raised a KeyError.
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=["m_star", "r_star", "r_star_mist", "rho_star", "logg_star", "teff", "feh", "teff_mist", "feh_mist", "ap_mag_bol", "age", 'sed_unc_scale'],
    # divergences=True,
    coords={"b_dim_0": [1]},
)
fig.savefig('plot/toi2000_corner_stellar_param_9.png', dpi=300)

In [None]:
# Joint corner of t0 and period samples. vstack of the two (planet, sample)
# arrays gives columns t0_0, t0_1, period_0, period_1 after the transpose
# (assumes a leading planet dimension of size 2).
fig = corner.corner(
    np.vstack([flat_samples['t0'], flat_samples['period']]).T,
)

In [None]:
def epoch_min_cov(period_samples, epoch_samples):
    """Shift epoch samples by a whole number of periods to decorrelate them from the period.

    The least-squares slope of epoch on period (from mean-subtracted samples)
    is rounded to an integer ``n``. The function then returns
    ``epoch_samples - n * period_samples``. That is the reference epoch moved
    by ``n`` orbits, the integer shift that minimises the magnitude of the
    sample covariance between period and the shifted epoch.
    """
    centred_period = period_samples - np.average(period_samples)
    centred_epoch = epoch_samples - np.average(epoch_samples)
    slope = np.dot(centred_period, centred_epoch) / np.dot(centred_period, centred_period)
    n_orbits = round(slope)
    return epoch_samples - period_samples * n_orbits

In [None]:
# 16/50/84% quantiles of the decorrelated epoch (via epoch_min_cov) for
# planet indices 0 and 1.
(
    np.quantile(epoch_min_cov(flat_samples['period'].data[0], flat_samples['t0'].data[0]), np.array([0.16, 0.50, 0.84])),
    np.quantile(epoch_min_cov(flat_samples['period'].data[1], flat_samples['t0'].data[1]), np.array([0.16, 0.50, 0.84])),
)

In [None]:
# Offsets of two hand-copied epochs from 1855.24456894 (presumably the
# upper/lower quantiles relative to the median above; confirm).
1855.24680948-1855.24456894, 1855.24231904-1855.24456894, 

In [None]:
# Design matrix with columns (period, t0) for planet index 0.
design_M = np.column_stack([flat_samples['period'].data[0], flat_samples['t0'].data[0]])
# design_M -= np.average(design_M, axis=0)
design_M.T @ design_M

In [None]:
# Column means of design_M, i.e. the mean period and the mean t0. Use axis=0:
# axis=1 averaged each sample's period with its t0, which is meaningless
# (cf. the commented-out centring in the cell above).
np.average(design_M, axis=0)

In [None]:
np.linalg.eig(design_M.T @ design_M)

In [None]:
# Python's round() rounds halves to even: this gives 4, but round(2.5) gives 2.
round(3.5)

In [None]:
# Cholesky factor of the period/t0 sample covariance for planet index 0.
# Computed inline because `cov_M` is only defined in a later cell, so this
# cell failed on a fresh kernel.
np.linalg.cholesky(np.cov(flat_samples['period'].data[0], flat_samples['t0'].data[0]))

In [None]:
# Least-squares slope of t0 on period for planet index 0 (the quantity that
# epoch_min_cov rounds to an integer).
X = flat_samples['period'].data[0]
mX = np.average(X)
dX = X-mX
Y = flat_samples['t0'].data[0]
mY = np.average(Y)
dY = Y-mY
(dX@dY)/(dX@dX)

In [None]:
# Second column becomes t0 - 78*period (epoch shifted by -78 orbits).
(design_M @ np.array([[  1, -78],
       [ 0,   1]]))

In [None]:
# Period/t0 sample covariance for planet index 0 and its eigen-decomposition.
cov_M = np.cov(flat_samples['period'].data[0], flat_samples['t0'].data[0]-0*flat_samples['period'].data[0])
eigen_val, eigen_vec = np.linalg.eig(cov_M)

In [None]:
np.sqrt(1.37694501e-07**2 + 9.12652931e-11**2*43), 3.06997221e-07

In [None]:
with model:
    fig = corner.corner(
        trace,
        var_names=["t0", "period"])

In [None]:
with model:
    fig = corner.corner(
        trace,
        var_names=["rho_star", "sqrt_ecc_vec_0", "b"])
# fig.savefig('plot/toi2000_lc_only_eccen_corner.pdf', bbox_inches='tight')

In [None]:
flat_samples['t0'].data

In [None]:
model.logp

In [None]:
# NOTE(review): this path is relative to 'chains/', while the sampling cell
# writes to '../chains/'. Confirm the working directory.
az.to_netcdf(trace, 'chains/toi2000_trace_2m_tess_rv_no_trend_sed_min_unc_prelim.nc')

In [None]:
trace = az.from_netcdf('chains/toi2000_trace_20s_tess_astep_lco_pest_rv_gp_sho_sed_2.nc')

In [None]:
# Recompute point estimates for the reloaded trace, plus per-variable
# posterior standard deviations.
flat_samples = trace.posterior.stack(sample=("chain", "draw"))
median_soln = {k:v.data for k, v in flat_samples.median(dim='sample').items()}
median_soln_unc = {k:v.data for k, v in flat_samples.std(dim='sample').items()}
# median_soln = {k: np.median(v, axis=-1) for k, v in flat_samples.items()}
max_post_index = flat_samples.log_prob.argmax(dim='sample')
max_post_soln = {k:v.data for k, v in flat_samples[{'sample': max_post_index}].items()}

In [None]:
# Upper quantile with 1 - q ~ 5.7e-7 (the two-sided 5-sigma tail probability).
flat_samples.quantile(0.999999426697,dim='sample')['m_planet_earth']

In [None]:
# RV residuals at the max-posterior draw. Subtracts the stored rv_pred_{i}
# (Keplerian sum plus instrument offset; the GP is not included) and adds
# jitter in quadrature to the uncertainties.
rv_residual_soln = max_post_soln
rv_residuals = []
rv_residual_uncs = []
for i, (rv, rv_unc, rv_jitter) in enumerate(zip(
    rv_data, rv_uncs, rv_residual_soln['rv_jitter'])):
    rv_residuals.append(rv - rv_residual_soln[f'rv_pred_{i}'])
    rv_residual_uncs.append(np.sqrt(rv_unc**2 + rv_jitter**2))

In [None]:
import astropy.timeseries as timeseries

In [None]:
# Lomb-Scargle periodogram of all residuals, for frequencies between 1/460
# and 2 (per unit of the time axis, presumably days).
freq, periodogram = timeseries.LombScargle(
    np.concatenate(rv_times),
    np.concatenate(rv_residuals),
    np.concatenate(rv_residual_uncs),
).autopower(minimum_frequency=1./460, maximum_frequency=2.)

In [None]:
# exoplanet's Lomb-Scargle estimator on the residuals (periods 10-460), and on
# a constant series with the same sampling as the window function.
lomb_scargle_times = np.concatenate(rv_times)
lomb_scargle_series = np.concatenate(rv_residuals)
lomb_scargle_uncs = np.concatenate(rv_residual_uncs)
lombscargle_dict = xo.lomb_scargle_estimator(
    lomb_scargle_times,
    lomb_scargle_series,
    lomb_scargle_uncs,
    min_period=10.,
    max_period=460.,
)
lombscargle_window_dict = xo.lomb_scargle_estimator(
    lomb_scargle_times,
    np.ones_like(lomb_scargle_times, dtype=np.float64),
    lomb_scargle_uncs,
    min_period=10.,
    max_period=460.,
)
lombscargle_dict

In [None]:
# Columns: time, residual, residual uncertainty (all instruments concatenated).
np.savetxt('data/toi2000/toi2000_rv_residuals.txt', np.column_stack([lomb_scargle_times, lomb_scargle_series, lomb_scargle_uncs]))

In [None]:
# One row per RV point across all instruments. The per-instrument lists have
# different lengths, so they must be concatenated first; column_stack on the
# ragged lists failed.
np.column_stack([np.concatenate(rv_times), np.concatenate(rv_residuals), np.concatenate(rv_residual_uncs)])

In [None]:
max_post_soln['rv_gamma'], max_post_soln['rv_jitter']

In [None]:
lombscargle_window_dict

In [None]:
# Residual periodogram overlaid with the window-function periodogram; saved
# to plot/toi2000_rv_residual_lomb_scargle.pdf.
fig, ax = plt.subplots()
ax.plot(*lombscargle_dict['periodogram'])
ax.plot(*lombscargle_window_dict['periodogram'])
ax.set_xscale('log')
ax.set_xlabel('Frequency')
ax.set_ylabel('L-S power')
fig.set_dpi(144)
fig.savefig('plot/toi2000_rv_residual_lomb_scargle.pdf', bbox_inches='tight')

In [None]:
# exoplanet's minimum-mass estimate for a trial 90.704 d signal in the
# residuals, divided by the max-posterior stellar mass expressed in M_jup.
# NOTE(review): the positional 1.08 is presumably a stellar mass in M_sun.
# Confirm it is consistent with max_post_soln['m_star'] used in the divisor.
xo.estimate_minimum_mass(
    90.704269208467736,
    np.concatenate(rv_times),
    np.concatenate(rv_residuals),
    np.concatenate(rv_residual_uncs),
    1.08
#     min_period=1.,
#     max_period=460.,
) / (max_post_soln['m_star'] * u.M_sun).to(u.M_jup)

In [None]:
max_post_soln['m_star']

In [None]:
# Exploratory RV plot with hard-coded two-planet values plus a trial third
# signal at P = 90.704 d (the period tested with estimate_minimum_mass).
# NOTE(review): `plot_multi_planet_folded_rv` is called unqualified and with a
# different argument layout than `plotting.plot_multi_planet_folded_rv` used
# elsewhere. Confirm it is defined in an earlier cell.
plot_orbit = xo.orbits.KeplerianOrbit(
    period=np.array([ 9.12704946,  3.09833298, 90.704]),
    t0=np.array([ 2055.30363783,  1818.0619524 , 2000.]),
    b=np.array([ 0.63065569,  0.79743971, 0.]),
    ecc=np.array([ 0.0058757,  0., 0.2]),
    omega=np.array([-0.94160963,  1.57079633, 1.57079633]),
    m_planet=np.array([  2.45097661e-04,   2.04797485e-05, 1.3101366e-04]),
    **{k: max_post_soln[k] for k in ['m_star', 'r_star']})

plot_rv_inst_styles = [
    {'color': 'C0', 'fmt': 'o'},
    {'color': 'C1', 'fmt': 'd'},
    {'color': 'C2', 'fmt': '^'},
    {'color': 'C2', 'fmt': '^'},
]

fig, axs = plot_multi_planet_folded_rv(
    3, plot_orbit, np.array([ 22.93022657,   2.74690921, 7.053596]),
    max_post_soln['rv_gamma'],
    max_post_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS', None],
#     trends=trend_generator(max_post_soln, len(rv_data)),
#     model_trend_func=lambda t: polynomial_design_matrix(t, max_post_soln['rv_time_offset'], 1) @ max_post_soln['rv_trend_coeff'],
    rv_inst_styles=plot_rv_inst_styles,
#     plot_residuals=False,
)
# Raw f-strings so `\m` in the LaTeX labels is not parsed as a string escape.
axs[0][0].set_title('Unfolded')
axs[0][0].set_xlabel(rf'$\mathrm{{BJD}}_\mathrm{{TDB}} - {TESS_EPOCH}$')
axs[1][0].set_title('Residuals')
axs[1][0].set_xlabel(rf'$\mathrm{{BJD}}_\mathrm{{TDB}} - {TESS_EPOCH}$')
axs[2][0].set_title('c')
axs[2][0].set_xlabel('Phase')
axs[3][0].set_title('b')
axs[3][0].set_xlabel('Phase')
for row in axs:
    row[0].set_ylabel('RV (m/s)')

# axs[2][0].remove()
# axs[3][0].remove()
# fig = plt.Figure()
# fig.add_axes(axs[2][0])
# fig
# fig.set_size_inches((8, 4*4))
fig.subplots_adjust(hspace=0.35)

In [None]:
# Only `matplotlib.pyplot as plt` is imported, so the name `matplotlib` is
# undefined here. plt.rcParams is the same global rcParams object.
plt.rcParams['figure.dpi'] = 300

In [None]:
# Time span covered by each TESS light-curve array (in the units of the time
# arrays, presumably days).
np.max(tess_time) - np.min(tess_time)

In [None]:
np.max(tess_2m_time) - np.min(tess_2m_time)