In [10]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
from scipy.stats import norm
import matplotlib

import jax.numpy as jnp
from jax import config, random
import numpyro, jax
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, init_to_value
# Run JAX in double precision (its default is 32-bit floats).
config.update('jax_enable_x64', True)
numpyro.set_platform('cpu')
num_chains = 4
# Ask XLA to expose one host (CPU) device per chain so that the chains can be
# run in parallel. This only takes effect if it is executed before JAX
# initialises its backend, i.e. in a freshly started kernel.
numpyro.set_host_device_count(num_chains)
n_devices = jax.local_device_count()
print ('# jax device count:', n_devices)
if n_devices < num_chains:
    import warnings
    # Make the mismatch visible instead of leaving it buried in the printed count.
    warnings.warn(
        f"Only {n_devices} JAX device(s) available for {num_chains} chains; "
        "numpyro.set_host_device_count had no effect. Restart the kernel and "
        "run this cell first, otherwise the chains cannot run in parallel.",
        RuntimeWarning,
    )

from jnkepler.jaxttv import JaxTTV
from jnkepler.jaxttv import ttv_default_parameter_bounds, ttv_optim_curve_fit, scale_pdic
import corner

# jax device count: 1


In [15]:
### Reload the best-fit params, the param bounds and the jttv object
### saved by the 3_jnkep_minimizer_fit file.
import pickle

# NOTE(review): pickle.load can execute arbitrary code stored in the file;
# only load a pickle that was produced locally by the previous step.
with open('jnkep_initial_fit_data.pkl', 'rb') as f:
    data = pickle.load(f)

popt, param_bounds, jttv = (data[field] for field in ('popt', 'param_bounds', 'jttv'))

# Show what was loaded, then list the attributes available on jttv.
print(popt, param_bounds, sep='\n')
dir(jttv)

{'period': array([ 7.91977991, 11.90383163]), 'ecosw': array([-0.12634932, -0.08642707]), 'esinw': array([-0.14745586, -0.17554991]), 'tic': array([1980.38363536, 1984.27272814]), 'lnpmass': array([ -9.0762036 , -10.13875893]), 'pmass': Array([1.14354921e-04, 3.95178166e-05], dtype=float64)}
{'tic': [array([1980.33403, 1984.22227]), array([1980.43403, 1984.32227])], 'period': [array([ 7.84055446, 11.78382531]), array([ 7.9989495 , 12.02188239])], 'ecosw': [array([-0.25, -0.25]), array([0.24, 0.24])], 'esinw': [array([-0.25, -0.25]), array([0.24, 0.24])], 'lnpmass': [array([-16.11809565, -16.11809565]), array([-6.90775528, -6.90775528])], 'pmass': [array([1.e-07, 1.e-07]), array([0.001, 0.001])]}


['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'check_residuals',
 'check_timing_precision',
 'dt',
 'errorobs',
 'errorobs_flatten',
 'get_transit_times_all',
 'get_transit_times_all_list',
 'get_transit_times_and_rvs_obs',
 'get_transit_times_obs',
 'linear_ephemeris',
 'nitr_kepler',
 'nitr_transit',
 'nplanet',
 'p_init',
 'pidx',
 'plot_model',
 'sample_means_and_stds',
 'set_tcobs',
 't_end',
 't_start',
 'tcall_linear',
 'tcobs',
 'tcobs_flatten',
 'tcobs_linear',
 'times',
 'transit_time_method']

# Set up and run HMC

In [6]:
def model_scaled(sample_keys, param_bounds):
    """NumPyro model that samples every parameter on the unit interval.

    For each key in ``sample_keys`` a site named ``<key>_scaled`` is drawn
    from Uniform(0, 1) and mapped linearly onto the interval
    ``[param_bounds[key][0], param_bounds[key][1]]``; the rescaled value is
    recorded as a deterministic site named ``<key>``.

    Args:
        sample_keys: names of the parameters to sample; each must be a key
            of ``param_bounds``. 'ecosw' and 'esinw' must be included, and
            'lnpmass' is required whenever 'pmass' is absent.
        param_bounds: mapping from parameter name to a (lower, upper) pair.

    Note:
        Reads the notebook-level ``jttv`` object for the transit-time model,
        the observed transit times and their uncertainties.
    """
    par = {}

    # priors: Uniform(0, 1) on the scaled value, then rescale to the bounds
    for key in sample_keys:
        lower = param_bounds[key][0]
        upper = param_bounds[key][1]
        scaled_name = key + "_scaled"
        # lower * 0 gives zeros shaped like the lower bound
        unit_prior = dist.Uniform(lower * 0, lower * 0 + 1.)
        par[scaled_name] = numpyro.sample(scaled_name, unit_prior)
        par[key] = numpyro.deterministic(key, par[scaled_name] * (upper - lower) + lower)

    # if the mass itself was not sampled, derive it from the log-mass
    if "pmass" not in sample_keys:
        par["pmass"] = numpyro.deterministic("pmass", jnp.exp(par["lnpmass"]))

    # Jacobian for uniform ecc prior: add -log(e) to the log density
    ecc = numpyro.deterministic("ecc", jnp.sqrt(par['ecosw']**2 + par['esinw']**2))
    numpyro.factor("eprior", -jnp.log(ecc))

    # model transit times; both returned quantities are stored with the samples
    tcmodel, ediff = jttv.get_transit_times_obs(par)
    numpyro.deterministic("ediff", ediff)
    numpyro.deterministic("tcmodel", tcmodel)

    # Gaussian likelihood of the observed transit times
    numpyro.sample(
        "obs",
        dist.Normal(loc=tcmodel, scale=jttv.errorobs_flatten),
        obs=jttv.tcobs_flatten,
    )

In [7]:
# Physical parameters to sample from. Listing "pmass" (rather than "lnpmass")
# gives a uniform mass prior.
sample_keys = [
    "ecosw",
    "esinw",
    "pmass",
    "period",
    "tic",
]

In [8]:
# scaled parameters
# scale_pdic comes from jnkepler.jaxttv; presumably it maps the best-fit values
# in popt onto the unit interval defined by param_bounds -- TODO confirm.
# The result is passed to init_to_value as the starting point for NUTS.
pdic_scaled = scale_pdic(popt, param_bounds)

In [9]:
# NUTS kernel for model_scaled with a dense mass matrix, initialised at the
# scaled best-fit values in pdic_scaled.
kernel = NUTS(
    model_scaled,
    init_strategy=init_to_value(values=pdic_scaled),
    dense_mass=True,
    #regularize_mass_matrix=False # this speeds up sampling for unknown reason
)

In [None]:
# Sampler configuration: 500 warm-up steps followed by 1500 retained draws,
# for num_chains chains (num_chains is set in the first cell).
mcmc = MCMC(kernel, num_warmup=500, num_samples=1500, num_chains=num_chains)

In [None]:
# 4hr30min on M1 mac studio
# Fixed seed so that the run is repeatable.
rng_key = random.PRNGKey(0)
# sample_keys and param_bounds are forwarded to model_scaled as its arguments;
# extra_fields asks the sampler to also record these quantities.
mcmc.run(rng_key, sample_keys, param_bounds, extra_fields=('potential_energy', 'num_steps', 'adapt_state'))

In [None]:
# Print the summary table of the sampled sites for the finished run.
mcmc.print_summary()

In [None]:
# save results
# The whole MCMC object is serialised with dill so the long run above
# does not have to be repeated.
import dill
with open("jnkep_fit_full.pkl", "wb") as f:
    dill.dump(mcmc, f)

# Plot models drawn from posteriors

In [None]:
# Posterior draws collected from the finished run.
samples = mcmc.get_samples()

In [None]:
# Presumably the per-transit mean and standard deviation of the modelled
# transit times over the posterior samples -- TODO confirm against jnkepler.
means, stds = jttv.sample_means_and_stds(samples)

In [None]:
# Plot the model using the posterior means, passing the posterior standard
# deviations as the model uncertainty (tcmodelunclist).
jttv.plot_model(means, tcmodelunclist=stds)

# Trace and corner plots

In [None]:
import arviz as az

# Convert the NumPyro run to an ArviZ InferenceData once; the same object is
# reused for the trace plot here and for the corner plot in the next cell.
idata = az.from_numpyro(mcmc)
# Pass the already-converted idata rather than mcmc, so ArviZ does not have to
# convert the run a second time.
# NOTE(review): az.plot_trace returns the matplotlib axes rather than a Figure,
# despite the name `fig` -- confirm before using it as a figure.
fig = az.plot_trace(idata, var_names=sample_keys, compact=False)
plt.tight_layout(pad=0.2)

In [None]:
# 3.003e-6 is approximately the Earth-to-Sun mass ratio, so 'mu' is presumably
# the planet mass in Earth masses (assuming pmass is in solar units -- TODO confirm).
EARTH_TO_SUN_MASS = 3.003e-6
idata.posterior['mu'] = idata.posterior['pmass'] / EARTH_TO_SUN_MASS

# Corner plot of the sampled parameters, with the mass shown as 'mu'.
names = ["period", "tic", "ecosw", "esinw", "mu"]
fig = corner.corner(idata, var_names=names, show_titles=True)