# An introduction to PyMC3 & exoplanet for astronomers

By: **Dan Foreman-Mackey**

In this tutorial, we will learn how to use [exoplanet](https://exoplanet.dfm.io) and [PyMC3](https://docs.pymc.io) to do Markov chain Monte Carlo (MCMC) with a focus on fitting [TESS](https://en.wikipedia.org/wiki/Transiting_Exoplanet_Survey_Satellite) data.
But first, we have to do some setup:

In [None]:
# We want to see plots in the browser
%matplotlib inline

In [None]:
# These are the dependencies missing on the science platform
!pip install -q -U corner exoplanet

# Let's make the plots look a little nicer
from matplotlib import rcParams
rcParams["savefig.dpi"] = 100
rcParams["figure.dpi"] = 100
rcParams["font.size"] = 16

# The installation of Theano is a little broken (but it'll work
# fine for our purposes). Deal with those issues as follows:
import os
import warnings
os.environ["MKL_THREADING_LAYER"] = "GNU"
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

## Part 2: Fitting a transit with exoplanet

In this example, we will actually fit some real TESS data.
Specifically we will fit the light curve of the first transiting planet that was discovered by TESS: [Pi Mensae c](https://arxiv.org/abs/1809.05967).
A more complete example of fitting these data can be found on the [exoplanet docs](https://exoplanet.dfm.io/en/stable/tutorials/tess/).

I've taken the liberty of doing the photometry, preprocessing, and de-trending in advance so you can [download the data from GitHub](https://github.com/dfm/tess-tutorial).
The script that I used to prepare this file [is also available in the same repository](https://github.com/dfm/tess-tutorial/blob/master/notebooks/preprocess.ipynb).

Here's how we download and plot the data:

In [None]:
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt

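# Read the light curve table and its header from the first extension
with fits.open("https://github.com/dfm/tess-tutorial/raw/master/data/pimen-ffi.fits") as hdus:
    data = hdus[1].data
    hdr = hdus[1].header
    x = np.array(data["time"], dtype=np.float64)
    y = np.array(data["flux"], dtype=np.float64)

# The exposure time is the frame time (in seconds) multiplied by the
# number of co-added frames; convert it from seconds to days
texp = hdr["FRAMETIM"] * hdr["NUM_FRM"]
texp /= 60.0 * 60.0 * 24.0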

plt.plot(x, y, ".k")
plt.xlabel("time [days]")
plt.ylabel("relative flux [ppt]");

You can see that I have cut out the transits of our planet and saved only the data near transit.
The fluxes are measured in parts per thousand because I find that those units tend to have the easiest order of magnitude to work with for transits.
The times are measured relative to the midpoint of Sector 1.
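
As a rough illustration of these unit conventions, here is a small sketch that converts a raw flux to parts per thousand and turns the two exposure keywords into days, as the loading cell does. The helper names are hypothetical, and normalizing by the median flux is an assumption for illustration; the actual preprocessing lives in the linked notebook.

```python
import numpy as np

SECONDS_PER_DAY = 60.0 * 60.0 * 24.0


def to_ppt(flux):
    """Convert a raw flux array to relative flux in parts per thousand.

    Assumes (hypothetically) that the baseline is well described by the
    median flux, so that out-of-transit points scatter around zero.
    """
    flux = np.asarray(flux, dtype=np.float64)
    return (flux / np.median(flux) - 1.0) * 1e3


def exposure_time_days(frame_time_s, num_frames):
    """Total integration time in days, given a per-frame time in seconds
    and the number of co-added frames (like FRAMETIM * NUM_FRM)."""
    return frame_time_s * num_frames / SECONDS_PER_DAY


# A 1% dip relative to the median baseline corresponds to -10 ppt
print(to_ppt([100.0, 100.0, 99.0, 100.0]))

# Illustrative values: 2 s frames co-added 900 times, i.e. 30 minutes
print(exposure_time_days(2.0, 900) * 24 * 60)
```

Working in ppt keeps typical transit depths and noise levels of order unity, which is the motivation given above.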

When pre-processing the data, I ran [BLS](http://docs.astropy.org/en/stable/stats/bls.html) to determine the period and phase of this candidate.
Let's use these values to plot the folded light curve:

In [None]:
period_guess = 6.26554
t0_guess = -1.19041

x_fold = (x - t0_guess + 0.5*period_guess) % period_guess - 0.5*period_guess

plt.scatter(x_fold, y, c=x)
plt.xlabel("time since transit [days]")
plt.ylabel("relative flux [ppt]");

Since the data points are colored by time, we can see that the depth of each transit is at least qualitatively consistent.
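
The folding step maps every time onto the interval from minus half a period to plus half a period around the nearest transit. Below is a minimal numpy sketch of that operation, together with a brute-force box search on synthetic data that illustrates the basic idea behind BLS: scan trial periods and reference times, and score how much lower the flux is inside a box-shaped window. This is not astropy's BLS implementation; the function names, the signal-to-noise score, and the synthetic light curve are all assumptions for illustration.

```python
import numpy as np


def fold(t, period, t0):
    """Phase-fold times so that transits land at zero, in [-P/2, P/2)."""
    return (t - t0 + 0.5 * period) % period - 0.5 * period


def box_search(t, flux, periods, duration, n_phase=100):
    """Brute-force box search (a simplified stand-in for BLS).

    For each trial period and reference time, compare the mean flux
    inside a box of width `duration` with the mean flux outside it, and
    keep the trial with the highest signal-to-noise ratio of the depth.
    """
    sigma = np.std(flux)
    best = (-np.inf, None, None)
    for period in periods:
        for t0 in np.linspace(0.0, period, n_phase, endpoint=False):
            in_transit = np.abs(fold(t, period, t0)) < 0.5 * duration
            n_in = in_transit.sum()
            n_out = in_transit.size - n_in
            if n_in < 3 or n_out < 3:
                continue
            depth = flux[~in_transit].mean() - flux[in_transit].mean()
            snr = depth / (sigma * np.sqrt(1.0 / n_in + 1.0 / n_out))
            if snr > best[0]:
                best = (snr, period, t0)
    return best


# Synthetic light curve (hypothetical numbers): 27 days at roughly
# 30-minute cadence, with a 1 ppt transit every 6.27 days
rng = np.random.default_rng(42)
t = np.arange(0.0, 27.0, 0.02)
true_period, true_t0, duration = 6.27, 1.3, 0.12
flux = rng.normal(0.0, 0.2, t.size)
flux[np.abs(fold(t, true_period, true_t0)) < 0.5 * duration] -= 1.0

snr, best_period, best_t0 = box_search(
    t, flux, np.linspace(5.0, 8.0, 151), duration)
print(best_period, best_t0, snr)
```

The trial period range here deliberately excludes half and double the true period; in a real search those aliases are exactly what the scoring has to discriminate against.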

Next up we will specify the model.
The syntax is the same as in the previous example, but there is a lot more going on here.
The inline comments should give most of the relevant information, but you can also take a look at [the exoplanet docs](https://exoplanet.dfm.io/en/stable/) for more information about the extra features provided.

In [None]:
import pymc3 as pm
import theano.tensor as tt

# The tradition is to import "exoplanet" with the shorthand "xo" 
import exoplanet as xo

with pm.Model() as model:
    
    # A parameter describing the observational uncertainties
    logs = pm.Uniform("logs", lower=-5, upper=0,
                      testval=np.log(np.std(y)))
    
    # The mean flux of the star in the units of the data. This should
    # be about zero because we're measuring relative flux, but if I
    # got the baseline slightly wrong in my preprocessing, this can
    # protect us.
    mean_flux = pm.Normal("mean_flux", mu=0, sd=1)
    
    # A prior on the limb darkening parameters. We'll use the
    # parameterization recommended by Kipping (2013) based on triangular
    # sampling.
    u = xo.distributions.QuadLimbDark("u")
    
    # We'll constrain the prior to be within 1% of the guess that we
    # provided. Making this range too large could cause problems: since
    # we only kept the data near transit, a model with half the period
    # (whose extra transits would fall in the gaps we cut out) would fit
    # the data equally well.
    period = pm.Uniform("period",
                        lower=period_guess*0.99,
                        upper=period_guess*1.01,
                        testval=period_guess)
    
    # We can see from the plot above that the reference transit time
    # is pretty close to correct so we'll constrain it to be within
    # 0.1 days of the initial guess.
    t0 = pm.Uniform("t0", lower=t0_guess-0.1, upper=t0_guess+0.1)
    
    # To sample the radius ratio and impact parameter, we'll use the
    # parameterization for the joint density recently recommended by
    # Espinoza (2018).
    r, b = xo.distributions.get_joint_radius_impact(
        min_radius=0.005, max_radius=0.05, testval_r=0.015)
    
    # Now set up a Keplerian orbit with the expected parameters
    orbit = xo.orbits.KeplerianOrbit(
        period=period, t0=t0, b=b)
    
    # The light curve model is computed using "starry". The returned
    # light curve will always have the shape (ntime, nplanet) but we
    # only have one planet so we can "squeeze" the result to remove
    # the last axis. Also: don't forget to convert to ppt and add in
    # the mean flux parameter
    star = xo.StarryLightCurve(u)
    light_curve = tt.squeeze(star.get_light_curve(
        orbit=orbit, r=r, t=x, texp=texp))*1e3 + mean_flux

    # Finally, this is the likelihood for the observations
    pm.Normal("obs", mu=light_curve, sd=tt.exp(logs), observed=y)
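
The `QuadLimbDark` prior follows the Kipping (2013) idea of sampling two parameters `q1` and `q2` uniformly on [0, 1] and mapping them to the quadratic coefficients `u1 = 2 sqrt(q1) q2` and `u2 = sqrt(q1) (1 - 2 q2)`, which covers exactly the physically allowed region. The numpy sketch below shows that mapping on its own; it illustrates the parameterization and is not exoplanet's code (the function name is hypothetical).

```python
import numpy as np


def kipping_to_quadratic(q1, q2):
    """Map Kipping (2013) parameters q1, q2 in [0, 1] to quadratic
    limb darkening coefficients (u1, u2)."""
    sqrt_q1 = np.sqrt(q1)
    u1 = 2.0 * sqrt_q1 * q2
    u2 = sqrt_q1 * (1.0 - 2.0 * q2)
    return u1, u2


# Uniform draws in (q1, q2) always give physically valid (u1, u2)
rng = np.random.default_rng(0)
q1, q2 = rng.uniform(size=(2, 10000))
u1, u2 = kipping_to_quadratic(q1, q2)

# Kipping's constraints: intensity positive everywhere and
# decreasing from the center of the disk toward the limb
valid = (u1 >= 0) & (u1 + u2 <= 1) & (u1 + 2 * u2 >= 0)
print(valid.all())
```

The appeal of this choice is that a flat prior in `(q1, q2)` never proposes an unphysical star, so the sampler does not waste effort on rejected regions.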

After defining the model, it can be useful to look at a plot of the initial model (as defined using the `testval` parameters above) to make sure that it's not completely unreasonable.
In PyMC3 (and therefore also exoplanet), none of the operations are actually *executed* when you define a model.
Instead, the previous cell just defines the *relationships* between parameters.
exoplanet [comes with some features](https://exoplanet.dfm.io/en/stable/user/api/#utilities) that make it easier to inspect the model after you have defined it.
For example, we can use `xo.utils.eval_in_model` to plot the initial light curve:

In [None]:
# Plot the data as above
plt.scatter(x_fold, y, c=x)

# Compute the initial transit model evaluated at each data point
# and overplot that
with model:
    transit_model = xo.utils.eval_in_model(light_curve)

# For plotting purposes, sort the folded times
inds = np.argsort(x_fold)
plt.plot(x_fold[inds], transit_model[inds], "k", label="initial model")

plt.legend(fontsize=12, loc=3)
plt.xlabel("time since transit [days]")
plt.ylabel("relative flux [ppt]");

That doesn't look perfect (that's why we're here after all!) but it'll do as a starting point.
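
The define-then-evaluate behavior described above can be illustrated without Theano. In the toy sketch below (pure Python, with hypothetical class names that are not part of PyMC3 or Theano), building an expression only records how values combine, and nothing is computed until `evaluate` is called with a dictionary of parameter values, loosely like evaluating a model at its `testval` point.

```python
class Node:
    """Base class: arithmetic on nodes builds a graph, not a number."""

    def __add__(self, other):
        return Op(lambda a, b: a + b, self, as_node(other))

    def __mul__(self, other):
        return Op(lambda a, b: a * b, self, as_node(other))

    # Both operations are commutative, so the reflected versions can reuse them
    __radd__ = __add__
    __rmul__ = __mul__


class Const(Node):
    """A fixed number embedded in the graph."""

    def __init__(self, value):
        self.value = value

    def evaluate(self, values):
        return self.value


class Var(Node):
    """A named placeholder whose value is supplied at evaluation time."""

    def __init__(self, name):
        self.name = name

    def evaluate(self, values):
        return values[self.name]


class Op(Node):
    """An operation applied to the results of its input nodes."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

    def evaluate(self, values):
        # Recursively evaluate the inputs, then apply this node's operation
        return self.fn(*(node.evaluate(values) for node in self.inputs))


def as_node(x):
    return x if isinstance(x, Node) else Const(x)


# "Define": this only records the graph, loosely mirroring the
# `light_curve * 1e3 + mean_flux` line in the model above
transit = Var("transit")
mean_flux = Var("mean_flux")
light_curve = transit * 1e3 + mean_flux

# "Evaluate": compute the graph at one chosen set of parameter values
print(light_curve.evaluate({"transit": -0.001, "mean_flux": 0.1}))
```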

Now we can sample the posterior for this model.
We won't use the `pm.sample` function directly like we did in the previous tutorial, because some of the parameters have significant covariances and the default tuning in the built-in sampler (which only adapts a diagonal mass matrix) handles them poorly.
exoplanet [comes with a wrapper around the "sample" function that adds support for covariances](https://exoplanet.dfm.io/en/stable/tutorials/pymc3-extras/#custom-tuning-schedule) and this makes a big difference here.
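
To see why modeling the covariance matters, here is a numpy sketch on synthetic correlated draws. Rescaling each parameter by its own standard deviation (the diagonal approach) fixes the scales but leaves the correlation in place, while transforming with the Cholesky factor of the full estimated covariance removes it. This is the general idea behind adapting a dense mass matrix, offered only as an illustration; it is not what `xo.PyMC3Sampler` literally runs, and the numbers are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical warm-up draws for two strongly correlated parameters
# (in transit fits, radius ratio and impact parameter are a common example)
true_cov = np.array([[1.0, 0.95], [0.95, 1.0]])
draws = rng.multivariate_normal(np.zeros(2), true_cov, size=5000)
centered = draws - draws.mean(axis=0)

# Diagonal scaling: divide each parameter by its own standard deviation.
# This fixes the scales but keeps the ~0.95 correlation.
diag_scaled = centered / draws.std(axis=0, ddof=1)

# Dense scaling: with the estimated covariance C = L L^T, the transformed
# draws L^{-1} x have identity sample covariance, so the correlation is gone
est_cov = np.cov(draws, rowvar=False)
chol = np.linalg.cholesky(est_cov)
whitened = np.linalg.solve(chol, centered.T).T

print(np.round(np.cov(diag_scaled, rowvar=False), 3))
print(np.round(np.cov(whitened, rowvar=False), 3))
```

A sampler that takes steps in the whitened coordinates can use a single step size for every direction, which is why dense adaptation helps so much for elongated posteriors.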

In [None]:
np.random.seed(42)
sampler = xo.PyMC3Sampler(finish=200)
with model:
    sampler.tune(tune=2000, step_kwargs=dict(target_accept=0.9))
    trace = sampler.sample(draws=2000)

After sampling using this method, we can still do the usual convergence checks.
For example, we can make a trace plot:

In [None]:
pm.traceplot(trace, varnames=["period", "t0"]);

And look at the quantitative summary to make sure that the effective number of samples is high enough to trust the result:

In [None]:
pm.summary(trace)
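
The effective sample size in that summary measures how many independent draws a correlated chain is worth. A minimal single-chain estimate divides the number of draws by the integrated autocorrelation time, truncating the autocorrelation sum with an automatic window (Sokal's heuristic, as used by emcee). PyMC3 computes its own estimate with a different, multi-chain method, so treat this numpy sketch as an illustration of the concept; the function names and the window constant `c=5` are assumptions.

```python
import numpy as np


def autocorr_1d(x):
    """Normalized autocorrelation function of a 1D chain, computed via FFT."""
    x = np.asarray(x, dtype=np.float64)
    n = len(x)
    # Zero-pad to a power of two of at least 2n to avoid circular wrap-around
    size = 1 << (2 * n - 1).bit_length()
    f = np.fft.rfft(x - x.mean(), n=size)
    acf = np.fft.irfft(f * np.conj(f), n=size)[:n]
    return acf / acf[0]


def effective_sample_size(x, c=5.0):
    """Estimate ESS = N / tau, where tau(M) = 1 + 2 * sum_{k=1..M} rho_k
    is summed up to the first window M satisfying M >= c * tau(M)."""
    rho = autocorr_1d(x)
    taus = 2.0 * np.cumsum(rho) - 1.0
    window = np.arange(len(taus)) >= c * taus
    m = np.argmax(window) if window.any() else len(taus) - 1
    return len(x) / taus[m]


rng = np.random.default_rng(7)

# Independent draws: ESS should be close to the number of draws
iid = rng.normal(size=10000)

# AR(1) chain with coefficient phi: the exact autocorrelation time is
# (1 + phi) / (1 - phi) = 19, so ESS should be roughly 20000 / 19
phi = 0.9
ar1 = np.empty(20000)
ar1[0] = rng.normal() / np.sqrt(1.0 - phi**2)
for i in range(1, ar1.size):
    ar1[i] = phi * ar1[i - 1] + rng.normal()

print(effective_sample_size(iid), effective_sample_size(ar1))
```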

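And we can make a corner plot of the key parameters (`r` and `b` are stored as length-1 vectors, so their columns in the data frame are named `r__0` and `b__0`):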

In [None]:
import corner
varnames = ["period", "t0", "r", "b"]
labels = ["period [days]", "transit time [days]", "radius ratio", "impact parameter"]
samples = pm.trace_to_dataframe(trace, varnames=varnames)
corner.corner(samples[["period", "t0", "r__0", "b__0"]], labels=labels);

Finally, we can plot the posterior constraint on the transit model:

In [None]:
# Compute the posterior parameters
median_period = np.median(trace["period"])
median_t0 = np.median(trace["t0"])
median_x_fold = (x - median_t0 + 0.5*median_period) % median_period - 0.5*median_period
median_inds = np.argsort(median_x_fold)

# Plot the data
plt.scatter(median_x_fold, y, c=x)

# This is a little convoluted, but we'll take 100 random samples from the chain
# and for each sample, we'll evaluate the predicted transit model and overplot it
with model:
    # Pre-compile a function to evaluate the light curve
    func = xo.utils.get_theano_function_for_var(light_curve)
    
    # Loop over 100 random samples
    for sample in xo.utils.get_samples_from_trace(trace, size=100):
        
        # Fold the times based on the period and phase of this sample
        fold = (x - sample["t0"] + 0.5*sample["period"]) % sample["period"] - 0.5*sample["period"]
        inds = np.argsort(fold)
        
        # Evaluate the light curve
        args = xo.utils.get_args_for_theano_function(sample)
        transit_model = func(*args)
        
        # And plot the light curve model
        plt.plot(fold[inds], transit_model[inds], "k", lw=0.5, alpha=0.2)

# Format the plot
plt.xlim(-0.39, 0.39)
plt.xlabel("time since transit [days]")
plt.ylabel("relative flux [ppt]");
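
Beyond the median values used to fold the data above, a common way to report these constraints is the median together with a central 68% credible interval taken from the 16th and 84th percentiles of the samples. Here is a minimal sketch, assuming the posterior draws for one parameter are available as a 1D array (for example `trace["period"]`). The synthetic array below is only a stand-in and is not a result from this fit.

```python
import numpy as np


def summarize(samples, level=68.0):
    """Return (median, lower error, upper error) for a 1D array of
    posterior draws, using a central credible interval of `level` percent."""
    samples = np.asarray(samples, dtype=np.float64)
    lo_q = 50.0 - 0.5 * level
    hi_q = 50.0 + 0.5 * level
    lo, med, hi = np.percentile(samples, [lo_q, 50.0, hi_q])
    return med, med - lo, hi - med


# Stand-in draws with hypothetical numbers (not output from the chain above)
rng = np.random.default_rng(3)
fake_period = rng.normal(6.27, 0.001, size=8000)

med, err_lo, err_hi = summarize(fake_period)
print(f"period = {med:.5f} -{err_lo:.5f} +{err_hi:.5f} days")
```

Reporting separate lower and upper errors keeps asymmetric posteriors (common for the impact parameter, for example) from being misrepresented by a single symmetric error bar.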