In [None]:
import itertools
import os.path as path

import aesara_theano_fallback.tensor as tt
import astropy.io.fits as fits
import astropy.timeseries as timeseries
import astropy.units as u
import arviz as az
import celerite2
import celerite2.terms as terms
import celerite2.theano.terms as theano_terms
import corner
import exoplanet as xo
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc3 as pm
import pymc3_ext as pmx
import scipy.stats as stats

# import third_party.keplersplinev2.keplersplinev2 as ksp

import adhocfitter.astro as astro
# import adhocfitter.limbdark as limbdark
import adhocfitter.plotting as plotting
import adhocfitter.priors as priors
import adhocfitter.timeseries as aftimeseries

# Disable OpenMP parallelism
# NOTE(review): numpy and the other numerical libraries are imported above,
# before this variable is set. OpenMP runtimes generally read OMP_NUM_THREADS
# when they initialise, so this may have no effect. Consider setting it before
# the imports or in the kernel environment. TODO confirm.
import os
os.environ["OMP_NUM_THREADS"] = "1"

# Root for files read via path.join(DATA_DIR, ...) (the RV tables below).
# NOTE(review): the TESS light curves are read from 'data/toi2000/...', which
# is not under this directory. Confirm the intended layout.
DATA_DIR = '../data'

In [None]:
def phase_fold(time_series, mid_transit, period):
    """Map times onto orbital phase in [-0.5, 0.5), with 0 at mid-transit.

    Args:
        time_series: Time(s), scalar or array, in the same units as `mid_transit`.
        mid_transit: Reference mid-transit epoch.
        period: Orbital period, in the same units as the times.

    Returns:
        Phase in units of the period (multiply by `period` for a time offset).
    """
    elapsed_cycles = (time_series - mid_transit) / period
    shifted = elapsed_cycles + 0.5
    return shifted % 1. - 0.5

def estimate_white_noise(light_curve, detrend_key, is_magnitude, planet_mask, cut=3.):
    """Robustly estimate the per-point white noise of the out-of-transit data.

    Only points with ``light_curve['mask']`` True and `planet_mask` False are
    used. The scatter comes from `ksp.robust_mean` with clipping threshold `cut`.

    NOTE(review): `ksp` (keplersplinev2) must be in scope. Its import near the
    top of this notebook is commented out, so calling this raises NameError
    until that import is restored.

    Args:
        light_curve: Structured array with a boolean 'mask' field and a
            `detrend_key` field.
        detrend_key: Field holding the detrended flux (or magnitude).
        is_magnitude: If True, the field is in magnitudes. It is converted to
            relative flux using the robust mean magnitude as zero point.
            NOTE(review): in this case the returned scatter is computed from
            the magnitudes, not from the converted flux. Confirm that is intended.
        planet_mask: Boolean array, True for in-transit points to exclude.
        cut: Outlier-rejection threshold passed to `ksp.robust_mean`.

    Returns:
        (stddev, unmasked_flux): the per-point scatter estimate, and the flux
        for every point (not only the noise points).
    """
    noise_mask = np.logical_and(light_curve['mask'], np.logical_not(planet_mask))
    values = light_curve[detrend_key]
    if not is_magnitude:
        unmasked_flux = values
        flux = unmasked_flux[noise_mask]
        _, mean_std, _ = ksp.robust_mean(flux, cut=cut)
    else:
        zero_point, mean_std, _ = ksp.robust_mean(values[noise_mask], cut=cut)
        unmasked_flux = 10 ** (-0.4 * (values - zero_point))
        flux = unmasked_flux[noise_mask]
    # `mean_std` is treated as the uncertainty of the mean; rescale by
    # sqrt(N - 1) to recover the per-point scatter.
    stddev = mean_std * np.sqrt(len(flux) - 1)
    return stddev, unmasked_flux

def estimate_white_noise_v2(flux, planet_mask, cut=3.):
    """Robust per-point white-noise estimate from out-of-transit flux.

    NOTE(review): requires `ksp` (keplersplinev2) in scope. Its import near the
    top of this notebook is commented out.

    Args:
        flux: Flux array.
        planet_mask: Boolean array, True for in-transit points to exclude.
        cut: Outlier-rejection threshold passed to `ksp.robust_mean`.

    Returns:
        (stddev, mean_std): the per-point scatter, and the uncertainty of the
        mean reported by `ksp.robust_mean`.
    """
    masked_flux = flux[np.logical_not(planet_mask)]
    _, mean_std, _ = ksp.robust_mean(masked_flux, cut)
    # Rescale with sqrt(N - 1), the same as estimate_white_noise. The previous
    # sqrt(N + 1) disagreed with it. This assumes robust_mean returns
    # sigma / sqrt(N - 1); TODO confirm against keplersplinev2.
    stddev = mean_std * np.sqrt(len(masked_flux) - 1)
    return stddev, mean_std

def mask_planet(time, period, mid_transit_time, transit_duration, duration_factor):
    """Only use in-transit points within `duration_factor` times of the transit duration.

    Returns a boolean (array) that is True where `time` lies within
    ``duration_factor * transit_duration / 2`` of a mid-transit time of the
    ephemeris (`period`, `mid_transit_time`). `transit_duration` must be in
    the same units as `period`.
    """
    # Half-window expressed as a fraction of the period (same operation order
    # as before, so boundary comparisons are unchanged).
    phase_limit = transit_duration / period / 2. * duration_factor
    folded = ((time - mid_transit_time) / period + 0.5) % 1. - 0.5
    return np.abs(folded) < phase_limit

In [None]:
# Initial transit parameters. Per-planet arrays are ordered (planet 0, planet 1).
test_epoch = np.array([2110.0659, 1901.719])  # TJD
test_epoch_unc = np.array([0.005, 0.02])
test_period = np.array([9.127055, 3.09833])  # day
test_period_unc = np.array([0.0001, 0.0002])  #np.array([0.00001, 0.000023])
test_last_seen = np.array([2111.0, 1855.3])
# 2073.55739092,  2073.55767828,  2073.55796009]),
#  array([ 1907.91526705,  1907.91617858,  1907.91709297])
# Shift each epoch by a whole number of periods to the last transit at or
# before `test_last_seen` (floor division).
test_epoch = (test_last_seen - test_epoch) // test_period * test_period + test_epoch
print(test_epoch)
test_duration = np.array([4.9, 2.0])  # hour
# NOTE(review): these widths exceed the durations themselves. Used as
# test +/- unc (as add_uniform_prior does), the lower bound would be negative.
test_duration_unc = np.array([6., 4.])
# Presumably (cos omega, sin omega) for omega = pi/2, matching "omega" below.
test_omega_vec = np.array([0., 1.])

test_radius_planet = np.array([0.066, 0.022])  # In units of the stellar radius (Rp/R*)
test_radius_planet_unc = np.array([0.5, 0.5])
test_impact_param = np.array([0.615, 0.756])
# Quadratic limb-darkening coefficients (u1, u2), not per-planet values.
# They are shared by every filter in test_params["u"].
test_limb_dark = np.array([0.37, 0.21])

# Estimated values from previous fits
test_mass_star = 1.05  # HARPS value
test_mass_star_unc = 0.2
test_feh = 0.415
test_eep = 380
# A_V extinction
test_av = 0.15  # From Keivan.
test_av_upper = 0.70  # Upper limit from Schlegel 1998 dust map.
# Gaia EDR3 parallax
PARALLAX_GAIA_EDR3 = 5.76044602669829
PARALLAX_GAIA_EDR3_UNC= 0.01048182
# Gaia EDR3 parallax zero-point correction, computed by:
# zpt.get_zpt(
#     phot_g_mean_mag=10.8542,
#     nu_eff_used_in_astrometry=1.5263987,
#     pseudocolour=float('nan'),
#     ecl_lat=-68.38768994372302,
#     astrometric_params_solved=31,  # 5-param solution.
# )
# Corrected parallax = parallax - zero_point, with zero_point = -0.012446.
PARALLAX_GAIA_EDR3 -= -0.012446

# Planet masses as planet/star mass ratios: Earth masses -> solar masses,
# divided by the stellar mass in solar masses.
test_mass_planet = (np.array([82., 11.]) * u.M_earth).to(u.M_sun).value / test_mass_star
test_mass_planet_unc = (np.array([300., 100.]) * u.M_earth).to(u.M_sun).value / test_mass_star # Upper limit.

print(test_mass_planet, test_mass_planet_unc)

# TIC 8.1 values
# test_density_star = 0.839021  # g cm^-3
# test_density_star_unc = 0.190382
# test_radius_star = 1.18093  # R_sun
# test_radius_star_unc = 0.0597739
# test_mass_star = 0.98  # M_sun
# test_mass_star_unc = 0.131459

# FEROS values
# test_density_star = 1.052  # g cm^-3
# test_density_star_unc = 0.052
# test_radius_star = 1.104  # R_sun
# test_radius_star_unc = 0.011
# test_mass_star = 1.001  # M_sun
# test_mass_star_unc = 0.026

# print(test_mass_star/test_radius_star**3 * astro.SOLAR_DENSITY*1e-3)

# Starting values for the fit, keyed by parameter name. Per-planet entries are
# length-2 arrays in planet order (0, 1).
test_params = {
    "t0": test_epoch,
    "period": test_period,
    "rp": test_radius_planet,
    "b": test_impact_param,
    # NOTE(review): presumably flags planets whose eccentricity stays fixed. Confirm.
    "ecc_fix": np.array([False, True]),
    "ecc": np.array([0., 0.]),
    "omega": np.array([np.pi/2, np.pi/2]),
    # Same quadratic limb-darkening starting point for every filter.
    "u": {"tess": test_limb_dark, "Rc": test_limb_dark, "tess_c": test_limb_dark},
#     "rho_star": test_density_star,
    "m_star": test_mass_star,
#     "r_star": test_radius_star,
    # Transit durations in hours.
    "tdur": test_duration,
    "m_planet": test_mass_planet,
    "av": test_av,
    "parallax": PARALLAX_GAIA_EDR3,
    "feh": test_feh,
    "eep": test_eep,
}

# Prior widths keyed like `test_params`. How each is interpreted depends on the
# helper that consumes it: a Normal sigma, a uniform half-width, or an upper
# limit (as noted above for "av" and "m_planet").
# "albedo_bond" has no matching entry in test_params.
prior_unc = {
    "t0": test_epoch_unc,
    "period": test_period_unc,
    "rp": test_radius_planet_unc,
#     "rho_star": test_density_star_unc,
    "m_star": test_mass_star_unc,
#     "r_star": test_radius_star_unc,
    "m_planet": test_mass_planet_unc,
    "tdur": test_duration_unc,
    "av": test_av_upper,
    "parallax": PARALLAX_GAIA_EDR3_UNC,
    "albedo_bond": 0.7,
}

In [None]:
def add_normal_prior(name, test_params, prior_unc, shape=None):
    """Create ``pm.Normal(name)`` centred on ``test_params[name]``.

    The standard deviation is ``prior_unc[name]``. `shape` is forwarded to
    `pm.Normal` only when it is not None.
    """
    normal_kwargs = {'mu': test_params[name], 'sd': prior_unc[name]}
    if shape is not None:
        normal_kwargs['shape'] = shape
    return pm.Normal(name, **normal_kwargs)

def add_uniform_prior(name, test_params, prior_unc, shape=None):
    """Create ``pm.Uniform(name)`` spanning ``test_params[name] +/- prior_unc[name]``.

    The test value is the centre ``test_params[name]``. `shape` is forwarded
    only when it is not None.
    """
    centre = test_params[name]
    half_width = prior_unc[name]
    uniform_kwargs = dict(
        lower=centre - half_width,
        upper=centre + half_width,
        testval=centre,
    )
    if shape is not None:
        uniform_kwargs['shape'] = shape
    return pm.Uniform(name, **uniform_kwargs)

In [None]:
def optimize_model(model, passes=None):
    """Find a MAP point for `model`, optionally warming up on parameter subsets.

    Args:
        model: The PyMC3 model to optimize.
        passes: Optional sequence of variable-name groups. Each group is
            optimized in turn, starting from the model's test point and then
            from the previous result. A final pass over all variables follows.
            ``None`` or an empty sequence runs only the final pass. Previously
            an empty sequence raised IndexError.

    Returns:
        The point returned by the final `pmx.optimize` call.
    """
    if passes is None:
        passes = ()
    with model:
        soln = model.test_point
        for group in passes:
            soln = pmx.optimize(start=soln, vars=[model[name] for name in group])
        # Final pass over every free variable.
        return pmx.optimize(start=soln)

In [None]:
def _transit_window_mask(times, ephemerides, selected, duration_factor):
    """Union of `mask_planet` windows for planets whose `selected` flag is true.

    `ephemerides` yields (period, t0, tdur) with tdur in days. Pairs beyond
    the shorter of `ephemerides`/`selected` are ignored (zip semantics).
    """
    window = np.zeros(len(times), dtype=bool)
    for (period, t0, tdur), use_planet in zip(ephemerides, selected):
        if use_planet:
            window = np.logical_or(
                window, mask_planet(times, period, t0, tdur, duration_factor))
    return window


def read_andrew_lc(filename, cols=('time', 'rawflux', 'kspflux', 'none'), keep_planets=None):
    """Load a KSPSAP light-curve CSV and keep only windows around selected transits.

    Ephemerides come from the notebook-level `test_params` ('period', 't0',
    and 'tdur' in hours).

    Args:
        filename: Path of a header-less CSV.
        cols: Column names for the CSV. Must include 'time' and 'kspflux'. If
            it includes 'kspflux_unc', per-point uncertainties are used.
        keep_planets: One boolean per planet. Points within 10 transit
            durations of kept planets are retained. Points within 3 durations
            of non-kept planets are then removed. Defaults to keeping every planet.

    Returns:
        (time, dflux, noise) for the selected points; dflux is flux - 1.
        `noise` is the per-point uncertainty array when 'kspflux_unc' is in
        `cols`. Otherwise it is a single robust estimate from
        `estimate_white_noise` using points outside 2 durations of every planet.
    """
    cols = list(cols)  # immutable default; pandas gets a list as before
    has_unc = 'kspflux_unc' in cols
    andrew_data = pd.read_csv(filename, names=cols, index_col=None)

    light_curve = np.zeros(
        len(andrew_data),
        dtype=[('mask', '?'), ('BJD', 'f8'), ('KSPSAP_FLUX', 'f8'), ('KSPSAP_FLUX_UNC', 'f8')])
    light_curve['mask'][:] = True
    light_curve['BJD'][:] = andrew_data['time']
    light_curve['KSPSAP_FLUX'][:] = andrew_data['kspflux']
    if has_unc:
        light_curve['KSPSAP_FLUX_UNC'][:] = andrew_data['kspflux_unc']

    times = light_curve['BJD']
    ephemerides = list(zip(test_params['period'], test_params['t0'], test_params['tdur'] / 24.))
    if keep_planets is None:
        keep_planets = [True for _ in test_params['period']]

    if not has_unc:
        # The preliminary mask is only needed for the white-noise estimate.
        prelim_planet_mask = _transit_window_mask(
            times, ephemerides, [True] * len(ephemerides), 2.)
        tess_noise, _ = estimate_white_noise(
            light_curve, 'KSPSAP_FLUX', False, prelim_planet_mask, cut=3.)

    # Keep wide windows around kept planets, minus narrower windows around dropped ones.
    planet_mask = np.logical_and(
        _transit_window_mask(times, ephemerides, keep_planets, 10.),
        np.logical_not(_transit_window_mask(
            times, ephemerides, [not keep for keep in keep_planets], 3.)))
    tess_mask = np.logical_and(light_curve['mask'], planet_mask)

    tess_time = np.ascontiguousarray(times[tess_mask])
    tess_flux = np.ascontiguousarray(light_curve['KSPSAP_FLUX'][tess_mask])
    tess_dflux = tess_flux - 1
    if has_unc:
        tess_noise = np.ascontiguousarray(light_curve['KSPSAP_FLUX_UNC'][tess_mask])
    return tess_time, tess_dflux, tess_noise

# Per-planet `keep_planets` flags: False removes that planet's transits
# (planet 0); True keeps windows around them (planet 1).
TRANSIT_KEEP_PLANETS = [False, True]

# NOTE(review): these paths are relative to the working directory, not to
# DATA_DIR ('../data'), which the RV files use. Confirm the intended layout.
tess_time, tess_dflux, tess_noise = read_andrew_lc(
    'data/toi2000/toi2000finallc-smallcutout-cbv.csv',
    keep_planets=TRANSIT_KEEP_PLANETS)

# Used for a plot below, but not included in the fit lists.
tess_2m_time, tess_2m_dflux, tess_2m_noise = read_andrew_lc(
    'data/toi2000/toi2000twentysecondbinnedtotwominues-cbv.csv',
    cols=['time', 'kspflux', 'kspflux_unc', 'none'],
    keep_planets=TRANSIT_KEEP_PLANETS)

tess_20s_time, tess_20s_dflux, tess_20s_noise = read_andrew_lc(
    'data/toi2000/toi2000twentysecond-cbv.csv',
    cols=['time', 'kspflux', 'kspflux_unc', 'none'],
    keep_planets=TRANSIT_KEEP_PLANETS)
# Replace the per-point 20-s uncertainties with their median (a single value).
tess_20s_noise = np.median(tess_20s_noise)

# Light curves used in the fit. The lists below are aligned by index.
lc_times = [
    tess_time,
    # tess_2m_time,
    tess_20s_time,
]
lc_dfluxes = [
    tess_dflux,
    # tess_2m_dflux,
    tess_20s_dflux,
]
lc_uncs = [
    tess_noise,
    # tess_2m_noise,
    tess_20s_noise,
]
# Exposure time of each light curve, in days (30 min and 20 s).
exposure_times = [
    30./60/24,
    # 2./60/24,
    20./60/60/24,
]
# Model oversampling per light curve: 15 for the 30-min data, 1 (none) for 20 s.
supersampling_factors = [15, 1]

In [None]:
# Fold the 30-min light curve on planet 1's initial ephemeris (x axis: hours
# from mid-transit) and overlay 60-bin means with standard-error bars.
fig, ax = plt.subplots(dpi=144)
# ax.plot(
#     phase_fold(tess_time, test_params['t0'][1], test_params['period'][1]),
#     tess_dflux, '.')
tess_phase = phase_fold(tess_time, test_params['t0'][1], test_params['period'][1]) * test_params['period'][1] * 24.
ax.plot(
    tess_phase,
    tess_dflux, '.', color='gray', alpha=0.1)
tess_binned_dflux, tess_bins, _ = stats.binned_statistic(
    tess_phase,
    tess_dflux,
    statistic='mean',
    bins=60,
)
# Standard error of the mean per bin, reusing the bin edges from above.
tess_binned_std, _, _ = stats.binned_statistic(
    tess_phase,
    tess_dflux,
    statistic=lambda x: np.std(x, ddof=1)/np.sqrt(len(x)),
    bins=tess_bins,
)
ax.errorbar(
    (tess_bins[:-1]+tess_bins[1:])/2,
    tess_binned_dflux,
    tess_binned_std,
    fmt='.')
# ax.set_xlim(-0.05, 0.05)
# ax.set_ylim(-0.001, 0.0005)
# NOTE(review): saves under 'plot/', while the periodogram figure later goes to
# '../plots/'. Confirm the directory exists and is intended.
fig.savefig('plot/toi2000_30min_planet_c.png')

In [None]:
# Fold the 20-s light curve on planet 1's initial ephemeris (x axis: hours
# from mid-transit) and overlay 60-bin means with standard-error bars.
fig, ax = plt.subplots(dpi=144)
# ax.plot(
#     phase_fold(tess_time, test_params['t0'][1], test_params['period'][1]),
#     tess_dflux, '.')
tess_20s_phase = phase_fold(tess_20s_time, test_params['t0'][1], test_params['period'][1]) * test_params['period'][1] * 24.
ax.plot(
    tess_20s_phase,
    tess_20s_dflux, '.', color='gray', alpha=0.1)
tess_20s_binned_dflux, tess_20s_bins, _ = stats.binned_statistic(
    tess_20s_phase,
    tess_20s_dflux,
    statistic='mean',
    bins=60,
)
# Standard error of the mean per bin, reusing the bin edges from above.
tess_20s_binned_std, _, _ = stats.binned_statistic(
    tess_20s_phase,
    tess_20s_dflux,
    statistic=lambda x: np.std(x, ddof=1)/np.sqrt(len(x)),
    bins=tess_20s_bins,
)
ax.errorbar(
    (tess_20s_bins[:-1]+tess_20s_bins[1:])/2,
    tess_20s_binned_dflux,
    tess_20s_binned_std,
    fmt='.')
# ax.set_xlim(-0.05, 0.05)
ax.set_ylim(-0.001, 0.0005)
# NOTE(review): saves under 'plot/'. Confirm the directory exists.
fig.savefig('plot/toi2000_20s_planet_c.png')

In [None]:
# Fold the 2-min (binned 20-s) light curve on planet 1's initial ephemeris
# (x axis: hours from mid-transit) with 60-bin means and standard errors.
# This figure is not saved.
fig, ax = plt.subplots(dpi=144)
# ax.plot(
#     phase_fold(tess_time, test_params['t0'][1], test_params['period'][1]),
#     tess_dflux, '.')
tess_2m_phase = phase_fold(tess_2m_time, test_params['t0'][1], test_params['period'][1]) * test_params['period'][1] * 24.
ax.plot(
    tess_2m_phase,
    tess_2m_dflux, '.', color='gray', alpha=0.1)
tess_2m_binned_dflux, tess_2m_bins, _ = stats.binned_statistic(
    tess_2m_phase,
    tess_2m_dflux,
    statistic='mean',
    bins=60,
)
# Standard error of the mean per bin, reusing the bin edges from above.
tess_2m_binned_std, _, _ = stats.binned_statistic(
    tess_2m_phase,
    tess_2m_dflux,
    statistic=lambda x: np.std(x, ddof=1)/np.sqrt(len(x)),
    bins=tess_2m_bins,
)
ax.errorbar(
    (tess_2m_bins[:-1]+tess_2m_bins[1:])/2,
    tess_2m_binned_dflux,
    tess_2m_binned_std,
    fmt='.')
ax.set_ylim(-0.001, 0.0005)

In [None]:
def plot_model_light_curve(
    ldlc_obj, orbit, radii_planets,
    texp, supersampling_factor,
    max_phases, num_points=2000):
    """Evaluate each planet's model transit on a dense phase grid for plotting.

    For planet i the grid spans [-max_phases[i], max_phases[i]] (in units of
    the period) around that planet's mid-transit. The full multi-planet light
    curve is evaluated there and only column i (planet i) is kept.

    Returns:
        (model_phases, model_light_curve): per-planet lists of the phase grid
        and the matching model flux offsets.
    """
    phase_grids = []
    planet_curves = []
    for planet_idx, half_width in enumerate(max_phases):
        grid = np.linspace(-half_width, half_width, num_points)
        grid_times = grid * orbit.period[planet_idx] + orbit.t0[planet_idx]
        evaluated = ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii_planets,
            t=grid_times,
            texp=texp,
            oversample=supersampling_factor,
        ).eval()
        phase_grids.append(grid)
        planet_curves.append(evaluated[:, planet_idx])
    return phase_grids, planet_curves

def plot_multi_planet_folded_light_curve(
    num_planets, orbit, rp,
    lc_times, lc_dfluxes, filters, limb_dark_params, mean_fluxes,
    exposure_times, supersampling_factors,
    nbins, max_phases, alphas, dpi=300):
    """Grid of phase-folded light curves: one row per light curve, one column per planet.

    In each panel the model of every *other* planet is subtracted from the
    data, so only that column's planet remains. Its model is overplotted,
    optionally with binned means and standard-error bars.

    Args:
        num_planets: Number of planet columns.
        orbit: Orbit exposing `period`, `t0` and `r_star` tensors.
        rp: Radius ratios; multiplied by `orbit.r_star` to get planet radii.
        lc_times, lc_dfluxes: Per-light-curve times and flux offsets.
        filters: Per-light-curve filter name; selects ``limb_dark_params['u_<filter>']``.
        limb_dark_params: Mapping 'u_<filter>' -> limb-darkening coefficients.
        mean_fluxes: Offsets subtracted from the data, indexed by planet column.
            NOTE(review): indexed by planet, not by light curve. Confirm intent.
        exposure_times, supersampling_factors: Per-light-curve integration settings.
        nbins: Per-light-curve bin count, or None to skip binning.
        max_phases: Per-planet half-width of the plotted phase window.
        alphas: Per-light-curve marker alpha for the raw points.
        dpi: Figure resolution.

    Returns:
        (fig, axs) with `axs` shaped (num_light_curves, num_planets).
    """
    num_light_curves = len(lc_times)
    periods = orbit.period.eval()
    epochs = orbit.t0.eval()
    radii_planets = rp * orbit.r_star.eval()
    ldlc_objs = {
        filter_name: xo.LimbDarkLightCurve(limb_dark_params[f'u_{filter_name}'])
        for filter_name in set(filters)
    }

    fig, axs = plt.subplots(
        num_light_curves,
        num_planets,
        sharex='col',
        sharey='row',
        figsize=(5*num_planets, 3*num_light_curves),
        dpi=dpi,
        squeeze=False,
    )

    for row, lc_time, lc_dflux, filter_name, texp, supersample, nbin, alpha in zip(
        axs, lc_times, lc_dfluxes, filters,
        exposure_times, supersampling_factors,
        nbins, alphas):

        ldlc_obj = ldlc_objs[filter_name]
        model_dflux = ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii_planets,
            t=lc_time,
            texp=texp,
            oversample=supersample,
        ).eval()
        model_dflux_sum = np.sum(model_dflux, axis=1)
        plot_model_phases, plot_model_dflux = plot_model_light_curve(
            ldlc_obj, orbit, radii_planets, texp, supersample, max_phases)

        for i, (ax, period, epoch, max_phase) in enumerate(
            zip(row, periods, epochs, max_phases)):

            lc_phase = phase_fold(lc_time, epoch, period)
            # Remove every other planet's model so only planet i remains.
            lc_dflux_only_planet = lc_dflux - model_dflux_sum + model_dflux[:, i] - mean_fluxes[i]
            ax.plot(
                lc_phase,
                lc_dflux_only_planet,
                '.', color='gray', alpha=alpha, rasterized=True)
            ax.plot(
                plot_model_phases[i], plot_model_dflux[i])

            phase_mask = np.abs(lc_phase) < max_phase
            select_phase = lc_phase[phase_mask]
            select_dflux = lc_dflux_only_planet[phase_mask]
            if nbin is not None:
                binned_mean, bins, _ = stats.binned_statistic(
                    select_phase, select_dflux, statistic='mean', bins=nbin)
                binned_err, _, _ = stats.binned_statistic(
                    select_phase, select_dflux, statistic=lambda a: np.std(a, ddof=1), bins=bins)
                binned_count, _ = np.histogram(select_phase, bins=bins)
                # (A leftover debug print of `binned_count` was removed here.)
                mid_bin = bins[:-1] + np.diff(bins) / 2.
                ax.errorbar(mid_bin, binned_mean,
                    yerr=binned_err/np.sqrt(binned_count),
                    fmt='.', color='C1')
            ax.set_xlim(-max_phase, max_phase)
    fig.tight_layout()
    return fig, axs

In [None]:
def make_simple_model():
    """Build a transit-only PyMC3 model for planet index 1 on the notebook's light curves.

    Reads the notebook globals `test_params`/`prior_unc` (index 1) for the
    period and mid-transit priors. The data come from `lc_times`,
    `lc_dfluxes`, `lc_uncs`, `exposure_times` and `supersampling_factors`.
    One Gaussian likelihood `lc_obs_{i}` is added per light curve.

    Returns:
        The constructed `pm.Model`.
    """
    with pm.Model() as model:
        # period = pm.Deterministic('period', tt.as_tensor_variable(test_params['period'][1]))
        # epoch = pm.Deterministic('t0', tt.as_tensor_variable(test_params['t0'][1]))
        period = pm.Normal('period', mu=test_params['period'][1], sigma=prior_unc['period'][1])
        epoch = pm.Normal('t0', mu=test_params['t0'][1], sigma=prior_unc['t0'][1])
        # Transit duration in hours; converted to days for the orbit below.
        duration = pm.Uniform('tdur', lower=0., upper=4., testval=2.)
        # Planet-to-star radius ratio.
        ror = pm.Uniform('ror', lower=0., upper=0.2, testval=test_params['rp'][1])
        impact_param = xo.distributions.ImpactParameter('b', ror=ror)
        # Stellar radius fixed to 1, so the planet radius equals `ror`.
        radius_star = pm.Deterministic('r_star', tt.as_tensor_variable(1.))

        orbit = xo.orbits.SimpleTransitOrbit(
            period=period,
            duration=duration/24.,
            t0=epoch,
            b=impact_param,
            r_star=radius_star,
            ror=ror
        )

        limb_dark_param = xo.distributions.QuadLimbDark('u_tess')
        ldlc_obj = xo.LimbDarkLightCurve(limb_dark_param)
        for i, (lc_time, lc_dflux, lc_noise, texp, supersample) in enumerate(zip(
            lc_times, lc_dfluxes, lc_uncs, exposure_times, supersampling_factors)):
            model_light_curve = ldlc_obj.get_light_curve(
                orbit=orbit,
                r=ror*radius_star,
                t=lc_time,
                texp=texp,
                oversample=supersample,
            )
            # NOTE: disabled, so the model has no 'lc_pred_{i}' variable and
            # optimizer/trace outputs will not contain it.
            # pm.Deterministic(f'lc_pred_{i}', model_light_curve)
            # Sum the per-planet columns (axis 1) into one model flux offset.
            pm.Normal(f'lc_obs_{i}', mu=tt.sum(model_light_curve, axis=1), sigma=lc_noise, observed=lc_dflux)
    return model

# Build the simple transit model and find its MAP point (single full pass).
model = make_simple_model()

map_soln = optimize_model(model)

In [None]:
# Inspect the MAP point returned by optimize_model.
map_soln

In [None]:
# Overlay the MAP model on the phase-folded 30-min (top) and 20-s (bottom)
# light curves, folded on planet 1's initial ephemeris.
# The per-point `lc_pred_{i}` Deterministic is commented out in
# make_simple_model, so `map_soln` has no 'lc_pred_*' keys (indexing them
# raised KeyError). Instead, evaluate each likelihood's mean at the MAP point.
# That mean is the planet-summed model light curve for each dataset.
with model:
    lc_map_preds = [
        pmx.eval_in_model(model[f'lc_obs_{i}'].distribution.mu, map_soln)
        for i in range(len(lc_times))
    ]

fig, ax = plt.subplots(2, sharex=True, figsize=(3, 4), dpi=300)
tess_phase = phase_fold(tess_time, test_params['t0'][1], test_params['period'][1])
ax[0].plot(
    tess_phase,
    tess_dflux, '.', color='gray', alpha=0.1)
ax[0].plot(
    tess_phase,
    lc_map_preds[0], '.')
tess_20s_phase = phase_fold(tess_20s_time, test_params['t0'][1], test_params['period'][1])
ax[1].plot(
    tess_20s_phase,
    tess_20s_dflux, '.', color='gray', alpha=0.05)
ax[1].plot(
    tess_20s_phase,
    lc_map_preds[1], '.')
ax[0].set_ylim(-0.0017, 0.0014)
ax[1].set_ylim(-0.0017, 0.0014)

In [None]:
# Fold the 20-s light curve on the MAP ephemeris (x axis: hours from
# mid-transit) and overlay 60-bin means with standard-error bars.
fig, ax = plt.subplots(dpi=144)
# ax.plot(
#     phase_fold(tess_time, test_params['t0'][1], test_params['period'][1]),
#     tess_dflux, '.')
tess_20s_phase = phase_fold(tess_20s_time, map_soln['t0'], map_soln['period']) * map_soln['period'] * 24.
ax.plot(
    tess_20s_phase,
    tess_20s_dflux, '.', color='gray', alpha=0.1)
tess_20s_binned_dflux, tess_20s_bins, _ = stats.binned_statistic(
    tess_20s_phase,
    tess_20s_dflux,
    statistic='mean',
    bins=60,
)
# Standard error of the mean per bin, reusing the bin edges from above.
tess_20s_binned_std, _, _ = stats.binned_statistic(
    tess_20s_phase,
    tess_20s_dflux,
    statistic=lambda x: np.std(x, ddof=1)/np.sqrt(len(x)),
    bins=tess_20s_bins,
)
ax.errorbar(
    (tess_20s_bins[:-1]+tess_20s_bins[1:])/2,
    tess_20s_binned_dflux,
    tess_20s_binned_std,
    fmt='.')
# ax.set_xlim(-0.05, 0.05)
ax.set_ylim(-0.001, 0.0005)

In [None]:
# Accumulates one trace per sampling run; re-running this cell discards them.
traces = []

In [None]:
# Sample the simple model, starting from the MAP point.
# NOTE(review): cores=32 / chains=32 is machine-specific. Adjust for the host.
with model:
    trace = pmx.sample(
        tune=3000,
        draws=1000,
        start=map_soln,
        cores=32,
        chains=32,
        # initial_accept=0.5,
        target_accept=0.97,
        return_inferencedata=True,
#         parameter_groups=[
#             pmx.ParameterGroup([simple_model.period, simple_model.t0]),
#             pmx.ParameterGroup([simple_model.tdur, simple_model.rp, simple_model.b]),
#             pmx.ParameterGroup([simple_model.mean_flux, simple_model.lc_jitter]),
#         ],
    )
# Keep every run so earlier traces can still be saved afterwards.
traces.append(trace)

In [None]:
# An incomplete expression (`az.`) was left in this cell. It raised SyntaxError
# and stopped "Run All". The posterior diagnostics for `trace` follow below.

In [None]:
# Trace plots for the free parameters of the simple transit model.
az.plot_trace(trace, var_names=['period', 't0', 'ror', 'b', 'tdur', 'u_tess'])

In [None]:
# NOTE(review): this is the same call as the trace-plot cell just before it; it can be removed.
az.plot_trace(trace, var_names=['period', 't0', 'ror', 'b', 'tdur', 'u_tess'])

In [None]:
# Posterior summary rounded to 7 digits (overwritten by the next cell).
summary = az.summary(trace, var_names=['period', 't0', 'ror', 'b', 'tdur', 'u_tess'], round_to=7)
summary

In [None]:
# Posterior summary rounded to 6 digits.
summary = az.summary(trace, var_names=['period', 't0', 'ror', 'b', 'tdur', 'u_tess'], round_to=6)
summary

In [None]:
# Plain-text rendering of the current `summary` table.
print(summary)

In [None]:
# Merge chains and draws into one 'sample' dimension. Then take each posterior
# variable's median as a plain numpy array (via `.data`).
flat_samples = trace.posterior.stack(sample=("chain", "draw"))
median_soln = {k:v.data for k, v in flat_samples.median(dim='sample').items()}
# median_soln = {k: np.median(v, axis=-1) for k, v in flat_samples.items()}
# max_post_index = flat_samples.log_prob.argmax(dim='sample')
# max_post_soln = {k:v.data for k, v in flat_samples[{'sample': max_post_index}].items()}

In [None]:
# Compare the MAP point with the posterior medians.
map_soln, median_soln

In [None]:
# Scratch check: np.atleast_1d turns a 0-d array into a shape-(1,) array.
np.atleast_1d(np.array(1923.40779516067))

In [None]:
# Phase-folded data and model at the posterior median, one row per light curve.
plot_orbit = xo.orbits.SimpleTransitOrbit(
    **{k: np.atleast_1d(median_soln[k]) for k in ['period', 't0', 'b']},
    duration=np.atleast_1d(median_soln['tdur']/24.),
    ror=np.atleast_1d(median_soln['ror']),
    r_star=median_soln['r_star'],
    )

# `median_soln` values are already numpy arrays, so pass them through as-is.
# Calling `.data` on an ndarray returns a raw memoryview, not an array.
fig, axs = plot_multi_planet_folded_light_curve(
    num_planets=1,
    orbit=plot_orbit,
    rp=np.atleast_1d(median_soln['ror']),
    lc_times=lc_times,
    lc_dfluxes=lc_dfluxes,
    filters=['tess', 'tess'],
    limb_dark_params={k: v for k, v in median_soln.items() if k.startswith('u_')},
    mean_fluxes=[0],
    exposure_times=exposure_times,
    supersampling_factors=supersampling_factors,
    nbins=[35]*2,
    max_phases=[0.05],
    alphas=[0.3]*2)

In [None]:
# Same posterior-median fold, zoomed to the transit depth range.
plot_orbit = xo.orbits.SimpleTransitOrbit(
    **{k: np.atleast_1d(median_soln[k]) for k in ['period', 't0', 'b']},
    duration=np.atleast_1d(median_soln['tdur']/24.),
    ror=np.atleast_1d(median_soln['ror']),
    r_star=median_soln['r_star'],
    )

# `median_soln` values are already numpy arrays, so pass them through as-is.
# Calling `.data` on an ndarray returns a raw memoryview, not an array.
fig, axs = plot_multi_planet_folded_light_curve(
    num_planets=1,
    orbit=plot_orbit,
    rp=np.atleast_1d(median_soln['ror']),
    lc_times=lc_times,
    lc_dfluxes=lc_dfluxes,
    filters=['tess', 'tess'],
    limb_dark_params={k: v for k, v in median_soln.items() if k.startswith('u_')},
    mean_fluxes=[0],
    exposure_times=exposure_times,
    supersampling_factors=supersampling_factors,
    nbins=[35]*2,
    max_phases=[0.05],
    alphas=[0.3]*2)

for row in axs:
    row[0].set_ylim(-0.001, 0.0005)

In [None]:
# plot_orbit = xo.orbits.SimpleTransitOrbit(
#     **{k: np.atleast_1d(median_soln[k]) for k in ['period', 't0', 'b']},
#     duration=np.atleast_1d(median_soln['tdur']/24.),
#     ror=np.atleast_1d(median_soln['ror']),
#     r_star=median_soln['r_star'],
#     )
# plot_orbit = xo.orbits.SimpleTransitOrbit(
#     **{k: np.atleast_1d(median_soln[k]) for k in ['period', 't0', 'b']},
#     duration=np.atleast_1d(median_soln['tdur']/24.),
#     ror=np.atleast_1d(median_soln['ror']),
#     r_star=median_soln['r_star'],
#     )

def compute_lc(soln):
    """Evaluate the single-planet transit model for `soln` on every light curve.

    Builds a `SimpleTransitOrbit` from the solution dict (keys 'period', 't0',
    'b', 'tdur' in hours, 'ror', 'r_star', 'u_tess'). It is evaluated on the
    notebook-level `lc_times`, with matching `exposure_times` and
    `supersampling_factors`.

    Returns:
        A list of evaluated model light curves, one per light curve, with one
        column per planet (sum over axis 1 for the total model).
    """
    orbit = xo.orbits.SimpleTransitOrbit(
        **{k: np.atleast_1d(soln[k]) for k in ['period', 't0', 'b']},
        duration=np.atleast_1d(soln['tdur']/24.),
        ror=np.atleast_1d(soln['ror']),
        r_star=soln['r_star'],
        )
    ldlc_obj = xo.LimbDarkLightCurve(soln['u_tess'])
    # Planet radii are loop-invariant; compute them once.
    radii = np.atleast_1d(soln['ror']) * soln['r_star']
    return [
        ldlc_obj.get_light_curve(
            orbit=orbit,
            r=radii,
            t=lc_time,
            texp=texp,
            oversample=supersample,
        ).eval()
        for lc_time, texp, supersample in zip(lc_times, exposure_times, supersampling_factors)
    ]

In [None]:
# Inspect the posterior-median solution.
median_soln

In [None]:
def plot_residual(soln):
    """Plot (model - data) residuals against phase for each light curve.

    One row per light curve in the notebook-level `lc_times`/`lc_dfluxes`,
    folded on ``soln['t0']``/``soln['period']``.

    Returns:
        A list of residual arrays (model minus data), one per light curve.
    """
    # One axis per light curve (previously fixed at 2).
    fig, axs = plt.subplots(len(lc_times), squeeze=False, dpi=300)
    residuals = []
    for row, model_lc, lc_time, lc_dflux in zip(axs, compute_lc(soln), lc_times, lc_dfluxes):
        residual = np.sum(model_lc, axis=1) - lc_dflux
        residuals.append(residual)
        phase = phase_fold(lc_time, soln['t0'], soln['period'])
        row[0].plot(phase, residual, '.')
        row[0].set_ylabel('Model - data')
    axs[-1, 0].set_xlabel('Phase')
    return residuals
residuals = plot_residual(map_soln)
# Free parameters of make_simple_model: period, t0, tdur, ror, b, u_tess (2).
num_param = 7
chi2_tot = 0
dof = 0  # total number of data points; num_param is subtracted at the end
for res, lc_noise in zip(residuals, lc_uncs):
    chi2 = np.sum(res**2/lc_noise**2)
    chi2_tot += chi2
    dof += len(res)
    print('chi^2: ', chi2)
    print('reduced chi^2: ', chi2 / (len(res) - num_param))
print('total reduced chi^2: ', chi2_tot / (dof - num_param))

In [None]:
# Residuals and chi^2 at the posterior median. This uses the `plot_residual`
# defined earlier in this notebook; the cell used to redefine it verbatim.
residuals = plot_residual(median_soln)
# Free parameters of make_simple_model: period, t0, tdur, ror, b, u_tess (2).
num_param = 7
chi2_tot = 0
dof = 0  # total number of data points; num_param is subtracted at the end
for res, lc_noise in zip(residuals, lc_uncs):
    chi2 = np.sum(res**2/lc_noise**2)
    chi2_tot += chi2
    dof += len(res)
    print('chi^2: ', chi2)
    print('reduced chi^2: ', chi2 / (len(res) - num_param))
print('total reduced chi^2: ', chi2_tot / (dof - num_param))

In [None]:
# Per-light-curve chi^2 at the posterior median (7 free parameters). This uses
# the `plot_residual` defined earlier in this notebook; the cell used to
# redefine it verbatim.
residuals = plot_residual(median_soln)
for res, lc_noise in zip(residuals, lc_uncs):
    chi2 = np.sum(res**2/lc_noise**2)
    print('chi^2: ', chi2)
    print('reduced chi^2: ', chi2 / (len(res) - 7))

In [None]:
# Scratch: 123 unseeded uniform draws. The output is not reproducible and the
# result is not assigned.
np.random.rand(123)

In [None]:
# Save the most recent trace. NOTE(review): the file name says '2m', but the
# fit lists above currently use the 30-min and 20-s data. Confirm the name.
az.to_netcdf(traces[-1], 'chains/toi2000_trace_simple_2m_all_free.nc')

In [None]:
# NOTE(review): needs at least three traces in `traces`. The sampling cell
# appends one per run, so on a fresh Run All this raises IndexError.
az.to_netcdf(traces[-3], 'chains/toi2000_trace_simple_30m_all_free.nc')

In [None]:
# Reload a saved trace from disk; this replaces the in-memory `trace`.
trace = az.from_netcdf('chains/toi2000_trace_simple_20s_all_free.nc')

In [None]:
# Summary of every variable in the reloaded trace.
az.summary(trace, round_to=6)

In [None]:
# Scratch: 0.081123 squared. This is presumably a fitted radius ratio, giving
# the approximate transit depth. TODO confirm the source value.
0.081123**2

In [None]:
def calculate_transit_duration_23(period, impact_param, scaled_semimajor_axis, ecc, sin_omega, scaled_planet_radius):
    """Full-transit duration T_23 (between second and third contact).

    T_23 = (P / pi) * arcsin( sqrt(((1 - k)^2 - b^2) / ((a/R*)^2 - b^2)) )
           * sqrt(1 - e^2) / (1 + e sin(omega)),
    where k = R_p / R_*.

    Args:
        period: Orbital period; the result has the same units.
        impact_param: Impact parameter b.
        scaled_semimajor_axis: a / R_*.
        ecc: Orbital eccentricity e.
        sin_omega: sin of the argument of periastron.
        scaled_planet_radius: R_p / R_*.

    Returns:
        T_23 in units of `period`. It is NaN when the square-root argument is
        negative (no full transit).
    """
    # e*sin(omega). The previous `sin_omega * sqrt(ecc)` was inconsistent with
    # the sqrt(1 - ecc**2) factor, which treats `ecc` as the eccentricity itself.
    esinw = ecc * sin_omega
    return ((period / np.pi) * np.sqrt(1 - ecc ** 2) / (1 + esinw)
            * np.arcsin(np.sqrt(((1 - scaled_planet_radius) ** 2 - impact_param ** 2)
                                / (scaled_semimajor_axis ** 2 - impact_param ** 2))))

In [None]:
# Scratch: ((1 - k)^2 - b^2) / ((1 + k)^2 - b^2) for k = 0.081123, b = 0.946562.
# Presumably the squared ratio of the inner (T_23) to outer (T_14) transit
# chords. TODO confirm.
((1 - 0.081123)**2 - 0.946562**2)/((1 + 0.081123)**2 - 0.946562**2)

In [None]:
# Scratch: sqrt(((1 - k)^2 - b^2) / ((1 + k)^2 - b^2)) for k = 0.02, b = 0.76
# (about 0.9095). Presumably an approximate T_23/T_14 ratio. TODO confirm.
np.sqrt(((1 - 0.02)**2 - 0.76**2)/((1 + 0.02)**2 - 0.76**2))

In [None]:
# Scratch: ((1 - x) / (1 + x))**2, where x = 0.909471908847024 is the value
# the previous cell computes. TODO document what this quantity represents.
((1-0.909471908847024)/(1+0.909471908847024))**2

In [None]:
# Scratch: convert the flux ratio from the previous cell (0.0022477...) to magnitudes.
-2.5*np.log10(0.002247709488158551)  # magnitude difference

In [None]:
# Trace plots of every variable in the reloaded trace.
az.plot_trace(trace)

In [None]:
def load_fancy_harps_rv(filename):
    """Read a whitespace-delimited HARPS ``.rdb`` radial-velocity table.

    The first line holds the column names. The second line is skipped
    (presumably the rdb row of dashes under the header). '#' starts a comment.

    Args:
        filename: Path to the rdb file.

    Returns:
        pandas.DataFrame with the file's columns (e.g. 'jdb', 'vrad', 'svrad').
    """
    return pd.read_csv(
        filename,
        # Same as the former `delim_whitespace=True`, which is deprecated in pandas >= 2.2.
        sep=r'\s+',
        # usecols=['jdb', 'vrad', 'svrad'],
        skiprows=lambda x: x == 1,
        comment='#',
        )

# Two HARPS datasets. 'jdb' is offset by 2400000; subtracting
# (TESS_EPOCH - 2400000) puts it on the TESS time system. RVs and errors are
# scaled by 1000 (presumably km/s -> m/s).
# NOTE(review): these per-file arrays are not used by the RV lists built from
# the combined table below.
harps_0_df = load_fancy_harps_rv(path.join(DATA_DIR, 'TOI-2000_harps_brahm.rdb_'))
harps_time_0 = np.array(harps_0_df.jdb - (aftimeseries.TESS_EPOCH - 2400000))
harps_rv_0 = np.array(harps_0_df.vrad*1000)
harps_rv_unc_0 = np.array(harps_0_df.svrad*1000)

harps_1_df = load_fancy_harps_rv(path.join(DATA_DIR, 'TOI-2000_harps_armstrong.rdb_'))
harps_time_1 = np.array(harps_1_df.jdb - (aftimeseries.TESS_EPOCH - 2400000))
harps_rv_1 = np.array(harps_1_df.vrad*1000)
harps_rv_unc_1 = np.array(harps_1_df.svrad*1000)

# harps_time = np.concatenate([harps_time_0, harps_time_1])
# harps_rv = np.concatenate([harps_rv_0, harps_rv_1])
# harps_rv_unc = np.concatenate([harps_rv_unc_0, harps_rv_unc_1])

# harps_sort_args = np.argsort(harps_time)
# harps_time = harps_time[harps_sort_args]
# harps_rv = harps_rv[harps_sort_args]
# harps_rv_unc = harps_rv_unc[harps_sort_args]

In [None]:
# Load the combined RV table and split it by instrument.
rv_table = aftimeseries.read_generic_rv(path.join(DATA_DIR, 'toi_2000_table_08.csv'))
chiron_time, chiron_rv, chiron_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'CHIRON')
feros_time, feros_rv, feros_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'FEROS')
harps_time, harps_rv, harps_rv_unc = aftimeseries.select_rv_by_instrument(rv_table, 'HARPS')

# Per-instrument lists, all in the order of `rv_names`.
rv_names = ['chiron', 'feros', 'harps']
rv_times = [chiron_time, feros_time, harps_time]
rv_data = [chiron_rv, feros_rv, harps_rv]
rv_uncs = [chiron_rv_unc, feros_rv_unc, harps_rv_unc]
num_rv_outside = len(rv_data)

rv_times_all = np.concatenate(rv_times)
rv_data_all = np.concatenate(rv_data)
rv_uncs_all = np.concatenate(rv_uncs)

# Starting values: mean RV per instrument as the offset, the mean reported
# uncertainty per instrument (only displayed here), and a tiny jitter.
test_gamma = np.array([np.average(i) for i in rv_data])
test_rv_unc = np.array([np.average(i) for i in rv_uncs])
test_jitter = np.array([1e-3]*len(rv_data))

# Semi-amplitude guesses and widths for the two transiting planets.
test_params['K'] = np.array([23., 9.])
prior_unc['K'] = np.array([50., 30.])
test_params['rv_gamma'] = test_gamma
prior_unc['rv_gamma'] = np.array([1000.]*num_rv_outside)
test_params['rv_jitter'] = test_jitter
# Jitter widths per instrument, in `rv_names` order (chiron, feros, harps).
prior_unc['rv_jitter'] = np.array([15., 30., 15.])

test_params['rv_fit_planet'] = np.array([True, True])

test_gamma, test_rv_unc, test_jitter, prior_unc['rv_gamma'], prior_unc['rv_jitter']

In [None]:
# Both HARPS datasets in one frame, ordered by observation time.
harps_df = (
    pd.concat([harps_0_df, harps_1_df])
    .sort_values('jdb')
)
harps_df.columns

In [None]:
# Lomb-Scargle periodograms of the HARPS activity indicators over 0.5-100 day
# periods. Each panel shows 10%/5%/1% false-alarm levels and dashed lines at
# 90 and 17.5 days (presumably the RV-only candidate periods).
activity_index = ['fwhm', 'bis_span', 's_mw', 'rhk']
fig, axs = plt.subplots(
    len(activity_index),
    figsize=(6, 4*len(activity_index)),
    dpi=144,
    sharex=True,
)
false_alarm_prob = [0.1, 0.05, 0.01]
for idx, ax in zip(activity_index, axs):
    ls_obj = timeseries.LombScargle(harps_df['jdb'], harps_df[idx])
    freq, power = ls_obj.autopower(
        minimum_frequency=1/100.,
        maximum_frequency=2.,
    )
    ax.plot(1/freq, power)
    false_alarm_levels = ls_obj.false_alarm_level(false_alarm_prob)
    for fal in false_alarm_levels:
        ax.axhline(fal, linestyle='--', color='gray')
    for p in [90, 17.5]:
        ax.axvline(p, linestyle='--', color='gray')
    ax.set_xscale('log')
    ax.text(0.05, 0.9, idx.upper(), fontsize='large', transform=ax.transAxes)
    ax.tick_params(which='both', direction='in', top=True, right=True)
    ax.set_ylabel('Normalized power')
# Only the bottom axis needs an x label (sharex=True).
ax.set_xlabel('Period (day)')
fig.tight_layout()
fig.savefig('../plots/toi_2000_harps_rv_activity_periodograms.png', bbox_inches='tight')

In [None]:

# Spectral window function: periodogram of a constant series sampled at all RV
# epochs (no centering, no mean fit). Note that `ls_bis_span` is a misnomer here.
ls_bis_span = timeseries.LombScargle(
    rv_times_all, np.ones_like(rv_times_all), center_data=False, fit_mean=False)
wn_freq, wn_pow = ls_bis_span.autopower(
    minimum_frequency=1/1000.,
    maximum_frequency=3.,
)
fig, ax = plt.subplots(dpi=144)
ax.plot(1/wn_freq, wn_pow)
ax.set_xscale('log')
ax.set_xlabel('Period (day)')
ax.set_ylabel('Spectral window function')
# NOTE(review): `ls_obj` is left over from the activity-periodogram cell (its
# last index, 'rhk'). These false-alarm levels belong to that periodogram,
# not to the window function, and the line breaks if that cell was not run.
false_alarm_levels = ls_obj.false_alarm_level(false_alarm_prob)
for fal in false_alarm_levels:
    ax.axhline(fal, linestyle='--', color='gray')
for p in [90, 17.5]:
    ax.axvline(p, linestyle='--', color='gray')

In [None]:
# Default-range Lomb-Scargle periodogram of the combined HARPS bisector span.
ls_bis_span = timeseries.LombScargle(np.array(harps_df.jdb), np.array(harps_df.bis_span))
plt.plot(*ls_bis_span.autopower())

In [None]:
# Keplerian orbit built only from the initial transit periods and epochs; the
# other orbital elements are left at their defaults.
test_rv_orbit = xo.orbits.KeplerianOrbit(
    period=test_params['period'],
    t0=test_params['t0'],
    # b=impact_param,
    # ecc=ecc,
    # cos_omega=omega_vec[0],
    # sin_omega=omega_vec[1],
    # r_star=radius_star,
)

In [None]:
# Rough RV semi-amplitude estimate at the transiting planets' periods and
# epochs, using all instruments pooled. NOTE(review): per-instrument offsets
# (test_gamma) are not subtracted first, which may bias the estimate.
xo.estimate_semi_amplitude(test_params['period'], rv_times_all, rv_data_all, rv_uncs_all, test_params['t0'])

In [None]:
# Starting values for two additional RV-only signals (indices 0 and 1).
test_params.update({
    'period_rv_only': np.array([89.9, 17.2]),
    't0_rv_only': np.array([2174.56, 2212.0]),
    'K_rv_only': np.array([15.5, 6.]),
})
# NOTE(review): 'period_rv_only' gives one width for two signals (this relies
# on broadcasting), and 't0_rv_only' has no width. Confirm both are intended.
prior_unc.update({
    'period_rv_only': np.array([20.]),
    'K_rv_only': np.array([30., 15.]),
})
num_planet_rv_only = 2

In [None]:
# Scratch: advance epoch 55495.07 by whole 17.208-day periods to just before
# 59220, then subtract 57000 (presumably converting to TESS time).
55495.07 + 17.208 * ((59220-55495.07)//17.208) - 57000

In [None]:
def make_multi_planet_rv_axes(num_planets, unfolded=True, residuals=True, figure_kwargs=None):
    """Create the figure layout for a multi-planet RV plot.

    The layout is an optional full-width top row for the unfolded (time-series)
    RV panel, above a row of ``num_planets`` phase-folded panels.  The top row
    can itself be split into an RV panel plus a residual panel sharing its
    x axis.

    Parameters
    ----------
    num_planets : int
        Number of phase-folded panels.
    unfolded : bool
        Add the full-width unfolded RV row.
    residuals : bool
        Split the unfolded row into RV + residual panels.  Ignored when
        ``unfolded`` is False.
    figure_kwargs : dict or None
        Keyword arguments for ``plt.figure``.  Defaults to ``{'dpi': 600}``.
        ``figsize`` defaults to ``(7, 8)`` unless given.  The dict passed in
        is never modified.

    Returns
    -------
    fig, folded_axs, unfolded_ax, residual_ax
        ``unfolded_ax`` and ``residual_ax`` are None when not requested.
    """
    # Work on a copy: mutating a default (or caller-owned) dict would leak
    # settings between calls.
    figure_kwargs = {'dpi': 600} if figure_kwargs is None else dict(figure_kwargs)
    figure_kwargs.setdefault('figsize', (7, 8))
    fig = plt.figure(constrained_layout=False, **figure_kwargs)
    if unfolded:
        gs = gridspec.GridSpec(2, num_planets, figure=fig, wspace=0.3, hspace=0.25, height_ratios=[5, 3])
        folded_row = 1
        if residuals:
            # RV panel on top, residuals directly underneath (no gap).
            gs0 = gs[0, :].subgridspec(2, 1, hspace=0, height_ratios=[2, 1])
            unfolded_ax = fig.add_subplot(gs0[0])
            residual_ax = fig.add_subplot(gs0[1], sharex=unfolded_ax)
        else:
            unfolded_ax = fig.add_subplot(gs[0, :])
            residual_ax = None
    else:
        # A single row of folded panels; ``gs`` must exist on this path too.
        gs = gridspec.GridSpec(1, num_planets, figure=fig, wspace=0.3)
        folded_row = 0
        unfolded_ax = None
        residual_ax = None
    folded_axs = [fig.add_subplot(gs[folded_row, i]) for i in range(num_planets)]
    return fig, folded_axs, unfolded_ax, residual_ax

In [None]:
# Default ``errorbar`` keyword arguments for RV data points.
rv_data_default_style = dict(
    ecolor='gray',
    elinewidth=1,
    alpha=0.7,
    fmt='o',
    markersize=3,
)
# Default line keywords for model curves; zorder=0 keeps them behind the data.
rv_model_default_style = dict(color='slateblue', zorder=0)
# Default ``fill_between`` keywords for the trend-uncertainty band.
rv_trend_unc_default_style = dict(color='slateblue', alpha=0.2)

def plot_model_rv(num_planets, orbit, rv_semiamps, num_points=500):
    """Return the phase-resolved model RV curve of each planet.

    Despite the name, nothing is plotted; the curves are returned.

    Parameters
    ----------
    num_planets : int
        Number of planets (leading entries of ``orbit``) to evaluate.
    orbit : exoplanet KeplerianOrbit (or compatible)
        Must provide ``period`` and ``t0`` with ``.eval()``, plus
        ``get_radial_velocity(t, K=...)`` whose ``.eval()`` has one column
        per planet.
    rv_semiamps : array-like
        Semi-amplitudes passed as ``K``.
    num_points : int
        Number of phase samples spanning [-0.5, 0.5].

    Returns
    -------
    model_phase : ndarray, shape (num_points,)
    model_rvs : list of ndarray
        ``model_rvs[i]`` is planet ``i``'s own RV over one orbit centred on
        its reference epoch ``t0[i]``.
    """
    # Evaluate the orbital elements once, as plain numpy arrays.
    periods = np.atleast_1d(orbit.period.eval())
    epochs = np.atleast_1d(orbit.t0.eval())
    model_phase = np.linspace(-0.5, 0.5, num_points)
    model_rvs = []
    for i in range(num_planets):
        model_time = model_phase * periods[i] + epochs[i]
        # All planets are evaluated; keep only planet i's column.
        model_rvs.append(orbit.get_radial_velocity(
            model_time, K=rv_semiamps,
        ).eval()[:, i])
    return model_phase, model_rvs

def plot_multi_planet_folded_rv(
    folded_axs,
    unfolded_ax,
    residual_ax,
    num_planets, orbit,
    rv_semiamps, gammas, jitters,
    rv_times, rv_data, rv_uncs, rv_names,
    trends=None, model_trend_func=None,
    model_trend_unc_func=None,
    rv_data_style=rv_data_default_style,
    rv_model_style=rv_model_default_style,
    rv_inst_styles=None,
    rv_trend_unc_style=rv_trend_unc_default_style,
):
    """Draw multi-instrument RV data and a Keplerian model on prepared axes.

    Parameters
    ----------
    folded_axs : sequence of Axes
        One phase-folded panel per planet.  Each panel shows the data with
        all other planets (and the trend) subtracted.
    unfolded_ax, residual_ax : Axes or None
        Time-series panel (offset-corrected data + model) and residual panel.
    num_planets : int
        Number of planets whose folded model curve is drawn.
    orbit : exoplanet KeplerianOrbit
    rv_semiamps : array-like
        Semi-amplitudes, one per planet of ``orbit``.
    gammas, jitters : sequence
        Per-instrument RV offset and jitter.  Jitter is added in quadrature
        to the error bars.
    rv_times, rv_data, rv_uncs, rv_names : sequences
        Per-instrument data and legend labels.
    trends : iterable, optional
        Per-instrument trend to subtract (default 0).
    model_trend_func : callable, optional
        ``t -> trend`` drawn on the unfolded panel and added to the model.
    model_trend_unc_func : callable, optional
        ``t -> (2, n)`` lower/upper offsets around the trend, shaded.
    rv_data_style, rv_model_style, rv_trend_unc_style : dict
        Matplotlib keywords.
    rv_inst_styles : sequence of dict, optional
        Per-instrument overrides merged onto ``rv_data_style``.
    """
    if rv_inst_styles is None:
        rv_inst_styles = [{} for _ in rv_times]

    # ---- unfolded (time-series) model ----
    if unfolded_ax is not None:
        min_obs_time = np.min([np.min(t) for t in rv_times])
        max_obs_time = np.max([np.max(t) for t in rv_times])
        # ~0.05 d sampling, but at least two points for very short baselines.
        num_model_points = max(int((max_obs_time - min_obs_time) // 0.05), 2)
        unfold_model_time = np.linspace(min_obs_time, max_obs_time, num=num_model_points)
        unfold_model = np.sum(orbit.get_radial_velocity(
                unfold_model_time, K=rv_semiamps,
        ).eval(), axis=1)
        unfold_model_trend = None
        unfold_model_trend_unc = None
        if model_trend_func is not None:
            unfold_model_trend = model_trend_func(unfold_model_time)
            unfold_model = unfold_model + unfold_model_trend
            unfolded_ax.plot(unfold_model_time, unfold_model_trend, color='gray', linestyle='--')
        if model_trend_unc_func is not None:
            unfold_model_trend_unc = model_trend_unc_func(unfold_model_time)
            # The band is drawn around the trend, so it needs a trend.
            if unfold_model_trend is not None:
                trend_unc = unfold_model_trend + unfold_model_trend_unc
                unfolded_ax.fill_between(unfold_model_time, trend_unc[0], trend_unc[1], **rv_trend_unc_style)
        unfolded_ax.plot(unfold_model_time, unfold_model, **rv_model_style)
        if residual_ax is not None:
            residual_ax.axhline(0, **rv_model_style)
            if unfold_model_trend_unc is not None:
                residual_ax.fill_between(unfold_model_time, unfold_model_trend_unc[0], unfold_model_trend_unc[1], **rv_trend_unc_style)

    # ---- folded model curves ----
    model_phase, model_rvs = plot_model_rv(num_planets, orbit, rv_semiamps)
    for ax, model_rv in zip(folded_axs, model_rvs):
        ax.plot(model_phase, model_rv, **rv_model_style)

    # ---- data, per instrument ----
    epochs = orbit.t0.eval()
    periods = orbit.period.eval()
    if trends is None:
        trends = [0.] * len(rv_times)
    for rv_time, rv, rv_unc, gamma, jitter, trend, label, plot_style in zip(
            rv_times, rv_data, rv_uncs, gammas, jitters, trends, rv_names, rv_inst_styles):
        rv_errorbar = np.sqrt(rv_unc**2 + jitter**2)
        rv_shifted = rv - gamma
        rv_detrend = rv_shifted - trend

        model_rv = orbit.get_radial_velocity(
            rv_time, K=rv_semiamps,
        ).eval()
        model_rv_sum = np.sum(model_rv, axis=1)

        # Merge fresh for each instrument so one instrument's overrides
        # (e.g. 'color') cannot leak into the next one.
        inst_style = {**rv_data_style, **plot_style}
        if unfolded_ax is not None:
            unfolded_ax.errorbar(rv_time, rv_shifted, rv_errorbar, label=label, **inst_style)
            if residual_ax is not None:
                residual_ax.errorbar(rv_time, rv_detrend - model_rv_sum, rv_errorbar, label=label, **inst_style)

        for i, (ax, period, epoch) in enumerate(zip(folded_axs, periods, epochs)):
            rv_phase = phase_fold(rv_time, epoch, period)
            # Subtract every other planet and the trend to isolate planet i.
            rv_only_planet = rv_shifted - model_rv_sum + model_rv[:, i] - trend
            ax.errorbar(rv_phase, rv_only_planet, rv_errorbar, label=label, **inst_style)

In [None]:
def _add_fixed_eccentricity(fix_flag):
    """Create eccentricity variables, optionally holding some orbits circular.

    Must be called inside a ``pm.Model`` context.  Free orbits are sampled as
    a vector ``v`` on the unit disk with ``ecc = |v|**2`` and
    ``(cos omega, sin omega) = v / sqrt(ecc)``.

    Parameters
    ----------
    fix_flag : sequence of bool
        One entry per planet; True fixes that orbit circular.

    Returns
    -------
    ecc : tensor or None
        Eccentricities, registered as deterministic ``ecc``.  None when every
        orbit is fixed.
    omega_vec : tensor or [None, None]
        Row 0 is cos(omega), row 1 is sin(omega), one column per planet.
        Fixed orbits get (0, 1).  ``omega`` is also registered as a
        deterministic.
    """
    # Every orbit circular: no eccentricity parameters at all.
    if all(fix_flag):
        return None, [None, None]

    if not any(fix_flag):
        # Nothing fixed: a single vectorised unit-disk variable for all planets.
        sqrt_ecc = pmx.UnitDisk('sqrt_ecc_vec', testval=np.array([1e-6, 1e-6]), shape=(2, len(fix_flag)))
        ecc = pm.Deterministic('ecc', tt.sum(sqrt_ecc * sqrt_ecc, axis=0))
        cos_sin_omega = sqrt_ecc / tt.sqrt(ecc)
        pm.Deterministic('omega', tt.arctan2(cos_sin_omega[1], cos_sin_omega[0]))
        return ecc, cos_sin_omega

    # Mixed case: one unit-disk variable per free planet, constants for the
    # fixed ones, then stack back into per-planet tensors.
    planet_eccs = []
    planet_omega_vecs = []
    for planet, is_fixed in enumerate(fix_flag):
        if is_fixed:
            planet_eccs.append(tt.as_tensor_variable(0.))
            planet_omega_vecs.append(tt.as_tensor_variable(np.array([0., 1.])))
            continue
        sqrt_ecc = pmx.UnitDisk(f'sqrt_ecc_vec_{planet}', testval=np.array([1e-6, 1e-6]))
        planet_ecc = tt.sum(sqrt_ecc * sqrt_ecc, axis=0)
        planet_eccs.append(planet_ecc)
        planet_omega_vecs.append(sqrt_ecc / tt.sqrt(planet_ecc))

    ecc = pm.Deterministic('ecc', tt.stack(planet_eccs, axis=0))
    cos_sin_omega = tt.stack(planet_omega_vecs, axis=1)
    pm.Deterministic('omega', tt.arctan2(cos_sin_omega[1], cos_sin_omega[0]))
    return ecc, cos_sin_omega

def polynomial_design_matrix(x, offset, order):
    """Design matrix of the powers 1..order of ``x - offset``.

    Column ``k - 1`` holds ``(x - offset)**k``.  The constant column
    (k = 0) is omitted.  Shape is ``(len(x), order)``.
    """
    centred = x - offset
    # increasing=True orders the columns as powers 0, 1, ..., order.
    basis = np.vander(centred, N=order + 1, increasing=True)
    return basis[:, 1:]

def model_set_up_polynomial_detrend(model, rv_times, order):
    """Add a shared polynomial RV trend (with no constant term) to ``model``.

    Creates in ``model``:
      * ``rv_time_offset``: deterministic midpoint of the overall time span;
      * ``rv_trend_coeff``: Normal(0, 10**-k) coefficient for power k = 1..order;
      * ``rv_design_matrix_{i}``: fixed design matrix for instrument ``i``.

    Parameters
    ----------
    model : pm.Model
    rv_times : sequence of ndarray
        Observation times, one array per instrument.
    order : int
        Highest polynomial power (>= 1).

    Raises
    ------
    ValueError
        If ``order`` < 1 or ``rv_times`` is empty.
    """
    if order < 1:
        raise ValueError('Polynomial must be at least of linear order')
    if len(rv_times) == 0:
        raise ValueError('rv_times must contain at least one instrument')
    min_time = min(np.min(t) for t in rv_times)
    max_time = max(np.max(t) for t in rv_times)
    # Centre the powers on the middle of the baseline.
    rv_offset = (min_time + max_time) / 2.
    with model:
        pm.Deterministic('rv_time_offset', tt.as_tensor_variable(rv_offset))
        # ``sigma`` (not the deprecated ``sd`` alias), like the other priors here.
        pm.Normal("rv_trend_coeff", mu=0, sigma=10.**-np.arange(1, order+1), shape=order)
        for i, t in enumerate(rv_times):
            design_matrix = tt.as_tensor_variable(polynomial_design_matrix(t, rv_offset, order))
            pm.Deterministic(f'rv_design_matrix_{i}', design_matrix)

def trend_generator(model, num_rv):
    """Lazily yield the polynomial trend of each RV instrument.

    ``model`` is indexed by name (a PyMC3 model or a mapping) and must hold
    ``rv_design_matrix_{i}`` and ``rv_trend_coeff``.
    """
    yield from (
        model[f'rv_design_matrix_{idx}'] @ model['rv_trend_coeff']
        for idx in range(num_rv)
    )

def trend_generator_trace(model, num_rv):
    """Like ``trend_generator`` but unwraps values through ``.data``.

    Intended for containers whose entries wrap arrays, presumably xarray
    DataArrays from a trace — TODO confirm against callers.
    """
    yield from (
        model[f'rv_design_matrix_{idx}'].data @ model['rv_trend_coeff'].data
        for idx in range(num_rv)
    )

def gp_generator(model, num_rv):
    """Lazily yield the stored GP prediction ``rv_gp_pred_{i}`` per instrument."""
    yield from (model[f'rv_gp_pred_{idx}'] for idx in range(num_rv))
        
def gp_matern(map_soln):
    """Build an (uncomputed) celerite2 GP with a Matern-3/2 kernel.

    Hyperparameters are read from ``map_soln['gp_sigma']`` and
    ``map_soln['gp_rho']``.
    """
    sigma, rho = map_soln['gp_sigma'], map_soln['gp_rho']
    return celerite2.GaussianProcess(terms.Matern32Term(sigma=sigma, rho=rho))

def gp_sho(map_soln):
    """Build an (uncomputed) celerite2 GP with an SHO kernel.

    Hyperparameters are read from ``gp_sigma``, ``gp_rho`` and ``gp_tau``
    in ``map_soln``.
    """
    sigma, rho, tau = (map_soln[key] for key in ('gp_sigma', 'gp_rho', 'gp_tau'))
    return celerite2.GaussianProcess(terms.SHOTerm(rho=rho, tau=tau, sigma=sigma))

def gp_trend_func_generator(kernel_func, map_soln, rv_times, rv_data, rv_uncs, return_unc=False):
    """Condition a numerical celerite2 GP on RV residuals and return predictors.

    Parameters
    ----------
    kernel_func : callable
        ``kernel_func(map_soln)`` returns an uncomputed GP (e.g. ``gp_sho``).
    map_soln : dict
        Must hold ``rv_jitter`` (one per instrument) and ``rv_pred_{i}``, the
        model prediction subtracted from ``rv_data[i]``.
    rv_times, rv_data, rv_uncs : sequences of ndarray
        Per-instrument times, RVs and uncertainties.
    return_unc : bool
        Also return a function for the +/-1 sigma predictive band.

    Returns
    -------
    gp_trend : callable
        ``t -> predictive mean``.
    gp_trend_unc : callable, only if ``return_unc``
        ``t -> array of shape (2, len(t))`` with rows ``(-stdev, +stdev)``.
    """
    gp = kernel_func(map_soln)

    # celerite2 expects time-sorted input: sort the concatenated times of all
    # instruments and apply the SAME permutation to every per-point array.
    gp_time = np.concatenate(rv_times)
    gp_sort_args = np.argsort(gp_time)
    gp_time = gp_time[gp_sort_args]
    rv_jitters = map_soln['rv_jitter']
    gp_diag = np.concatenate([
        unc * unc + rv_jitters[i] * rv_jitters[i] for i, unc in enumerate(rv_uncs)
    ])[gp_sort_args]
    gp.compute(gp_time, diag=gp_diag, quiet=True)

    # Residuals must be permuted too, otherwise they do not line up with the
    # sorted times/diagonal (the PyMC3 GP set-up sorts them the same way).
    gp_res = np.concatenate([
        rv - map_soln[f'rv_pred_{i}'] for i, rv in enumerate(rv_data)
    ])[gp_sort_args]

    def gp_trend(t):
        return gp.predict(gp_res, t=t)

    def gp_trend_unc(t):
        _, variance = gp.predict(gp_res, t=t, return_var=True)
        stdev = np.sqrt(variance)
        return np.vstack((-stdev, stdev))

    if return_unc:
        return gp_trend, gp_trend_unc
    return gp_trend

def model_set_up_gp_matern32(model, rv_times, rv_residuals, rv_uncs, rv_jitters):
    """Add a Matern-3/2 celerite2 GP over the RV residuals of all instruments.

    Creates, in ``model``:
      * uniform priors ``gp_sigma`` and ``gp_rho``;
      * the GP marginal likelihood ``rv_gp_obs`` on the time-sorted residuals;
      * one deterministic ``rv_gp_pred_{i}`` per instrument.

    The variables are created inside ``with model:``, so this works whether
    or not the caller is already inside the model context.

    Parameters
    ----------
    model : pm.Model
    rv_times : sequence of ndarray
        Observation times, one array per instrument.
    rv_residuals : sequence of tensors
        Data minus model, one per instrument.
    rv_uncs : sequence of ndarray
        Reported RV uncertainties, one per instrument.
    rv_jitters : tensor
        Per-instrument jitter, added in quadrature to ``rv_uncs``.
    """
    with model:
        sigma = pm.Uniform('gp_sigma',
            lower=0., upper=np.sqrt(1000), testval=1.)
        # NOTE(review): bounds are 15..100 divided by sqrt(3); presumably a
        # timescale range in days mapped onto rho — TODO confirm.
        rho = pm.Uniform('gp_rho',
            lower=15./np.sqrt(3), upper=100./np.sqrt(3), testval=16.)
        kernel = theano_terms.Matern32Term(sigma=sigma, rho=rho)
        gp = celerite2.theano.GaussianProcess(kernel)

        # celerite2 expects sorted times: sort the concatenated times and
        # permute the diagonal and residuals identically.
        gp_time = np.concatenate(rv_times)
        gp_sort_args = np.argsort(gp_time)
        gp_time = gp_time[gp_sort_args]
        gp_diag = tt.concatenate([
            unc*unc + rv_jitters[i]*rv_jitters[i] for i, unc in enumerate(rv_uncs)
        ])[gp_sort_args]
        gp_res = tt.concatenate(rv_residuals)[gp_sort_args]

        gp.compute(gp_time, diag=gp_diag, quiet=True)
        gp.marginal('rv_gp_obs', observed=gp_res)
        # GP mean evaluated at each instrument's own observation times.
        for i, rv_time in enumerate(rv_times):
            pm.Deterministic(f'rv_gp_pred_{i}', gp.predict(gp_res, t=rv_time))

def model_set_up_gp_sho(model, rv_times, rv_residuals, rv_uncs, rv_jitters):
    """Add an SHO celerite2 GP over the RV residuals of all instruments.

    Creates, in ``model``:
      * uniform priors ``gp_sigma``, ``gp_rho`` and ``gp_tau``;
      * the GP marginal likelihood ``rv_gp_obs`` on the time-sorted residuals;
      * one deterministic ``rv_gp_pred_{i}`` per instrument.

    The variables are created inside ``with model:``, so this works whether
    or not the caller is already inside the model context.

    Parameters
    ----------
    model : pm.Model
    rv_times : sequence of ndarray
        Observation times, one array per instrument.
    rv_residuals : sequence of tensors
        Data minus model, one per instrument.
    rv_uncs : sequence of ndarray
        Reported RV uncertainties, one per instrument.
    rv_jitters : tensor
        Per-instrument jitter, added in quadrature to ``rv_uncs``.
    """
    with model:
        sigma = pm.Uniform('gp_sigma',
            lower=0., upper=100., testval=1.)
        rho = pm.Uniform('gp_rho',
            lower=15., upper=200., testval=16.)
        tau = pm.Uniform('gp_tau',
            lower=0., upper=200., testval=16.)
        kernel = theano_terms.SHOTerm(rho=rho, tau=tau, sigma=sigma)

        # celerite2 expects sorted times: sort the concatenated times and
        # permute the diagonal and residuals identically.
        gp_time = np.concatenate(rv_times)
        gp_sort_args = np.argsort(gp_time)
        gp_time = gp_time[gp_sort_args]
        gp_diag = tt.concatenate([
            unc*unc + rv_jitters[i]*rv_jitters[i] for i, unc in enumerate(rv_uncs)
        ])[gp_sort_args]
        gp_res = tt.concatenate(rv_residuals)[gp_sort_args]

        gp = celerite2.theano.GaussianProcess(kernel)
        gp.compute(gp_time, diag=gp_diag, quiet=True)
        gp.marginal('rv_gp_obs', observed=gp_res)
        # GP mean evaluated at each instrument's own observation times.
        for i, rv_time in enumerate(rv_times):
            pm.Deterministic(f'rv_gp_pred_{i}', gp.predict(gp_res, t=rv_time))
    

def make_simple_rv_model(num_planets, rv_data, test_params, prior_unc, ecc_flag,
                         rv_times=None, rv_uncs=None):
    """Build a PyMC3 model of Keplerian planets plus an SHO GP on the residuals.

    Priors: Normal ``period``, ``t0`` and ``K`` (per planet); eccentricity via
    ``_add_fixed_eccentricity(ecc_flag)``; Normal ``rv_gamma`` and HalfNormal
    ``rv_jitter`` (per instrument).  The data enter only through the GP
    marginal likelihood added by ``model_set_up_gp_sho``.

    Parameters
    ----------
    num_planets : int
    rv_data : sequence of ndarray
        RVs, one array per instrument.
    test_params, prior_unc : dict
        Prior centres and widths keyed by parameter name.
    ecc_flag : sequence of bool
        Per planet; True fixes the orbit circular.
    rv_times, rv_uncs : sequence of ndarray, optional
        Times and uncertainties matching ``rv_data``.  When omitted, the
        notebook-level ``rv_times`` / ``rv_uncs`` are used, which was the
        only behaviour before these parameters existed.

    Returns
    -------
    pm.Model
        Also exposes deterministics ``rv_pred_{i}`` (Keplerian sum plus
        ``rv_gamma[i]``) and ``log_prob``.
    """
    if rv_times is None:
        rv_times = globals()['rv_times']
    if rv_uncs is None:
        rv_uncs = globals()['rv_uncs']
    with pm.Model() as model:
        period = pm.Normal('period', mu=test_params['period'], sigma=prior_unc['period'], shape=num_planets)
        epoch = pm.Normal('t0', mu=test_params['t0'], sigma=prior_unc['t0'], shape=num_planets)
        rv_semiamp = pm.Normal('K', mu=test_params['K'], sigma=prior_unc['K'], shape=num_planets)

        ecc, omega_vec = _add_fixed_eccentricity(ecc_flag)

        orbit = xo.orbits.KeplerianOrbit(
            period=period,
            t0=epoch,
            ecc=ecc,
            cos_omega=omega_vec[0],
            sin_omega=omega_vec[1],
        )

        num_rvs = len(rv_data)
        rv_gamma = pm.Normal('rv_gamma', mu=test_params['rv_gamma'], sigma=prior_unc['rv_gamma'], shape=num_rvs)
        rv_jitter = pm.HalfNormal(
            'rv_jitter', sigma=prior_unc["rv_jitter"], shape=num_rvs)

        rv_residuals = []
        for i, (rv_time, rv) in enumerate(zip(rv_times, rv_data)):
            rv_predict = orbit.get_radial_velocity(rv_time, K=rv_semiamp)
            rv_predict_sum = tt.sum(rv_predict, axis=1) + rv_gamma[i]
            pm.Deterministic(f'rv_pred_{i}', rv_predict_sum)
            rv_residuals.append(rv - rv_predict_sum)

        # The GP marginal likelihood (unc**2 + jitter**2 on its diagonal) is
        # the only data likelihood in this model.
        model_set_up_gp_sho(model, rv_times, rv_residuals, rv_uncs, rv_jitter)

        pm.Deterministic('log_prob', model.logpt)
    return model

In [None]:
# MAP fit of the 2-planet Keplerian + SHO-GP model (second planet circular).
# The signature is (num_planets, rv_data, test_params, prior_unc, ecc_flag);
# the previous call passed the eccentricity flags as ``rv_data`` and omitted
# the remaining arguments, so it raised a TypeError.
rv_model = make_simple_rv_model(2, rv_data, test_params, prior_unc, [False, True])

rv_map_soln = optimize_model(
    rv_model,
    [
        ['rv_jitter', 'rv_gamma', 'K'],
        ['t0', 'period'],
        ['gp_sigma', 'gp_rho', 'gp_tau'],
    ]
)
rv_map_soln

In [None]:
def make_rv_model_with_rv_only_planets(
    num_planets, ecc_flag, num_planets_rv_only,
    rv_times, rv_data, rv_uncs,
    test_params, prior_unc):
    """Build a PyMC3 RV model with ephemeris-constrained and RV-only planets.

    * ``num_planets`` planets with Normal priors on ``period``, ``t0`` and
      ``K``, and eccentricity from ``_add_fixed_eccentricity(ecc_flag)``.
    * ``num_planets_rv_only`` extra planets on circular orbits, with priors
      from ``add_uniform_prior`` for ``period_rv_only`` and ``K_rv_only``.
      ``t0_rv_only`` is uniform within half a period of its fiducial value.
    * Per instrument: Normal ``rv_gamma``, HalfNormal ``rv_jitter``, and a
      Gaussian likelihood ``rv_obs_{i}`` whose width is the uncertainty and
      jitter added in quadrature.

    Deterministics: ``rv_pred_{i}`` (per-planet RVs of the first group, shape
    (n, num_planets), no offset), ``rv_only_pred_{i}`` (RV-only planets), and
    ``log_prob``.
    """
    with pm.Model() as model:
        period = pm.Normal('period', mu=test_params['period'], sigma=prior_unc['period'], shape=num_planets)
        epoch = pm.Normal('t0', mu=test_params['t0'], sigma=prior_unc['t0'], shape=num_planets)
        rv_semiamp = pm.Normal('K', mu=test_params['K'], sigma=prior_unc['K'], shape=num_planets)

        ecc, omega_vec = _add_fixed_eccentricity(ecc_flag)

        orbit = xo.orbits.KeplerianOrbit(
            period=period,
            t0=epoch,
            ecc=ecc,
            cos_omega=omega_vec[0],
            sin_omega=omega_vec[1],
        )

        period_rv_only = add_uniform_prior('period_rv_only', test_params, prior_unc, shape=num_planets_rv_only)
        rv_semiamp_rv_only = add_uniform_prior('K_rv_only', test_params, prior_unc, shape=num_planets_rv_only)
        epoch_rv_only = pm.Uniform(
            't0_rv_only', lower=test_params['t0_rv_only']-period_rv_only/2,
            upper=test_params['t0_rv_only']+period_rv_only/2,
            testval=test_params['t0_rv_only'],
            shape=num_planets_rv_only)
        orbit_rv_only = xo.orbits.KeplerianOrbit(
            period=period_rv_only,
            t0=epoch_rv_only,
        )

        num_rvs = len(rv_data)
        rv_gamma = pm.Normal('rv_gamma', mu=test_params['rv_gamma'], sigma=prior_unc['rv_gamma'], shape=num_rvs)
        rv_jitter = pm.HalfNormal(
            'rv_jitter', sigma=prior_unc["rv_jitter"], shape=num_rvs)

        for i, (rv_time, rv, rv_unc) in enumerate(zip(rv_times, rv_data, rv_uncs)):
            rv_predict = pm.Deterministic(f'rv_pred_{i}', orbit.get_radial_velocity(rv_time, K=rv_semiamp))
            rv_only_pred = pm.Deterministic(
                f'rv_only_pred_{i}',
                orbit_rv_only.get_radial_velocity(rv_time, K=rv_semiamp_rv_only))
            # Sum over the planet axis whenever there is one.  Testing
            # ``num_planets_rv_only > 1`` instead skipped the sum for a single
            # RV-only planet; if its prediction is an (n, 1) column, adding it
            # to an (n,) vector broadcasts to (n, n).
            rv_only_sum = tt.sum(rv_only_pred, axis=1) if rv_only_pred.ndim > 1 else rv_only_pred
            rv_predict_sum = tt.sum(rv_predict, axis=1) + rv_only_sum + rv_gamma[i]

            total_unc = tt.sqrt(rv_unc * rv_unc + rv_jitter[i] * rv_jitter[i])
            pm.Normal(
                f"rv_obs_{i}",
                mu=rv_predict_sum,
                sigma=total_unc,
                observed=rv,
            )

        pm.Deterministic('log_prob', model.logpt)
    return model

In [None]:
# MAP fit: two ephemeris-constrained planets (the second held circular) plus
# two RV-only planets.  The parameter groups are passed to optimize_model,
# presumably optimised in turn.
rv_model = make_rv_model_with_rv_only_planets(
    num_planets=2,
    ecc_flag=[False, True],
    num_planets_rv_only=2,
    rv_times=rv_times,
    rv_data=rv_data,
    rv_uncs=rv_uncs,
    test_params=test_params,
    prior_unc=prior_unc,
)

amplitude_group = ['rv_jitter', 'rv_gamma', 'K', 'K_rv_only']
ephemeris_group = ['t0', 'period', 't0_rv_only', 'period_rv_only']
rv_map_soln = optimize_model(rv_model, [amplitude_group, ephemeris_group])
rv_map_soln

In [None]:
# Diagnostic plot of the MAP solution.  Layout: unfolded time series with
# residuals, plus one phase-folded panel per planet (ephemeris-constrained
# planets first, then the RV-only ones).
num_rv_only = len(rv_map_soln['period_rv_only'])
plot_periods = np.concatenate([rv_map_soln['period'], rv_map_soln['period_rv_only']])
num_plot_planets = len(plot_periods)

fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(num_plot_planets)
# Raw strings: '\m' and '\,' are invalid escape sequences in normal literals.
rv_unit_label = r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)'
rv_unfolded_ax.set_ylabel(rv_unit_label)
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')
rv_folded_axs[0].set_title('c')
rv_folded_axs[1].set_title('b')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(rv_unit_label)

# The RV-only planets are modelled on circular orbits: pad with ecc = 0 and
# omega = pi/2.
plot_orbit = xo.orbits.KeplerianOrbit(
    period=plot_periods,
    t0=np.concatenate([rv_map_soln['t0'], rv_map_soln['t0_rv_only']]),
    ecc=np.concatenate([rv_map_soln['ecc'], [0.] * num_rv_only]),
    omega=np.concatenate([rv_map_soln['omega'], [np.pi / 2] * num_rv_only]),
)

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': 'gray',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^'},
    {'color': '#a12568', 'fmt': 'd'},
    {'color': '#fec260', 'fmt': 'o'},
]

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs, rv_unfolded_ax, rv_residual_ax,
    num_plot_planets, plot_orbit, np.concatenate([rv_map_soln['K'], rv_map_soln['K_rv_only']]),
    rv_map_soln['rv_gamma'],
    rv_map_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')

# Label every folded panel with its period, including the RV-only planets.
for ax, p in zip(rv_folded_axs, plot_periods):
    ax.text(0.05, 0.05, rf'$P = {p:.6f}\,\mathrm{{d}}$', transform=ax.transAxes)


In [None]:
# Collects the InferenceData of each sampling run.
rv_traces = list()

In [None]:
# NUTS sampling started from the MAP solution: 32 chains on 32 cores,
# 5000 tuning + 5000 kept draws each, high target acceptance.
with rv_model:
    trace = pmx.sample(
        tune=5000,
        draws=5000,
        start=rv_map_soln,
        cores=32,
        chains=32,
        target_accept=0.97,
        # Fixed seed so Restart & Run All reproduces the same chains.
        random_seed=2000,
        return_inferencedata=True,
        idata_kwargs={
            'log_likelihood': False,
        },
    )
# NOTE: re-running this cell appends another trace.
rv_traces.append(trace)

In [None]:
# Posterior group of the sampled InferenceData.
trace.posterior

In [None]:
# Save all groups of the InferenceData so later cells can reload it without
# re-sampling.
trace.to_netcdf('../chains/toi2000_rv_only_90d_17d.nc')

In [34]:
# Reload saved chains.  Judging by the file names: the fit with ~90 d and
# ~17 d RV-only signals, and the fit with only the ~90 d RV-only signal.
trace, trace_3p = (
    az.from_netcdf(chain_path)
    for chain_path in ('../chains/toi2000_rv_only_90d_17d.nc',
                       '../chains/toi2000_rv_only_90d.nc')
)

In [None]:
# Trace plots of the sampled parameters (this model has no GP terms).
display_var_names = (
    ['period', 't0', 'sqrt_ecc_vec_0', 'K']
    + ['period_rv_only', 't0_rv_only', 'K_rv_only']
    + ['rv_gamma', 'rv_jitter']
)
az.plot_trace(trace, var_names=display_var_names)
plt.tight_layout()

In [35]:
display_var_names = [
    'period', 't0',
    'sqrt_ecc_vec_0', 'ecc', 'omega',
    'K',
    'period_rv_only', 't0_rv_only', 'K_rv_only',
    'rv_gamma', 'rv_jitter',
]
# Posterior summary of ``trace_3p``: ArviZ diagnostics, a 99.7% HDI, and the
# median with 16th/84th-percentile offsets from it.  ``omega`` is flagged as
# circular; ``coords`` selects index 0 of ``ecc_dim_0`` / ``omega_dim_0``.
summary_3p = az.summary(
    trace_3p,
    var_names=display_var_names,
    coords={"ecc_dim_0": [0], "omega_dim_0": [0]},
    circ_var_names={'omega'},
    hdi_prob=0.997,
    round_to=7,
    skipna=True,
    stat_funcs={
        'median': np.median,
        '-': lambda x: np.quantile(x, 0.16) - np.median(x),
        '+': lambda x: np.quantile(x, 0.84) - np.median(x),
    },
)
summary_3p

  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Unnamed: 0,mean,sd,hdi_0.15%,hdi_99.85%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat,median,-,+
period[0],9.127056,0.0001,9.126761,9.127354,2e-07,1e-07,262225.955838,117714.59408,1.00028,9.127056,-0.0001,0.0001
period[1],3.09829,0.000197,3.097719,3.098874,4e-07,3e-07,257859.466112,121590.794132,1.000218,3.09829,-0.000195,0.000196
t0[0],2110.065875,0.004998,2110.051,2110.080903,9.8e-06,6.9e-06,262318.129617,121999.444219,1.000362,2110.065872,-0.004971,0.004973
t0[1],1855.241018,0.019753,1855.182,1855.299321,3.87e-05,2.74e-05,260601.344009,122573.739122,1.000334,1855.240985,-0.019635,0.019674
sqrt_ecc_vec_0[0],-0.184189,0.081018,-0.3603775,0.120038,0.0002345,0.0001659,149831.971336,85019.078465,1.000097,-0.19573,-0.064607,0.087207
sqrt_ecc_vec_0[1],-0.060607,0.167664,-0.4487414,0.377508,0.000468,0.0003309,132601.205648,136575.963603,1.000104,-0.070283,-0.16688,0.193103
ecc[0],0.072274,0.037105,3e-07,0.20803,9.64e-05,7.06e-05,146504.093423,101417.132021,1.00004,0.069299,-0.03308,0.037591
ecc[1],0.0,0.0,0.0,0.0,0.0,0.0,160000.0,160000.0,,0.0,0.0,0.0
omega[0],-2.869136,0.784226,0.3811322,-0.214577,0.0030235,0.0052202,138142.126792,152059.054436,1.000045,-2.082607,-0.681459,4.764533
omega[1],1.570796,-0.0,1.570796,1.570796,-0.0,0.0,160000.0,160000.0,,1.570796,0.0,0.0


In [36]:
display_var_names = [
    'period', 't0',
    'sqrt_ecc_vec_0', 'ecc', 'omega',
    'K',
    'period_rv_only', 't0_rv_only', 'K_rv_only',
    'rv_gamma', 'rv_jitter',
]
# Posterior summary of ``trace``: ArviZ diagnostics, a 99.7% HDI, and the
# median with 16th/84th-percentile offsets from it.  ``omega`` is flagged as
# circular; ``coords`` selects index 0 of ``ecc_dim_0`` / ``omega_dim_0``.
summary = az.summary(
    trace,
    var_names=display_var_names,
    coords={"ecc_dim_0": [0], "omega_dim_0": [0]},
    circ_var_names={'omega'},
    hdi_prob=0.997,
    round_to=7,
    skipna=True,
    stat_funcs={
        'median': np.median,
        '-': lambda x: np.quantile(x, 0.16) - np.median(x),
        '+': lambda x: np.quantile(x, 0.84) - np.median(x),
    },
)
summary

  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)


Unnamed: 0,mean,sd,hdi_0.15%,hdi_99.85%,mcse_mean,mcse_sd,ess_bulk,ess_tail,r_hat,median,-,+
period[0],9.127056,0.0001,9.126766,9.127364,2e-07,1e-07,243749.305372,122461.135197,1.000234,9.127056,-9.9e-05,9.9e-05
period[1],3.098302,0.000189,3.097733,3.098856,4e-07,3e-07,234454.822184,123894.147924,1.000234,3.098302,-0.000188,0.000188
t0[0],2110.065927,0.004993,2110.050956,2110.08059,1.02e-05,7.2e-06,240825.593138,123000.694298,1.000112,2110.06593,-0.004957,0.004966
t0[1],1855.242212,0.019475,1855.183172,1855.297992,4.03e-05,2.85e-05,233796.561716,123893.727777,1.000166,1855.242223,-0.019384,0.019331
sqrt_ecc_vec_0[0],-0.169048,0.062888,-0.325116,0.077174,0.0001835,0.0001297,140690.26243,77700.766452,1.000086,-0.175135,-0.054151,0.064783
sqrt_ecc_vec_0[1],-0.129944,0.134262,-0.430034,0.261885,0.0004193,0.0002965,114475.919565,104037.982415,1.000132,-0.148567,-0.11456,0.162061
ecc[0],0.067444,0.0324,6e-06,0.181892,8.48e-05,6e-05,133590.064694,92895.048559,1.00003,0.063981,-0.027412,0.03442
ecc[1],0.0,0.0,0.0,0.0,0.0,0.0,160000.0,160000.0,,0.0,0.0,0.0
omega[0],-2.550544,0.61427,0.785852,-0.645068,0.0009644,0.0045649,123200.11414,149058.923817,1.000095,-2.218772,-0.504687,4.336813
omega[1],1.570796,-0.0,1.570796,1.570796,-0.0,0.0,160000.0,160000.0,,1.570796,0.0,0.0


In [None]:
# Summary for the GP (SHO) variant of the model.  Only variables that exist
# in ``trace`` are requested: the trace loaded above has no ``gp_*``
# variables, and az.summary raises on names it cannot find.
gp_summary_candidates = [
    'period', 't0',
    'sqrt_ecc_vec_0', 'ecc', 'omega',
    'K',
    'rv_gamma', 'rv_jitter',
    'gp_sigma', 'gp_rho', 'gp_tau',
]
display_var_names = [name for name in gp_summary_candidates if name in trace.posterior]
summary = az.summary(
    trace, var_names=display_var_names,
    round_to=7,
    hdi_prob=0.997,
    skipna=True,
    coords={"ecc_dim_0": [0], "omega_dim_0": [0]},
    circ_var_names={'omega'},
    # Median with 16th/84th-percentile offsets from it.
    stat_funcs={
        'median': np.median,
        '-': lambda x: np.quantile(x, 0.16) - np.median(x),
        '+': lambda x: np.quantile(x, 0.84) - np.median(x),
    },
)
summary

In [37]:
# Merge chains and draws into a single 'sample' dimension.
rv_flat_samples = trace.posterior.stack(sample=("chain", "draw"))
# Posterior median of every variable, as plain arrays.
rv_median_soln = {
    name: values.data
    for name, values in rv_flat_samples.median(dim='sample').items()
}
# The single sample with the highest log_prob (model.logpt), as plain arrays.
rv_max_post_index = rv_flat_samples.log_prob.argmax(dim='sample')
rv_max_post_soln = {
    name: values.data
    for name, values in rv_flat_samples.isel(sample=rv_max_post_index).items()
}

In [38]:
# Minimum planet masses (in Earth masses, per the helper's name) from point
# estimates transcribed by hand.  Ordered like the plot panels: c (9.13 d),
# b (3.10 d), then the ~90.7 d and ~17.3 d RV-only signals.
# Arguments follow the order used with posterior samples below: K, ecc,
# period (d), and 1.082411, which is presumably the stellar mass in solar
# masses — TODO confirm.  0.063981 matches the posterior median of ecc[0];
# the others are treated as circular.
# NOTE(review): hard-coded copies; they will not follow a re-run of the fit.
astro.calculate_min_planet_mass_earth(
    np.array([22.807773, 5.948806, 15.426891, 6.332211]),
    np.array([0.063981, 0, 0, 0]),
    np.array([9.127056, 3.098302, 90.735801, 17.289302]),
    1.082411,
)

array([  78.41729143,   14.29707418,  114.28276836,   26.99338278])

In [39]:
# 16th/50th/84th percentiles of the minimum masses of the two planets with
# sampled eccentricity.  Posterior samples of K, ecc and P are propagated
# together with Gaussian draws around 1.082411 +/- 0.06 (presumably the
# stellar mass in solar masses — TODO confirm).
# Use the generator created here (previously created but ignored in favour
# of the unseeded global np.random), seeded for reproducible percentiles.
rng = np.random.default_rng(2000)
random_size = np.array(rv_flat_samples['K']).shape
np.quantile(
    astro.calculate_min_planet_mass_earth(
        np.array(rv_flat_samples['K']),
        np.array(rv_flat_samples['ecc']),
        np.array(rv_flat_samples['period']),
        rng.normal(1.082411, 0.06, random_size),
    ), np.array([0.16, 0.5, 0.84]), axis=1)

array([[ 73.70009743,  11.75306891],
       [ 78.30158959,  14.26916744],
       [ 83.00635799,  16.82592233]])

In [40]:
# Same as above but assuming circular orbits (ecc = 0).  Uses the seeded
# generator rather than the global np.random state.
rng = np.random.default_rng(2000)
random_size = np.array(rv_flat_samples['K']).shape
np.quantile(
    astro.calculate_min_planet_mass_earth(
        np.array(rv_flat_samples['K']),
        0.,
        np.array(rv_flat_samples['period']),
        rng.normal(1.082411, 0.06, random_size),
    ), np.array([0.16, 0.5, 0.84]), axis=1)

array([[ 73.9525846 ,  11.76010221],
       [ 78.50904247,  14.26861454],
       [ 83.18823412,  16.82321847]])

In [43]:
# Minimum masses for circular orbits with K drawn from Gaussians of
# 23.7 +/- 1.0 and 4.59 +/- 1.0 m/s instead of this fit's posterior.
# NOTE(review): the origin of these K values is not recorded here — TODO cite.
rng = np.random.default_rng(2000)
random_size = np.array(rv_flat_samples['K']).shape
np.quantile(
    astro.calculate_min_planet_mass_earth(
        np.vstack([
            rng.normal(23.7, 1.0, random_size[1]),
            rng.normal(4.59, 1.0, random_size[1]),
        ]),
        0.,
        np.array(rv_flat_samples['period']),
        rng.normal(1.082411, 0.06, random_size),
    ), np.array([0.16, 0.5, 0.84]), axis=1)

array([[ 77.07953907,   8.60121837],
       [ 81.57870989,  11.01865788],
       [ 86.21072454,  13.44872796]])

In [44]:
# Minimum masses of the RV-only planets (circular orbits), from their
# posterior K and P samples.  Uses the seeded generator for reproducibility.
rng = np.random.default_rng(2000)
random_size = np.array(rv_flat_samples['K_rv_only']).shape
np.quantile(
    astro.calculate_min_planet_mass_earth(
        np.array(rv_flat_samples['K_rv_only']),
        0.,
        np.array(rv_flat_samples['period_rv_only']),
        rng.normal(1.082411, 0.06, random_size),
    ), np.array([0.16, 0.5, 0.84]), axis=1)

array([[ 104.25013667,   22.13895974],
       [ 114.13631714,   26.94245281],
       [ 124.36671844,   31.78584776]])

In [45]:
# Lower/upper 1-sigma offsets (16th/84th percentile minus median) of the
# ~90 d RV-only planet's minimum mass.
# NOTE(review): hard-coded from an earlier run; they do not match the
# quantiles printed by the cell above (104.250, 114.136, 124.367). Compute
# them from that result instead of retyping.
104.30264212 - 114.13412687, 124.37240489 - 114.13412687

(-9.831484750000001, 10.238278019999996)

In [46]:
# Upper then lower 1-sigma offsets (84th/16th percentile minus median) of the
# ~17 d RV-only planet's minimum mass.
# NOTE(review): hard-coded from an earlier run; they do not match the printed
# quantiles two cells above (22.139, 26.942, 31.786).
31.77303186-26.93358525, 22.11476994-26.93358525

(4.8394466099999995, -4.8188153100000015)

In [47]:
# Planet masses M = (M sin i) / sin i, combining posterior K/ecc/P samples,
# the stellar-mass draws, and Gaussian inclinations in degrees: 87.94 +/- 0.13
# for planet 0 (c) and 84.73 +/- 0.5 for planet 1 (b).  The inclinations are
# presumably from the transit fit — TODO cite.
# ``random_size`` is recomputed here instead of relying on the value left
# over from the RV-only cell, and draws come from a seeded generator.
rng = np.random.default_rng(2000)
random_size = np.array(rv_flat_samples['K']).shape
min_masses = astro.calculate_min_planet_mass_earth(
    np.array(rv_flat_samples['K']),
    np.array(rv_flat_samples['ecc']),
    np.array(rv_flat_samples['period']),
    rng.normal(1.082411, 0.06, random_size),
)
inclinations = np.deg2rad(np.vstack([
    rng.normal(87.94, 0.13, random_size[1]),
    rng.normal(84.73, 0.5, random_size[1]),
]))
np.quantile(min_masses / np.sin(inclinations), np.array([0.16, 0.5, 0.84]), axis=1)

array([[ 73.74526564,  11.80280977],
       [ 78.35712706,  14.331918  ],
       [ 83.06024511,  16.89949275]])

In [48]:
# Lower/upper 1-sigma offsets of the masses of planet 0 (c) and planet 1 (b).
# NOTE(review): hard-coded from an earlier run; they differ from the
# quantiles printed by the cell above (73.745/78.357/83.060 and
# 11.803/14.332/16.899).
np.array([73.76581984, 83.03444882]) - 78.33842635, np.array([11.80104626, 16.90607814]) - 14.33273728

(array([-4.57260651,  4.69602247]), array([-2.53169102,  2.57334086]))

In [None]:
# Publication plot of the maximum-posterior sample.  Layout: unfolded time
# series with residuals, plus one phase-folded panel per planet
# (ephemeris-constrained planets first, then the RV-only ones).
num_rv_only = len(rv_max_post_soln['period_rv_only'])
plot_periods = np.concatenate([rv_max_post_soln['period'], rv_max_post_soln['period_rv_only']])
num_plot_planets = len(plot_periods)

fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(num_plot_planets)
# Raw strings: '\m' and '\,' are invalid escape sequences in normal literals.
rv_unit_label = r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)'
rv_unfolded_ax.set_ylabel(rv_unit_label)
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')
rv_folded_axs[0].set_title('c')
rv_folded_axs[1].set_title('b')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(rv_unit_label)

# RV-only planets are circular: pad with ecc = 0 and omega = pi/2.
plot_orbit = xo.orbits.KeplerianOrbit(
    period=plot_periods,
    t0=np.concatenate([rv_max_post_soln['t0'], rv_max_post_soln['t0_rv_only']]),
    ecc=np.concatenate([rv_max_post_soln['ecc'], [0.] * num_rv_only]),
    omega=np.concatenate([rv_max_post_soln['omega'], [np.pi / 2] * num_rv_only]),
)

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': 'gray',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^'},
    {'color': '#a12568', 'fmt': 'd'},
    {'color': '#fec260', 'fmt': 'o'},
]

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs, rv_unfolded_ax, rv_residual_ax,
    num_plot_planets, plot_orbit, np.concatenate([rv_max_post_soln['K'], rv_max_post_soln['K_rv_only']]),
    rv_max_post_soln['rv_gamma'],
    rv_max_post_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')

# Label every folded panel with its period, including the RV-only planets.
for ax, p in zip(rv_folded_axs, plot_periods):
    ax.text(0.05, 0.05, rf'$P = {p:.6f}\,\mathrm{{d}}$', transform=ax.transAxes)

fig.savefig('../plots/toi_2000_rv_only_90d_17d.pdf', bbox_inches='tight')

In [None]:
# Same plot for whichever solution is in ``rv_map_soln`` (e.g. the fit with a
# single RV-only planet).  The panel count and the circular-orbit padding are
# derived from the solution.  Hard-coding 3 panels while padding with
# ``num_planet_rv_only`` entries could not match both the 3- and 4-planet fits.
num_rv_only = len(rv_map_soln['period_rv_only'])
plot_periods = np.concatenate([rv_map_soln['period'], rv_map_soln['period_rv_only']])
num_plot_planets = len(plot_periods)

fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = plotting.make_multi_planet_rv_axes(num_plot_planets)
# Raw strings: '\m' and '\,' are invalid escape sequences in normal literals.
rv_unit_label = r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)'
rv_unfolded_ax.set_ylabel(rv_unit_label)
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')
rv_folded_axs[0].set_title('c')
rv_folded_axs[1].set_title('b')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(rv_unit_label)

# RV-only planets are circular: pad with ecc = 0 and omega = pi/2.
plot_orbit = xo.orbits.KeplerianOrbit(
    period=plot_periods,
    t0=np.concatenate([rv_map_soln['t0'], rv_map_soln['t0_rv_only']]),
    ecc=np.concatenate([rv_map_soln['ecc'], [0.] * num_rv_only]),
    omega=np.concatenate([rv_map_soln['omega'], [np.pi / 2] * num_rv_only]),
)

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': 'gray',
    'elinewidth': 1,
}

plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^'},
    {'color': '#a12568', 'fmt': 'd'},
    {'color': '#fec260', 'fmt': 'o'},
]

plotting.plot_multi_planet_folded_rv(
    rv_folded_axs, rv_unfolded_ax, rv_residual_ax,
    num_plot_planets, plot_orbit, np.concatenate([rv_map_soln['K'], rv_map_soln['K_rv_only']]),
    rv_map_soln['rv_gamma'],
    rv_map_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')

# Label every folded panel with its period, including the RV-only planets.
for ax, p in zip(rv_folded_axs, plot_periods):
    ax.text(0.05, 0.05, rf'$P = {p:.6f}\,\mathrm{{d}}$', transform=ax.transAxes)


In [None]:
def rv_bic(rv_soln, rv_vals, rv_uncs, num_params):
    """Bayesian information criterion, k ln(n) - 2 ln L, for an RV fit.

    Parameters
    ----------
    rv_soln : mapping
        Fit solution. Must provide 'rv_jitter' (one entry per instrument) and
        'rv_pred_{i}' (model RVs) for every instrument index i.
    rv_vals : sequence of array-like
        Per-instrument RV measurements.
    rv_uncs : sequence of array-like
        Per-instrument RV uncertainties, aligned with ``rv_vals``.
    num_params : int
        Number of free model parameters k.

    Returns
    -------
    The BIC value; n is the total number of RV points over all instruments.

    Raises
    ------
    ValueError
        If ``rv_vals``, ``rv_uncs`` and ``rv_soln['rv_jitter']`` do not have
        the same number of instruments. Without this check, ``zip`` would
        silently drop the extra instruments from the BIC.
    """
    jitters = rv_soln['rv_jitter']
    if not (len(rv_vals) == len(rv_uncs) == len(jitters)):
        raise ValueError(
            f'instrument count mismatch: {len(rv_vals)} RV sets, '
            f'{len(rv_uncs)} uncertainty sets, {len(jitters)} jitters')

    num_points = 0
    log_likelihood = 0
    for i, (my_rv, my_unc, my_jitter) in enumerate(zip(rv_vals, rv_uncs, jitters)):
        num_points += len(my_rv)
        # Measurement error and jitter are combined in quadrature into one
        # variance. NOTE(review): assumes log_prob_gaussian takes a variance
        # (not a std) and returns the summed log-probability — confirm.
        # Any GP prediction (rv_gp_pred_{i}) is not included in the model here.
        log_likelihood += priors.log_prob_gaussian(
            rv_soln[f'rv_pred_{i}'],
            my_rv,
            my_unc * my_unc + my_jitter * my_jitter)
    return num_params * np.log(num_points) - 2 * log_likelihood

In [None]:
# BIC of rv_max_post_soln with k = 19 free parameters.
# NOTE(review): k is entered by hand — confirm it matches the model that produced the solution.
rv_bic(rv_max_post_soln, rv_data, rv_uncs, 19)

In [None]:
# BIC of rv_max_post_soln with k = 16 free parameters.
# NOTE(review): k is entered by hand — confirm it matches the model that produced the solution.
rv_bic(rv_max_post_soln, rv_data, rv_uncs, 16)

In [None]:
# Difference of two hard-coded BIC values.
# NOTE(review): these look like numbers pasted from earlier cell outputs,
# presumably from different model runs. They will not update on re-run.
# Consider storing the rv_bic(...) results in variables and subtracting those.
600.0033015661337-513.21787135048135

In [None]:
# Display the full solution mapping for inspection.
rv_max_post_soln

In [None]:
# Lower-tail quantiles at the two-sided Gaussian tail probabilities for
# 5, 4.5 and 4 sigma: erfc(n / sqrt(2)) = 2 * sf(n). These equal the former
# literals 1-0.999999426696856, 1-0.999993204653751 and 1-0.999936657516334,
# computed without the cancellation from subtracting from 1.
rv_flat_samples.quantile(2 * stats.norm.sf([5., 4.5, 4.]), dim='sample')

In [None]:
# Show the fitted periods of the RV-only planets.
rv_max_post_soln['period_rv_only']

In [None]:
# Two-planet RV figure: folded panels per planet, the unfolded series, and
# residuals, with the GP (SHO) trend from gp_trend_func_generator.
# Raw strings below: '\m' and '\,' are LaTeX, not Python escape sequences.
fig, rv_folded_axs, rv_unfolded_ax, rv_residual_ax = make_multi_planet_rv_axes(2)
rv_unfolded_ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_ylabel(r'Residuals ($\mathrm{m}\,\mathrm{s}^{-1}$)')
rv_residual_ax.set_xlabel(r'$\mathrm{BJD}_\mathrm{TDB} - 2{,}457{,}000$')
rv_folded_axs[0].set_title('b')
rv_folded_axs[1].set_title('c')
for ax in rv_folded_axs:
    ax.set_xlabel('Phase')
    ax.set_ylabel(r'RV ($\mathrm{m}\,\mathrm{s}^{-1}$)')

plot_orbit = xo.orbits.KeplerianOrbit(
    **{k: rv_max_post_soln[k] for k in ['period', 't0', 'ecc', 'omega']})

plot_rv_data_style = {
    'markeredgecolor': '#2a0944',
    'markeredgewidth': 0.5,
    'alpha': 0.7,
    'markersize': 3,
    'ecolor': 'gray',
    'elinewidth': 1,
}

# One style per instrument, in the same order as the instrument names below.
plot_rv_inst_styles = [
    {'color': '#3b185f', 'fmt': '^'},
    {'color': '#a12568', 'fmt': 'd'},
    {'color': '#fec260', 'fmt': 'o'},
]

rv_trend_func, rv_trend_unc_func = gp_trend_func_generator(
    gp_sho, rv_max_post_soln, rv_times, rv_data, rv_uncs, return_unc=True)

# The folded axes are passed reversed, so the first planet in
# rv_max_post_soln['period'] is drawn on the panel titled 'c'.
# NOTE(review): this calls an unprefixed plot_multi_planet_folded_rv, while the
# earlier RV figure uses plotting.plot_multi_planet_folded_rv — confirm they match.
plot_multi_planet_folded_rv(
    rv_folded_axs[::-1], rv_unfolded_ax, rv_residual_ax,
    2, plot_orbit, rv_max_post_soln['K'],
    rv_max_post_soln['rv_gamma'],
    rv_max_post_soln['rv_jitter'],
    rv_times,
    rv_data,
    rv_uncs,
    ['CHIRON', 'FEROS', 'HARPS'],
    # trends=trend_generator(rv_max_post_soln, len(rv_data)),
    trends=gp_generator(rv_max_post_soln, len(rv_data)),
    model_trend_func=rv_trend_func,
    model_trend_unc_func=rv_trend_unc_func,
    # model_trend_func=lambda t: polynomial_design_matrix(t, rv_max_post_soln['rv_time_offset'], 2) @ rv_max_post_soln['rv_trend_coeff'],
    rv_data_style=plot_rv_data_style,
    rv_inst_styles=plot_rv_inst_styles,
)

rv_unfolded_ax.set_xlim(1905, 2375)
rv_unfolded_ax.legend(loc='lower center')

for ax, p in zip(rv_folded_axs[::-1], rv_max_post_soln["period"]):
    ax.text(0.05, 0.05, rf'$P = {p:.6f}\,\mathrm{{d}}$', transform=ax.transAxes)

# Create the output directory if needed, so savefig does not fail on a fresh checkout.
rv_plot_file = 'plot/toi2000_rv_only_gp_sho.pdf'
os.makedirs(os.path.dirname(rv_plot_file), exist_ok=True)
fig.savefig(rv_plot_file, bbox_inches='tight')

In [None]:
# Parameters shown in the corner plot. The GP hyperparameters
# ('gp_sigma', 'gp_rho') are left out; uncomment them to include them.
display_var_names = [
    'period',
    't0',
    'sqrt_ecc_vec_0',
    'K',
    'rv_gamma',
    'rv_jitter',
    # 'gp_sigma', 'gp_rho',
]

# Contour levels are 1 - exp(-s**2 / 2) for s = 0.5, 1, ..., 2.5: the mass
# enclosed within s sigma of a 2-D Gaussian. The 1-D quantiles mark the
# median and the 16th/84th percentiles.
_ = corner.corner(
    trace,
    var_names=display_var_names,
    quantiles=(0.16, 0.5, 0.84),
    levels=1 - np.exp(-0.5 * np.array([0.5, 1, 1.5, 2, 2.5]) ** 2),
)