# Modeling differential rotation

In [None]:
%matplotlib inline

In [None]:
%run notebook_setup.py

In [None]:
import matplotlib.pyplot as plt
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from tqdm.notebook import tqdm
from scipy.special import factorial
from scipy.interpolate import interp1d
import starry

# Switch starry out of its lazy mode so map methods hand back plain numerical
# values (NOTE(review): meaning of `lazy` taken from starry's config docs — confirm).
starry.config.lazy = False
# Silence starry's informational output (NOTE(review): confirm against the docs).
starry.config.quiet = True

## Differential rotation operator

In [None]:
class DifferentialOperator(object):
    """Linear differential rotation operator acting on spherical harmonic vectors.

    The operator works on a pixel grid obtained from ``starry``. For every
    distinct latitude of that grid, the intensity along the latitude circle is
    fit with a truncated Fourier series in longitude; shearing the surface
    then amounts to evaluating that series at shifted longitudes.

    Parameters
    ----------
    ydeg : int
        Spherical harmonic degree of the maps the operator acts on.
    oversample : int, optional
        Forwarded to ``starry.Map.get_pixel_transforms``.
    eps : float, optional
        Ridge term added to the normal equations of the Fourier fit.
    """

    def __init__(self, ydeg, oversample=3, eps=1e-6):
        self.ydeg = ydeg

        # Pixel coordinates and the pixel transforms. `P` is applied on the
        # coefficient side and `Q` on the pixel side (see `get_D`).
        ylm_map = starry.Map(ydeg)
        lat, lon, P, Q, _, _ = ylm_map.get_pixel_transforms(oversample=oversample)
        self.lat = lat
        self.lon = lon
        self.P = P
        self.Q = Q

        # One boolean pixel mask per distinct latitude, shape (nlat, npix)
        self.unique_lat = np.unique(self.lat)
        self.idx = self.unique_lat[:, None] == self.lat[None, :]

        # Dimensions
        self.npix = len(self.lat)
        self.nlat = len(self.unique_lat)
        self.ncoeff = (self.ydeg + 1) ** 2

        # Per-latitude matrices taking coefficients to Fourier amplitudes
        self.T = []
        last = self.nlat - 1
        for k, mask in enumerate(self.idx):
            X = self._get_X(self.lon[mask])
            # Ridge-regularized least squares: (X^T X + eps I)^-1 X^T
            gram = X.T.dot(X) + eps * np.eye(X.shape[1])
            A = np.linalg.solve(gram, X.T)
            if k in (0, last):
                # The poles don't rotate: keep only a constant term taken
                # from the first pixel of the band
                A[:, :] = 0
                A[0, 0] = 1
            self.T.append(A.dot(self.P[mask]))

        # Scratch buffer reused by `get_D`, and the sin^2(latitude) shear weights
        self.Dp = np.empty((self.npix, self.ncoeff))
        self.mag = np.sin(self.unique_lat * np.pi / 180.0) ** 2

    def _get_X(self, theta):
        """Return the Fourier design matrix for longitudes ``theta`` (degrees).

        The columns are a constant term, then ``sin(n * theta)`` and then
        ``cos(n * theta)`` for ``n = 1 .. order - 1``, where ``order`` is tied
        to the number of points.
        """
        # TODO: Figure out the best regression order
        n_pts = len(theta)
        order = n_pts // 2 - 1
        phases = [n * theta * np.pi / 180 for n in range(1, order)]
        sin_block = np.reshape(np.sin(phases), (-1, n_pts)).T
        cos_block = np.reshape(np.cos(phases), (-1, n_pts)).T
        return np.hstack((np.ones((n_pts, 1)), sin_block, cos_block))

    def get_D(self, theta):
        """Return the operator matrix for a shear angle ``theta`` in degrees.

        Each latitude band is shifted in longitude by ``theta`` times its
        ``sin^2(latitude)`` weight. The internal buffer ``self.Dp`` is
        overwritten on every call; the returned matrix is a new array.
        """
        for mask, shear, T_lat in zip(self.idx, self.mag, self.T):
            shifted = self.lon[mask] + theta * shear
            # Wrap back into [-180, 180)
            shifted = ((shifted + 180) % 360) - 180
            self.Dp[mask] = self._get_X(shifted).dot(T_lat)
        D = self.Q.dot(self.Dp)
        # Preserve luminosity: the first output coefficient is passed through
        D[0, :] = 0
        D[0, 0] = 1
        return D

## Global parameters

In [None]:
# Spherical harmonic degrees
# `ydeg_true` is the degree of the map that generates the data and of the
# differential rotation operator. `ydeg_inf` is not used in this part of the
# notebook (presumably the degree used for inference — confirm).
ydeg_true = 30
ydeg_inf = 20

# Stellar parameters
prot = 1.0  # equatorial rotation period, in the same units as `time`
alpha = 0.02  # differential rotation shear
inc = 75  # inclination passed to `starry.Map` (presumably degrees — confirm)

# Light curve noise
ferr = 1e-3  # standard deviation of the Gaussian noise added to the flux

# Time & rotational phase array
time = np.linspace(-10, 10, 1000)
theta = 360.0 / prot * time  # rotational phase in degrees at each time

In [None]:
# Customize a private copy of the colormap. `plt.get_cmap` may return the
# globally registered "plasma" instance, and changing its under/over colors in
# place is deprecated in Matplotlib and would affect every other plot that
# uses "plasma".
import copy

cmap = copy.copy(plt.get_cmap("plasma"))
cmap.set_under("k")  # values below vmin are drawn in black
cmap.set_over("w")  # values above vmax are drawn in white

## Generate a surface map

Let's create an image of a stellar surface with several funny-looking spots:

In [None]:
map = starry.Map(ydeg_true, inc=inc)

# Spot catalogue: latitudes `y`, longitudes `x` and sizes `s`. The size sets
# both the amplitude and the width of each spot.
y = [10, 42.5, 10, 10, -50, -30, -10, 20, -40, 50]
x = [0, 0, -32.5, 32.5, -90, -130, 130, 130, 130, -120]
s = [0.1, 0.1, 0.1, 0.1, 0.15, 0.1, 0.1, 0.075, 0.075, 0.125]
for size, spot_lat, spot_lon in zip(s, y, x):
    map.add_spot(
        amp=-0.2 * size,
        lat=spot_lat,
        lon=spot_lon,
        sigma=0.35 * size,
        relative=False,
    )

# Keep a copy of the spherical harmonic coefficients of the spotted map
y0 = np.array(map.y)

# Rescale the map so the background intensity is ~1.0
amp0 = 2.4
map.amp = amp0

# Show the surface on a rectangular projection
map.show(
    projection="rect",
    cmap=cmap,
    norm=colors.Normalize(vmin=0, vmax=1),
    colorbar=True,
)

In order to experiment with differential rotation, let's get the map as a matrix of pixel intensities on a rectangular latitude/longitude grid:

In [None]:
# Pixel intensities of the map on a rectangular projection at res=200
I0 = map.render(projection="rect", res=200)

The columns of the matrix `I0` are the longitudes `lon_grid` and the rows are the latitudes `lat_grid`, which we can obtain as follows:

In [None]:
# Latitude and longitude grids for the same projection and resolution as `I0`
lat_grid, lon_grid = map.get_latlon_grid(projection="rect", res=200)

These are just (flattened) meshgrids of a latitude array,

In [None]:
# Sorted array of the distinct latitudes on the grid
lat_arr = np.unique(lat_grid)

and a longitude array,

In [None]:
# Sorted array of the distinct longitudes on the grid
lon_arr = np.unique(lon_grid)

each of length 200:

In [None]:
# Sanity check: the res=200 grid has 200 distinct longitudes and 200 distinct latitudes
assert len(lon_arr) == len(lat_arr) == 200

If we visualized this image with `imshow` (as you can check), we'd get the same figure as above.

## Add differential rotation

It's straightforward to apply differential rotation to this map given an equatorial rotational period ``prot``, a differential rotation shear ``alpha``, and the time ``t`` elapsed since the epoch at which the original map is defined. In the frame rotating with the equator, each latitude ``lat`` is shifted in longitude by ``360 / prot * alpha * t * sin(lat)^2`` degrees. Let's code up a function that returns a new version of the map after differential rotation:

In [None]:
def diff_rotate(I, lat_arr, lon_arr, prot, alpha, t):
    """Return a differentially rotated copy of the lat/lon image ``I``.

    Each row of ``I`` (one latitude from ``lat_arr``, in degrees) is resampled
    at longitudes shifted by ``360 / prot * alpha * t * sin(lat)^2`` degrees.
    Shifted longitudes are wrapped into [-180, 180) and the row is linearly
    interpolated on ``lon_arr``, extrapolating beyond its end points.

    The result has the same shape and dtype as ``I``.
    """
    omega_eq = 360.0 / prot
    rotated = np.zeros_like(I)
    for row, latitude in enumerate(lat_arr):
        shear = np.sin(latitude * np.pi / 180.0) ** 2
        shifted = lon_arr + omega_eq * alpha * t * shear
        shifted = ((shifted + 180) % 360) - 180
        sample_row = interp1d(lon_arr, I[row], fill_value="extrapolate")
        rotated[row] = sample_row(shifted)
    return rotated

In [None]:
# Differentially rotate the pixel image to nine epochs between t = -10 and
# t = +10 and show each snapshot in its own panel, labelled with its time.
fig, ax = plt.subplots(3, 3, figsize=(12, 6))
ax = ax.flatten()
n_panels = len(ax)
snapshot_times = np.linspace(-10, 10, n_panels)
Irot_exact = np.zeros((n_panels,) + I0.shape)
for i, (panel, t) in enumerate(zip(ax, snapshot_times)):
    Irot_exact[i] = diff_rotate(I0, lat_arr, lon_arr, prot, alpha, t)
    panel.imshow(
        Irot_exact[i],
        extent=(-180, 180, -90, 90),
        origin="lower",
        cmap=cmap,
        vmin=0,
        vmax=1,
    )
    panel.set(xticks=[], yticks=[])
    panel.set_ylabel(r"${:.2f}$".format(t), fontsize=10)

## Expand in spherical harmonics

In [None]:
# Build the operator at the same spherical harmonic degree as the true map
DiffOp = DifferentialOperator(ydeg_true)

In [None]:
# Inspect the per-latitude shear weights, sin^2(latitude)
DiffOp.mag

In [None]:
# Visualize the transform matrix of the latitude band with index 10
plt.imshow(DiffOp.T[10], aspect="auto")

In [None]:
# Visualize the full operator matrix for a shear angle of 90 degrees
plt.imshow(DiffOp.get_D(90))

In [None]:
# Same nine epochs as the exact calculation, but now the shear is applied to
# the spherical harmonic coefficients with the linear operator and the map is
# re-rendered. Note that this overwrites the coefficients of `map`.
fig, ax = plt.subplots(3, 3, figsize=(12, 6))
ax = ax.flatten()
Irot_ylm = np.zeros_like(Irot_exact)
snapshot_times = np.linspace(-10, 10, len(ax))
for i, (panel, t) in enumerate(zip(ax, snapshot_times)):
    D = DiffOp.get_D(360 / prot * alpha * t)
    sheared = D.dot(y0)
    # Set every coefficient except the first one
    map[1:, :] = sheared[1:]
    Irot_ylm[i] = map.render(projection="rect", res=200)
    panel.imshow(
        Irot_ylm[i],
        extent=(-180, 180, -90, 90),
        origin="lower",
        cmap=cmap,
        vmin=0,
        vmax=1,
    )
    panel.set(xticks=[], yticks=[])
    panel.set_ylabel(r"${:.2f}$".format(t), fontsize=10)

## Compare

In [None]:
# Residuals between the exact pixel-space rotation and the spherical harmonic
# operator at each epoch, on a +/-0.01 color scale.
fig, ax = plt.subplots(3, 3, figsize=(12, 6))
ax = ax.flatten()
snapshot_times = np.linspace(-10, 10, len(ax))
for panel, exact, approx, t in zip(ax, Irot_exact, Irot_ylm, snapshot_times):
    residual = exact - approx
    panel.imshow(
        residual,
        extent=(-180, 180, -90, 90),
        origin="lower",
        cmap=cmap,
        vmin=-0.01,
        vmax=0.01,
    )
    panel.set(xticks=[], yticks=[])
    panel.set_ylabel(r"${:.2f}$".format(t), fontsize=10)

## Generate a light curve

In [None]:
# Design matrix evaluated at every rotational phase in `theta`; its rows are
# later contracted with spherical harmonic coefficient vectors to get the flux
# (NOTE(review): shape assumed to be (len(theta), ncoeff) — confirm in starry docs).
A = map.design_matrix(theta=theta)

In [None]:
# Spherical harmonic coefficients of the differentially rotated surface at
# every time stamp, shape (len(time), ncoeff). The shear angle passed to the
# operator is the same one used for the snapshot figures above.
# NOTE(review): this cell previously contracted `A` with `y.T`, but `y` is the
# plain list of spot latitudes and has no `.T`. The einsum signature implies a
# per-time coefficient array, which is rebuilt here — confirm this is the
# intended source of the coefficients.
y_t = np.array([DiffOp.get_D(360.0 / prot * alpha * t_k).dot(y0) for t_k in time])

# Row-wise dot product: the flux at time k is A[k] . y_t[k]
flux0 = np.einsum("ij,ij->i", A, y_t)

# Normalize to the median flux
flux0 /= np.nanmedian(flux0)

In [None]:
# Add white Gaussian noise with standard deviation `ferr`. No seed is set in
# this cell, so the draw depends on the global NumPy random state.
flux = flux0 + ferr * np.random.randn(len(flux0))

In [None]:
# Noisy observations as black dots, noiseless model as a thin line
plt.plot(time, flux, "k.", alpha=0.5, ms=3)
plt.plot(time, flux0, "C0-", lw=1, alpha=0.5)
plt.xlabel("time [periods]")
# Trailing semicolon suppresses the text repr of the last expression in the notebook
plt.ylabel("flux [normalized]");

## Inference