# Ensemble analyses

In this tutorial we'll show how to perform a very simple ensemble analysis to infer the statistical properties of the spots on a group of stars.

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Use the "retina" inline figure format (higher-resolution inline images).
%config InlineBackend.figure_format = "retina"

In [None]:
import logging
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Disable annoying font warnings. This is the logger that
# `matplotlib.font_manager` writes to; going through `logging` avoids relying
# on the private `matplotlib.font_manager._log` attribute.
logging.getLogger("matplotlib.font_manager").setLevel(logging.CRITICAL)

# Disable theano deprecation warnings
import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=matplotlib.MatplotlibDeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning, module="theano")

# Style
plt.style.use("default")
plt.rcParams["savefig.dpi"] = 100
plt.rcParams["figure.dpi"] = 100
plt.rcParams["figure.figsize"] = (12, 4)
plt.rcParams["font.size"] = 14
plt.rcParams["text.usetex"] = False
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = ["Liberation Sans"]
plt.rcParams["font.cursive"] = ["Liberation Sans"]

# Fall back to Computer Modern for math text. Newer matplotlib spells this
# `mathtext.fallback = "cm"`; older versions only know the boolean
# `mathtext.fallback_to_cm` and raise KeyError for the new key. The old key
# must ONLY be set in the fallback branch: matplotlib versions that dropped
# it raise KeyError if it is assigned unconditionally.
try:
    plt.rcParams["mathtext.fallback"] = "cm"
except KeyError:
    plt.rcParams["mathtext.fallback_to_cm"] = True

# Short arrays when printing
np.set_printoptions(threshold=0)

In [None]:
# Drop the setup-only names so the next cell's imports start from a clean slate.
del matplotlib, plt, warnings

## Setup

In [None]:
from starry_process import StarryProcess, calibrate
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import pymc3 as pm
import pymc3_ext as pmx
from corner import corner
import theano
import theano.tensor as tt

In [None]:
import matplotlib
from corner import corner as _corner


def _padded_limits(lo, hi, truth, pad=0.1):
    """
    Return axis limits widened so that `truth` falls inside them.

    Args:
        lo, hi: The current axis limits.
        truth: The value that must be visible. `None` (which `corner`
            accepts as "no truth for this parameter") is ignored.
        pad: Extra margin placed beyond `truth`, as a fraction of the
            distance from `truth` to the opposite limit.

    Returns:
        A new `(lo, hi)` tuple, or `None` if the limits already contain
        `truth` (or `truth` is `None`) and should be left unchanged.

    """
    if truth is None:
        return None
    if truth < lo:
        return truth - pad * (hi - truth), hi
    if truth > hi:
        # Mirror of the branch above: pad by a fraction of the full span,
        # not of the (possibly tiny) overshoot `truth - hi`.
        return lo, truth + pad * (truth - lo)
    return None


def corner(*args, **kwargs):
    """
    Override `corner.corner` by making some appearance tweaks.

    All arguments are forwarded to `corner.corner`, except the extra keyword
    `corner_label_size` (default 16), which is consumed here and sets the
    font size of the outer axis labels. Outer tick labels are shrunk to
    size 10 and offset notation is disabled. If `truths` is passed, the
    axis limits are widened so that every truth value is visible.

    Returns:
        The figure returned by `corner.corner`.

    """
    # `corner_label_size` is our own option, not a `corner.corner` argument,
    # so strip it before forwarding the remaining kwargs.
    label_size = kwargs.pop("corner_label_size", 16)

    # Get the usual corner plot
    figure = _corner(*args, **kwargs)

    # Get the axes as an (ndim, ndim) grid
    ndim = int(np.sqrt(len(figure.axes)))
    axes = np.array(figure.axes).reshape((ndim, ndim))

    # Smaller tick labels on the outer axes. `Tick.label1` is used because
    # the `Tick.label` alias is deprecated/removed in newer matplotlib.
    for ax in axes[1:, 0]:
        for tick in ax.yaxis.get_major_ticks():
            tick.label1.set_fontsize(10)
        formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
        ax.yaxis.set_major_formatter(formatter)
        ax.set_ylabel(ax.get_ylabel(), fontsize=label_size)
    for ax in axes[-1, :]:
        for tick in ax.xaxis.get_major_ticks():
            tick.label1.set_fontsize(10)
        formatter = matplotlib.ticker.ScalarFormatter(useOffset=False)
        ax.xaxis.set_major_formatter(formatter)
        ax.set_xlabel(ax.get_xlabel(), fontsize=label_size)

    # Pad the axes to always include the truths. Off-diagonal panel
    # (row, col) is checked against truth `col` along x and truth `row`
    # along y; the matching diagonal panel receives the same x-limits.
    truths = kwargs.get("truths", None)
    if truths is not None:
        for row in range(1, ndim):
            for col in range(row):
                limits = _padded_limits(*axes[row, col].get_xlim(), truths[col])
                if limits is not None:
                    axes[row, col].set_xlim(*limits)
                    axes[col, col].set_xlim(*limits)

                limits = _padded_limits(*axes[row, col].get_ylim(), truths[row])
                if limits is not None:
                    axes[row, col].set_ylim(*limits)
                    axes[row, row].set_xlim(*limits)

    return figure

## Generate the ensemble

In this section we will generate a synthetic ensemble of light curves of stars with "similar" spot properties. Let's define some true values for the spot properties of the ensemble:

In [None]:
# True hyperparameters of the synthetic ensemble (angles in degrees).
truths = dict(r=15, mu=30, sigma=5, c=0.05, n=20)

In [None]:
from IPython.display import display, Markdown
from starry_process.defaults import defaults

# Show the true hyperparameters of the ensemble as a Markdown table,
# one row per entry of `truths`.
display(
    Markdown(
        "\n".join(
            [
                "",
                "| parameter | description | true value",
                "| - | :- | :-:",
            ]
            + [
                f"| `{name}` | {desc} | `{truths[name]}`"
                for name, desc in (
                    ("r", "mean radius in degrees"),
                    ("mu", "latitude distribution mode in degrees"),
                    ("sigma", "latitude distribution standard deviation in degrees"),
                    ("c", "fractional spot contrast"),
                    ("n", "number of spots"),
                )
            ]
            + [""]
        )
    )
)

Now let's generate 500 light curves from stars at random inclinations with spots drawn from the distributions above.
We'll do this by adding discrete circular spots to each star via the `starry_process.calibrate.generate`
function.
Note that in order to mimic real observations, we'll normalize each light curve to its mean value and subtract unity to get the "relative" flux.
For simplicity, we'll give all of the light curves the same period and photometric uncertainty.

In [None]:
# Simulate the ensemble: 500 normalized light curves sharing one rotation
# period and one photometric uncertainty, with spot properties set by `truths`.
data = calibrate.generate(
    generate={
        "normalized": True,
        "nlc": 500,
        "period": 1.0,
        "ferr": 1e-3,
        "nspots": {"mu": truths["n"]},
        "radius": {"mu": truths["r"]},
        "latitude": {"mu": truths["mu"], "sigma": truths["sigma"]},
        "contrast": {"mu": truths["c"]},
    }
)

The variable `data` is a dictionary containing the light curves, the stellar maps (expressed as vectors of spherical harmonic coefficients `y`), plus some metadata.

In [None]:
# Pull the time grid, fluxes, flux uncertainty and true maps out of `data`.
t, flux, ferr, y = (data[key] for key in ("t", "flux", "ferr", "y"))

Let's visualize some of the light curves, all on the same scale:

In [None]:
# Plot the first 15 light curves on a shared vertical scale (parts per
# thousand). Only panel 10 (bottom-left of the 3x5 grid) keeps its tick
# labels and gets axis labels.
fig, ax = plt.subplots(3, 5)
for k, panel in enumerate(ax.ravel()):
    panel.plot(t, 1000 * flux[k])
    panel.set_ylim(-50, 50)
    panel.set_xticks([0, 1, 2, 3, 4])
    if k == 10:
        panel.set_xlabel("rotations")
        panel.set_ylabel("flux [ppt]")
    else:
        panel.set_xticklabels([])
        panel.set_yticklabels([])

In the next section, we'll assume we observe only these 500 light curves. We do not know the inclinations of any of the stars or anything about their spot properties: only that all the stars have statistically similar spot distributions.

## Inference

Let's set up a simple probabilistic model using `pymc3` and solve for the five quantities above: the spot radius, the mode and standard deviation of the spot latitude, the spot contrast, and the number of spots. We'll place uniform priors on everything except for the latitude mode `mu`, on which we'll place an isotropic prior, i.e., one that is uniform in $\sin\mu$ over $0^\circ \le \mu \le 90^\circ$.

In [None]:
with pm.Model() as model:

    # Names of the five hyperparameters, in the same order as the keys of
    # `truths` (for use later)
    varnames = ["r", "mu", "sigma", "c", "n"]

    # Spot latitude params. Isotropic prior on the mode: with u ~ U(0, 1),
    # mu = 90 - arccos(u) [deg] = arcsin(u), so sin(mu) is uniform on [0, 1]
    # and mu ranges over [0, 90] degrees.
    # Uniform prior on the standard deviation.
    u = pm.Uniform("u", 0.0, 1.0)
    mu = 90 - tt.arccos(u) * 180 / np.pi
    # Register `mu` by name so its value shows up in the solution
    # (the results table below reads `mu` from `map_soln`)
    pm.Deterministic("mu", mu)
    sigma = pm.Uniform("sigma", 1.0, 20.0)

    # Spot radius (uniform prior)
    r = pm.Uniform("r", 10.0, 30.0)

    # Spot contrast & number of spots (uniform prior). `testval` overrides
    # the default initial (test) value of these two variables.
    c = pm.Uniform("c", 0.0, 1.0, testval=0.1)
    n = pm.Uniform("n", 1.0, 50.0, testval=5)

    # Instantiate the GP
    sp = StarryProcess(r=r, mu=mu, sigma=sigma, c=c, n=n)

    # Compute the log likelihood of all light curves. The third argument is
    # the flux variance (ferr squared); `p=1.0` presumably is the rotation
    # period, matching `period=1.0` used to generate the data -- confirm.
    lnlike = sp.log_likelihood(t, flux, ferr ** 2, p=1.0)
    pm.Potential("lnlike", lnlike)

We could go on to do inference using `NUTS` or `ADVI` or any of the other samplers supported by `pymc3`, but that would take a few hours (at least). Since we have many light curves, let's just optimize the log probability function to get the MAP (maximum a posteriori) solution -- that will be a good estimate of the true spot properties.

In [None]:
# Optimize the log probability of `model` to get the MAP point estimate.
# The result is a mapping keyed by variable name (the table below reads
# `r`, `mu`, `sigma`, `c` and `n` from it).
map_soln = pmx.optimize(model=model)

Here's what we got:

In [None]:
from IPython.display import display, Markdown
from starry_process.defaults import defaults

# Compare the true hyperparameters with the MAP estimates in a Markdown
# table; each row carries the format spec used for the inferred value.
display(
    Markdown(
        "\n".join(
            [
                "",
                "| parameter | description | true value | inferred value",
                "| - | :- | :-: | :-:",
            ]
            + [
                f"| `{name}` | {desc} | `{truths[name]}` | `{map_soln[name]:{spec}}`"
                for name, desc, spec in (
                    ("r", "mean radius in degrees", ".2f"),
                    ("mu", "latitude distribution mode in degrees", ".2f"),
                    (
                        "sigma",
                        "latitude distribution standard deviation in degrees",
                        ".2f",
                    ),
                    ("c", "fractional spot contrast", ".4f"),
                    ("n", "number of spots", ".2f"),
                )
            ]
            + [""]
        )
    )
)

Not bad! We correctly inferred *all* the hyperparameters of the GP! Note, importantly, that we don't yet have any estimate of the uncertainty on any of these parameters. To get that, we need to actually sample the posterior. Stay tuned!