(mw_transit)=

# Multi-wavelength Transit Fitting

In this tutorial, we'll see how to set up a transit model for inference on multi-wavelength transit data. If you haven't checked out the [Transit tutorial](transit.ipynb) yet, we'd recommend doing that first, since this model is an extension of the one there!

```{note}
This tutorial requires some [extra packages](about.ipynb) that are not included in the `jaxoplanet` dependencies.
```

## Setup

We first set up the number of CPUs to use and enable double-precision numbers in JAX. We also import the required packages.

In [None]:
# import os
# os.environ['XLA_FLAGS'] = '--xla_cpu_use_thunk_runtime=false'

In [None]:
import jaxoplanet
from jaxoplanet.light_curves import limb_dark_light_curve
from jaxoplanet.orbits import TransitOrbit
import numpy as np
import matplotlib.pyplot as plt
import numpyro
import numpyro.distributions as dist
import numpyro_ext
import numpyro_ext.distributions as distx
import numpyro_ext.optim as optimx
import jax
import jax.numpy as jnp
import corner
import arviz as az
from functools import partial

numpyro.set_host_device_count(
    2
)  # For multi-core parallelism (useful when running multiple MCMC chains in parallel)
numpyro.set_platform("cpu")  # For CPU (use "gpu" for GPU)
jax.config.update(
    "jax_enable_x64", True
)  # For 64-bit precision since JAX defaults to 32-bit
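A quick way to confirm that double precision is active is to check the dtype of a newly created array. The minimal sketch below repeats the config update so it can run by itself:

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats; ideally do this at start-up, before creating any arrays
jax.config.update("jax_enable_x64", True)

x = jnp.linspace(0.0, 1.0, 5)
print(x.dtype)  # float64 if x64 is enabled, float32 otherwise
```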

## Simulating multiwavelength data

We'll start this tutorial by simulating some multi-wavelength transit data *somewhat like* what we might expect from a JWST transit observation (i.e., a single, high-precision transit). <br>
But first, to get some terminology out of the way:
1. **Spectroscopic** light curves are those measured over a *very narrow* wavelength range.
2. **Broadband** (or **white**) light curves are those measured over a much wider wavelength range. For JWST, the broadband light curve would be generated by summing the spectroscopic light curves.

For our simulated transit observation, let's vary the **depth** (and hence the radius ratio) and the **flux uncertainty** of each spectroscopic light curve, but keep everything else (e.g., the limb-darkening coefficients [LDCs]) the same. In practice, the LDCs would likely have some wavelength dependence.

We'll also need a value for the orbital period. A single transit can't really constrain the period, so let's just set it arbitrarily to `PERIOD = 10.0` days.

And to keep things easy we'll stick with 10 spectroscopic light curves. <br>
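In the simulation cells that follow, the injected spectrum is a constant depth plus a Gaussian bump. Each light curve's radius ratio is the square root of its depth:

$$
\delta(\lambda) = 0.01 + 10^{-3}\exp\left[-\left(\frac{\lambda - 4.3}{0.2}\right)^{2}\right],
\qquad
\frac{R_p}{R_\star} = \sqrt{\delta(\lambda)},
$$

Here $\lambda$ is in the same units as `wavelengths`. Treating the depth as $(R_p/R_\star)^2$ ignores limb darkening, so the actual in-transit flux drop of the limb-darkened model will differ somewhat from $\delta$.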

In [None]:
PERIOD = 10.0  # day
DURATION = 0.3  # day
T0 = 0.0  # day
B = 0.3  # impact parameter
U = np.array([0.4, 0.3])  # LDCs

num_lcs = 10
wavelengths = np.linspace(4.0, 4.9, num_lcs)  # micron

# Let's make the depth vary across the spectroscopic light curves
DEPTHS = 0.01 + 1e-3 * np.exp(-(((wavelengths - 4.3) / 0.2) ** 2))

# Let's see what the theoretical transmission spectrum would look like
fig, ax = plt.subplots(dpi=150)
ax.plot(wavelengths, DEPTHS, marker=".", ms=10, ls=":")
ax.set_xlabel("wavelength [μm]", fontsize=10)
ax.set_ylabel("transit depth [unitless]", fontsize=10);

In [None]:
t = np.linspace(-0.5, 0.5, 300)
params = {
    "period": PERIOD,
    "duration": DURATION,
    "b": B,
    "t0": T0,
    "u": U,
    "rors": jnp.sqrt(DEPTHS),
}


def eval_limb_dark_light_curve(params, t):
    orbit = TransitOrbit(
        period=params["period"],
        duration=params["duration"],
        impact_param=params["b"],
        time_transit=params["t0"],
        radius_ratio=params["rors"],
    )
    return limb_dark_light_curve(orbit, params["u"])(t)


y_true = jax.vmap(
    eval_limb_dark_light_curve,
    in_axes=(
        {
            "period": None,
            "duration": None,
            "b": None,
            "t0": None,
            "u": None,
            "rors": 0,
        },
        None,
    ),
)(params, t)

stddevs = 1e-5 * wavelengths**3
yerr = np.repeat(stddevs, repeats=t.size).reshape(num_lcs, t.size)
keys = jax.random.split(jax.random.PRNGKey(99), num=stddevs.size)
dy = jax.vmap(
    lambda stddev, key: stddev * jax.random.normal(key, shape=(t.size,)), in_axes=(0, 0)
)(stddevs, keys)
y = y_true + dy

# Let's check our spectroscopic light curves
fig, ax = plt.subplots(dpi=200)
offset = 0.0
for _y_true, _y, stddev, wv in zip(y_true, y, stddevs, wavelengths):
    ax.plot(t, _y_true + offset, lw=0.5, color="k")
    ax.errorbar(
        t, _y + offset, yerr=stddev, marker=".", ms=1, ls="none", lw=0.8, capsize=0
    )
    ax.annotate(f"{wv:.1f} μm", xy=(-0.5, 0.002 + offset), fontsize=8)
    offset += 0.01
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux + arbitrary offset", fontsize=10);
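The `in_axes` dictionary used above to compute `y_true` tells `jax.vmap` to map only over `rors` and to broadcast every other parameter. The model further down reaches the same result a different way: it closes over the shared parameters and maps over the radius ratio alone. The sketch below shows that the two approaches agree. It uses a hypothetical box-shaped toy light curve (`toy_light_curve`) in place of `eval_limb_dark_light_curve`:

```python
import jax
import jax.numpy as jnp


def toy_light_curve(params, t):
    # Hypothetical box-shaped transit: flux drops by ror**2 while |t - t0| < duration / 2
    in_transit = jnp.abs(t - params["t0"]) < 0.5 * params["duration"]
    return 1.0 - params["ror"] ** 2 * in_transit


t = jnp.linspace(-0.5, 0.5, 11)
shared = {"t0": 0.0, "duration": 0.3}
rors = jnp.array([0.1, 0.11, 0.12])

# Option 1: an in_axes dict that maps only over "ror"
in_axes = ({"t0": None, "duration": None, "ror": 0}, None)
y1 = jax.vmap(toy_light_curve, in_axes=in_axes)(shared | {"ror": rors}, t)

# Option 2: close over the shared parameters and vmap over ror alone
y2 = jax.vmap(lambda ror: toy_light_curve(shared | {"ror": ror}, t))(rors)

print(y1.shape)  # one row per radius ratio
```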

## Broadband light curve
Let's also see what the broadband light curve would look like. We'll use [inverse-variance weighting](https://en.wikipedia.org/wiki/Inverse-variance_weighting) to combine the spectroscopic light curves.
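With weights $w_i = 1/\sigma_i^2$, the combined flux at each time is $\sum_i w_i y_i / \sum_i w_i$, and its uncertainty is $(\sum_i w_i)^{-1/2}$. The helper below computes the same thing as the next cell and is checked on toy numbers. The name `inverse_variance_mean` is our own:

```python
import numpy as np


def inverse_variance_mean(y, sigma):
    """Combine the rows of y (shape (n_lcs, n_times)) given per-row uncertainties sigma."""
    w = 1.0 / sigma**2               # w_i = 1 / sigma_i^2
    mean = np.dot(w, y) / np.sum(w)  # sum_i w_i y_i(t) / sum_i w_i, at every time t
    err = np.sqrt(1.0 / np.sum(w))   # uncertainty of the weighted mean
    return mean, err


y = np.array([[1.0, 2.0], [3.0, 4.0]])
# The first (less noisy) row gets 4x the weight of the second
mean, err = inverse_variance_mean(y, np.array([1.0, 2.0]))
print(mean, err)
```

With equal uncertainties this reduces to the plain mean, with an uncertainty of $\sigma/\sqrt{N}$.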

In [None]:
inv_var = 1 / stddevs**2
y_bb = np.dot(y.T, inv_var)
y_bb /= np.sum(inv_var)
yerr_bb = np.sqrt(np.sum(inv_var) ** -1)

fig, ax = plt.subplots(dpi=200)
ax.errorbar(t, y_bb, yerr=yerr_bb, marker=".", ls="none", lw=0.8, color="k")
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux", fontsize=10);

## Setting up our Numpyro model

We'll follow a pretty similar setup for the numpyro model as the one we set up in the [single transit tutorial](transit.ipynb).

Let's also assume we have some informative priors from previous measurements that are relatively close to the true values of the duration, mid-transit time, impact parameter, and depths. To keep this example simple, we'll hold the limb-darkening coefficients (LDCs) fixed at their true values `U`. To fit them instead, you could sample them with the `QuadLDParams` distribution from the `numpyro_ext` package (see the commented-out line in the model). It implements the uninformative prior for quadratic LD described in [Kipping (2013)](https://doi.org/10.1093/mnras/stt1435).

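For reference, here is a minimal JAX sketch of the $(q_1, q_2) \to (u_1, u_2)$ reparameterization from Kipping (2013). It is written from the paper's equations, not from the `numpyro_ext` source. Treat it as an illustration of the idea behind `QuadLDParams` rather than its implementation. The function names are hypothetical:

```python
import jax
import jax.numpy as jnp


def kipping_q_to_u(q1, q2):
    """Map (q1, q2) in [0, 1]^2 to quadratic LDCs (u1, u2), following Kipping (2013)."""
    sqrt_q1 = jnp.sqrt(q1)
    u1 = 2.0 * sqrt_q1 * q2
    u2 = sqrt_q1 * (1.0 - 2.0 * q2)
    return u1, u2


def kipping_u_to_q(u1, u2):
    """Inverse mapping, valid for physically allowed (u1, u2) with u1 + u2 > 0."""
    q1 = (u1 + u2) ** 2
    q2 = u1 / (2.0 * (u1 + u2))
    return q1, q2


# Uniform draws in (q1, q2) land inside the physically allowed (u1, u2) region
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
q1 = jax.random.uniform(key1, (1000,))
q2 = jax.random.uniform(key2, (1000,))
u1, u2 = kipping_q_to_u(q1, q2)
```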
In [None]:
def jitter_value(value, jitter_fraction, key):
    """Perturb `value` by a Gaussian draw with standard deviation `jitter_fraction * value`."""
    jitter = jitter_fraction * value * jax.random.normal(key)
    return value + jitter


# Prior means: the true values, slightly perturbed
mu_duration = jitter_value(DURATION, 1e-3, jax.random.PRNGKey(8))
mu_t0 = jitter_value(T0, 1e-4, jax.random.PRNGKey(131))  # T0 = 0, so this equals T0
mu_b = jitter_value(B, 1e-2, jax.random.PRNGKey(23))
keys = jax.random.split(jax.random.PRNGKey(55), num=num_lcs)
mu_depths = jax.vmap(jitter_value, in_axes=(0, None, 0))(DEPTHS, 1e-3, keys)

# Uncomment to center the priors exactly on the true values instead
# mu_duration = DURATION
# mu_t0 = T0
# mu_b = B
# mu_depths = DEPTHS

In [None]:
def model(t, yerr, y=None):
    # Priors

    ## Parameters shared across spectroscopic light curves
    logD = numpyro.sample("logD", dist.Normal(jnp.log(mu_duration), 1e-2))
    duration = numpyro.deterministic("duration", jnp.exp(logD))

    t0 = numpyro.sample(
        "t0", dist.Normal(mu_t0, 1e-3)
    )  # We usually have pretty good constraints on t0
    b = numpyro.sample(
        "b",
        dist.TruncatedNormal(mu_b, 0.1, low=0.0, high=1.0),
    )
    # The LDCs are fixed to U here; to fit them instead, sample
    # u = numpyro.sample("u", distx.QuadLDParams())
    # and pass `u` in place of `U` below

    ## Parameters for each light curve
    depths = numpyro.sample(
        "depths",
        dist.TruncatedNormal(
            mu_depths,
            1e-3 * jnp.ones_like(mu_depths),
            low=0.0,
            high=1.0,
        ),
    )
    rors = jnp.atleast_1d(numpyro.deterministic("rors", jnp.sqrt(depths)))

    params = {
        "period": PERIOD,
        "duration": duration,
        "t0": t0,
        "b": b,
        "u": U,
    }

    # Evaluate one light curve per radius ratio, sharing all other parameters
    def _light_curve(ror):
        return eval_limb_dark_light_curve(params | {"rors": ror}, t)

    y_model = jax.vmap(_light_curve)(rors)

    numpyro.sample("obs", dist.Normal(y_model, yerr), obs=y)
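The `obs` site treats every flux point as an independent Gaussian measurement. The model's log-likelihood is therefore a sum of normal log-densities over all light curves and times. The sketch below computes that sum by hand and checks it against `jax.scipy.stats.norm.logpdf`. The arrays are made-up stand-ins for `y`, `y_model`, and `yerr`:

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

jax.config.update("jax_enable_x64", True)


def gaussian_log_likelihood(y, y_model, yerr):
    # Sum over all light curves and times of log N(y | y_model, yerr**2)
    resid = (y - y_model) / yerr
    return jnp.sum(-0.5 * resid**2 - jnp.log(yerr) - 0.5 * jnp.log(2.0 * jnp.pi))


# Made-up stand-ins with shape (num_lcs, num_times)
y_model = jnp.ones((2, 3))
yerr = jnp.full((2, 3), 1e-3)
y = y_model + 1e-3 * jnp.array([[0.5, -1.0, 0.0], [2.0, 0.1, -0.3]])

ll_manual = gaussian_log_likelihood(y, y_model, yerr)
ll_reference = jnp.sum(norm.logpdf(y, loc=y_model, scale=yerr))
print(ll_manual, ll_reference)
```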

## Checking priors
Let's check our priors to make sure that:
1. their ranges are physically sensible, and
2. they're not *too* far off from the true values.

In [None]:
n_prior_samples = 2000
prior_samples = numpyro.infer.Predictive(model, num_samples=n_prior_samples)(
    jax.random.PRNGKey(0), t, yerr
)

# Let's make it into an arviz InferenceData object.
# To do so, we'll first reshape the samples to (chains, draws, *shape),
# treating all prior draws as a single chain.
converted_prior_samples = {
    p: np.expand_dims(v, axis=0) for p, v in prior_samples.items()
}
prior_samples_inf_data = az.from_dict(converted_prior_samples)

# Plot the corner plot
fig = plt.figure(figsize=(20, 20))
_ = corner.corner(
    prior_samples_inf_data,
    fig=fig,
    var_names=["t0", "duration", "b", "rors"],
    truths=[T0, DURATION, B, *jnp.sqrt(DEPTHS)],
    show_titles=True,
    title_kwargs={"fontsize": 10},
    label_kwargs={"fontsize": 10},
)
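To make the shape bookkeeping above concrete: `Predictive` returns arrays whose leading axis indexes the draws, while ArviZ expects `(chain, draw, *shape)`. Adding a leading axis of length 1 therefore turns all prior draws into a single chain. The sketch below uses placeholder arrays:

```python
import numpy as np

# Placeholder prior draws shaped like Predictive's output: leading axis = draw
n_draws, n_lcs = 2000, 10
prior_draws = {"b": np.zeros(n_draws), "rors": np.zeros((n_draws, n_lcs))}

# Add a leading "chain" axis of length 1
converted = {name: np.expand_dims(v, axis=0) for name, v in prior_draws.items()}
print({name: v.shape for name, v in converted.items()})
```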

## Optimize and get MAP estimate
Let's optimize the model to calculate the *maximum a posteriori* (MAP) estimate so that we can use it as the starting point for our MCMC run.

We've found the optimization to be more robust (i.e., not sensitive to the random seed) when we optimize the parameters in batches instead of all at once. 
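To illustrate batched optimization, here is a minimal sketch that runs plain gradient descent on a toy two-parameter loss, standing in for the negative log posterior. This is not how `numpyro_ext.optim.optimize` works internally. `optimize_subset` is a hypothetical helper that updates only the named `sites` and holds the rest fixed, mirroring the `sites=` argument used below:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def loss(params):
    # Toy coupled quadratic; its minimum is at a = 1.25, b = 1.75
    a, b = params["a"], params["b"]
    return (a - 1.0) ** 2 + (b - 2.0) ** 2 + 0.5 * (a - b) ** 2


def optimize_subset(loss_fn, params, sites, num_steps=200, learning_rate=0.1):
    """Hypothetical helper: gradient descent on `sites`, all other entries held fixed."""
    grad_fn = jax.jit(jax.grad(loss_fn))
    for _ in range(num_steps):
        grads = grad_fn(params)
        params = {
            k: v - learning_rate * grads[k] if k in sites else v
            for k, v in params.items()
        }
    return params


start = {"a": jnp.array(0.0), "b": jnp.array(0.0)}
stage1 = optimize_subset(loss, start, sites=["a"])       # first batch
stage2 = optimize_subset(loss, stage1, sites=["b"])      # second batch
final = optimize_subset(loss, stage2, sites=["a", "b"])  # everything at once
print(final)
```

Each stage can only lower the loss. The final joint stage then polishes the solution, much like the last `optimx.optimize` call below.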

In [None]:
# Starting values keyed by the model's sample sites (the duration is sampled as logD)
init_params = {
    "logD": jnp.log(mu_duration),
    "b": mu_b,
    "t0": mu_t0,
    "depths": mu_depths,
}

keys = jax.random.split(jax.random.PRNGKey(535), num=3)

# Stage 1: the parameters shared across light curves
soln = optimx.optimize(
    model,
    sites=["logD", "t0", "b"],
    start=init_params,
)(keys[0], t, yerr, y=y)

# Stage 2: the per-light-curve depths
soln = optimx.optimize(
    model,
    sites=["depths"],
    start=soln,
)(keys[1], t, yerr, y=y)

# Stage 3: everything at once
soln = optimx.optimize(
    model,
    start=soln,
)(keys[2], t, yerr, y=y)

Let's extract the model parameters from the `soln` dictionary and plot our MAP model.

In [None]:
# Parameters needed by `eval_limb_dark_light_curve`, besides the fixed period and LDCs
keys_to_retrieve = [k for k in params.keys() if k not in ("u", "period")]

In [None]:
map_params = {"period": PERIOD, "u": U} | {k: soln[k] for k in keys_to_retrieve}
print(map_params)

in_axes = {
    "period": None,
    "duration": None,
    "b": None,
    "t0": None,
    "u": None,
    "rors": 0,
}

y_model = jax.vmap(eval_limb_dark_light_curve, in_axes=(in_axes, None))(map_params, t)

fig, ax = plt.subplots(dpi=200)
offset = 0.0
_label = "MAP model"
for _y_model, _y, stddev, wv in zip(y_model, y, stddevs, wavelengths):
    ax.errorbar(
        t, _y + offset, yerr=stddev, marker=".", ms=1, ls="none", lw=0.8, capsize=0
    )
    ax.annotate(f"{wv:.1f} μm", xy=(-0.5, 0.002 + offset), fontsize=8)
    ax.plot(t, _y_model + offset, lw=0.5, color="k", label=_label)
    offset += 0.01
    _label = None
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux + arbitrary offset", fontsize=10)
ax.legend(markerscale=2, edgecolor="k");

## Sampling

Now let's sample the posterior with the No-U-Turn Sampler (NUTS), starting both chains from the MAP solution. We use a dense mass matrix because some of the shared parameters (e.g., the duration and impact parameter) can be strongly correlated.

In [None]:
sampler = numpyro.infer.MCMC(
    numpyro.infer.NUTS(
        model,
        dense_mass=True,
        regularize_mass_matrix=True,
        # Initialize every latent site (logD, t0, b, depths) at its MAP value
        init_strategy=numpyro.infer.init_to_value(values=soln),
    ),
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
    progress_bar=True,
)

In [None]:
sampler.run(jax.random.PRNGKey(432), t, yerr, y=y)

In [None]:
samples = sampler.get_samples()
inf_data = az.from_numpyro(sampler)
az.summary(inf_data)

In [None]:
corner.corner(
    inf_data,
    var_names=["duration", "t0", "b", "rors"],
    truths=[DURATION, T0, B, *jnp.sqrt(DEPTHS)],
);

In [None]:
az.plot_trace(
    inf_data,
    var_names=["duration", "t0", "b", "rors"],
    backend_kwargs={"constrained_layout": True},
);

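Under the hood, NUTS explores an *unconstrained* version of the parameter space, in which bounded parameters like `b` and `depths` are mapped onto the whole real line. Let's transform our posterior samples into that space with `numpyro.infer.util.unconstrain_fn` and overlay them on the constrained samples. Of the parameters plotted, only `b` should differ: `logD` and `t0` are already unbounded, and `rors` is a deterministic quantity.

In [None]: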
func_unconstrain = partial(
    numpyro.infer.util.unconstrain_fn, model, (t, yerr), {"y": y}
)

In [None]:
unconstrained_samples = jax.vmap(func_unconstrain)(samples)
unconstrained_samples = {
    k: jnp.expand_dims(v, axis=0) for k, v in unconstrained_samples.items()
}

In [None]:
fig = corner.corner(
    inf_data,
    var_names=["logD", "t0", "b", "rors"],
    truths=[np.log(DURATION), T0, B, *jnp.sqrt(DEPTHS)],
)

corner.corner(
    unconstrained_samples, fig=fig, var_names=["logD", "t0", "b", "rors"], color="C3"
);
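To build intuition for the unconstrained `b` axis, consider how NumPyro handles interval constraints. As we understand it, NumPyro maps an interval-constrained site, such as the `TruncatedNormal(..., low=0.0, high=1.0)` prior on `b`, to the real line with a rescaled logit, and maps it back with a sigmoid. Below is a minimal sketch of that pair of transforms. The function names are hypothetical, and the sketch ignores the Jacobian term that NUTS also needs:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

low, high = 0.0, 1.0  # bounds of the TruncatedNormal prior on b


def to_unconstrained(b):
    # Rescale to (0, 1), then apply the logit
    z = (b - low) / (high - low)
    return jnp.log(z) - jnp.log1p(-z)


def to_constrained(x):
    # Inverse: sigmoid, then rescale to (low, high)
    return low + (high - low) * jax.nn.sigmoid(x)


b = jnp.array([0.05, 0.3, 0.5, 0.95])
x = to_unconstrained(b)
print(x)
```

Values of `b` near the bounds get stretched far out along the unconstrained axis. This is why the red contours for `b` look so different from the constrained ones.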

In [None]:
corner.corner(
    inf_data,
    var_names=["duration", "t0", "b", "depths"],
    truths=[DURATION, T0, B, *DEPTHS],
);

In [None]:
az.plot_trace(
    inf_data,
    var_names=["duration", "t0", "b", "depths"],
    backend_kwargs={"constrained_layout": True},
)

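## Recovered transmission spectrum

Let's compare the recovered depths with the injected transmission spectrum. We summarize each depth's posterior by its median and its median absolute deviation (MAD).

In [None]: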
median_sample_depths = np.nanmedian(samples["depths"], axis=0)
MAD_sample_depths = np.nanmedian(
    np.abs(samples["depths"] - median_sample_depths), axis=0
)
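A note on these error bars: for a Gaussian posterior, the MAD is only about 0.6745 times the standard deviation, so MAD error bars are narrower than 1σ. The quick numerical check below uses synthetic Gaussian draws with arbitrary, illustration-only numbers. It shows that multiplying the MAD by ≈1.4826 recovers σ:

```python
import numpy as np

# Synthetic, Gaussian "posterior" draws for a single depth (arbitrary numbers)
rng = np.random.default_rng(42)
draws = rng.normal(loc=0.0101, scale=2e-4, size=200_000)

median = np.median(draws)
mad = np.median(np.abs(draws - median))

# For a Gaussian, MAD ~= 0.6745 * sigma, so 1.4826 * MAD is a robust estimate of sigma
sigma_from_mad = 1.4826 * mad
print(mad / draws.std(), sigma_from_mad / draws.std())
```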

In [None]:
fig, ax = plt.subplots(dpi=150)
ax.plot(wavelengths, DEPTHS, marker=".", ms=10, ls=":", label="true")
ax.errorbar(
    wavelengths,
    median_sample_depths,
    yerr=MAD_sample_depths,
    capsize=0,
    ls="none",
    marker=".",
    label="recovered (median ± MAD)",
)
ax.set_xlabel("wavelength [μm]", fontsize=10)
ax.set_ylabel("transit depth [unitless]", fontsize=10)
ax.legend();

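## Posterior predictive check

Let's draw 100 random posterior samples, evaluate the light-curve model for each of them, and compare the results with the true light curves.

In [None]:
# Randomly pick 100 posterior draws (without replacement)
idx = jax.random.choice(
    jax.random.PRNGKey(22), jnp.arange(samples["b"].size), shape=(100,), replace=False
)
model_params = (
    {"period": PERIOD}
    | {"u": U}
    | {k: v[idx] for k, v in samples.items() if k in keys_to_retrieve}
)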

In [None]:
inner_vmap = jax.vmap(
    eval_limb_dark_light_curve,
    in_axes=(
        {
            "period": None,
            "u": None,
            "duration": None,
            "b": None,
            "t0": None,
            "rors": 0,
        },
        None,
    ),
)
y_sampled = jax.vmap(
    inner_vmap,
    in_axes=(
        {
            "period": None,
            "u": None,
            "duration": 0,
            "b": 0,
            "t0": 0,
            "rors": 0,
        },
        None,
    ),
)(model_params, t)
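The nested `vmap` above maps over light curves on the inside, where only `rors` varies, and over posterior draws on the outside, where every sampled parameter varies. Below is the same pattern on a hypothetical box-shaped toy model standing in for `eval_limb_dark_light_curve`. The small sizes make the output shape, (draws, light curves, times), easy to check:

```python
import jax
import jax.numpy as jnp


def toy_light_curve(params, t):
    # Hypothetical box-shaped transit (stand-in for eval_limb_dark_light_curve)
    in_transit = jnp.abs(t - params["t0"]) < 0.5 * params["duration"]
    return 1.0 - params["ror"] ** 2 * in_transit


n_draws, n_lcs, n_times = 4, 3, 50
t = jnp.linspace(-0.5, 0.5, n_times)
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
draws = {
    "t0": 1e-3 * jax.random.normal(key1, (n_draws,)),               # one per draw
    "duration": jnp.full((n_draws,), 0.3),                            # one per draw
    "ror": 0.1 + 1e-3 * jax.random.normal(key2, (n_draws, n_lcs)),  # one per draw and light curve
}

# Inner vmap: over light curves (ror only); outer vmap: over posterior draws (everything)
inner = jax.vmap(toy_light_curve, in_axes=({"t0": None, "duration": None, "ror": 0}, None))
outer = jax.vmap(inner, in_axes=({"t0": 0, "duration": 0, "ror": 0}, None))
y_draws = outer(draws, t)
print(y_draws.shape)
```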

In [None]:
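fig, ax = plt.subplots(dpi=200)
# Posterior draws (one faint line per draw and light curve)
for _y_arr in y_sampled:
    offset = 0.0
    for _y in _y_arr:
        ax.plot(t, _y + offset, alpha=0.1)
        offset += 0.01
# True light curves on top
offset = 0.0
for _y in y_true:
    ax.plot(t, _y + offset, zorder=100, color="k")
    offset += 0.01

ax.set_xlim(-0.2, +0.2)
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("relative flux + arbitrary offset", fontsize=10);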

In [None]:
fig, ax = plt.subplots(dpi=200)
ax.hist(samples["t0"], bins="auto", histtype="step")

In [None]:
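# Total log-likelihood of each posterior draw, summed over light curves and times
ll = jnp.sum(inf_data.log_likelihood.obs.data, axis=(2, 3))  # shape: (chain, draw)

In [None]:
fig, ax = plt.subplots(dpi=150)
ax.plot(ll.T)  # one line per chain
ax.set_xlabel("draw", fontsize=10)
ax.set_ylabel("log-likelihood", fontsize=10);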

In [None]:
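# Residuals relative to the noiseless truth, i.e. the injected noise `dy`
residuals = y - y_true

In [None]:
fig, ax = plt.subplots(dpi=200)
for res in residuals:
    ax.plot(t, res, marker=".", ls="none")
ax.set_xlabel("time [day]", fontsize=10)
ax.set_ylabel("data - true model", fontsize=10);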

In [None]:
fig, ax = plt.subplots(dpi=200)
for res in residuals:
    ax.hist(res, bins="auto", histtype="step", lw=2, color="k")
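Because `y` was generated as `y_true + dy`, these residuals are exactly the injected noise. Each light curve's residual scatter should therefore match the `stddev` used to simulate it. Here is a sketch of that consistency check with hypothetical noise levels:

```python
import numpy as np

rng = np.random.default_rng(0)
stddevs = np.array([5e-4, 1e-3, 2e-3])  # hypothetical per-light-curve noise levels
n_times = 5000
y_true = np.ones((stddevs.size, n_times))
y = y_true + stddevs[:, None] * rng.standard_normal((stddevs.size, n_times))

residuals = y - y_true
# Per-light-curve scatter of the residuals, in units of the assumed uncertainty
normalized_scatter = residuals.std(axis=1) / stddevs
print(normalized_scatter)
```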