# Time evolution of `starry` maps

In this notebook, we're going to take a look at how to model a star whose light curve evolves in time. The assumption here is that the evolution is due to either spot migration/evolution or differential rotation, so we need a way to model a time-variable surface map. There are a few different ways to do that. Please note that these are all **experimental features** -- we're still working on the most efficient way of modeling temporal variability, so stay tuned!

In [None]:
# Render matplotlib figures inline, below the cells that create them.
%matplotlib inline

In [None]:
# Execute the shared notebook setup script (its contents are not shown here).
%run notebook_setup.py

Let's begin with our usual imports.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import HTML, display
from matplotlib.animation import FuncAnimation
from scipy.special import factorial
from tqdm.notebook import tqdm

import starry

# Evaluate starry calls eagerly so they return concrete values; the rest of
# the notebook feeds their outputs straight into NumPy and matplotlib.
starry.config.lazy = False
# Presumably silences starry's informational messages.
starry.config.quiet = True

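## Generate the data

We simulate a star with three dark spots on a differentially rotating surface: a spot at latitude $\phi$ rotates with angular velocity $\omega = \omega_\mathrm{eq}\left(1 - \alpha \sin^2\phi\right)$, so spots at different latitudes drift relative to one another over the 30-day baseline.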

In [None]:
np.random.seed(0)

# High-resolution "true" map (ydeg=10, udeg=1); map[1] is the limb-darkening
# coefficient, kept at 0.5 throughout the notebook.
map = starry.Map(10, 1)
map[1] = 0.5

inc = 60  # inclination [deg]
alpha = 0.02  # differential-rotation shear
P = 1.0  # equatorial rotation period [days]
omega_eq = 360.0 / P  # equatorial angular velocity [deg / day]
time = np.linspace(0, 30, 1000)  # [days]
time_ani = time[::10]  # coarser time grid used for the map animations
true_flux = np.zeros_like(time)

res = 300
true_image_rect = np.zeros((len(time_ani), res, res))
# NOTE(review): one res x res orthographic frame per time sample
# (~720 MB in float64); it is not used by the cells shown in this notebook.
true_image_ortho = np.zeros((len(time), res, res))

# Spot parameters shared by all three spots. The spot size is named
# `spot_sigma` so it is not confused with the photometric noise `sigma`.
intensity = -0.5
spot_sigma = 0.05

for lat, lon in zip([-30, 30, -20], [-90, 60, 135]):
    # Differential rotation: spots at higher |latitude| rotate more slowly.
    omega = omega_eq * (1 - alpha * np.sin(lat * np.pi / 180.0) ** 2)

    # One spot per map; since the flux is linear in the map coefficients
    # (the design-matrix fit below relies on this), summing the per-spot
    # light curves gives the light curve of the spotted star up to an
    # overall scale, which the normalization below removes.
    map.reset(inc=inc)
    map[1] = 0.5
    map.add_spot(intensity=intensity, sigma=spot_sigma, lat=lat, lon=lon)
    true_flux += map.flux(theta=omega * time)
    true_image_ortho += map.render(theta=omega * time, res=res)

    # HACK: approximate the true rectangular map sequence by rendering the
    # spot once and rolling it in longitude by its drift relative to the
    # equator, (omega - omega_eq) * t, converted to pixels (res per 360 deg).
    spot_rect = map.render(projection="rect", res=res)
    shift = np.array((omega - omega_eq) * time_ani * res / 360, dtype=int)
    for n in range(len(time_ani)):
        true_image_rect[n] += np.roll(spot_rect, shift[n], axis=1)

# Photometric noise: zero-mean Gaussian with standard deviation `sigma`,
# matching the diagonal noise model used in the fit below. (Uniform draws
# from np.random.random would be biased by sigma / 2 and have the wrong
# spread.)
sigma = 1e-3
flux = true_flux / np.nanmedian(true_flux)
flux += sigma * np.random.randn(len(time))

In [None]:
# Display the ground-truth rectangular-projection maps (one frame per
# entry of `time_ani`).
map.show(projection="rect", image=true_image_rect)

In [None]:
# The simulated (normalized, noisy) light curve that we fit below.
plt.plot(time, flux)
plt.xlabel("time [days]")
plt.ylabel("flux [normalized]")
plt.title("Simulated light curve");

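## Taylor expansion

We model the time-variable map by expanding its spherical harmonic coefficients as a Taylor series in time, $\mathbf{y}(t) = \sum_{n=0}^{N} \mathbf{y}_n \, t^n / n!$, where $t \in [-1, 1]$ is the rescaled time and $N$ is the expansion `order`. Because the flux is linear in $\mathbf{y}$, the light curve remains linear in the unknowns $\mathbf{y}_0, \ldots, \mathbf{y}_N$. The full design matrix is obtained by weighting each row of the static design matrix $\mathsf{A}_0$ by each Taylor term $t^n / n!$ and stacking the results column-wise, so all coefficients can be solved for with (regularized) linear least squares.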

In [None]:
# Expansion order and the (assumed known) rotation period.
order = 4
P = 1.00

# Lower-resolution model map sharing the true inclination and
# limb-darkening coefficient.
map = starry.Map(ydeg=5, udeg=1)
map.inc = inc
map[1] = 0.5

# Static design matrix: maps the spherical harmonic coefficients to the
# flux at each rotational phase.
theta = time * (360.0 / P)
A0 = map.design_matrix(theta=theta)

In [None]:
# Rescale time to [-1, 1] (time starts at 0 here) and build the scaled
# Taylor basis, whose column n is t**n / n!.
t = 2.0 * (time / time[-1] - 0.5)
coeff = 1.0 / factorial(np.arange(order + 1))
T = np.vander(t, order + 1, increasing=True) * coeff

# Weight every row of A0 by each Taylor term and place the resulting
# blocks side by side, in order n = 0, 1, ..., order.
A = np.hstack([A0 * column[:, np.newaxis] for column in T.T])

In [None]:
# Taylor basis: one column per expansion order, one row per time sample.
plt.imshow(T, aspect="auto")
plt.xlabel("Taylor order $n$")
plt.ylabel("time index")
plt.title("Taylor basis $t^n / n!$")
plt.colorbar();

In [None]:
# Full design matrix: (order + 1) copies of A0, each weighted by one Taylor
# term, placed side by side.
plt.imshow(A, aspect="auto")
plt.xlabel("column (blocks of $Y_{lm}$ coefficients, one per Taylor order)")
plt.ylabel("time index")
plt.title("Design matrix $A$")
plt.colorbar();

In [None]:
# Diagonal noise model with per-point standard deviation `sigma`. Going by
# its name, `cho_C` is the Cholesky factor of the data covariance
# diag(sigma**2), i.e. simply sigma on the diagonal.
cho_C = sigma * np.eye(len(flux))

# Prior mean: only the first coefficient (Y_00 of the n = 0 Taylor term,
# which becomes the map amplitude) is nonzero.
mu = np.zeros(A.shape[1])
mu[0] = 1.0

# Presumably the inverse prior covariance: precision 1e2 on every
# coefficient except the first, which gets precision 1.
LInv = 1e2 * np.eye(map.Ny * (order + 1))
LInv[0, 0] = 1.0

x, cho_cov = starry.linalg.solve(A, flux, cho_C, mu, LInv)

In [None]:
# Best-fit light curve implied by the solved coefficients.
model = A @ x

plt.plot(time, flux, ".", ms=3, label="data")
plt.plot(time, model, label="model")
plt.xlabel("time [days]")
plt.ylabel("flux [normalized]")
plt.legend();

In [None]:
# Evaluate the Taylor-expanded coefficients at each animation time and
# render the corresponding rectangular map.
#
# The animation times must be rescaled exactly as in the fit (by time[-1]).
# Rescaling by time_ani[-1] instead would stretch the expansion, because
# time_ani = time[::10] ends at time[990] < time[-1].
t_ani = 2.0 * (time_ani / time[-1] - 0.5)
T_ani = np.vander(t_ani, order + 1, increasing=True) * coeff

# x holds (order + 1) consecutive blocks of map.Ny coefficients (matching the
# column blocks of A), so reshape to (order + 1, Ny) and evaluate all
# animation times at once: row n of y_ani is the coefficient vector at
# time_ani[n].
x_blocks = x.reshape(order + 1, -1)
y_ani = T_ani.dot(x_blocks)

image = np.empty((len(time_ani), res, res))
for n in tqdm(range(len(time_ani))):
    yn = y_ani[n]
    # The first coefficient sets the overall amplitude; the remaining ones
    # are stored relative to it.
    map.amp = yn[0]
    map[1:, :] = yn[1:] / map.amp
    image[n] = map.render(res=res, projection="rect")

In [None]:
# Animate the inferred maps (top) against the ground truth (bottom) over
# `time_ani`.
fig, ax = plt.subplots(2, figsize=(7, 6))
extent = (-180, 180, -90, 90)

# imshow fixes its color limits from the first frame, and set_array() in
# the animation does not rescale them; use limits spanning all frames so
# later frames are not clipped.
img1 = ax[0].imshow(
    image[0],
    origin="lower",
    cmap="plasma",
    extent=extent,
    vmin=np.nanmin(image),
    vmax=np.nanmax(image),
)
img2 = ax[1].imshow(
    true_image_rect[0],
    origin="lower",
    cmap="plasma",
    extent=extent,
    vmin=np.nanmin(true_image_rect),
    vmax=np.nanmax(true_image_rect),
)

# Latitude/longitude grid lines, drawn above the images.
lats = np.linspace(-90, 90, 7)[1:-1]
lons = np.linspace(-180, 180, 13)
for i, axis in enumerate(ax):
    for lat in lats:
        axis.axhline(lat, color="k", lw=0.5, alpha=0.5, zorder=100)
    for lon in lons:
        axis.axvline(lon, color="k", lw=0.5, alpha=0.5, zorder=100)
    axis.set_yticks(lats)
    axis.set_ylabel("Latitude [deg]")
    axis.set_xticks(lons)
    if i == 1:
        axis.set_xlabel("Longitude [deg]")
    else:
        axis.set_xticklabels([])
ax[0].set_title("Inferred (Taylor expansion)")
ax[1].set_title("True")


def updatefig(i):
    """Show frame ``i`` of the inferred and true map sequences.

    Returns the updated artists, as required by ``FuncAnimation`` with
    ``blit=True``.
    """
    img1.set_array(image[i])
    img2.set_array(true_image_rect[i])
    return (img1, img2)


ani = FuncAnimation(fig, updatefig, interval=75, blit=True, frames=image.shape[0])

# Close the static figure so that only the rendered video is displayed.
plt.close(fig)
display(HTML(ani.to_html5_video()))