# Doppler Solve: PyMC3

## Setup

In [None]:
%matplotlib inline

In [None]:
%run notebook_setup.py

In [None]:
import starry

# Turn on starry's global "lazy" flag before any maps are created.
# NOTE(review): presumably this makes map quantities symbolic Theano tensors
# so they can be used inside the PyMC3 model below -- confirm in starry docs.
starry.config.lazy = True
# Turn on starry's global "quiet" flag.
# NOTE(review): presumably this silences informational messages -- confirm.
starry.config.quiet = True

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import starry
import george
import pymc3 as pm
import pymc3_ext as pmx
import theano.tensor as tt
from tqdm.auto import tqdm

In [None]:
def generate(
    nc=1,
    show=True,
    flux_err=1e-4,
    ydeg=15,
    u=(0.5, 0.25),
    nt=16,
    inc=40,
    veq=60000,
    smoothing=0.075,
    **kwargs
):
    """Generate a synthetic Doppler imaging dataset.

    Builds a non-lazy ``starry.DopplerMap``, loads one or two surface
    components (image + Gaussian absorption-line spectrum each), and returns
    noisy unnormalized and normalized model fluxes.

    Parameters
    ----------
    nc : int
        Number of map components. Only 1 (``"spot"``) and 2
        (``"star"`` + ``"spot"``) are supported.
    show : bool
        If true, call ``show_components(show_spectra=True)`` on the map.
    flux_err : float
        Standard deviation of the Gaussian noise added to each flux array.
    ydeg, nt, inc, veq
        Forwarded unchanged to ``starry.DopplerMap`` (see the starry
        documentation for their meaning and units).
    u : sequence of float
        Limb darkening coefficients; ``len(u)`` is passed as ``udeg``.
    smoothing : float
        Forwarded to ``DopplerMap.load``.
    **kwargs
        Accepted for call-site convenience and ignored.

    Returns
    -------
    flux0 : ndarray
        ``map.flux(normalize=False)`` plus noise.
    flux : ndarray
        ``map.flux(normalize=True)`` plus independently drawn noise.
    map_kwargs : dict
        The keyword arguments used to construct the ``DopplerMap`` (without
        ``lazy``), so the caller can build a matching map.

    Raises
    ------
    NotImplementedError
        If ``nc`` is not 1 or 2.
    """
    # Per-component surface images and absorption-line centers, keyed by the
    # number of components. Validating here, before the map is built, makes an
    # unsupported ``nc`` fail immediately and with a useful message.
    components = {
        1: (["spot"], [643.0]),
        2: (["star", "spot"], [643.025, 642.975]),
    }
    if nc not in components:
        raise NotImplementedError(
            "generate() only supports nc=1 or nc=2 (got nc={!r}).".format(nc)
        )
    images, line_centers = components[nc]

    # Instantiate (explicitly non-lazy so the fluxes below are numerical)
    wav = np.linspace(642.85, 643.15, 200)
    map_kwargs = dict(
        ydeg=ydeg, udeg=len(u), nc=nc, veq=veq, inc=inc, nt=nt, wav=wav
    )
    doppler_map = starry.DopplerMap(lazy=False, **map_kwargs)

    # Limb darkening (TODO: fix __setitem__)
    # The coefficient vector is set directly, with a leading -1.0 prepended.
    doppler_map._u = np.append([-1.0], u)

    # Component spectra: one unit-depth Gaussian absorption line per
    # component, evaluated on the map's ``wav0`` grid -> shape (nc, len(wav0)).
    mu = np.array(line_centers)
    sig = 0.0085
    dw = doppler_map.wav0.reshape(1, -1) - mu.reshape(-1, 1)
    spectra = 1.0 - np.exp(-0.5 * dw ** 2 / sig ** 2)

    # Load the component maps
    doppler_map.load(maps=images, spectra=spectra, smoothing=smoothing)

    # Show
    if show:
        doppler_map.show_components(show_spectra=True)

    # Generate unnormalized data
    flux0 = doppler_map.flux(normalize=False)
    flux0 += flux_err * np.random.randn(*flux0.shape)

    # Generate normalized data (noise is drawn separately from flux0's)
    flux = doppler_map.flux(normalize=True)
    flux += flux_err * np.random.randn(*flux.shape)

    return flux0, flux, map_kwargs

## Solve

In [None]:
# Simulation settings passed to the data generator; ``flux_err`` is looked up
# again later when the likelihood is defined.
settings = {
    "flux_err": 1e-4,
    "ydeg": 15,
    "nt": 16,
    "inc": 40,
    "veq": 60000,
    "smoothing": 0.075,
}
flux0, flux, map_kwargs = generate(nc=1, **settings)

In [None]:
# Regularization params: Laplace scale (``b``) of the prior on the pixel
# vector ``p`` (pb) and on the spectrum ``spectrum_`` (sb).
pb = 1e-3
sb = 1e-5

# Build the PyMC3 model: a Laplace-regularized map and spectrum whose starry
# flux model is compared against the (flattened) observed flux.
with pm.Model() as model:

    # Instantiate a uniform map. No ``lazy=`` override is given here, so the
    # global ``starry.config.lazy`` setting applies.
    map = starry.DopplerMap(**map_kwargs)
    # Limb darkening vector set directly; these values equal
    # ``np.append([-1.0], [0.5, 0.25])``, i.e. the default ``u`` of generate().
    map._u = np.array([-1.0, 0.5, 0.25])

    # SHT matrix: the fourth value returned by get_pixel_transforms(). It is
    # used below as ``dot(SHT, p)`` to turn a length-``npix`` pixel vector
    # into the ``map.Ny`` coefficients stored in ``map._y``.
    # NOTE(review): presumably the pixel -> spherical harmonic transform;
    # confirm against the starry API.
    _, _, _, SHT, _, _ = map._map.get_pixel_transforms()
    npix = SHT.shape[1]

    # Initial guesses. Seeding the global NumPy RNG makes them reproducible;
    # the order of the two randn() calls matters for that reason.
    # NOTE(review): guess_p is centered on 0.5 while the prior on ``p`` below
    # is centered on 1 -- confirm this offset is intended.
    np.random.seed(0)
    guess_p = 0.5 + 0.01 * np.random.randn(npix)
    guess_spectrum_ = 1 + 0.01 * np.random.randn(map.nw0_)

    # The data, flattened to 1-D to line up with the flattened model below.
    # This rebinds the notebook-level ``flux`` variable.
    flux = flux.reshape(-1)
    flux_err = settings["flux_err"]

    # Prior on the map: Laplace prior on the pixel vector, then transform to
    # a (map.Ny, 1) column and assign it as the map's coefficient vector.
    p = pm.Laplace("p", mu=1, b=pb, shape=(npix,), testval=guess_p)
    map._y = tt.reshape(
        tt.dot(SHT, p),
        (map.Ny, 1),
    )

    # Prior on the spectrum: Laplace prior centered on 1, reshaped to a
    # single-row (1, map.nw0_) array before being assigned to the map.
    spectrum_ = pm.Laplace(
        "spectrum_", mu=1, b=sb, shape=(map.nw0_,), testval=guess_spectrum_
    )
    map._spectrum = tt.reshape(
        spectrum_,
        (1, map.nw0_),
    )

    # Compute the model and flatten it to length nt * nw
    flux_model = map.flux()
    flux_model = tt.reshape(flux_model, (map.nt * map.nw,))

    # Likelihood term: independent Gaussian errors of width ``flux_err``
    pm.Normal("obs", mu=flux_model, sd=flux_err, observed=flux)

In [None]:
# Optimizer settings: number of Adam iterations and learning rate.
niter = 1000
lr = 1e-1

# Run the optimizer, recording the objective at every step and keeping the
# point with the lowest objective seen so far (initialized to the model's
# test point, which is also the optimizer's starting point).
loss = []
best_loss = np.inf
map_soln = model.test_point
with model:
    # The optimizer is a generator with no length, so pass ``total``
    # explicitly; otherwise tqdm can only show a running counter with no
    # percentage or time estimate.
    for obj, point in tqdm(
        pmx.optim.optimize_iterator(
            pmx.optim.Adam(lr=lr), niter, vars=[p, spectrum_], start=map_soln
        ),
        total=niter,
    ):
        loss.append(obj)
        if obj < best_loss:
            best_loss = obj
            map_soln = point

In [None]:
# Display the map components at the best point found by the optimizer.
# NOTE(review): presumably ``point=`` makes starry evaluate the model's
# symbolic variables at ``map_soln`` inside this model context -- confirm.
with model:
    map.show_components(point=map_soln);