In [1]:
import itertools
import os.path as path
import re

import astropy.units as u
import arviz as az
import celerite2
import celerite2.terms as terms
import celerite2.theano.terms as theano_terms
import corner
import exoplanet as xo
# import isochrones
import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import matplotlib.transforms as transforms
import numpy as np
import pandas as pd
import pymc3 as pm
import scipy.stats as stats

import adhocfitter.astro as astro
import adhocfitter.plotting as plotting
import adhocfitter.timeseries as aftimeseries

import third_party.keplersplinev2.keplersplinev2 as ksp

# Show up to 100 rows when displaying pandas DataFrames (e.g. the az.summary table).
pd.set_option('display.max_rows', 100)

# Figure fonts: prefer the TeX Gyre families (fonts from the TeX/opentype system
# directories are registered in the next cell), then common fallbacks.
plt.rcParams['font.sans-serif'] = ['TeX Gyre Heros', 'Helvetica', 'Arial', 'sans serif']
plt.rcParams['font.cursive'] = ['TeX Gyre Chorus', 'Apple Chancery', 'cursive']
plt.rcParams['mathtext.fontset'] = 'stixsans'

# Relative to the notebook's working directory; all input tables are read from here.
DATA_DIR = '../data'



In [2]:
# Register every font file found in these system font directories with matplotlib,
# so the families named in rcParams above can be resolved.
# NOTE(review): absolute Linux paths; on other machines these directories may not exist.
for p in matplotlib.font_manager.findSystemFonts(fontpaths=[
    '/usr/share/texmf/fonts/opentype/', '/usr/share/fonts/opentype']):
    matplotlib.font_manager.fontManager.addfont(p)

In [3]:
# Initial ephemeris guesses for the two transiting planets
# (index 0: P ~ 9.127 d, index 1: P ~ 3.098 d).
test_epoch = np.array([2100.93878864,  1901.71996768])  # TJD
test_epoch_unc = np.array([0.001, 0.001])
test_period = np.array([9.12706134,  3.09831393])  # day
test_period_unc = np.array([0.0001, 0.0001])  #np.array([0.00001, 0.000023])
test_last_seen = np.array([2101.0, 1902.0])
# Move each epoch to the last transit at or before `test_last_seen`
# (floor division counts the whole periods in between).
test_epoch = (test_last_seen - test_epoch) // test_period * test_period + test_epoch
print(test_epoch)
test_duration = np.array([5.0, 5.0])  # hour
test_duration_unc = np.array([6., 4.])

# Starting values passed to the light-curve readers below; more keys are added later.
test_params = {
    "t0": test_epoch,
    "period": test_period,
    "tdur": test_duration,
}

# Prior widths, filled in per parameter in later cells.
prior_unc = {}

[ 2100.93878864  1901.71996768]


In [4]:
def binned_weighted_statistics(x, y, yerr, bins):
    """Inverse-variance weighted means of ``x`` and ``y`` in bins of ``x``.

    Bins are half-open ``[bins[i], bins[i+1])`` except the last one, which also
    includes its right edge (the ``np.histogram`` convention). Points outside
    ``[bins[0], bins[-1]]`` and NaN ``x`` values are ignored.

    Parameters
    ----------
    x, y, yerr : array_like, same length
        Abscissa, values and 1-sigma uncertainties of ``y``.
    bins : array_like
        Monotonically increasing bin edges (``len(bins) - 1`` bins).

    Returns
    -------
    x_binned, y_binned, yerr_binned : ndarray, shape (len(bins) - 1,)
        Weighted mean of ``x`` and of ``y`` (weights ``1 / yerr**2``) and the
        standard error of the weighted mean, ``sqrt(1 / sum(1 / yerr**2))``.
        Empty bins are NaN.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    yerr = np.asarray(yerr, dtype=float)
    bins = np.asarray(bins, dtype=float)
    nbins = len(bins) - 1

    # One O(len(x) log nbins) pass instead of building a full mask per bin.
    bin_index = np.searchsorted(bins, x, side='right') - 1
    # The last bin is closed on the right.
    bin_index[x == bins[-1]] = nbins - 1
    in_range = (bin_index >= 0) & (bin_index < nbins)
    bin_index = bin_index[in_range]
    yerr_in = yerr[in_range]
    weight = 1. / (yerr_in * yerr_in)

    count = np.bincount(bin_index, minlength=nbins)
    weight_sum = np.bincount(bin_index, weights=weight, minlength=nbins)
    weighted_x = np.bincount(bin_index, weights=weight * x[in_range], minlength=nbins)
    weighted_y = np.bincount(bin_index, weights=weight * y[in_range], minlength=nbins)

    x_binned = np.full(nbins, np.nan)
    y_binned = np.full(nbins, np.nan)
    yerr_binned = np.full(nbins, np.nan)
    filled = count > 0
    x_binned[filled] = weighted_x[filled] / weight_sum[filled]
    y_binned[filled] = weighted_y[filled] / weight_sum[filled]
    yerr_binned[filled] = np.sqrt(1. / weight_sum[filled])
    return x_binned, y_binned, yerr_binned

In [5]:
# TESS light curves. The current period/epoch/duration guesses are passed to the
# reader (durations converted from hours to days); presumably it uses them to
# select or mask data around transits — see aftimeseries.read_tess_lc.
tess_time, tess_dflux, tess_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_02.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)
# Expand the noise to one value per flux point (apparently a single value is returned here).
tess_noise = np.full_like(tess_dflux, tess_noise)

# 20 s cadence data.
tess_20s_time, tess_20s_dflux, tess_20s_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_03.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)
# NOTE(review): keeps only the first noise entry; assumes it is constant — confirm.
tess_20s_noise = tess_20s_noise[0]
tess_20s_times = [tess_20s_time]
tess_20s_dfluxes = [tess_20s_dflux]
tess_20s_noises = [tess_20s_noise]

# The 20 s data binned to 2 minutes.
tess_2m_time, tess_2m_dflux, tess_2m_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_03_binned_to_2_min.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)
tess_2m_times = [tess_2m_time]
tess_2m_dfluxes = [tess_2m_dflux]
tess_2m_noises = [tess_2m_noise]

# The 20 s data binned to 30 minutes.
tess_y3_time, tess_y3_dflux, tess_y3_noise = aftimeseries.read_tess_lc(
    path.join(DATA_DIR, 'toi_2000_table_03_binned_to_30_min.csv'),
    test_params['period'], test_params['t0'], test_params['tdur']/24.)

In [6]:
# Ground-based follow-up photometry. When a list of regressor columns is given
# (e.g. 'SKY', 'Width_T1'), the reader returns a fourth value, apparently the
# design matrix used for detrending.
# Exposure time 70 s
astep_time, astep_dflux, astep_unc, astep_design = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_06.csv'),
    'BJD', 'FLUX', 'ERRFLUX', ['SKY'], time_epoch=2450000.)

# Exposure time 30 s (except last cadence of B)
lco_zs_time, lco_zs_dflux, lco_zs_unc, lco_zs_design = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_04_zs.tsv'),
    '#BJD_TDB', 'rel_flux_T1_n', 'rel_flux_err_T1_n', ['Sky/Pixel_T1'], delim_whitespace=True)
lco_B_time, lco_B_dflux, lco_B_unc, lco_B_design = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_04_B.tsv'),
    '#BJD_TDB', 'rel_flux_T1_n', 'rel_flux_err_T1_n', ['Width_T1'], delim_whitespace=True)

pest_Ic_time, pest_Ic_dflux, pest_Ic_unc = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_05_Ic.txt'),
    '#BJD_TDB', 'flux', 'flux_err', delim_whitespace=True)
pest_B_time, pest_B_dflux, pest_B_unc = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_05_B.txt'),
    '#BJD_TDB', 'flux', 'flux_err', delim_whitespace=True)

# Light curves included in the ground-based lists. The LCO B and PEST light
# curves are loaded above but currently left out (commented entries).
ground_times = [astep_time, lco_zs_time] #, lco_B_time, pest_Ic_time, pest_B_time]
ground_dfluxes = [astep_dflux, lco_zs_dflux] #, lco_B_dflux, pest_Ic_dflux, pest_B_dflux]
ground_uncs = [astep_unc, lco_zs_unc] #, lco_B_unc, pest_Ic_unc, pest_B_unc]
# Exposure times converted from seconds to days.
ground_exp_times = list(np.array([70., 30.]) / 3600 / 24) #  30., 60., 60.]) / 3600 / 24)
# One factor per light curve; sized from the list so it cannot drift out of sync.
ground_supersampling_factors = [1] * len(ground_times)
ground_filters = ['Rc', 'zs'] #'B', 'Ic', 'B']
ground_detrends = [astep_design, lco_zs_design] #, lco_B_design, None, None]

In [7]:
# Full (unfolded) TESS light curves for the long time-series plot below.
# NOTE(review): 'kspflux' is passed as both the flux and the uncertainty column and
# the returned uncertainty is discarded (`_`), presumably because table 02 has none.
tess_full_30m_time, tess_full_30m_dflux, _ = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_02.csv'),
    'time', 'kspflux', 'kspflux', time_epoch=aftimeseries.TESS_EPOCH,
)
# Despite the "20s" name, this reads the 20 s data binned to 2 minutes.
tess_full_20s_time, tess_full_20s_dflux, tess_full_20s_unc = aftimeseries.read_generic_lc(
    path.join(DATA_DIR, 'toi_2000_table_03_binned_to_2_min.csv'),
    'time', 'kspflux', 'kspflux_unc', time_epoch=aftimeseries.TESS_EPOCH,
)
tess_full_times = [tess_full_30m_time, tess_full_20s_time]
tess_full_dfluxes = [tess_full_30m_dflux, tess_full_20s_dflux]

In [8]:
# TESS light curves entering the model: table 02 (30-min cadence) and the 20 s
# data binned to 30 min. The unbinned 20 s data could be appended via tess_20s_*.
lc_times = [tess_time, tess_y3_time]  # + tess_20s_times
lc_dfluxes = [tess_dflux, tess_y3_dflux]  # + tess_20s_dfluxes
# One uncertainty array per light curve. full_like expands a scalar noise value
# (as returned for table 02 above) and copies a per-point array unchanged.
lc_uncs = [tess_noise, np.full_like(tess_y3_dflux, tess_y3_noise)]  # + tess_20s_noises
# Size every per-light-curve setting from the lists actually used.
num_lc = len(lc_times)
# 30-min exposures, each integrated with 15 model sub-samples.
exposure_times = [0.5/24] * num_lc
supersampling_factors = [15] * num_lc
filters = ['tess'] * num_lc

test_params['mean_flux'] = np.array([0.]*num_lc)
test_params['lc_jitter'] = np.array([1e-9]*num_lc)

np.array(exposure_times)*24*3600, supersampling_factors, filters

(array([ 1800.,  1800.]), [15, 15], ['tess', 'tess'])

In [9]:
rv_table = aftimeseries.read_generic_rv(path.join(DATA_DIR, 'toi_2000_table_08.csv'))
# Split the RVs by spectrograph. The order CHIRON, FEROS, HARPS below is the
# order of every per-instrument array (rv_gamma, rv_jitter, their priors).
chiron_time, chiron_rv, chiron_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'CHIRON')
feros_time, feros_rv, feros_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'FEROS')
harps_time, harps_rv, harps_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'HARPS')

rv_names = ['chiron', 'feros', 'harps']
rv_times = [chiron_time, feros_time, harps_time]
rv_data = [chiron_rv, feros_rv, harps_rv]
rv_uncs = [chiron_rv_unc, feros_rv_unc, harps_rv_unc]
num_rv_outside = len(rv_data)

# Starting values per instrument: offset = mean RV, mean reported uncertainty
# (for display only), and a negligible jitter.
test_gamma = np.array([np.average(i) for i in rv_data])
test_rv_unc = np.array([np.average(i) for i in rv_uncs])
test_jitter = np.array([1e-3]*len(rv_data))

# Per-planet K guesses and prior widths (index 0: P ~ 9.13 d, index 1: P ~ 3.10 d).
test_params['K'] = np.array([23., 9.])
prior_unc['K'] = np.array([50., 30.])
test_params['rv_gamma'] = test_gamma
prior_unc['rv_gamma'] = np.array([1000.]*num_rv_outside)
test_params['rv_jitter'] = test_jitter
# Jitter prior width per instrument (CHIRON, FEROS, HARPS).
prior_unc['rv_jitter'] = np.array([15., 30., 15.])

# Presumably flags which planets are included in the RV model — both here.
test_params['rv_fit_planet'] = np.array([True, True])

test_gamma, test_rv_unc, test_jitter, prior_unc['rv_gamma'], prior_unc['rv_jitter']

(array([ 6659.03408652,  8114.59285714,  8113.53725   ]),
 array([ 8.74179577,  6.72142857,  2.3035    ]),
 array([ 0.001,  0.001,  0.001]),
 array([ 1000.,  1000.,  1000.]),
 array([ 15.,  30.,  15.]))

In [10]:
# Display floats in fixed-point notation (6 decimals, minimum width 11).
pd.options.display.float_format = '{:11f}'.format

In [27]:
# NOTE(review): stale cell (execution count 27) loading an older chain. In file
# order it is immediately overwritten by the next cell; consider deleting it.
trace = az.from_netcdf('../chains/toi2000_trace_20s_tess_astep_lco_pest_rv_gp_sho_sed_15.nc')

In [11]:
# Posterior samples analysed in the rest of the notebook.
trace = az.from_netcdf('../chains/toi-2000_trace_20230201T191240.nc')

In [12]:
# Posterior quantities to tabulate (sampled parameters and derived quantities).
display_vars = [
    "period", "t0", "rp", "b", "tdur",
    "sqrt_ecc_vec_0", "ecc", "omega",
    "r_planet_earth",
    "m_planet_earth",
    "rho_planet",
    "K",
    "m_star_0", "feh_0", "eep",
    "m_star", "r_star", "rho_star", "logg_star", "teff",
    "feh", "age", "parallax", "av", "L_star",
    "r_star_mist", "teff_mist", "feh_mist",
    "r_star_sed", "teff_sed", "ap_mag_bol", "ap_mag_bol_sed",
    "sed_unc_scale",
    "u_tess", "u_Rc", "u_zs", #"u_B", "u_Ic",
    "mean_flux_0", "mean_flux_1", #"mean_flux_5", "mean_flux_6",
    "lc_detrend_coeffs_2", "lc_detrend_coeffs_3", #"lc_detrend_coeffs_4",
    "lc_jitter",
    "rv_gamma", "rv_jitter",
    # "rv_trend_coeff",
    "gp_sigma", "gp_rho", "gp_tau",
    "a", "aor", "incl",
    "distance", "m_planet", "irradiation",
]

# Summary table with a 68% HDI plus three extra columns: the median and the
# 16th/84th-percentile offsets from it ('-' and '+'), which texify_summary reads.
# round_to=26 presumably keeps (effectively) full precision.
summary = az.summary(
    trace,
    var_names=display_vars,
    round_to=26,
    hdi_prob=0.68,
    # kind='stats',
    # extend=False,
    # ecc and omega: only index 0 (the P ~ 9.13 d planet) is reported.
    coords={"ecc_dim_0": [0], "omega_dim_0": [0]},
    skipna=True,
    # omega is an angle: ArviZ uses circular statistics for its own columns.
    # NOTE(review): the custom median/quantile stat_funcs below are linear even for omega.
    circ_var_names={'omega'},
    stat_funcs={
        'median': np.median,
        '-': lambda x: np.quantile(x, 0.16) - np.median(x),
        '+': lambda x: np.quantile(x, 0.84) - np.median(x),
    },
)
summary

  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Unnamed: 0,mean,sd,hdi_16%,hdi_84%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat,median,-,+
period[0],9.127055,7e-06,9.127048,9.127062,0.0,0.0,95506.486481,75151.524035,1.000488,9.127055,-7e-06,7e-06
period[1],3.098331,2e-05,3.098312,3.098352,0.0,0.0,8919.272428,13561.372398,1.004017,3.09833,-1.9e-05,2.1e-05
t0[0],2110.065873,0.000276,2110.0656,2110.066148,1e-06,1e-06,76876.893846,61428.873545,1.000136,2110.065875,-0.000276,0.000272
t0[1],1855.244202,0.002216,1855.242016,1855.24631,1.7e-05,1.2e-05,16870.370642,18215.37672,1.001651,1855.244189,-0.002118,0.002185
rp[0],0.065803,0.000707,0.065145,0.066509,4e-06,3e-06,30132.600502,37518.832612,1.00096,0.065809,-0.000683,0.000683
rp[1],0.021794,0.000916,0.020905,0.022763,8e-06,6e-06,12730.430071,27320.723607,1.002476,0.021825,-0.000969,0.000892
b[0],0.626886,0.045059,0.592956,0.676894,0.000283,0.0002,28771.471489,31910.807768,1.001033,0.631475,-0.046792,0.038636
b[1],0.755558,0.055619,0.724013,0.822853,0.000641,0.000455,8345.379054,17709.829948,1.003699,0.769886,-0.070707,0.037584
tdur[0],3.655001,0.029048,3.624796,3.682315,0.000124,8.8e-05,54970.5654,68437.24437,1.000417,3.654269,-0.028087,0.029538
tdur[1],2.006155,0.148701,1.8345,2.132227,0.001798,0.001273,8591.187677,13655.384833,1.003684,1.959276,-0.097812,0.209191


In [13]:
# Planet radii in Jupiter radii. Median and offsets are scaled by the same positive
# Earth->Jupiter factor (a plain float from astropy). Rows are relabelled so
# texify_summary produces distinct macro names.
r_planet_jupiter_summary = (
    summary.loc[['r_planet_earth[0]', 'r_planet_earth[1]'], ['median', '-', '+']].set_index(
        pd.Index(['r_planet_jupiter[0]', 'r_planet_jupiter[1]']))
    * (u.earthRad).to(u.jupiterRad)
)
r_planet_jupiter_summary

Unnamed: 0,median,-,+
r_planet_jupiter[0],0.726503,-0.027172,0.028038
r_planet_jupiter[1],0.240772,-0.013587,0.013525


In [15]:
# Planet masses in Jupiter masses; same rescaling and relabelling as for the radii.
m_planet_jupiter_summary = (
    summary.loc[['m_planet_earth[0]', 'm_planet_earth[1]'], ['median', '-', '+']].set_index(
        pd.Index(['m_planet_jupiter[0]', 'm_planet_jupiter[1]']))
    * (u.earthMass).to(u.jupiterMass)
)
m_planet_jupiter_summary

Unnamed: 0,median,-,+
m_planet_jupiter[0],0.257031,-0.014407,0.014713
m_planet_jupiter[1],0.034742,-0.007525,0.007684


In [16]:
def format_sig_fig(value, lower, upper):
    r"""Format a value with asymmetric uncertainties as LaTeX strings.

    The precision is set by the smaller uncertainty, which is shown to two
    significant figures. Moderate numbers are written in fixed point. Otherwise
    the value's decimal exponent is factored out: ``(v \pm u) \times 10^{n}``.

    Parameters
    ----------
    value : float
        Central value (e.g. the posterior median).
    lower : float
        Lower offset (normally negative; only its magnitude is used).
    upper : float
        Upper offset (normally positive).

    Returns
    -------
    (value_str, unc_str) : tuple of str
        ``unc_str`` is ``\pm u`` when both offsets round to the same string and
        ``_{-l}^{+u}`` otherwise. In the factored form ``value_str`` opens a
        parenthesis that ``unc_str`` closes.

    Raises
    ------
    ValueError
        If neither format applies. This happens when the uncertainty's second
        significant digit lies above the value's leading digit and either
        ``|value| < 1e-3`` or the smaller uncertainty is >= 1000.
    """
    lower_abs = abs(lower)
    lower_order = np.floor(np.log10(lower_abs)) if lower_abs != 0 else 0
    upper_order = np.floor(np.log10(upper)) if upper != 0 else 0
    value_order = int(np.floor(np.log10(abs(value)))) if value != 0 else 0

    # Decimal position of the second significant digit of the smaller uncertainty.
    unc_order = int(min(lower_order, upper_order)) - 1
    if unc_order <= 1 and value_order >= -3:
        # Fixed-point notation, rounded at that digit.
        dec_places = abs(unc_order) if unc_order <= 0 else 0
        fmt_str = f'{{:.{dec_places}f}}'
        value_str = fmt_str.format(value)
        lower_str = fmt_str.format(lower_abs)
        upper_str = fmt_str.format(upper)
        if lower_str == upper_str:
            unc_str = f'\\pm {lower_str}'
        else:
            unc_str = f'_{{-{lower_str}}}^{{+{upper_str}}}'
    elif unc_order <= value_order:
        # Factor out 10^value_order and keep digits down to the uncertainty's.
        dec_places = abs(value_order - unc_order)
        fmt_str = f'{{:.{dec_places}f}}'
        exponent = 10 ** -value_order
        value_str = '(' + fmt_str.format(value * exponent)
        lower_str = fmt_str.format(lower_abs * exponent)
        upper_str = fmt_str.format(upper * exponent)
        if lower_str == upper_str:
            unc_str = f'\\pm {lower_str}) \\times 10^{{{value_order}}}'
        else:
            unc_str = f'_{{-{lower_str}}}^{{+{upper_str}}}) \\times 10^{{{value_order}}}'
    else:
        # Previously fell through to `return` with unbound locals.
        raise ValueError(
            f'Cannot format {value!r} with uncertainties -{lower_abs!r}/+{upper!r}: '
            'the uncertainty is much larger than the value.')
    return value_str, unc_str

format_sig_fig(3.09832834, -0.000014452, 0.000027123)

english_num_exceptions = {
    0: 'zero', 1: 'one', 2: 'two', 3: 'three', 4: 'four',
    5: 'five', 6: 'six', 7: 'seven', 8: 'eight', 9: 'nine',
}

def wordify_number(n):
    """Spell out a single digit ``n`` (0-9) in English.

    LaTeX macro names cannot contain digits, so digits are spelled out.

    Raises
    ------
    ValueError
        If ``n`` is not one of 0-9.
    """
    if n in english_num_exceptions:
        return english_num_exceptions[n]
    else:
        raise ValueError('Numbers greater than 9 not implemented.')

def texify_name(s, prefix='sysParam'):
    """Turn an ArviZ summary label into a letters-only LaTeX macro name.

    Underscore-separated parts are capitalized and joined, digits are spelled
    out, and the index suffix becomes ``Sub<Index...>``. Examples:
    ``'sqrt_ecc_vec_0[0]'`` -> ``'sysParamSqrtEccVecZeroSubZero'``,
    ``'t0[1]'`` -> ``'sysParamTzeroSubOne'``, and
    ``'x[0, 1]'`` -> ``'sysParamXSubZeroOne'``.

    Raises
    ------
    ValueError
        If ``s`` does not start with a word character, or a number/index is
        outside 0-9 (via ``wordify_number``).
    """
    # `\w+` is greedy so trailing digits stay in the name. The optional bracket
    # captures one or more comma-separated indices (ArviZ writes 'x[0, 1]').
    index_match = re.match(r'(\w+)(?:\[(\d+(?:\s*,\s*\d+)*)\])?', s)
    if index_match is None:
        raise ValueError(f'Cannot convert {s!r} to a LaTeX macro name.')
    full_name = index_match.group(1)
    new_name_parts = [prefix]
    for n in full_name.split('_'):
        if n.isdigit():
            new_name_parts.append(wordify_number(int(n)).capitalize())
        elif not n.isalpha():
            # Mixed letters and digits (e.g. 't0'): spell each digit out in place.
            spelled = ''.join(
                wordify_number(int(c)) if c.isdigit() else c for c in n)
            new_name_parts.append(spelled.capitalize())
        else:
            new_name_parts.append(n.capitalize())
    indices = index_match.group(2)
    if indices is not None:
        new_name_parts.append('Sub')
        for index in re.split(r'\s*,\s*', indices):
            new_name_parts.append(wordify_number(int(index)).capitalize())
    return ''.join(new_name_parts)

texify_name('sqrt_ecc_vec_0[0]')

def texify_val_unc(k, val, low, upp):
    """Return the two LaTeX macro definitions (value, then uncertainty) for label ``k``.

    The value macro is named by texify_name(k). The uncertainty macro has the
    same name with an ``Unc`` suffix. Both bodies are wrapped in inline math.
    """
    val_str, unc_str = format_sig_fig(val, low, upp)
    tex_name = texify_name(k)
    value_macro = f'\\newcommand{{\\{tex_name}}}{{${val_str}$}}'
    unc_macro = f'\\newcommand{{\\{tex_name}Unc}}{{${unc_str}$}}'
    return value_macro, unc_macro

def texify_summary(sum_table):
    """Render every row of a table with 'median', '-', '+' columns as LaTeX macros.

    Row labels become macro names. Returns the definitions joined by newlines,
    value macro first for each row.
    """
    return '\n'.join(
        macro
        for label, row in sum_table.iterrows()
        for macro in texify_val_unc(label, row['median'], row['-'], row['+'])
    )

In [17]:
# LaTeX \newcommand definitions for every summary row (presumably pasted into the manuscript).
print(texify_summary(summary))

\newcommand{\sysParamPeriodSubZero}{$9.1270550$}
\newcommand{\sysParamPeriodSubZeroUnc}{$_{-0.0000072}^{+0.0000073}$}
\newcommand{\sysParamPeriodSubOne}{$3.098330$}
\newcommand{\sysParamPeriodSubOneUnc}{$_{-0.000019}^{+0.000021}$}
\newcommand{\sysParamTzeroSubZero}{$2110.06588$}
\newcommand{\sysParamTzeroSubZeroUnc}{$_{-0.00028}^{+0.00027}$}
\newcommand{\sysParamTzeroSubOne}{$1855.2442$}
\newcommand{\sysParamTzeroSubOneUnc}{$_{-0.0021}^{+0.0022}$}
\newcommand{\sysParamRpSubZero}{$0.06581$}
\newcommand{\sysParamRpSubZeroUnc}{$\pm 0.00068$}
\newcommand{\sysParamRpSubOne}{$0.02182$}
\newcommand{\sysParamRpSubOneUnc}{$_{-0.00097}^{+0.00089}$}
\newcommand{\sysParamBSubZero}{$0.631$}
\newcommand{\sysParamBSubZeroUnc}{$_{-0.047}^{+0.039}$}
\newcommand{\sysParamBSubOne}{$0.770$}
\newcommand{\sysParamBSubOneUnc}{$_{-0.071}^{+0.038}$}
\newcommand{\sysParamTdurSubZero}{$3.654$}
\newcommand{\sysParamTdurSubZeroUnc}{$_{-0.028}^{+0.030}$}
\newcommand{\sysParamTdurSubOne}{$1.959$}
\newcommand{\sysPar

In [19]:
# Same, for the radii and masses converted to Jupiter units.
print(texify_summary(r_planet_jupiter_summary))
print(texify_summary(m_planet_jupiter_summary))

\newcommand{\sysParamRPlanetJupiterSubZero}{$0.727$}
\newcommand{\sysParamRPlanetJupiterSubZeroUnc}{$_{-0.027}^{+0.028}$}
\newcommand{\sysParamRPlanetJupiterSubOne}{$0.241$}
\newcommand{\sysParamRPlanetJupiterSubOneUnc}{$\pm 0.014$}
\newcommand{\sysParamMPlanetJupiterSubZero}{$0.257$}
\newcommand{\sysParamMPlanetJupiterSubZeroUnc}{$_{-0.014}^{+0.015}$}
\newcommand{\sysParamMPlanetJupiterSubOne}{$0.0347$}
\newcommand{\sysParamMPlanetJupiterSubOneUnc}{$_{-0.0075}^{+0.0077}$}


In [30]:
# Merge chains and draws into a single 'sample' dimension.
flat_samples = trace.posterior.stack(sample=("chain", "draw"))
# Per-variable posterior medians as plain NumPy arrays.
median_soln = {k:v.data for k, v in flat_samples.median(dim='sample').items()}
# The draw with the largest `log_prob` (approximate maximum-posterior solution).
max_post_index = flat_samples.log_prob.argmax(dim='sample')
max_post_soln = {k:v.data for k, v in flat_samples[{'sample': max_post_index}].items()}

In [None]:
# Predicted mid-transit times of planet 1 at 113, 214 and 221 orbits after its
# median t0, plus half the transit durations (same units as tdur).
median_soln['t0'][1] + median_soln['period'][1] * np.array([113, 214, 221]), median_soln['tdur']/2

In [None]:
# Start/end of an observing window (BJD) in hours relative to a predicted transit
# time in TJD (= BJD - 2457000); 2586.450 ~ t0[1] + 236 * period[1].
(np.array([2459586.36, 2459586.54]) - 2457000 - 2586.45024633) * 24

In [None]:
# Same for another window; 2564.762 ~ t0[1] + 229 * period[1].
(np.array([2459564.68, 2459564.85]) - 2457000 - 2564.76193271) * 24

In [None]:
# Same for another window; 2251.831 ~ t0[1] + 128 * period[1].
# NOTE(review): subtracting 57000 gives TJD only if these are BJD - 2400000; for
# true MJD (JD - 2400000.5) TJD = MJD - 56999.5 — confirm the time system.
(np.array([59251.536778, 59251.850281]) - 57000 - 2251.8305503670549) * 24

In [None]:
# Ad-hoc ratio; the origin of these constants is not recorded here.
(0.47353519+0.97857525)/(0.97857525*2)

In [None]:
def _quantile_row(samples):
    """Return [median, q16 - median, q84 - median] of a 1-D sample array."""
    median, low, high = np.quantile(samples, [0.5, 0.16, 0.84])
    return [median, low - median, high - median]


def make_summary_table(var_names, samples):
    """Median and 68% central-interval offsets for each set of samples.

    Parameters
    ----------
    var_names : sequence of str
        Row labels, one per entry of ``samples``.
    samples : sequence of array_like
        Samples of each variable.

    Returns
    -------
    pandas.DataFrame
        Columns ``'median'``, ``'-'`` (16th percentile minus median) and
        ``'+'`` (84th percentile minus median), as read by ``texify_summary``.
    """
    return pd.DataFrame(
        data=np.array([_quantile_row(s) for s in samples]),
        index=var_names,
        columns=['median', '-', '+'],
    )


def make_summary_circular(var_names, samples, period=2 * np.pi):
    """Like ``make_summary_table`` but for angles that wrap with ``period``.

    Each sample set is first unwrapped into the window of width ``period``
    centred on its circular mean. A posterior straddling the wrap point (e.g.
    +-pi for radians) is then summarized as one contiguous peak. Samples
    already inside that window are unchanged (up to rounding), so posteriors
    that do not wrap get the same numbers as ``make_summary_table``. The median
    is reported in the unwrapped window, so it may lie outside the input range.

    Parameters
    ----------
    var_names : sequence of str
        Row labels, one per entry of ``samples``.
    samples : sequence of array_like
        Angle samples of each variable.
    period : float, optional
        Wrap period of the angles (default ``2*pi``, i.e. radians; use 360 for degrees).
    """
    scale = 2 * np.pi / period
    rows = []
    for s in samples:
        s = np.asarray(s, dtype=float)
        center = np.arctan2(np.mean(np.sin(s * scale)), np.mean(np.cos(s * scale))) / scale
        # Map every sample into [center - period/2, center + period/2).
        unwrapped = (s - center + period / 2) % period - period / 2 + center
        rows.append(_quantile_row(unwrapped))
    return pd.DataFrame(
        data=np.array(rows),
        index=var_names,
        columns=['median', '-', '+'],
    )

In [None]:
# omega[0] summary row converted from radians to degrees.
# NOTE(review): this scales every column, including ess/r_hat, which are not angles.
summary.loc[['omega[0]']] / np.pi * 180

In [None]:
# Circular summary of the sqrt(e)-vector direction for planet 0: average the unit
# vectors, then report the mean direction (deg) and the circular standard deviation
# sqrt(-2 ln R) (deg), where R is the mean resultant length. Presumably this is
# omega under the usual (sqrt(e) cos w, sqrt(e) sin w) parametrization, and assumes
# the samples have shape (2, n_sample) — confirm.
omega_avg_vec = np.average(flat_samples['sqrt_ecc_vec_0'] / np.linalg.norm(flat_samples['sqrt_ecc_vec_0'], axis=0), axis=1)
np.arctan2(omega_avg_vec[1], omega_avg_vec[0]) / np.pi * 180, np.sqrt(-2 * np.log(np.linalg.norm(omega_avg_vec))) / np.pi * 180

In [None]:
# Seeded generator so the Monte Carlo draws below (the random per-sample factor
# in the equilibrium-temperature calculation) are reproducible on re-run.
# The seed value itself is arbitrary.
rng = np.random.default_rng(20230201)

In [None]:
sample_len = len(flat_samples['sample'])
# Equilibrium temperature per posterior sample. The third argument is drawn
# uniformly in [0, 0.7) for each sample (presumably the Bond albedo — confirm against
# astro.calculate_temperature_eq_flux_avg). Planet 1 is evaluated at ecc = 0.
temp_eq_favg_0 = astro.calculate_temperature_eq_flux_avg(flat_samples['teff'], flat_samples['aor'][0], 0.7*rng.random(sample_len), ecc=flat_samples['ecc'][0], heat_dist=0.5)
temp_eq_favg_1 = astro.calculate_temperature_eq_flux_avg(flat_samples['teff'], flat_samples['aor'][1], 0.7*rng.random(sample_len), ecc=0., heat_dist=0.5)
# omega of planet 0 in degrees, folded into [0, 360).
# NOTE(review): summarized with linear quantiles, which is misleading if the
# posterior straddles 0/360.
omega_fold_0 = (flat_samples['omega'][0] / np.pi * 180 + 360) % 360
derived_var_summary = make_summary_table(['temp_eq[0]', 'temp_eq[1]', 'omega_fold[0]'], [temp_eq_favg_0, temp_eq_favg_1, omega_fold_0])
derived_var_summary

In [31]:
mass_planet_earth_corrected = flat_samples['m_planet_earth'] * flat_samples['m_star']
rho_planet_earth_corrected = flat_samples['m

In [None]:
# omega vs. ecc for planet 0.
_ = corner.corner(trace, var_names=['omega', 'ecc'], coords={"ecc_dim_0": [0], "omega_dim_0": [0]},)

In [None]:
print(texify_summary(derived_var_summary))

In [None]:
# Compare the model's deterministic `log_prob` with the sampler's `lp`.
# NOTE(review): with a tolerance of 1e6 this check is effectively vacuous; 1e-6 may
# have been intended, or the two are expected to differ (e.g. by Jacobian terms) — confirm.
abs(trace.posterior.log_prob - trace.sample_stats.lp) < 1e6

In [None]:
# Minimum mass of planet 1 from a fixed K-like value at its median period and the
# median stellar mass; presumably arguments are (K, ecc, period, m_star) — confirm.
astro.calculate_min_planet_mass_earth(13.11318376, 0, median_soln['period'][1], median_soln['m_star'])

In [None]:
# Secondary mode of planet 1's transit duration: the highest-log_prob draw among
# samples with tdur[1] > tdur_cut. Only referenced by a commented-out `truths` below.
tdur_cut = 2.
tdur_cut_indx = flat_samples.tdur[1, :] > tdur_cut
# argmax is an index into the *selected* subset, so the same selection is applied first.
second_mode_idx = (flat_samples.log_prob[tdur_cut_indx]).argmax(dim='sample')
second_mode_soln_array = flat_samples[{'sample': np.array(tdur_cut_indx)}][{'sample': np.array(second_mode_idx)}]
second_mode_soln = {k:v.data for k, v in second_mode_soln_array.items()}

In [None]:
# rp appears to be Rp/R*, so rp**2 approximates the transit depth.
median_soln['rp']**2, median_soln['t0']

In [None]:
median_soln['period']

In [None]:
# 16th/84th-percentile offsets of the periods from their medians.
flat_samples.quantile([0.16, 0.84], dim='sample')['period'] - median_soln['period']

In [None]:
trace

In [None]:
# Corner plot of the RV jitter and GP hyperparameters. Contour levels
# 1 - exp(-n^2/2) enclose the 2-D Gaussian probability within n sigma.
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=['rv_jitter', 'gp_rho', 'gp_tau', 'gp_sigma'],
)

In [None]:
# sqrt(e)-vector samples of planet 0, with the circular-orbit point (0, 0) marked.
fig = corner.corner(
    trace,
    var_names=["sqrt_ecc_vec_0"],
    truths=[0, 0],
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5, 3])**2/2),
)

In [None]:
# Transit parameters of planet 1 (index 1, P ~ 3.10 d).
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=["period", "t0", "rp", "b", "tdur"], # , "sqrt_ecc_vec_0"],
    # truths=second_mode_soln[{"period_dim_0": 1, "t0_dim_0": 1, "b_dim_0": 1, "rp_dim_0": 1, "tdur_dim_0": 1}][["period", "t0", "rp", "b", "tdur"]],
    coords={"period_dim_0": [1], "t0_dim_0": [1], "b_dim_0": [1], "rp_dim_0": [1], "ecc_dim_0": [1], "omega_dim_0": [1], "tdur_dim_0": [1], "K_dim_0": [1]},
)
# fig.savefig("plot/toi2000_corner_20s_rv_gp_sho_lc_only.png", dpi=144)

In [None]:
# Transit, RV and sqrt(e)-vector parameters of planet 0 (index 0, P ~ 9.13 d).
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=["period", "t0", "rp", "b", "K", "sqrt_ecc_vec_0"],
    coords={"period_dim_0": [0], "t0_dim_0": [0], "rp_dim_0": [0], "b_dim_0": [0], "ecc_dim_0": [0], "omega_dim_0": [0], "K_dim_0": [0]},
)
# fig.savefig('plot/toi2000_20s_

In [None]:
# Stellar parameters.
# NOTE(review): "r_star_pred", "teff_pred", "feh_pred" and "mag_bol" are not among
# the variables tabulated above (which use *_mist / *_sed / ap_mag_bol); they may be
# names from an older trace — confirm they exist in the current one.
fig = corner.corner(
    trace,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    var_names=["m_star", "r_star", "r_star_pred", "rho_star", "logg_star", "teff", "teff_pred", "feh", "feh_pred", "mag_bol", "age", ],
)
# fig.savefig('plot/toi2000_corner_stellar_param_11.png', dpi=300)

In [None]:
# Upper quantiles of ecc at the two-sided Gaussian 4, 3, 2.5 and 2 sigma
# probabilities, i.e. one-sided upper limits at those levels.
flat_samples.quantile(np.array([0.999936657516334, 0.997300203936740, 0.987580669348448, 0.954499736103642]), dim='sample')['ecc']

In [None]:
# The matching lower quantiles of K (one-sided lower limits).
flat_samples.quantile(1-np.array([0.999936657516334, 0.997300203936740, 0.987580669348448, 0.954499736103642]), dim='sample')['K']

In [None]:
def epoch_min_cov(period_samples, epoch_samples, verbose=True):
    """Re-reference epoch samples to the transit least correlated with the period.

    ``t0 - n * P`` describes the same ephemeris referenced ``n`` orbits earlier.
    ``cov(t0 - n P, P)`` vanishes at the least-squares slope
    ``n* = cov(P, t0) / var(P)``, so the samples are shifted by ``round(n*)`` orbits.

    Parameters
    ----------
    period_samples, epoch_samples : array_like, 1-D, same length
        Posterior samples of the period and the reference epoch.
    verbose : bool, optional
        Print the unrounded ``n*`` (default True, as before).

    Returns
    -------
    array
        ``epoch_samples - period_samples * round(n*)``. If the period samples
        have zero variance, ``n*`` is undefined and no shift is applied.
    """
    period_dev = period_samples - np.average(period_samples)
    epoch_dev = epoch_samples - np.average(epoch_samples)
    period_var = period_dev @ period_dev
    if period_var == 0:
        # A fixed period cannot correlate with the epoch; keep the original reference.
        offset_periods = 0.
    else:
        offset_periods = (period_dev @ epoch_dev) / period_var
    if verbose:
        print(offset_periods)
    return epoch_samples - period_samples * round(offset_periods)

In [None]:
# [median, q16 - median, q84 - median] of t0 for each planet after re-referencing
# the epoch to minimize its correlation with the period (epoch_min_cov).
list(map(
    lambda b: [b[1], b[0]-b[1], b[2]-b[1]],
    (
        np.quantile(epoch_min_cov(flat_samples['period'].data[0], flat_samples['t0'].data[0]), np.array([0.16, 0.50, 0.84])),
        np.quantile(epoch_min_cov(flat_samples['period'].data[1], flat_samples['t0'].data[1]), np.array([0.16, 0.50, 0.84])),
)))

In [None]:
# Period vs. epoch for planet 1 as stored in the trace; with zero offset the
# truth marker is simply the median t0.
plot_offset_periods = 0
fig = corner.corner(
    trace,
    var_names=["period", "t0"],
    coords={"period_dim_0": [1], "t0_dim_0": [1]},
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    truths=[median_soln['period'][1], np.median(flat_samples['t0'][1]-flat_samples['period'][1]*plot_offset_periods)],
)

In [None]:
# Same, with the epoch moved `plot_offset_periods` orbits earlier (t0 - n P),
# built directly from the flattened samples.
plot_offset_periods = 50
fig = corner.corner(
    np.vstack([flat_samples['period'][1], flat_samples['t0'][1]-flat_samples['period'][1]*plot_offset_periods]).T,
    quantiles=(0.16, 0.5, 0.84),
    levels=1-np.exp(-np.array([0.5, 1, 1.5, 2, 2.5])**2/2),
    truths=[median_soln['period'][1], np.median(flat_samples['t0'][1]-flat_samples['period'][1]*plot_offset_periods)],
)

In [None]:
def plot_long_light_curve(
    axs, lc_times, lc_dfluxes, periods, epochs, alphas, ylim, markerys, marker_styles=None, y_unit=1.):
    """Scatter full light curves and mark each planet's predicted transit times.

    Parameters
    ----------
    axs : sequence of matplotlib Axes
        One panel per light curve (zipped with lc_times, lc_dfluxes, alphas).
    lc_times, lc_dfluxes : sequences of arrays
        Time stamps and relative flux of each light curve.
    periods, epochs : sequences of float
        Ephemeris of each planet, in the same time units as lc_times.
    alphas : sequence of float
        Opacity of the scatter points, per light curve.
    ylim : (float, float) or None
        y-limits in flux units; None leaves the limits alone.
    markerys : sequence of float
        Height (flux units) of each planet's transit markers.
    marker_styles : sequence of dict, optional
        Extra ``Axes.plot`` keyword arguments per planet.
    y_unit : float
        Every plotted y value (and ylim) is divided by this.
    """
    styles = [{}] * len(periods) if marker_styles is None else marker_styles
    for ax, time, flux, alpha in zip(axs, lc_times, lc_dfluxes, alphas):
        ax.scatter(
            time, flux / y_unit,
            s=5, c='gray', marker='.', alpha=alpha, rasterized=True)
        for period, t0, marker_y, style in zip(periods, epochs, markerys, styles):
            # Transit numbers n with min(time) < t0 + n * period <= max(time).
            first_n = (min(time) - t0) // period
            last_n = (max(time) - t0) // period
            transit_times = t0 + np.arange(first_n + 1, last_n + 1) * period
            marker_heights = np.full_like(transit_times, marker_y / y_unit)
            ax.plot(transit_times, marker_heights, '^', markersize=3, **style)
        if ylim is None:
            continue
        ax.set_ylim(ylim[0] / y_unit, ylim[1] / y_unit)

In [None]:
# Full TESS light curves (30-min data, then the 2-min binned 20 s data) with the
# maximum-posterior transit times marked. ylim and marker heights are in relative flux.
fig, axs = plt.subplots(2, dpi=150)
plot_long_light_curve(
    axs, tess_full_times, tess_full_dfluxes,
    max_post_soln['period'], max_post_soln['t0'],
    [0.1, 0.01], [-0.008, 0.003], [-0.007, -0.003])

In [None]:
def plot_model_light_curve(
    ldlc_obj, orbit, radii_planets,
    texp, supersampling_factor,
    max_phases, num_points=2000):
    """Evaluate each planet's model transit on a dense phase grid.

    Planet ``i`` is sampled at ``num_points`` phases in
    ``[-max_phases[i], max_phases[i]]``, converted to times
    ``phase * period[i] + t0[i]``. Only column ``i`` of the evaluated light
    curve (that planet's column) is kept.

    Returns
    -------
    (model_phases, model_light_curve) : two lists with one array per planet.
    """
    phase_grids = [np.linspace(-limit, limit, num_points) for limit in max_phases]
    curves = []
    for planet, grid in enumerate(phase_grids):
        times = grid * orbit.period[planet] + orbit.t0[planet]
        flux = ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii_planets,
            t=times,
            texp=texp,
            oversample=supersampling_factor,
        ).eval()
        curves.append(flux[:, planet])
    return phase_grids, curves

def plot_multi_planet_folded_light_curve(
    num_planets, orbit, rp,
    lc_times, lc_dfluxes, filters, limb_dark_params, mean_fluxes,
    exposure_times, supersampling_factors,
    nbins, max_phases, alphas, dpi=300,
    detrend_series=None,
    detrend_coeffs=None):
    """Phase-folded light curves: one row per light curve, one column per planet.

    Each panel removes three things from the data: the other planets' model
    transits, the light curve's mean-flux offset, and any linear detrending
    trend. Only that column's planet is left. Its model is overplotted,
    optionally with binned means.

    Parameters
    ----------
    num_planets : int
    orbit : exoplanet orbit
        Provides ``period``, ``t0`` and ``r_star`` (evaluated with ``.eval()``).
    rp : array_like
        Planet radii relative to ``orbit.r_star`` (multiplied by it).
    lc_times, lc_dfluxes : sequences of arrays, one per light curve
    filters : sequence of str
        Filter per light curve; ``limb_dark_params['u_<filter>']`` is used.
    mean_fluxes : sequence of float
        Mean-flux offset of each light curve.
    exposure_times, supersampling_factors : sequences, one entry per light curve
    nbins : sequence of (int or None)
        Number of phase bins per light curve; None disables binning.
    max_phases : sequence of float
        Half-width of the plotted phase window, per planet.
    alphas : sequence of float
        Opacity of the unbinned points, per light curve.
    dpi : int
    detrend_series : sequence of (2-D array or None), optional
        Detrending regressors per light curve; an intercept column is prepended.
    detrend_coeffs : sequence of (array or None), optional
        Coefficients for those design matrices (intercept first).

    Returns
    -------
    fig, axs
        ``axs`` has shape (num_light_curves, num_planets).

    Raises
    ------
    ValueError
        If a light curve has detrending regressors but no coefficients.

    Notes
    -----
    Calls a ``phase_fold(time, epoch, period)`` helper that is not defined in
    the visible part of this notebook.
    """
    num_light_curves = len(lc_times)
    if detrend_series is None:
        detrend_design = [None] * num_light_curves
    else:
        detrend_design = []
        for detrend in detrend_series:
            if detrend is None:
                detrend_design.append(None)
                continue
            # Prepend an intercept column of ones to the regressors.
            detrend_design.append(np.hstack((
                np.ones((len(detrend), 1)),
                detrend)))
    if detrend_coeffs is None:
        detrend_coeffs = [None] * num_light_curves

    periods = orbit.period.eval()
    epochs = orbit.t0.eval()
    radii_planets = rp * orbit.r_star.eval()
    # One limb-darkened light-curve model per distinct filter.
    ldlc_objs = {
        filter_name: xo.LimbDarkLightCurve(limb_dark_params[f'u_{filter_name}'])
        for filter_name in set(filters)
    }

    fig, axs = plt.subplots(
        num_light_curves,
        num_planets,
        sharex='col',
        sharey='row',
        figsize=(5*num_planets, 3*num_light_curves),
        dpi=dpi,
        squeeze=False,
    )

    per_light_curve = zip(
        axs, lc_times, lc_dfluxes, filters,
        exposure_times, supersampling_factors,
        nbins, alphas, detrend_design, detrend_coeffs)
    for lc_index, (row, lc_time, lc_dflux, filter_name, texp, supersample,
                   nbin, alpha, design, coeffs) in enumerate(per_light_curve):

        ldlc_obj = ldlc_objs[filter_name]
        # Model of every planet at the observed times (one column per planet).
        model_dflux = ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii_planets,
            t=lc_time,
            texp=texp,
            oversample=supersample,
        ).eval()
        model_dflux_sum = np.sum(model_dflux, axis=1)
        plot_model_phases, plot_model_dflux = plot_model_light_curve(
            ldlc_obj, orbit, radii_planets, texp, supersample, max_phases)

        if design is None:
            lc_trend = 0.
        elif coeffs is not None:
            lc_trend = design @ coeffs
        else:
            raise ValueError(
                f'Unknown detrending coeffs for light curve {lc_index}')

        for i, (ax, period, epoch, max_phase) in enumerate(
                zip(row, periods, epochs, max_phases)):

            lc_phase = phase_fold(lc_time, epoch, period)
            # Keep only planet i: remove the other planets' models, this light
            # curve's (not this planet's) mean-flux offset, and its trend.
            lc_dflux_only_planet = (
                lc_dflux - model_dflux_sum + model_dflux[:, i]
                - mean_fluxes[lc_index] - lc_trend)
            ax.plot(
                lc_phase,
                lc_dflux_only_planet,
                '.', color='gray', alpha=alpha, rasterized=True)
            ax.plot(
                plot_model_phases[i], plot_model_dflux[i])

            phase_mask = np.abs(lc_phase) < max_phase
            select_phase = lc_phase[phase_mask]
            select_dflux = lc_dflux_only_planet[phase_mask]
            if nbin is not None:
                binned_mean, bins, _ = stats.binned_statistic(
                    select_phase, select_dflux, statistic='mean', bins=nbin)
                binned_err, _, _ = stats.binned_statistic(
                    select_phase, select_dflux, statistic=lambda a: np.std(a, ddof=1), bins=bins)
                binned_count, _ = np.histogram(select_phase, bins=bins)
                mid_bin = (bins[:-1] + np.diff(bins) / 2.)
                # Error bars are the standard error of the mean in each bin.
                ax.errorbar(mid_bin, binned_mean,
                    yerr=binned_err/np.sqrt(binned_count),
                    fmt='.', color='C1')
            ax.set_xlim(-max_phase, max_phase)
    fig.tight_layout()
    return fig, axs

In [None]:
# Folded TESS light curves at the maximum-posterior sample.
lc_plot_soln = max_post_soln

# Orbit built from the plotted solution, including the stellar mass and radius.
plot_soln = {k: lc_plot_soln[k] for k in ['period', 't0', 'b', 'ecc', 'omega', 'm_planet']}
plot_soln.update({k: lc_plot_soln[k] for k in ['m_star', 'r_star']})
plot_orbit = xo.orbits.KeplerianOrbit(
    **plot_soln)

fig, axs = plotting.plot_multi_planet_folded_light_curve(
    2, plot_orbit, lc_plot_soln['rp'],
    lc_times,
    lc_dfluxes,
    filters,
    # max_post_soln values are already NumPy arrays; `ndarray.data` would pass a
    # raw memoryview instead of the limb-darkening coefficients.
    {k: v for k, v in lc_plot_soln.items() if k.startswith('u_')},
    [lc_plot_soln[f'mean_flux_{i}'] for i in range(num_lc)],
    exposure_times, supersampling_factors,
    [50]*num_lc, [0.035, 0.05], [0.3]*num_lc, dpi=600)
axs[0][0].set_ylim(-0.0020, 0.0005)
axs[1][0].set_ylim(-0.0020, 0.0005)
# Column 0 is planet index 0 (P ~ 9.13 d), column 1 is planet index 1 (P ~ 3.10 d).
axs[0][0].set_title('c')
axs[0][1].set_title('b')
axs[-1][0].set_xlabel('Phase')
axs[-1][1].set_xlabel('Phase')
for row in axs:
    # Raw string: '\D' and '\m' are invalid escape sequences in a normal literal.
    row[0].set_ylabel(r'$\Delta\mathrm{flux} / \mathrm{flux}$')
# fig.savefig('plot/toi2000_lc_2m_rv_gp_sho_sed_prelim.png', dpi=144, bbox_inches='tight')

In [None]:
def setup_detrend(detrend_series, detrend_coeffs, num_light_curves=None):
    """Build per-light-curve detrending design matrices and coefficient lists.

    Parameters
    ----------
    detrend_series : sequence of (2-D array or None), or None
        Regressors for each light curve. Each non-None entry gets a column of
        ones (intercept) prepended. None entries stay None.
    detrend_coeffs : sequence or None
        Coefficients per light curve; returned unchanged when given.
    num_light_curves : int, optional
        Number of light curves. Only needed when both inputs are None;
        otherwise it is taken from whichever input is given.

    Returns
    -------
    (detrend_design, detrend_coeffs) : tuple
        Missing inputs are replaced by lists of None.

    Raises
    ------
    ValueError
        If both inputs are None and ``num_light_curves`` is not given.
    """
    # The count used to come from an undefined global; infer it instead.
    if num_light_curves is None:
        if detrend_series is not None:
            num_light_curves = len(detrend_series)
        elif detrend_coeffs is not None:
            num_light_curves = len(detrend_coeffs)
        else:
            raise ValueError(
                'num_light_curves is required when detrend_series and '
                'detrend_coeffs are both None.')

    if detrend_series is None:
        detrend_design = [None] * num_light_curves
    else:
        detrend_design = []
        for detrend in detrend_series:
            if detrend is None:
                detrend_design.append(None)
                continue
            detrend_design.append(np.hstack((
                np.ones((len(detrend), 1)),
                detrend)))
    if detrend_coeffs is None:
        detrend_coeffs = [None] * num_light_curves
    return detrend_design, detrend_coeffs

def optimize_detrend_limb_dark(orbit, rp, lc_time, lc_dflux, lc_unc, limb_dark_test, texp, supersample, detrend_design):
    """Remove a linear trend from a light curve, holding the transit model fixed.

    The transit model for ``limb_dark_test`` is summed over planets and
    subtracted from the data. The residual is fit with ``detrend_design`` by
    ordinary (unweighted) linear least squares, and the fitted trend is removed
    from ``lc_dflux``. ``lc_unc`` is currently unused. Despite the name, the
    limb-darkening coefficients are not optimized; the commented PyMC3 draft
    below sketches a joint fit.

    Returns
    -------
    ndarray
        The detrended flux, ``lc_dflux - detrend_design @ coeffs``.
    """
    transit_model = np.sum(
        xo.LimbDarkLightCurve(limb_dark_test).get_light_curve(
            orbit=orbit,
            r=rp,
            t=lc_time,
            texp=texp,
            oversample=supersample,
        ).eval(),
        axis=-1,
    )
    coeffs, *_ = np.linalg.lstsq(detrend_design, lc_dflux - transit_model, rcond=None)
    return lc_dflux - detrend_design @ coeffs
    # with pm.Model() as mymodel:
    #     limb_dark = xo.distributions.QuadLimbDark('u', testval=limb_dark_test)
    #     ldlc_obj = xo.LimbDarkLightCurve(limb_dark)
    #     raw_light_curve = ldlc_obj.get_light_curve(
    #         orbit=orbit,
    #         r=rp,
    #         t=lc_time,
    #         texp=texp,
    #         oversample=supersample,
    #     )
    #     detrend_params = pm.Normal(f'lc_detrend_coeffs', mu=1e-5, sd=1., shape=detrend_design.shape[1])
    #     light_curve = tt.sum(raw_light_curve, axis=-1) + detrend_design @ detrend_params
    #     jitter = pm.Uniform(
    #         "lc_jitter", testval=1e-9, lower=0., upper=1.)
    #     obs_likelihood = pm.Normal(
    #         'lc_obs',
    #         mu=light_curve,
    #         sd=tt.sqrt(lc_unc*lc_unc + jitter*jitter),
    #         observed=lc_dflux,
    #     )
    #     map_soln = pmx.optimize(start=model.test_point)
    # return lc_dflux - detrend_design @ map_soln['detrend_params']

def plot_lc_bins(period, bin_width, max_phase):
    """Return symmetric bin edges, in units of phase, centred on mid-transit.

    The window of half-width ``max_phase`` (phase units) is widened to a whole
    number of bins of width ``bin_width`` (time units, same as ``period``) on
    each side. Edges are spaced ``bin_width / period`` apart in phase.
    """
    n_half = int(np.ceil(max_phase * period / bin_width))
    half_span = n_half * bin_width
    edges_in_time = np.linspace(-half_span, half_span, 2 * n_half + 1)
    return edges_in_time / period

def plot_model_light_curve_2(
    ldlc_obj, orbit, radii_planets,
    texp, supersampling_factor,
    max_phases, index_planet,
    num_points=2000,
    ):
    """Evaluate single-planet model light curves on dense phase grids.

    For each ``(planet, half_width)`` pair taken from ``index_planet`` and
    ``max_phases``, ``num_points`` phases on ``[-half_width, half_width]`` are
    mapped to times ``phase * period + t0`` of that planet. The light curve of
    all planets is evaluated there, and only that planet's column is kept.

    Returns
    -------
    (list of ndarray, list of ndarray)
        Phase grids and the matching model flux, one entry per requested planet.
    """
    phase_grids = []
    curves = []
    for planet, half_width in zip(index_planet, max_phases):
        grid = np.linspace(-half_width, half_width, num_points)
        phase_grids.append(grid)
        grid_time = grid * orbit.period[planet] + orbit.t0[planet]
        all_planets = ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii_planets,
            t=grid_time,
            texp=texp,
            oversample=supersampling_factor,
        ).eval()
        curves.append(all_planets[:, planet])
    return phase_grids, curves

def plot_single_planet_folded_light_curve(
    ax, num_planets, index_planet, orbit, ror,
    lc_times, lc_dfluxes, lc_uncs, limb_dark_params, mean_fluxes,
    exposure_times, supersampling_factors,
    bin_width, bin_supersampling, max_phase,
    plot_phase=True,
    detrend_series=None,
    detrend_coeffs=None,
    y_unit=1.,
    plot_model_style=dict(),
    plot_errorbar_style={'fmt': '.'},):
    """Draw one planet's folded, detrended and binned light curve on ``ax``.

    Each light curve is corrected according to the design matrix and
    coefficients returned by ``setup_detrend``:

    * no design matrix: subtract its ``mean_flux``;
    * design matrix without coefficients: fit them by least squares against the
      transit model (``optimize_detrend_limb_dark``) and subtract the fit;
    * design matrix with coefficients: subtract ``design @ coeffs``.

    The corrected data are folded on planet ``index_planet``'s ephemeris and
    restricted to the plotted window. They are then binned with
    ``binned_weighted_statistics``; per-bin counts are printed. The model curve
    is integrated over ``bin_width`` with ``bin_supersampling`` oversampling.

    ``max_phase`` is indexed by ``index_planet``. With ``plot_phase=False`` the
    x axis is in hours from mid-transit instead of phase. Fluxes are divided
    by ``y_unit``.
    """
    designs, coeff_sets = setup_detrend(detrend_series, detrend_coeffs)

    period = orbit.period.eval()[index_planet]
    epoch = orbit.t0.eval()[index_planet]
    radii_planets = ror * orbit.r_star.eval()

    x_scale = 1. if plot_phase else period * 24.

    ldlc_obj = xo.LimbDarkLightCurve(limb_dark_params)

    bins = plot_lc_bins(period, bin_width, max_phase[index_planet])
    # The bin edges round the requested window out to whole bins.
    window = bins[-1]

    model_phases, model_dfluxes = plot_model_light_curve_2(
        ldlc_obj, orbit, radii_planets, bin_width, bin_supersampling,
        [window] * num_planets, [index_planet])
    ax.plot(model_phases[0] * x_scale, model_dfluxes[0] / y_unit, **plot_model_style)
    ax.set_xlim(-window * x_scale, window * x_scale)

    folded_phases = []
    corrected_dfluxes = []
    per_light_curve = zip(
        lc_times, lc_dfluxes, lc_uncs, exposure_times, supersampling_factors,
        designs, coeff_sets, mean_fluxes)
    for lc_time, lc_dflux, lc_unc, texp, supersample, design, coeffs, mean_flux in per_light_curve:
        folded = aftimeseries.phase_fold(lc_time, epoch, period)
        if design is None:
            corrected = lc_dflux - mean_flux
        elif coeffs is not None:
            corrected = lc_dflux - design @ coeffs
        else:
            corrected = optimize_detrend_limb_dark(
                orbit, radii_planets, lc_time, lc_dflux, lc_unc,
                limb_dark_params, texp, supersample, design)
        folded_phases.append(folded)
        corrected_dfluxes.append(corrected)

    all_phase = np.concatenate(folded_phases)
    all_dflux = np.concatenate(corrected_dfluxes)
    all_unc = np.concatenate(lc_uncs)

    in_window = np.abs(all_phase) < window
    window_phase = all_phase[in_window]
    window_dflux = all_dflux[in_window]
    window_unc = all_unc[in_window]

    bin_counts, _ = np.histogram(window_phase, bins=bins)
    print(bin_counts)
    binned_phase, binned_mean, binned_err = binned_weighted_statistics(
        window_phase, window_dflux, window_unc, bins)
    ax.errorbar(
        binned_phase * x_scale, binned_mean / y_unit,
        yerr=binned_err / y_unit,
        **plot_errorbar_style)


In [None]:
def plot_lc_bins_multiple(periods, bin_widths, max_phases):
    """Per-planet ``plot_lc_bins``: one array of phase bin edges per planet."""
    return [
        plot_lc_bins(period, width, half_window)
        for period, width, half_window in zip(periods, bin_widths, max_phases)
    ]

# Default matplotlib kwargs for plot_multi_planet_folded_light_curve_2:
# binned points are drawn with `ax.errorbar`, unbinned points with `ax.scatter`.
default_lc_errorbar_style = {'fmt': '.'}
default_lc_scatter_style = {'marker': '.', 'color': 'gray', 'alpha': 0.1}

def plot_multi_planet_folded_light_curve_2(
    num_planets, axs, orbit, rp,
    lc_times, lc_dfluxes, lc_uncs, limb_dark_params, mean_fluxes,
    exposure_times, supersampling_factors,
    bin_widths, max_phases,
    model_exp, model_supersampling,
    plot_phase=True,
    y_unit=1.,
    plot_model_style=dict(),
    plot_errorbar_style=default_lc_errorbar_style,
    plot_scatter_style=default_lc_scatter_style):
    """Plot each planet's phase-folded light curve with the other planets removed.

    For planet ``i`` (drawn on ``axs[i]``), the processing is:

    * From every light curve, subtract the transit model of all *other*
      planets and that light curve's own mean-flux offset.
    * Fold the result on planet ``i``'s ephemeris.
    * Scatter the unbinned points.
    * Overplot binned means with standard errors (sample std / sqrt(count)).

    The model curve comes from ``plotting.plot_model_light_curve``.

    Parameters
    ----------
    num_planets : int
        Number of planets (and axes).
    axs : sequence of matplotlib Axes
        One axis per planet, in the planet order of ``orbit``.
    orbit : exoplanet KeplerianOrbit
        Its ``period``, ``t0`` and ``r_star`` are evaluated here.
    rp : array_like
        Per-planet radius ratios; multiplied by ``orbit.r_star``.
    lc_times, lc_dfluxes : sequence of arrays
        One entry per light curve.
    lc_uncs : sequence of arrays
        Currently unused: the error bars come from the scatter within each bin.
    limb_dark_params : array_like
        Limb-darkening coefficients for ``xo.LimbDarkLightCurve``.
    mean_fluxes : sequence of float
        Mean-flux offset of each light curve -- one entry per light curve,
        matching ``plot_single_planet_folded_light_curve``.
    exposure_times, supersampling_factors : sequence
        Per-light-curve exposure time and oversampling for the data model.
    bin_widths, max_phases : sequence
        Per-planet bin width (time units of ``period``) and half-width of the
        plotted phase window.
    model_exp, model_supersampling
        Exposure time and oversampling for the plotted model curve.
    plot_phase : bool
        If true the x axis is in phase, otherwise in hours from mid-transit.
    y_unit : float
        Fluxes are divided by this before plotting.
    plot_model_style, plot_errorbar_style, plot_scatter_style : dict
        Keyword arguments for ``ax.plot``, ``ax.errorbar`` and ``ax.scatter``.

    Raises
    ------
    ValueError
        If ``mean_fluxes`` does not have one entry per light curve.
    """
    if len(mean_fluxes) != len(lc_times):
        raise ValueError(
            f'mean_fluxes has {len(mean_fluxes)} entries but there are '
            f'{len(lc_times)} light curves; expected one mean flux per light curve')

    periods = orbit.period.eval()
    epochs = orbit.t0.eval()
    radii_planets = rp * orbit.r_star.eval()
    ldlc_obj = xo.LimbDarkLightCurve(limb_dark_params)

    bin_list = plot_lc_bins_multiple(periods, bin_widths, max_phases)
    # Bin edges round the requested window out to whole bins.
    new_max_phases = [b[-1] for b in bin_list]
    if plot_phase:
        phase_factors = [1.] * num_planets
    else:
        phase_factors = periods * 24.  # phase -> hours from mid-transit

    plot_model_phases, plot_model_dflux = plotting.plot_model_light_curve(
        ldlc_obj, orbit, radii_planets, model_exp, model_supersampling, new_max_phases)
    for ax, model_phase, model_dflux, phase_factor, new_max_phase in zip(
            axs, plot_model_phases, plot_model_dflux, phase_factors, new_max_phases):
        ax.plot(model_phase*phase_factor, model_dflux/y_unit, **plot_model_style)
        ax.set_xlim(-new_max_phase*phase_factor, new_max_phase*phase_factor)

    # planet_only_dfluxes[i][j]: light curve j with the other planets' transits
    # and light curve j's own mean-flux offset removed.
    planet_only_dfluxes = [[] for _ in range(num_planets)]
    for lc_time, lc_dflux, texp, supersample, mean_flux in zip(
            lc_times, lc_dfluxes, exposure_times, supersampling_factors, mean_fluxes):
        model_dflux = ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii_planets,
            t=lc_time,
            texp=texp,
            oversample=supersample,
        ).eval()
        # Columns of model_dflux are the individual planets.
        model_dflux_sum = np.sum(model_dflux, axis=1)
        for i, lc_dflux_only_planet in enumerate(planet_only_dfluxes):
            lc_dflux_only_planet.append(
                lc_dflux - model_dflux_sum + model_dflux[:, i] - mean_flux)

    lc_times_all = np.concatenate(lc_times)
    lc_dflux_all_per_planet = [np.concatenate(a) for a in planet_only_dfluxes]

    for ax, period, epoch, new_max_phase, phase_factor, lc_dflux_only_planet, bins in zip(
            axs, periods, epochs, new_max_phases, phase_factors, lc_dflux_all_per_planet, bin_list):
        lc_phase = aftimeseries.phase_fold(lc_times_all, epoch, period)
        phase_mask = np.abs(lc_phase) < new_max_phase
        select_phase = lc_phase[phase_mask]
        select_dflux = lc_dflux_only_planet[phase_mask]
        ax.scatter(select_phase*phase_factor, select_dflux/y_unit, **plot_scatter_style)

        # Unweighted per-bin mean; the error is the standard error of the mean.
        binned_mean, _, _ = stats.binned_statistic(
            select_phase, select_dflux, statistic='mean', bins=bins)
        binned_err, _, _ = stats.binned_statistic(
            select_phase, select_dflux, statistic=lambda a: np.std(a, ddof=1), bins=bins)
        binned_count, _ = np.histogram(select_phase, bins=bins)
        print(binned_count)
        # Place each bin at the mean phase of its points, not the bin centre.
        binned_phase, _, _ = stats.binned_statistic(
            select_phase, select_phase, statistic='mean', bins=bins)
        ax.errorbar(
            binned_phase*phase_factor, binned_mean/y_unit,
            yerr=binned_err/np.sqrt(binned_count)/y_unit,
            **plot_errorbar_style)

In [None]:
# Folded ground-based light curves of planet index 0 at the max-posterior solution.
lc_plot_soln = max_post_soln

plot_soln = {
    key: lc_plot_soln[key]
    for key in ['period', 't0', 'b', 'ecc', 'omega', 'm_planet', 'm_star', 'r_star']
}
plot_orbit = xo.orbits.KeplerianOrbit(**plot_soln)

fig, ax = plt.subplots(dpi=144)
plot_single_planet_folded_light_curve(
    ax=ax,
    num_planets=2,
    index_planet=0,
    orbit=plot_orbit,
    ror=lc_plot_soln['rp'],
    lc_times=ground_times,
    lc_dfluxes=ground_dfluxes,
    lc_uncs=ground_uncs,
    limb_dark_params=lc_plot_soln['u_Rc'],
    mean_fluxes=[0.]*len(ground_times),
    exposure_times=ground_exp_times,
    supersampling_factors=ground_supersampling_factors,
    bin_width=15./60/24,
    bin_supersampling=5,
    max_phase=np.array([0.028, 0.05]),
    plot_phase=False,
    detrend_series=ground_detrends,
    detrend_coeffs=[
        lc_plot_soln['lc_detrend_coeffs_2'],
        lc_plot_soln['lc_detrend_coeffs_3'],
        None,
        None,
        None,
    ],
)

In [None]:
# Folded TESS light curves of both planets at the max-posterior solution.
lc_plot_soln = max_post_soln

plot_soln = {
    key: lc_plot_soln[key]
    for key in ['period', 't0', 'b', 'ecc', 'omega', 'm_planet', 'm_star', 'r_star']
}
plot_orbit = xo.orbits.KeplerianOrbit(**plot_soln)

fig, axs = plt.subplots(2, figsize=(8, 10), dpi=144)
plot_multi_planet_folded_light_curve_2(
    num_planets=2,
    axs=axs,
    orbit=plot_orbit,
    rp=lc_plot_soln['rp'],
    lc_times=lc_times,
    lc_dfluxes=lc_dfluxes,
    lc_uncs=lc_uncs,
    limb_dark_params=lc_plot_soln['u_tess'],
    mean_fluxes=[lc_plot_soln['mean_flux_0'], lc_plot_soln['mean_flux_1']],
    exposure_times=exposure_times,
    supersampling_factors=supersampling_factors,
    bin_widths=[12./60/24, 6./60/24],
    max_phases=[0.03, 0.05],
    model_exp=0.5/24,
    model_supersampling=15,
    plot_phase=False,
)

In [None]:
def make_multi_planet_lc_axes(num_planets, num_unfolded, num_folded, forbidden=set(), figure_kwargs={'dpi': 600}):
    """Create a figure with stacked unfolded light-curve axes above a grid of folded axes.

    Parameters
    ----------
    num_planets : int
        Rows of the folded-axes grid.
    num_unfolded : int
        Number of stacked full-width axes on top; with 0 the folded grid uses
        the whole figure.
    num_folded : int
        Columns of the folded-axes grid.
    forbidden : set of (row, col)
        Grid cells left without an axis (``None`` in the result).
    figure_kwargs : dict
        Passed to ``plt.figure``; ``figsize`` is always set to (7, 8.5).
        The dict is copied, never mutated.

    Returns
    -------
    fig, folded_axs, unfolded_axs
        ``folded_axs`` is a list of rows (one per planet) of Axes or ``None``.
    """
    # Copy so neither the shared default dict nor the caller's dict is mutated.
    figure_kwargs = dict(figure_kwargs)
    figure_kwargs['figsize'] = (7, 8.5)
    fig = plt.figure(constrained_layout=True, **figure_kwargs)

    unfolded_axs = []
    if num_unfolded > 0:
        gs = gridspec.GridSpec(2, 1, figure=fig, hspace=0.4, height_ratios=[4, 6])
        gs0 = gs[0, :].subgridspec(num_unfolded, 1, height_ratios=[1] * num_unfolded)
        unfolded_axs = [fig.add_subplot(gs0[i]) for i in range(num_unfolded)]
        folded_slot = gs[1, :]
    else:
        # Without unfolded axes the folded grid fills the figure; previously
        # `gs` was undefined on this path.
        gs = gridspec.GridSpec(1, 1, figure=fig)
        folded_slot = gs[0, :]

    gs1 = folded_slot.subgridspec(num_planets, num_folded, height_ratios=[1] * num_planets)
    folded_axs = [
        [None if (i, j) in forbidden else fig.add_subplot(gs1[i, j])
         for j in range(num_folded)]
        for i in range(num_planets)
    ]
    return fig, folded_axs, unfolded_axs

def make_multi_planet_lc_axes_fixed(figure_kwargs={'dpi': 600}):
    """Create the fixed light-curve figure layout.

    The figure has two stacked full-width unfolded axes on top. Below them are
    a full-width folded axis and a row of three folded axes that share a y axis.

    ``figure_kwargs`` is passed to ``plt.figure`` with ``figsize`` forced to
    (7, 8.5); the dict is copied, never mutated.

    Returns
    -------
    fig, folded_axs, unfolded_axs
        ``folded_axs`` is ``[top_full_width, bottom_left, bottom_mid, bottom_right]``.
    """
    num_planets = 2
    num_unfolded = 2
    num_folded = 3

    # Copy so neither the shared default dict nor the caller's dict is mutated.
    figure_kwargs = dict(figure_kwargs)
    figure_kwargs['figsize'] = (7, 8.5)
    fig = plt.figure(**figure_kwargs)

    gs = gridspec.GridSpec(2, 1, figure=fig, hspace=0.15, height_ratios=[4, 6])
    gs0 = gs[0, :].subgridspec(num_unfolded, 1, height_ratios=[1] * num_unfolded)
    unfolded_axs = [fig.add_subplot(gs0[i]) for i in range(num_unfolded)]

    gs1 = gs[1, :].subgridspec(
        num_planets, num_folded, hspace=0.25, wspace=0.05, height_ratios=[1] * num_planets)
    folded_axs = [fig.add_subplot(gs1[0, :])]
    bottom_left = fig.add_subplot(gs1[1, 0])
    folded_axs.append(bottom_left)
    folded_axs.append(fig.add_subplot(gs1[1, 1], sharey=bottom_left))
    folded_axs.append(fig.add_subplot(gs1[1, 2], sharey=bottom_left))
    return fig, folded_axs, unfolded_axs

In [None]:
# Matplotlib style kwargs for the light-curve figure assembled in the next statements.
# Model curves (passed as `plot_model_style` to the folded-light-curve plotters).
lc_model_style = {
    'color': '#866bd6',
}

# One style per planet, passed to `plot_long_light_curve` (defined elsewhere);
# presumably used for transit markers on the unfolded light curves -- TODO confirm.
lc_transit_mark_styles = [
    {
        'color': '#db9448',
        'markeredgecolor': '#a66e33',
        'markeredgewidth': 0.5,
    },
    {
        'color': '#5690f5',
        'markeredgecolor': '#325796',
        'markeredgewidth': 0.5,
    },
]

# Binned data points (`ax.errorbar` kwargs); used for the TESS and the
# ground-based folded panels alike.
lc_tess_style = {
    'fmt': 'o',
    'color': '#fec260',
    'ecolor': '#70501b',
    'elinewidth': 0.5,
    'markersize': 3,
    'markeredgecolor': '#70501b',
    'markeredgewidth': 0.5,
}

# Unbinned data points (`ax.scatter` kwargs).
lc_scatter_style = {
    'c': 'gray',
    'alpha': 0.1,
    'marker': '.',
    's': 10,
}

fig, lc_folded_axs, lc_unfolded_axs = make_multi_planet_lc_axes_fixed()

lc_plot_soln = max_post_soln

# Raw string: mathtext sequences such as `\D`, `\m` and `\,` are invalid Python
# escapes in ordinary literals (SyntaxWarning since Python 3.12). Same text.
lc_flux_label = r'$\Delta \mathrm{flux}\, / \, \mathrm{flux}$ (‰)'

plot_long_light_curve(
    lc_unfolded_axs, tess_full_times, tess_full_dfluxes,
    lc_plot_soln['period'], lc_plot_soln['t0'],
    [0.2, 0.05], [-0.0079, 0.003], [-0.007, -0.003],
    lc_transit_mark_styles,
    y_unit=1e-3)

for ax in lc_unfolded_axs:
    ax.text(0.015, 0.429, 'b', ha='center', fontsize='x-small', transform=ax.transAxes)
    ax.text(0.015, 0.063, 'c', ha='center', fontsize='x-small', transform=ax.transAxes)

plot_soln = {k: lc_plot_soln[k] for k in ['period', 't0', 'b', 'ecc', 'omega', 'm_planet', 'm_star', 'r_star']}
plot_orbit = xo.orbits.KeplerianOrbit(**plot_soln)

# Ground-based folds of planet index 0 (labelled "Planet c" below): ASTEP ...
plot_single_planet_folded_light_curve(
    lc_folded_axs[2], 2, 0, plot_orbit, lc_plot_soln['rp'],
    [astep_time], [astep_dflux], [astep_unc],
    lc_plot_soln['u_Rc'], [0.]*len(ground_times),
    [ground_exp_times[0]], [ground_supersampling_factors[0]],
    12./60/24, 6, np.array([0.03, 0.05]),
    plot_phase=False,
    detrend_series=[astep_design],
    detrend_coeffs=[
        lc_plot_soln['lc_detrend_coeffs_2'],
        lc_plot_soln['lc_detrend_coeffs_3'],
        None,
        None,
        None,
    ],
    y_unit=1e-3,
    plot_model_style=lc_model_style,
    plot_errorbar_style=lc_tess_style,
)

# ... and LCO (z_s band).
plot_single_planet_folded_light_curve(
    lc_folded_axs[3], 2, 0, plot_orbit, lc_plot_soln['rp'],
    [lco_zs_time], [lco_zs_dflux], [lco_zs_unc],
    lc_plot_soln['u_zs'], [0.]*len(ground_times),
    [ground_exp_times[1]], [ground_supersampling_factors[1]],
    12./60/24, 6, np.array([0.03, 0.05]),
    plot_phase=False,
    detrend_series=[lco_zs_design],
    detrend_coeffs=[
        lc_plot_soln['lc_detrend_coeffs_3'],
        None,
        None,
        None,
    ],
    y_unit=1e-3,
    plot_model_style=lc_model_style,
    plot_errorbar_style=lc_tess_style,
)

# TESS folds: planet index 0 -> lc_folded_axs[1], planet index 1 -> lc_folded_axs[0].
plot_multi_planet_folded_light_curve_2(
    2, [lc_folded_axs[1], lc_folded_axs[0]],
    plot_orbit, lc_plot_soln['rp'],
    lc_times, list(lc_dfluxes), lc_uncs,
    lc_plot_soln['u_tess'],
    [lc_plot_soln['mean_flux_0'], lc_plot_soln['mean_flux_1']],
    exposure_times, supersampling_factors,
    [15./60/24, 10./60/24], [0.03, 0.06], 0.5/24, 15,
    plot_phase=False,
    y_unit=1e-3,
    plot_model_style=lc_model_style,
    plot_errorbar_style=lc_tess_style,
    plot_scatter_style=lc_scatter_style,
)


lc_unfolded_axs[0].set_xlim(1542, 1626)
lc_unfolded_axs[1].set_xlim(2280, 2362)

for ax in lc_unfolded_axs:
    ax.set_ylabel(lc_flux_label, fontsize="small")
    ax.tick_params(direction='in', which='both', right=True, labelsize='x-small')
    ax.minorticks_on()
# Time-axis label only on the bottom unfolded panel.
lc_unfolded_axs[-1].set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2\,457\,000$ (day)', fontsize="small")

lc_unfolded_axs[0].text(0.1, 0.85, 'TESS Year 1', ha='center', transform=lc_unfolded_axs[0].transAxes)
lc_unfolded_axs[1].text(0.1, 0.85, 'TESS Year 3', ha='center', transform=lc_unfolded_axs[1].transAxes)

lc_folded_axs[0].set_xlim(-4.5, 4.5)
lc_folded_axs[0].set_ylim(-0.750, 0.350)
for ax in lc_folded_axs[1:]:
    ax.set_xlim(-4.5, 4.5)
    ax.set_ylim(-5.600, 1.200)
for ax in lc_folded_axs:
    ax.tick_params(direction="in", which='both', right=True, labelsize="x-small")
    ax.minorticks_on()
    ax.set_xlabel('Hours from Mid-Transit', fontsize="small")
for ax in lc_folded_axs[:2]:
    ax.set_ylabel(lc_flux_label, fontsize="small")
lc_folded_axs[2].tick_params(labelleft=False)
lc_folded_axs[3].tick_params(labelleft=False)

lc_folded_axs[0].text(0.2/3, 0.1, 'Planet b\n(TESS)', ha='center', transform=lc_folded_axs[0].transAxes)
lc_folded_axs[1].text(0.2, 0.1, 'Planet c\n(TESS)', ha='center', transform=lc_folded_axs[1].transAxes)
lc_folded_axs[2].text(0.2, 0.1, 'Planet c\n($R_\\mathrm{C}$)', ha='center', transform=lc_folded_axs[2].transAxes)
lc_folded_axs[3].text(0.2, 0.1, 'Planet c\n($z_s$)', ha='center', transform=lc_folded_axs[3].transAxes)


fig.savefig('../plots/toi_2000_lc.pdf', bbox_inches='tight', pad_inches=0.01)

In [None]:
# Default matplotlib kwargs for plot_multi_planet_folded_rv.
# RV data points (`errorbar` kwargs); per-instrument styles are merged on top.
rv_data_default_style = {
    'ecolor': 'gray',
    'elinewidth': 1,
    'alpha': 0.7,
    'fmt': 'o',
    'markersize': 3,
}
# Model RV curves (`plot`) and the zero line on the residual axis (`axhline`).
rv_model_default_style = {
    'color': 'slateblue',
    'zorder': 0,
}
# Shaded trend-uncertainty bands (`fill_between`).
rv_trend_unc_default_style = {
    'color': 'slateblue',
    'alpha': 0.2,
}

def plot_model_rv(num_planets, orbit, rv_semiamps, num_points=500):
    """Evaluate each planet's own RV signal over one orbital phase.

    Parameters
    ----------
    num_planets : int
        Number of planets (from index 0) to evaluate.
    orbit : exoplanet orbit
        Must provide ``period[i]``, ``t0[i]`` and ``get_radial_velocity``.
    rv_semiamps : array_like
        Per-planet semi-amplitudes ``K``.
    num_points : int
        Number of phase samples.

    Returns
    -------
    model_phase : numpy.ndarray
        ``num_points`` phases evenly spaced on [-0.5, 0.5].
    model_rvs : list of arrays
        For planet ``i``, its RV contribution alone, evaluated at times
        ``model_phase * period[i] + t0[i]``.
    """
    # (Removed dead locals that evaluated the graph twice; one of them bound
    # `epochs` to `orbit.period` by mistake.)
    model_phase = np.linspace(-0.5, 0.5, num_points)
    model_rvs = []
    for i in range(num_planets):
        model_time = model_phase * orbit.period[i] + orbit.t0[i]
        all_planets_rv = orbit.get_radial_velocity(
            model_time, K=rv_semiamps,
        ).eval()
        model_rvs.append(all_planets_rv[:, i])
    return model_phase, model_rvs

def plot_multi_planet_folded_rv(
    folded_axs,
    unfolded_ax,
    residual_ax,
    num_planets, orbit,
    rv_semiamps, gammas, jitters,
    rv_times, rv_data, rv_uncs, rv_names,
    trends=None, model_trend_func=None,
    model_trend_unc_func=None,
    rv_data_style=rv_data_default_style,
    rv_model_style=rv_model_default_style,
    rv_inst_styles=None,
    rv_trend_unc_style=rv_trend_unc_default_style,
):
    """Plot multi-instrument RVs: time series, residuals and per-planet phase folds.

    Parameters
    ----------
    folded_axs : sequence of Axes
        One axis per planet (planet order of ``orbit``) for phase-folded RVs.
    unfolded_ax : Axes or None
        Time-series axis showing offset-corrected data and the full model.
    residual_ax : Axes or None
        Residual axis; only drawn on when ``unfolded_ax`` is given.
    num_planets : int
        Number of planets whose model curves are drawn on ``folded_axs``.
    orbit : exoplanet orbit
        Provides ``t0``, ``period`` and ``get_radial_velocity``.
    rv_semiamps : array_like
        Per-planet semi-amplitudes ``K``.
    gammas, jitters : sequence of float
        Per-instrument offsets and jitters; jitter is added in quadrature to
        the error bars.
    rv_times, rv_data, rv_uncs, rv_names : sequence
        Per-instrument times, RVs, uncertainties and legend labels.
    trends : iterable, optional
        Per-instrument trend values at that instrument's times, subtracted from
        the residual and folded data. Defaults to zero.
    model_trend_func : callable, optional
        ``t -> trend``; drawn dashed on ``unfolded_ax`` and added to the model.
    model_trend_unc_func : callable, optional
        ``t -> (2, len(t))`` lower/upper offsets, drawn as a band around the
        trend (when ``model_trend_func`` is given) and around zero on
        ``residual_ax``.
    rv_data_style, rv_model_style, rv_trend_unc_style : dict
        kwargs for ``errorbar``, ``plot``/``axhline`` and ``fill_between``.
    rv_inst_styles : sequence of dict, optional
        Per-instrument overrides merged on top of ``rv_data_style``.
    """
    if rv_inst_styles is None:
        rv_inst_styles = [dict()] * len(rv_times)

    if unfolded_ax is not None:
        min_obs_time = np.min([np.min(t) for t in rv_times])
        max_obs_time = np.max([np.max(t) for t in rv_times])
        unfold_model_time = np.linspace(
            min_obs_time, max_obs_time, num=int((max_obs_time-min_obs_time)//0.05))
        unfold_model = np.sum(orbit.get_radial_velocity(
            unfold_model_time, K=rv_semiamps,
        ).eval(), axis=1)

        # Evaluated independently of model_trend_func: the residual band needs
        # it too (previously a NameError when only the uncertainty was given).
        unfold_model_trend_unc = None
        if model_trend_func is not None:
            unfold_model_trend = model_trend_func(unfold_model_time)
            unfold_model += unfold_model_trend
            unfolded_ax.plot(unfold_model_time, unfold_model_trend, color='gray', linestyle='--')
            if model_trend_unc_func is not None:
                unfold_model_trend_unc = model_trend_unc_func(unfold_model_time)
                trend_unc = unfold_model_trend + unfold_model_trend_unc
                unfolded_ax.fill_between(unfold_model_time, trend_unc[0], trend_unc[1], **rv_trend_unc_style)
        elif model_trend_unc_func is not None:
            unfold_model_trend_unc = model_trend_unc_func(unfold_model_time)
        unfolded_ax.plot(unfold_model_time, unfold_model, **rv_model_style)

        if residual_ax is not None:
            residual_ax.axhline(0, **rv_model_style)
            if unfold_model_trend_unc is not None:
                residual_ax.fill_between(
                    unfold_model_time, unfold_model_trend_unc[0], unfold_model_trend_unc[1],
                    **rv_trend_unc_style)

    model_phase, model_rvs = plot_model_rv(num_planets, orbit, rv_semiamps)
    for ax, model_rv in zip(folded_axs, model_rvs):
        ax.plot(model_phase, model_rv, **rv_model_style)

    epochs = orbit.t0.eval()
    periods = orbit.period.eval()
    if trends is None:
        trends = [0.] * len(rv_times)
    for rv_time, rv, rv_unc, gamma, jitter, trend, label, inst_style in zip(
            rv_times, rv_data, rv_uncs, gammas, jitters, trends, rv_names, rv_inst_styles):
        rv_errorbar = np.sqrt(rv_unc**2+jitter**2)
        rv_shifted = rv - gamma
        rv_detrend = rv_shifted - trend

        model_rv = orbit.get_radial_velocity(
            rv_time, K=rv_semiamps,
        ).eval()
        model_rv_sum = np.sum(model_rv, axis=1)

        # Fresh merge per instrument so one instrument's overrides do not leak
        # into the next.
        data_style = {**rv_data_style, **inst_style}
        if unfolded_ax is not None:
            unfolded_ax.errorbar(rv_time, rv_shifted, rv_errorbar, label=label, **data_style)
            if residual_ax is not None:
                residual_ax.errorbar(rv_time, rv_detrend-model_rv_sum, rv_errorbar, label=label, **data_style)

        for i, (ax, period, epoch) in enumerate(zip(folded_axs, periods, epochs)):
            rv_phase = aftimeseries.phase_fold(rv_time, epoch, period)
            # Data minus every other planet's signal and the trend.
            rv_only_planet = rv_shifted - model_rv_sum + model_rv[:, i] - trend
            ax.errorbar(rv_phase, rv_only_planet, rv_errorbar, label=label, **data_style)

In [None]:
def make_multi_planet_rv_axes(num_planets, unfolded=True, residuals=True, figure_kwargs={'dpi': 600}):
    """Create an RV figure: optional time-series (+ residual) axes above per-planet folded axes.

    Parameters
    ----------
    num_planets : int
        Number of folded axes, placed side by side.
    unfolded : bool
        Whether to add a full-width time-series axis on top.
    residuals : bool
        With ``unfolded``, split the top axis into RV (2/3) and residual (1/3)
        panels that share x.
    figure_kwargs : dict
        Passed to ``plt.figure``; ``figsize`` is always set to (7, 8.5).
        The dict is copied, never mutated.

    Returns
    -------
    fig, folded_axs, unfolded_ax, residual_ax
        ``unfolded_ax``/``residual_ax`` are ``None`` when not created.
    """
    # Copy so neither the shared default dict nor the caller's dict is mutated.
    figure_kwargs = dict(figure_kwargs)
    figure_kwargs['figsize'] = (7, 8.5)
    fig = plt.figure(constrained_layout=False, **figure_kwargs)

    unfolded_ax = None
    residual_ax = None
    if unfolded:
        gs = gridspec.GridSpec(2, num_planets, figure=fig, wspace=0.15, hspace=0.15, height_ratios=[5, 3])
        folded_row = 1
        if residuals:
            gs0 = gs[0, :].subgridspec(2, 1, hspace=0, height_ratios=[2, 1])
            unfolded_ax = fig.add_subplot(gs0[0])
            residual_ax = fig.add_subplot(gs0[1], sharex=unfolded_ax)
        else:
            unfolded_ax = fig.add_subplot(gs[0, :])
    else:
        # Single row of folded axes (previously `gs` was undefined here).
        gs = gridspec.GridSpec(1, num_planets, figure=fig, wspace=0.15)
        folded_row = 0

    folded_axs = [fig.add_subplot(gs[folded_row, i]) for i in range(num_planets)]
    return fig, folded_axs, unfolded_ax, residual_ax

In [None]:
def polynomial_design_matrix(x, offset, order):
    """Return columns ``(x - offset)**k`` for ``k = 1 .. order`` (no constant column)."""
    shifted = x - offset
    powers = np.vander(shifted, N=order + 1, increasing=True)
    # Column 0 is (x - offset)**0; it is dropped.
    return powers[:, 1:]

def model_set_up_polynomial_detrend(model, rv_times, order):
    """Add a shared polynomial RV trend to a pymc3 model.

    The following variables are registered in ``model``:

    * ``rv_time_offset``: midpoint of the overall time span, as a Deterministic.
    * ``rv_trend_coeff``: Normal coefficients with widths ``10**-1 .. 10**-order``.
    * ``rv_design_matrix_{i}``: one Deterministic design matrix per data set,
      built by ``polynomial_design_matrix``.

    Raises
    ------
    ValueError
        If ``order < 1``.
    """
    if order < 1:
        raise ValueError('Polynomial must be at least of linear order')
    earliest = min(np.min(times) for times in rv_times)
    latest = max(np.max(times) for times in rv_times)
    rv_offset = (earliest + latest) / 2.
    coeff_widths = 10.**-np.arange(1, order+1)
    with model:
        pm.Deterministic('rv_time_offset', tt.as_tensor_variable(rv_offset))
        pm.Normal("rv_trend_coeff", mu=0, sd=coeff_widths, shape=order)
        for idx, times in enumerate(rv_times):
            design = tt.as_tensor_variable(polynomial_design_matrix(times, rv_offset, order))
            pm.Deterministic(f'rv_design_matrix_{idx}', design)

def trend_generator(model, num_rv):
    """Lazily yield ``design_matrix_i @ trend_coeff`` for each of ``num_rv`` data sets."""
    yield from (
        model[f'rv_design_matrix_{idx}'] @ model['rv_trend_coeff']
        for idx in range(num_rv)
    )

def trend_generator_trace(model, num_rv):
    """Like ``trend_generator``, for containers whose entries hold arrays in ``.data``."""
    for idx in range(num_rv):
        design = model[f'rv_design_matrix_{idx}'].data
        coeffs = model['rv_trend_coeff'].data
        yield design @ coeffs

def gp_generator(model, num_rv):
    """Lazily yield the stored GP prediction ``rv_gp_pred_{i}`` for each of ``num_rv`` data sets."""
    keys = (f'rv_gp_pred_{idx}' for idx in range(num_rv))
    for key in keys:
        yield model[key]

def gp_matern(map_soln):
    """Return a celerite2 GP with a Matern-3/2 kernel.

    The kernel uses ``map_soln['gp_sigma']`` and ``map_soln['gp_rho']``.
    """
    amplitude, length_scale = map_soln['gp_sigma'], map_soln['gp_rho']
    return celerite2.GaussianProcess(terms.Matern32Term(sigma=amplitude, rho=length_scale))

def gp_sho(map_soln):
    """Return a celerite2 GP with an SHO kernel.

    The kernel uses ``gp_sigma``, ``gp_rho`` and ``gp_tau`` from ``map_soln``.
    """
    amplitude = map_soln['gp_sigma']
    rho_value = map_soln['gp_rho']
    tau_value = map_soln['gp_tau']
    sho = terms.SHOTerm(rho=rho_value, tau=tau_value, sigma=amplitude)
    return celerite2.GaussianProcess(sho)

def gp_trend_func_generator(kernel_func, map_soln, rv_times, rv_data, rv_uncs, return_unc=False):
    """Condition a GP on RV residuals and return predictor function(s).

    The GP from ``kernel_func(map_soln)`` is computed on the time-sorted union
    of all data sets' times. The per-point variance is
    ``unc**2 + jitter**2``, with jitter taken per data set from
    ``map_soln['rv_jitter']``. The GP is conditioned on the residuals
    ``rv_data[i] - map_soln['rv_pred_{i}']``.

    Returns
    -------
    gp_trend : callable
        ``t -> GP predictive mean at t``.
    gp_trend_unc : callable (only when ``return_unc`` is true)
        ``t -> array of shape (2, len(t))`` holding ``(-stdev, +stdev)``.
    """
    gp = kernel_func(map_soln)

    all_times = np.concatenate(rv_times)
    order = np.argsort(all_times)
    sorted_times = np.ascontiguousarray(all_times[order])
    jitters = map_soln['rv_jitter']
    variances = np.concatenate([
        unc*unc + jitters[inst]*jitters[inst]
        for inst, unc in enumerate(rv_uncs)
    ])
    sorted_variances = np.ascontiguousarray(variances[order])
    gp.compute(sorted_times, diag=sorted_variances, quiet=True)

    residuals = np.concatenate([
        obs - map_soln[f'rv_pred_{inst}'] for inst, obs in enumerate(rv_data)
    ])[order]

    def gp_trend(t):
        return gp.predict(residuals, t=t)

    def gp_trend_unc(t):
        _, var = gp.predict(residuals, t=t, return_var=True)
        sd = np.sqrt(var)
        return np.vstack((-sd, sd))

    return (gp_trend, gp_trend_unc) if return_unc else gp_trend

In [None]:
fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(2)
# Raw strings: mathtext `\m` and `\,` are invalid escapes in ordinary string
# literals (SyntaxWarning since Python 3.12); the rendered text is unchanged.
rv_unfolded_ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)', labelpad=-0.5, fontsize='small')
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)', labelpad=-0.5, fontsize='small')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2\,457\,000$', fontsize='small')

for ax in rv_folded_axs:
    ax.set_xlabel('Phase', fontsize='small')
rv_folded_axs[0].set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)', labelpad=-1.0, fontsize='small')

rv_plot_soln = max_post_soln

plot_orbit = xo.orbits.KeplerianOrbit(
    **{k: rv_plot_soln[k] for k in ['period', 't0', 'ecc', 'omega']})

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': '#555555',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^', "alpha": 0.5},
    {'color': '#a12568', 'fmt': 'd', "alpha": 0.3},
    {'color': '#fec260', 'fmt': 'o', "alpha": 1.},
]

# GP trend conditioned on the RV residuals, plus its +/- 1 stdev band.
rv_trend_func, rv_trend_unc_func = gp_trend_func_generator(
    gp_sho, rv_plot_soln, rv_times, rv_data, rv_uncs, return_unc=True)

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs[::-1], rv_unfolded_ax, rv_residual_ax,
    2, plot_orbit, rv_plot_soln['K'],
    rv_plot_soln['rv_gamma'],
    rv_plot_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    trends=gp_generator(rv_plot_soln, len(rv_data)),
    model_trend_func=rv_trend_func,
    model_trend_unc_func=rv_trend_unc_func,
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.set_ylim(-75, 50)
rv_unfolded_ax.legend(loc='lower center', fontsize='small')
rv_unfolded_ax.tick_params(direction="in", which='both', labelsize="x-small", top=True, right=True, labelbottom=False)
rv_unfolded_ax.minorticks_on()
rv_residual_ax.tick_params(direction="in", which='both', labelsize="x-small", right=True)
rv_residual_ax.minorticks_on()
rv_residual_ax.set_ylim(-30, 30)

for l, ax, p in zip(['c', 'b'], rv_folded_axs[::-1], rv_plot_soln["period"]):
    ax.text(0.05, 0.05, f'Planet {l}\n$P = {p:.6f}\\,\\mathrm{{d}}$', transform=ax.transAxes)
    ax.tick_params(direction="in", which='both', labelsize="x-small", top=True, right=True)
    ax.set_xticks([-0.5, -0.25, 0, 0.25, 0.5])
    ax.set_xticklabels(['−0.5', '−0.25', '0', '0.25', '0.5'])
    ax.minorticks_on()

rv_folded_axs[0].set_ylim(-25, 25)
rv_folded_axs[1].set_ylim(-45, 45)

fig.savefig('../plots/toi_2000_rv.pdf', bbox_inches='tight', pad_inches=0.01)

In [None]:
# Posterior of the RV-only fit, flattened over (chain, draw).
trace_4p = az.from_netcdf('../chains/toi2000_rv_only_90d_17d.nc')
rv_flat_samples = trace_4p.posterior.stack(sample=("chain", "draw"))

# Per-variable posterior medians, as plain arrays.
rv_median_soln = {
    name: values.data
    for name, values in rv_flat_samples.median(dim='sample').items()
}

# The single sample with the highest log probability.
rv_max_post_index = rv_flat_samples.log_prob.argmax(dim='sample')
rv_max_post_soln = {
    name: values.data
    for name, values in rv_flat_samples[{'sample': rv_max_post_index}].items()
}

In [None]:
fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(4)
# Raw strings: mathtext `\m` and `\,` are invalid escapes in ordinary string
# literals (SyntaxWarning since Python 3.12); the rendered text is unchanged.
rv_axis_label = r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)'
rv_unfolded_ax.set_ylabel(rv_axis_label, labelpad=-0.5, fontsize='small')
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)', labelpad=-0.5, fontsize='small')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2\,457\,000$', fontsize='small')

# Folded panels in the planet order of `plot_orbit` below: the two transiting
# planets (labelled c and b), then the two RV-only signals.
rv_folded_axs_for_plotting = [
    rv_folded_axs[1],
    rv_folded_axs[0],
    rv_folded_axs[3],
    rv_folded_axs[2],
]

rv_plot_soln = rv_median_soln

# Transiting planets first, then the RV-only signals (circular: ecc=0, omega=pi/2).
all_periods = np.concatenate([rv_plot_soln['period'], rv_plot_soln['period_rv_only']])
plot_orbit = xo.orbits.KeplerianOrbit(
    period=all_periods,
    t0=np.concatenate([rv_plot_soln['t0'], rv_plot_soln['t0_rv_only']]),
    ecc=np.concatenate([rv_plot_soln['ecc'], [0.]*2]),
    omega=np.concatenate([rv_plot_soln['omega'], [np.pi/2]*2]),
)

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': '#555555',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^', "alpha": 0.5},
    {'color': '#a12568', 'fmt': 'd', "alpha": 0.3},
    {'color': '#fec260', 'fmt': 'o', "alpha": 1.},
]

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs_for_plotting, rv_unfolded_ax, rv_residual_ax,
    4, plot_orbit, np.concatenate([rv_plot_soln['K'], rv_plot_soln['K_rv_only']]),
    rv_plot_soln['rv_gamma'],
    rv_plot_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.set_ylim(-75, 50)
rv_unfolded_ax.legend(loc='lower center', fontsize='small')
rv_unfolded_ax.tick_params(direction="in", which='both', labelsize="x-small", top=True, right=True, labelbottom=False)
rv_unfolded_ax.minorticks_on()
rv_residual_ax.tick_params(direction="in", which='both', labelsize="x-small", right=True)
rv_residual_ax.minorticks_on()
rv_residual_ax.set_ylim(-30, 30)

# Every folded panel gets the same y label and phase ticks; only the two
# transiting planets are named.
for l, ax, p in zip(['c', 'b', None, None], rv_folded_axs_for_plotting, all_periods):
    if l is not None:
        ax.text(0.05, 0.05, f'Planet {l}\n$P = {p:.6f}\\,\\mathrm{{d}}$', transform=ax.transAxes)
    else:
        ax.text(0.05, 0.05, f'\n$P = {p:.1f}\\,\\mathrm{{d}}$', transform=ax.transAxes)
    ax.tick_params(direction="in", which='both', labelsize="x-small", top=True, right=True)
    ax.set_xticks([-0.5, -0.25, 0, 0.25, 0.5])
    ax.set_xticklabels(['−0.5', '−0.25', '0', '0.25', '0.5'])
    ax.minorticks_on()
    ax.set_ylabel(rv_axis_label, labelpad=-1.0, fontsize='small')
    ax.set_xlabel('Phase', fontsize='small')


rv_folded_axs[0].set_ylim(-25, 25)
rv_folded_axs[1].set_ylim(-45, 45)
rv_folded_axs[2].set_ylim(-25, 25)
rv_folded_axs[3].set_ylim(-35, 35)

fig.savefig('../plots/toi_2000_rv_4p.pdf', bbox_inches='tight', pad_inches=0.01)