# Modelling K2-24 with `simppler`

This is a reproduction of the [RadVel tutorial](https://radvel.readthedocs.io/en/latest/tutorials/K2-24_Fitting%2BMCMC.html) on the same dataset.
It is meant to show side by side how similar models are implemented in the two packages (note that `simppler` models have a `.to_radvel()` method to easily convert models).

## Importing the data

Let us first load the K2-24 observations directly from the RadVel repository and display the first few rows. We then extract the relevant columns and plot them.

In [None]:
from pandas import read_csv

url = "https://raw.githubusercontent.com/California-Planet-Search/radvel/refs/heads/master/example_data/epic203771098.csv"
df = read_csv(url, index_col=0)
df.head()

In [None]:
from matplotlib import rcParams
import matplotlib.pyplot as plt
rcParams["font.size"] = 12.0

t = df.t.values
vel = df.vel.values
errvel = df.errvel.values

def plot_data():
    plt.figure(figsize=(12, 4))
    plt.errorbar(t, vel, yerr=errvel, fmt="k.", capsize=2, mfc="w", label="Data")
    plt.xlabel("Time [d]")
    plt.ylabel("RV [m/s]")
plot_data()
plt.title("RVs of K2-24")
plt.show()

## Building the model

To build a `simppler` model, we must first specify our parameters as prior distributions.

To follow the RadVel tutorial, we will fix some parameters.
`simpple` allows us to do this with a `Fixed` distribution in the prior.
Parameters with a fixed distribution will not be included in the model dimensions or keys by default, but are registered and passed to the forward model when needed.
See the [dedicated `simpple` tutorial](https://simpple.readthedocs.io/en/stable/tutorials/fixed-parameters.html) on this topic for more info.
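
To make the bookkeeping concrete, here is a minimal sketch of how fixed parameters can be left out of the sampled dimensions and then merged back in before the forward model is evaluated. It only illustrates the idea and is not `simpple`'s implementation: the `prior` dictionary, `free_keys`, `fixed_values` and `full_parameters` are hypothetical names, and `None` stands in for a free parameter.

```python
# Hypothetical stand-in for a prior: None marks a free parameter,
# a number marks a parameter held fixed at that value.
prior = {"per1": 20.8851, "tc1": 2072.7948, "logk1": None, "gamma": None}

# Only the free parameters count towards the model dimensions and keys.
free_keys = [name for name, value in prior.items() if value is None]
fixed_values = {name: value for name, value in prior.items() if value is not None}
ndim = len(free_keys)


def full_parameters(theta):
    """Merge a vector of free values with the fixed values."""
    if len(theta) != ndim:
        raise ValueError(f"expected {ndim} values, got {len(theta)}")
    return {**fixed_values, **dict(zip(free_keys, theta))}


# The sampler or optimizer only sees two numbers, but the forward model
# would receive all four parameters.
params = full_parameters([1.1, -10.0])
print(ndim, params)
```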

We will use a builder function to easily create models with varying subsets of fixed parameters.

A few things to note:

- Models are created via `RVModel`
- The `RVModel` requires a dictionary of parameter distributions
- The `RVModel` also requires:
    - The number of planets (2 in this case)
    - `t`, `rv` and `erv`: The RV data
    - The basis to be used for orbital parameters
    - Model times, to be used when plotting model curves (optional, will be set to `t` by default)
    - Time base to be used as reference point for the trend component of the model (optional, 0 by default)
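
As a reminder of what the `per tc secosw sesinw logk` basis encodes, the sketch below converts one planet's parameters to eccentricity, argument of periastron and semi-amplitude, and evaluates a polynomial trend around a time base. It assumes the RadVel conventions (`secosw` is √e cos ω, `sesinw` is √e sin ω, `logk` is the natural log of K, and the trend is `gamma + dvdt * (t - time_base) + curv * (t - time_base)**2`). I have not checked that `simppler` follows them in every detail, and the helper names `to_physical` and `trend` are hypothetical.

```python
import math


def to_physical(secosw, sesinw, logk):
    """Convert (sqrt(e) cos w, sqrt(e) sin w, ln K) to (e, w, K)."""
    ecc = secosw**2 + sesinw**2
    omega = math.atan2(sesinw, secosw)  # radians
    k = math.exp(logk)  # m/s if K was in m/s before taking the log
    return ecc, omega, k


def trend(t, gamma, dvdt, curv, time_base):
    """Quadratic trend evaluated relative to the reference time."""
    dt = t - time_base
    return gamma + dvdt * dt + curv * dt**2


# Values used for the fixed "circular" case and the test parameters.
ecc, omega, k = to_physical(0.01, 0.01, 1.1)
print(ecc, omega, k)

# At t == time_base only the constant offset remains.
print(trend(2420.0, -10.0, -0.02, 0.01, time_base=2420.0))
```

With `secosw = sesinw = 0.01` the eccentricity is 2e-4, which is why fixing both to 0.01 gives a nearly circular orbit.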

In [None]:
import numpy as np
import simppler.model as smod
from simpple import distributions as sdist
periods = [20.8851, 42.3633]
period_errs = [0.0003, 0.0006]
t0s = [2072.7948, 2082.6251]
t0_errs = [0.0007, 0.0004]
def build_model(vary):
    # TODO: Eccentricity constraint
    if vary == "all":
        parameters = {
            "per1": sdist.Normal(periods[0], period_errs[0]),
            "tc1": sdist.Normal(t0s[0], t0_errs[0]),
            "secosw1": sdist.Uniform(-1, 1),
            "sesinw1": sdist.Uniform(-1, 1),
            "logk1": sdist.Normal(np.log(5), 10),
            "per2": sdist.Normal(periods[1], period_errs[1]),
            "tc2": sdist.Normal(t0s[1], t0_errs[1]),
            "secosw2": sdist.Uniform(-1, 1),
            "sesinw2": sdist.Uniform(-1, 1),
            "logk2": sdist.Normal(np.log(5), 10),
            "gamma": sdist.Normal(0, 10.0),
            "dvdt": sdist.Normal(0, 1.0),
            "curv": sdist.Normal(0, 1e-1),
            "jit": sdist.Normal(np.log(3), 0.5),
        }
    elif vary == "ecc":
        parameters = {
            "per1": sdist.Fixed(periods[0]),
            "tc1": sdist.Fixed(t0s[0]),
            "secosw1": sdist.Uniform(-1, 1),
            "sesinw1": sdist.Uniform(-1, 1),
            "logk1": sdist.Normal(np.log(5), 10),
            "per2": sdist.Fixed(periods[1]),
            "tc2": sdist.Fixed(t0s[1]),
            "secosw2": sdist.Uniform(-1, 1),
            "sesinw2": sdist.Uniform(-1, 1),
            "logk2": sdist.Normal(np.log(5), 10),
            "gamma": sdist.Normal(0, 10.0),
            "dvdt": sdist.Normal(0, 1.0),
            "curv": sdist.Normal(0, 1e-1),
            "jit": sdist.Normal(np.log(3), 0.5),
        }
    else:
        parameters = {
            "per1": sdist.Fixed(periods[0]),
            "tc1": sdist.Fixed(t0s[0]),
            "secosw1": sdist.Fixed(0.01),
            "sesinw1": sdist.Fixed(0.01),
            "logk1": sdist.Normal(np.log(5), 10),
            "per2": sdist.Fixed(periods[1]),
            "tc2": sdist.Fixed(t0s[1]),
            "secosw2": sdist.Fixed(0.01),
            "sesinw2": sdist.Fixed(0.01),
            "logk2": sdist.Normal(np.log(5), 10),
            "gamma": sdist.Normal(0, 10.0),
            "dvdt": sdist.Normal(0, 1.0),
            "curv": sdist.Normal(0, 1e-1),
            "jit": sdist.Normal(np.log(3), 0.5),
        }
    tmod = np.linspace(t.min() - 5, t.max() + 5, num=1000)
    time_base = 2420
    return smod.RVModel(parameters, 2, t, vel, errvel, "per tc secosw sesinw logk", tmod=tmod, time_base=time_base)

## Circular model

Let us start by building a circular model and plotting it with test values.
The `simppler.plot` module has a few utility functions to plot RV data and models from an `RVModel`.

In [None]:
from simppler.plot import plot_rv, plot_phase

model = build_model("circular")
test_p = {"per1": periods[0], "tc1": t0s[0], "secosw1": 0.01, "sesinw1": 0.01, "logk1": 1.1}
test_p |= {"per2": periods[1], "tc2": t0s[1], "secosw2": 0.01, "sesinw2": 0.01, "logk2": 1.1}
test_p |= {"gamma": -10, "dvdt": -0.02, "curv": 0.01, "jit": 1.0}
plot_rv(model, parameters=test_p, residuals=False)
plt.title("K2-24 RVs with test model")
plt.show()

In [None]:
# Evaluate both as a tuple so the notebook displays each value
model.log_likelihood(test_p), model.log_prob(test_p)

### Optimization

Let us start by doing a simple maximum a posteriori (MAP) estimate.
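
A MAP estimate maximizes the log posterior (log prior plus log likelihood), which is the same as minimizing its negative. The toy sketch below shows this on a one-parameter problem, a constant RV offset with a Gaussian prior, where the answer is also known in closed form. The data are made up, the golden-section search is a stdlib stand-in for `scipy.optimize.minimize`, and none of the names come from `simppler`.

```python
import math

# Hypothetical toy data: three RV measurements [m/s] and their uncertainties.
rv = [-4.0, -6.5, -5.2]
erv = [1.0, 1.5, 1.2]
prior_mu, prior_sigma = 0.0, 10.0  # Gaussian prior on the offset


def neg_log_prob(gamma):
    """Negative log posterior (up to a constant) of a constant-offset model."""
    log_prior = -0.5 * ((gamma - prior_mu) / prior_sigma) ** 2
    log_like = sum(-0.5 * ((y - gamma) / s) ** 2 for y, s in zip(rv, erv))
    return -(log_prior + log_like)


def golden_section_minimize(f, lo, hi, tol=1e-7):
    """Minimize a unimodal 1D function on [lo, hi] by golden-section search."""
    inv_phi = (math.sqrt(5) - 1) / 2
    a, b = lo, hi
    while b - a > tol:
        c = b - inv_phi * (b - a)
        d = a + inv_phi * (b - a)
        if f(c) < f(d):
            b = d  # the minimum lies in [a, d]
        else:
            a = c  # the minimum lies in [c, b]
    return 0.5 * (a + b)


gamma_map = golden_section_minimize(neg_log_prob, -50.0, 50.0)

# Closed-form MAP for a Gaussian likelihood and a Gaussian prior, as a check.
precision = sum(1 / s**2 for s in erv) + 1 / prior_sigma**2
gamma_exact = (
    sum(y / s**2 for y, s in zip(rv, erv)) + prior_mu / prior_sigma**2
) / precision
print(gamma_map, gamma_exact)
```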

In [None]:
from scipy.optimize import minimize

vary_p = {p: v for p, v in test_p.items() if p in model.vary_p}
res = minimize(lambda p: - model.log_prob(p), np.array(list(vary_p.values())), method="Nelder-Mead")

In [None]:
opt_p = dict(zip(model.keys(), res.x))

In [None]:
fig, axs = plot_rv(model, opt_p)
axs[0].set_title("K2-24 RVs with the optimized circular model")
plt.show()

In [None]:
fig, axs = plot_phase(model, opt_p)
axs[0].set_title("K2-24 phase folds for the optimized circular model")
plt.show()

### Sampling

We can also do MCMC sampling for our circular model.

In [None]:
import emcee

nwalkers = 50
nsteps = 10_000
ndim = model.ndim
sampler = emcee.EnsembleSampler(nwalkers, ndim, model.log_prob)

In [None]:
rng = np.random.default_rng()
p0 = res.x + 1e-4 * rng.normal(size=(nwalkers, ndim))
_ = sampler.run_mcmc(p0, nsteps, progress=True)

In [None]:
from simpple.plot import chainplot
chainplot(sampler.get_chain(), labels=model.keys())
plt.show()

In [None]:
import corner
chains = sampler.get_chain(discard=2000, flat=True, thin=10)
corner.corner(chains, labels=model.keys(), show_titles=True)
plt.show()

The `plot_rv` and `plot_phase` functions also accept MCMC chains as input.
By default they will display 100 samples from the chain.
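
It helps to keep track of the array shapes involved. The sketch below works out the shape of the flattened chain for the settings used above, assuming that `emcee` keeps every `thin`-th step after `discard` (my reading of its slicing is `chain[discard + thin - 1::thin]`) and that the circular model has six free parameters (`logk1`, `logk2`, `gamma`, `dvdt`, `curv`, `jit`). The transpose passed to the plotting functions then has one row per parameter.

```python
# Settings used for the circular model above.
nwalkers, nsteps, ndim = 50, 10_000, 6
discard, thin = 2000, 10

# Number of steps kept per walker after burn-in removal and thinning.
kept_steps = len(range(discard + thin - 1, nsteps, thin))

# flat=True merges the step and walker axes into a single sample axis.
flat_shape = (kept_steps * nwalkers, ndim)

# chains.T swaps the axes: one row per parameter, one column per sample.
transposed_shape = flat_shape[::-1]
print(kept_steps, flat_shape, transposed_shape)
```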


In [None]:
fig, axs = plot_rv(model, chains.T, n_samples=100)
axs[0].set_title("Posterior samples for the circular model")
plt.show()

In [None]:
fig, axs = plot_phase(model, chains.T, n_samples=100)
axs[0].set_title("Phase-folded posterior samples for the circular model")
plt.show()

## Eccentric orbits

Let us now repeat all the steps we did for the circular model, but for an eccentric model.

In [None]:
model = build_model("ecc")

### Optimization

In [None]:
vary_p = {p: v for p, v in test_p.items() if p in model.vary_p}
res = minimize(lambda p: - model.log_prob(p), np.array(list(vary_p.values())), method="Powell")

In [None]:
opt_p = dict(zip(model.keys(), res.x))

In [None]:
fig, axs = plot_rv(model, opt_p)
axs[0].set_title("K2-24 RVs with the optimized eccentric model")
plt.show()

In [None]:
fig, axs = plot_phase(model, opt_p)
axs[0].set_title("K2-24 phase folds with the optimized eccentric model")
plt.show()

The optimized model looks like it provides a better fit to the data. Let us now explore the posterior a bit more with MCMC.

### Sampling

In [None]:
import emcee

nwalkers = 50
nsteps = 10_000
ndim = model.ndim
sampler = emcee.EnsembleSampler(nwalkers, ndim, model.log_prob)

In [None]:
rng = np.random.default_rng()
p0 = res.x + 1e-4 * rng.normal(size=(nwalkers, ndim))
_ = sampler.run_mcmc(p0, nsteps, progress=True)

In [None]:
chainplot(sampler.get_chain(), labels=model.keys())
plt.show()

In [None]:
import corner
chains = sampler.get_chain(discard=1000, flat=True, thin=10)
corner.corner(chains, labels=model.keys(), show_titles=True)
plt.show()

In [None]:
fig, axs = plot_rv(model, chains.T)
axs[0].set_title("Posterior samples for the eccentric model")
plt.show()

In [None]:
fig, axs = plot_phase(model, chains.T)
axs[0].set_title("Phase-folded posterior samples for the eccentric model")
plt.show()

## Orbit without fixed parameters

Instead of freezing some parameters, let us see what happens if we let them all vary.
For parameters with good external constraints, we will use Gaussian priors instead of fixing them.
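
To get a feel for how strongly such a prior constrains a parameter, the sketch below evaluates a Gaussian log-density for the first planet's period, using the mean and width from the builder function. `normal_logpdf` is a hypothetical helper written out by hand and is not part of `simpple`.

```python
import math


def normal_logpdf(x, mu, sigma):
    """Log-density of a normal distribution with mean mu and width sigma."""
    return -0.5 * ((x - mu) / sigma) ** 2 - math.log(sigma * math.sqrt(2 * math.pi))


# Transit-based constraint on the period of the first planet [d].
mu, sigma = 20.8851, 0.0003

# Moving 3 sigma away from the mean lowers the log-prior by 0.5 * 3**2 = 4.5,
# regardless of how small sigma is in absolute terms.
penalty = normal_logpdf(mu, mu, sigma) - normal_logpdf(mu + 3 * sigma, mu, sigma)
print(penalty)
```

Because the width is only 0.0003 d, the sampler is effectively confined to a very narrow range of periods, so the results should stay close to those of the fixed-period models unless the RVs strongly disagree with the transit ephemeris.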

In [None]:
model = build_model("all")

### Optimization

In [None]:
vary_p = {p: v for p, v in test_p.items() if p in model.vary_p}
res = minimize(lambda p: - model.log_prob(p), np.array(list(vary_p.values())), method="Powell")

In [None]:
opt_p = dict(zip(model.keys(), res.x))

In [None]:
fig, axs = plot_rv(model, opt_p)
axs[0].set_title("K2-24 RVs with the optimized model")
plt.show()

In [None]:
fig, axs = plot_phase(model, opt_p)
axs[0].set_title("K2-24 phase folds for the optimized model")
plt.show()

### Sampling

In [None]:
import emcee

nwalkers = 50
nsteps = 10_000
ndim = model.ndim
sampler = emcee.EnsembleSampler(nwalkers, ndim, model.log_prob)

In [None]:
rng = np.random.default_rng()
p0 = res.x + 1e-4 * rng.normal(size=(nwalkers, ndim))
_ = sampler.run_mcmc(p0, nsteps, progress=True)

In [None]:
chainplot(sampler.get_chain(), labels=model.keys())
plt.show()

In [None]:
import corner
chains = sampler.get_chain(discard=1000, flat=True, thin=10)
corner.corner(chains, labels=model.keys(), show_titles=True)
plt.show()

In [None]:
fig, axs = plot_rv(model, chains.T)
axs[0].set_title("Posterior samples for the full model")
plt.show()

In [None]:
fig, axs = plot_phase(model, chains.T)
axs[0].set_title("Phase-folded posterior samples for the full model")
plt.show()