# Modelling K2-24 with `simppler`

This is a reproduction of the [RadVel tutorial](https://radvel.readthedocs.io/en/latest/tutorials/K2-24_Fitting%2BMCMC.html) on the same dataset.
Hopefully this can provide a useful comparison between how to implement similar models with the two packages (note that `simppler` models have a `.to_radvel()` method to easily convert models).

## Importing the data

Let us first load the K2-24 observations directly from the RadVel repository, extract the relevant columns, and display them.

In [None]:
from pandas import read_csv

# Raw CSV file hosted in the RadVel GitHub repository (read over the network).
url = "https://raw.githubusercontent.com/California-Planet-Search/radvel/refs/heads/master/example_data/epic203771098.csv"
# The first CSV column is used as the DataFrame index.
df = read_csv(url, index_col=0)
# Last expression of the cell: the notebook displays the first rows.
df.head()

In [None]:
import matplotlib.pyplot as plt

# Pull the columns used by the fit out of the DataFrame as NumPy arrays:
# observation times, velocities and velocity uncertainties. These
# notebook-level names are read by `plot_data` and `build_model`.
t = df.t.values
vel = df.vel.values
errvel = df.errvel.values

def plot_data():
    """Open a new wide figure and draw the RV measurements with error bars.

    Reads the notebook-level arrays ``t``, ``vel`` and ``errvel``. The figure
    is not shown here (no ``plt.show()``), so the caller can overlay models,
    add a title or a legend, and display it afterwards.
    """
    marker_style = {"fmt": "k.", "capsize": 2, "mfc": "w", "label": "Data"}
    plt.figure(figsize=(12, 4))
    plt.errorbar(t, vel, yerr=errvel, **marker_style)
    plt.ylabel("RV [m/s]")
    plt.xlabel("Time [d]")
# Draw the data, then add a title and display the figure.
plot_data()
plt.title("RVs of K2-24")
plt.show()

## Building the model

To build a `simppler` model, we must first specify our parameters as prior distributions.

To follow the RadVel tutorial, we will fix some parameters.
`simpple` allows us to do this with a `Fixed` distribution in the prior.
Parameters with a fixed distribution will not be included in the model dimensions or keys by default, but are registered and passed to the forward model when needed.
See the [dedicated simpple tutorial](https://simpple.readthedocs.io/en/stable/tutorials/fixed-parameters.html) on this topic for more info.

We will use a builder function to easily create models with varying subsets of fixed parameters.

In [None]:
import numpy as np
import simppler.model as smod
from simpple import distributions as sdist
# Per-planet periods and `tc` times with their uncertainties
# (index 0 -> planet 1, index 1 -> planet 2). `build_model` uses them either
# as the mean/width of Normal priors or as fixed parameter values.
# NOTE(review): presumably the transit ephemeris quoted in the RadVel
# tutorial — confirm the source of these numbers.
periods = [20.8851, 42.3633]
period_errs = [0.0003, 0.0006]
t0s = [2072.7948, 2082.6251]
t0_errs = [0.0007, 0.0004]
def build_model(vary):
    """Build the two-planet RV model with a chosen subset of free parameters.

    The priors are assembled once instead of being copied per configuration,
    so the configurations cannot drift apart. The insertion order of the
    parameters is the same for every configuration: ``per``, ``tc``,
    ``secosw``, ``sesinw``, ``logk`` for each planet, followed by ``gamma``,
    ``dvdt``, ``curv`` and ``jit``.

    Parameters
    ----------
    vary : str
        Selects which orbital parameters are free:

        - ``"all"``: ``per``/``tc`` get Normal priors built from ``periods``,
          ``period_errs``, ``t0s`` and ``t0_errs``; ``secosw``/``sesinw`` get
          ``Uniform(-1, 1)`` priors.
        - ``"ecc"``: ``per``/``tc`` are fixed; ``secosw``/``sesinw`` get
          ``Uniform(-1, 1)`` priors.
        - any other value (the notebook passes ``"circular"``): ``per``,
          ``tc``, ``secosw`` and ``sesinw`` are all fixed, the last two
          at 0.01.

        ``logk``, ``gamma``, ``dvdt``, ``curv`` and ``jit`` always keep their
        Normal priors.

    Returns
    -------
    smod.RVModel
        Model built on the notebook-level data arrays ``t``, ``vel`` and
        ``errvel``, with a 1000-point prediction grid ``tmod`` extending
        5 time units beyond the data on each side.
    """
    # TODO: Eccentricity constraint
    free_ephemeris = vary == "all"
    free_ecc = vary in ("all", "ecc")

    parameters = {}
    ephemerides = zip(periods, period_errs, t0s, t0_errs)
    for i, (per, per_err, t0, t0_err) in enumerate(ephemerides, start=1):
        if free_ephemeris:
            parameters[f"per{i}"] = sdist.Normal(per, per_err)
            parameters[f"tc{i}"] = sdist.Normal(t0, t0_err)
        else:
            parameters[f"per{i}"] = sdist.Fixed(per)
            parameters[f"tc{i}"] = sdist.Fixed(t0)
        for name in ("secosw", "sesinw"):
            if free_ecc:
                parameters[f"{name}{i}"] = sdist.Uniform(-1, 1)
            else:
                parameters[f"{name}{i}"] = sdist.Fixed(0.01)
        parameters[f"logk{i}"] = sdist.Normal(np.log(5), 10)

    # System-level parameters are free in every configuration.
    parameters["gamma"] = sdist.Normal(0, 10.0)
    parameters["dvdt"] = sdist.Normal(0, 1.0)
    parameters["curv"] = sdist.Normal(0, 1e-1)
    parameters["jit"] = sdist.Normal(np.log(3), 0.5)

    num_planets = len(periods)
    tmod = np.linspace(t.min() - 5, t.max() + 5, num=1000)
    time_base = 2420
    return smod.RVModel(
        parameters,
        num_planets,
        t,
        vel,
        errvel,
        "per tc secosw sesinw logk",
        tmod=tmod,
        time_base=time_base,
    )

In [None]:
model = build_model("circular")

# Hand-picked values for every parameter (fixed ones included), used for a
# quick visual check of the forward model against the data.
test_p = {
    "per1": periods[0],
    "tc1": t0s[0],
    "secosw1": 0.01,
    "sesinw1": 0.01,
    "logk1": 1.1,
    "per2": periods[1],
    "tc2": t0s[1],
    "secosw2": 0.01,
    "sesinw2": 0.01,
    "logk2": 1.1,
    "gamma": -10,
    "dvdt": -0.02,
    "curv": 0.01,
    "jit": 1.0,
}

plot_data()
plt.plot(
    model.tmod,
    model.forward(test_p, model.tmod),
)
plt.show()

In [None]:
# A notebook cell only displays its last bare expression, so print both
# quantities explicitly instead of silently discarding the log-likelihood.
print("log-likelihood:", model.log_likelihood(test_p))
print("log-probability:", model.log_prob(test_p))

## Optimization

In [None]:
from scipy.optimize import minimize

# Starting point: the test values of the free parameters only, in the order
# in which they appear in `test_p`.
vary_p = {name: test_p[name] for name in test_p if name in model.vary_p}


def neg_log_prob(p):
    """Objective for the minimizer: the negated log-probability of `model`."""
    return -model.log_prob(p)


res = minimize(
    neg_log_prob,
    np.array(list(vary_p.values())),
    method="Nelder-Mead",
)

In [None]:
# Map the optimizer's solution vector back onto the model's parameter names.
opt_p = dict(zip(model.keys(), res.x))

In [None]:
# Overlay the model evaluated at the optimized parameters on the data.
plot_data()
plt.plot(model.tmod, model.forward(opt_p, model.tmod), label="MAP model")
plt.legend()
plt.show()

## Sampling

In [None]:
import emcee

# Ensemble settings: number of walkers and number of steps per walker.
nwalkers = 50
nsteps = 10_000
# Dimensionality of the sampled space as reported by the model.
ndim = model.ndim
# The sampler calls `model.log_prob` on each proposed parameter vector.
sampler = emcee.EnsembleSampler(nwalkers, ndim, model.log_prob)

In [None]:
# Unseeded generator: the initial positions differ from run to run.
rng = np.random.default_rng()
# Start all walkers in a tight Gaussian ball (scale 1e-4) around the
# optimizer solution.
p0 = res.x + 1e-4 * rng.normal(size=(nwalkers, ndim))
# The return value is bound to `_` so the cell does not display it.
_ = sampler.run_mcmc(p0, nsteps, progress=True)

In [None]:
from simpple.plot import chainplot
# Plot the full chain (nothing discarded yet), one label per model parameter.
chainplot(sampler.get_chain(), labels=model.keys())
plt.show()

In [None]:
import corner
# Drop the first 2000 steps, keep every 10th step, and flatten the walkers
# into a single list of samples.
chains = sampler.get_chain(discard=2000, flat=True, thin=10)
corner.corner(chains, labels=model.keys(), show_titles=True)
plt.show()

In [None]:
from simppler.plot import plot_orbit
# Orbit plot from the flattened posterior samples of the circular model.
plot_orbit(model, chains)
plt.show()

In [None]:
# Median of the flattened samples for each parameter, keyed by parameter name.
med_p = dict(zip(model.keys(), np.median(chains, axis=0)))
# Model evaluated with an empty planet list on the prediction grid.
# NOTE(review): presumably this leaves only the non-planetary terms
# (offset/trend) — confirm against the simppler API.
mod_sys = model.forward(med_p, model.tmod, planets=[])
# Single-planet curves with the planet-free model subtracted.
mod_b = model.forward(med_p, model.tmod, planets=[1]) - mod_sys
mod_c = model.forward(med_p, model.tmod, planets=[2]) - mod_sys
# Same decomposition evaluated at the observation times.
mod_sys_data = model.forward(med_p, model.t, planets=[])
mod_b_data = model.forward(med_p, model.t, planets=[1]) - mod_sys_data
mod_c_data = model.forward(med_p, model.t, planets=[2]) - mod_sys_data

## Eccentric orbits

In [None]:
# Rebuild the model with `secosw`/`sesinw` free (periods and `tc` stay fixed).
# This rebinds `model`, so every following cell works on the eccentric model.
model = build_model("ecc")

### Optimization

In [None]:
# Starting point: the test values of the parameters that are free in the
# current model, in the order in which they appear in `test_p`.
vary_p = {name: test_p[name] for name in test_p if name in model.vary_p}


def neg_log_prob(p):
    """Objective for the minimizer: the negated log-probability of `model`."""
    return -model.log_prob(p)


res = minimize(
    neg_log_prob,
    np.array(list(vary_p.values())),
    method="Powell",
)

In [None]:
# Map the optimizer's solution vector back onto the model's parameter names.
opt_p = dict(zip(model.keys(), res.x))

In [None]:
planets = [1, 2]
all_planets = list(range(1, model.num_planets+1))
n_bins = 10
# Model evaluated with an empty planet list; subtracted below to isolate the
# curve of a single planet.
mod_sys = model.forward(opt_p, model.tmod, planets=[])
for planet in planets:
    # TODO: opt, med or samples
    # Period and `tc` are read from the fixed parameters, so this cell needs a
    # model in which they are fixed.
    per = model.fixed_p[f"per{planet}"].value
    t0 = model.fixed_p[f"tc{planet}"].value

    # Fold the observation and model times on the period, then sort by phase.
    phase = (model.t - t0) % per / per
    phase_mod = (model.tmod - t0) % per / per
    phase_inds = np.argsort(phase)
    phase_mod_inds = np.argsort(phase_mod)
    phase = phase[phase_inds]
    phase_erv = errvel[phase_inds]

    # Remove everything except the current planet from the data.
    other_planets = [i for i in all_planets if i != planet]
    mod_others_data = model.forward(opt_p, model.t, planets=other_planets)
    phase_resid = vel[phase_inds] - mod_others_data[phase_inds]
    mod_planet = model.forward(opt_p, model.tmod, planets=[planet]) - mod_sys

    # Bin the residuals in phase. Digitizing against the interior edges gives
    # indices 0..n_bins-1 and keeps the largest phase (which equals the last
    # histogram edge) in the last bin instead of an overflow index.
    _, bins = np.histogram(phase, bins=n_bins)
    bin_centers = (bins[:-1] + bins[1:]) / 2
    bin_inds = np.digitize(phase, bins[1:-1])
    resid_by_bin = [phase_resid[bin_inds == i] for i in range(n_bins)]
    # Skip empty bins: their mean/std would be NaN and raise warnings.
    filled = [i for i, resid in enumerate(resid_by_bin) if resid.size > 0]
    bin_means = [resid_by_bin[i].mean() for i in filled]
    # Scatter of the same residuals that are averaged, not of the raw RVs.
    bin_stds = [resid_by_bin[i].std() for i in filled]

    plt.figure(figsize=(12, 4))
    plt.errorbar(phase, phase_resid, yerr=phase_erv, fmt="k.", capsize=2, mfc="w", label="Data")
    plt.errorbar(bin_centers[filled], bin_means, yerr=bin_stds, fmt="r.", capsize=3, mfc="w", label="Binned data")
    plt.plot(phase_mod[phase_mod_inds], mod_planet[phase_mod_inds], label="Model")
    plt.axhline(0.0, linestyle="--", color="k", alpha=0.5)
    # The x axis is the folded phase in [0, 1), not a time in days.
    plt.xlabel("Orbital phase")
    plt.ylabel("RV [m/s]")
    plt.legend()
    plt.show()

In [None]:
plot_data()
# Pass the named parameters (`opt_p`) like every other `forward` call in this
# notebook, rather than the raw optimizer vector `res.x`.
plt.plot(model.tmod, model.forward(opt_p, model.tmod), label="MAP model")
plt.legend()
plt.show()

### Sampling

In [None]:
import emcee

# Same ensemble settings as for the circular fit: walkers and steps per walker.
nwalkers = 50
nsteps = 10_000
# Re-read the dimensionality from the model currently bound to `model`.
ndim = model.ndim
# The sampler calls `model.log_prob` on each proposed parameter vector.
sampler = emcee.EnsembleSampler(nwalkers, ndim, model.log_prob)

In [None]:
# Unseeded generator: the initial positions differ from run to run.
rng = np.random.default_rng()
# Start all walkers in a tight Gaussian ball (scale 1e-4) around the
# optimizer solution.
p0 = res.x + 1e-4 * rng.normal(size=(nwalkers, ndim))
# The return value is bound to `_` so the cell does not display it.
_ = sampler.run_mcmc(p0, nsteps, progress=True)

In [None]:
# Plot the full chain (nothing discarded yet), one label per model parameter.
chainplot(sampler.get_chain(), labels=model.keys())
plt.show()

In [None]:
import corner
# Drop the first 1000 steps, keep every 10th step, and flatten the walkers
# into a single list of samples.
chains = sampler.get_chain(discard=1000, flat=True, thin=10)
corner.corner(chains, labels=model.keys(), show_titles=True)
plt.show()

In [None]:
# NOTE(review): the flattened chain is transposed before being passed in,
# presumably because `get_posterior_pred` expects parameters on the first
# axis, and 99 is presumably the number of draws — confirm against the
# simpple API.
posterior_preds = model.get_posterior_pred(chains.T, 99, model.tmod)

plot_data()
for i, pred in enumerate(posterior_preds):
    # Only the first curve gets a label, so the legend shows a single entry.
    plt.plot(model.tmod, pred, color="C0", alpha=0.1, label="Posterior samples" if i == 0 else None)
plt.legend()
plt.show()

In [None]:
import simppler.plot as sp
# Orbit plot from the flattened posterior samples of the eccentric model.
sp.plot_orbit(model, chains)
plt.show()