# Modeling differential rotation

In [None]:
%matplotlib inline

In [None]:
%run notebook_setup.py

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from scipy.special import factorial
from scipy.interpolate import interp1d
import starry

# NOTE(review): with lazy mode on, starry results are deferred objects; the
# cells below materialize them explicitly with `.eval()`.
starry.config.lazy = True
# Presumably suppresses starry's informational output -- TODO confirm.
starry.config.quiet = True

## Global parameters

In [None]:
# Spherical harmonic degrees
# ydeg_true: degree of the map used to generate the synthetic data.
# ydeg_inf:  (lower) degree of the map used in the inference section.
ydeg_true = 30
ydeg_inf = 20

# Stellar parameters
# prot:  equatorial rotation period; it sets the time unit (the light curve
#        is plotted against "time [periods]").
# alpha: differential rotation shear; a latitude `lat` drifts in longitude by
#        (360 / prot) * alpha * t * sin(lat)**2 degrees after a time t.
# inc:   passed to starry.Map as `inc` -- presumably the stellar inclination
#        in degrees; TODO confirm.
prot = 1.0
alpha = 0.02
inc = 75

# Light curve noise
# Standard deviation of the Gaussian noise added to the median-normalized flux.
ferr = 1e-3

# Time & rotational phase array
# theta advances by 360 degrees per equatorial period `prot`.
time = np.linspace(-10, 10, 1000)
theta = 360.0 / prot * time

## Generate a surface map

Let's generate a rectangular lat/lon grid:

In [None]:
# Grid resolution: an odd number of latitudes puts one row exactly on the
# equator.
nlat = 101
nlon = 200
# Coordinates in degrees. Both -180 and +180 are included, so the seam
# meridian is sampled twice (first and last column).
lat_arr = np.linspace(-90, 90, nlat)
lon_arr = np.linspace(-180, 180, nlon)
# Default "xy" indexing: both grids have shape (nlat, nlon), rows = latitude.
lon_grid, lat_grid = np.meshgrid(lon_arr, lat_arr)

and create an image of a stellar surface with several funny-looking spots:

In [None]:
# Start from a uniform, unit-brightness surface and subtract dark spots.
I = np.ones((nlat, nlon))

# Spot catalogue: centre latitude (y) and longitude (x) in degrees, and the
# Gaussian width (s) expressed in the same /180-scaled units as `r` below.
y = [10, 42.5, 10, 10, -50, -30, -10, 20, -40, 50]
x = [0, 0, -32.5, 32.5, -90, -130, 130, 130, 130, -120]
s = [0.1, 0.1, 0.1, 0.1, 0.15, 0.1, 0.1, 0.075, 0.075, 0.125]

for spot_lat, spot_lon, spot_width in zip(y, x, s):
    # Distance is measured on the flat lat/lon rectangle (no wrap at +/-180).
    dlat = (lat_grid - spot_lat) / 180
    dlon = (lon_grid - spot_lon) / 180
    r = np.sqrt(dlat ** 2 + dlon ** 2)
    I -= np.exp(-((r / spot_width) ** 2))

# Overlapping spots can drive the intensity negative; floor it at zero.
I[I < 0] = 0

plt.imshow(I, origin="lower", vmax=1, extent=(-180, 180, -90, 90))
plt.xlabel("longitude [deg]")
plt.ylabel("latitude [deg]")
plt.colorbar();

## Add differential rotation

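It's straightforward to apply differential rotation to this map given an equatorial rotational period ``prot``, a differential rotation shear ``alpha``, and the time ``t`` elapsed since the epoch at which the original map is defined. In the frame co-rotating with the equator, a feature at latitude ``lat`` is displaced in longitude by ``360 / prot * alpha * t * sin(lat) ** 2`` degrees, so the equator stays fixed and the poles drift the most. Let's code up a function that returns a new version of the map after differential rotation: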

In [None]:
def diff_rotate(I, lon_arr, lat_arr, prot, alpha, t):
    """Return a differentially rotated copy of a lat/lon image.

    Each latitude row of ``I`` is resampled at longitudes offset by
    ``360 / prot * alpha * t * sin(lat) ** 2`` degrees, so the equator is left
    in place and the offset grows towards the poles.

    Parameters
    ----------
    I : 2-D array
        Image indexed as ``I[latitude_index, longitude_index]``.
    lon_arr : 1-D array
        Longitudes (degrees) of the columns of ``I``. Sample longitudes are
        wrapped into [-180, 180), so this axis must cover that interval or
        the interpolator raises ``ValueError``.
    lat_arr : 1-D array
        Latitudes (degrees) of the rows of ``I``.
    prot : float
        Equatorial rotation period.
    alpha : float
        Differential rotation shear.
    t : float
        Time since the epoch at which ``I`` is defined, in the units of
        ``prot``.

    Returns
    -------
    array
        New array with the shape and dtype of ``I``; ``I`` is not modified.
    """
    # Longitude offset (degrees) that a row at |lat| = 90 would receive.
    polar_drift = 360.0 / prot * alpha * t

    Irot = np.zeros_like(I)
    for k, lat in enumerate(lat_arr):
        shifted = lon_arr + polar_drift * np.sin(lat * np.pi / 180.0) ** 2
        # Wrap back into [-180, 180). Note that +180 itself maps to -180, so
        # that sample reads the first column of the row.
        wrapped = (shifted + 180) % 360 - 180
        # Linear interpolation along longitude within this row.
        Irot[k] = interp1d(lon_arr, I[k])(wrapped)
    return Irot

In [None]:
# 3x3 grid of snapshots of the differentially rotated image.
fig, ax = plt.subplots(3, 3, figsize=(12, 6))
ax = ax.flatten()

# One snapshot per panel, evenly spaced in time over [-10, 10].
snapshot_times = np.linspace(-10, 10, len(ax))
for panel, t in zip(ax, snapshot_times):
    Irot = diff_rotate(I, lon_arr, lat_arr, prot, alpha, t)
    panel.imshow(Irot, origin="lower")
    panel.set(xticks=[], yticks=[])
    # The y-label carries the snapshot time.
    panel.set_ylabel(r"${:.2f}$".format(t), fontsize=10)

## Expand in spherical harmonics

In [None]:
# NOTE(review): `map` shadows the Python builtin; it is kept because later
# cells refer to this name.
map = starry.Map(ydeg_true, inc=inc)
# Intensity design matrix on the flattened lat/lon grid -- presumably one row
# per grid point and one column per spherical harmonic coefficient, so that
# P.dot(coeffs) gives the intensity at each point; TODO confirm.
P = map.intensity_design_matrix(lat=lat_grid.flatten(), lon=lon_grid.flatten()).eval()
# Ridge-regularized least-squares inverse, Q = (P^T P + 1e-8 I)^-1 P^T, so
# that Q.dot(image.flat) yields the coefficients that best fit an image.
Q = np.linalg.solve(P.T.dot(P) + 1e-8 * np.eye(P.shape[1]), P.T)

In [None]:
# Same 3x3 snapshot grid as above, but each image is first projected onto
# spherical harmonics (via Q) and rendered by starry.
fig, ax = plt.subplots(3, 3, figsize=(12, 6))
ax = ax.flatten()

snapshot_times = np.linspace(-10, 10, len(ax))
for panel, t in zip(ax, snapshot_times):
    Irot = diff_rotate(I, lon_arr, lat_arr, prot, alpha, t)
    # Load the fitted coefficients of this snapshot into the map.
    map[:, :] = Q.dot(Irot.flat)
    map.show(ax=panel, projection="rect", cmap="viridis", grid=False)
    panel.set(xticks=[], yticks=[])
    panel.set_ylabel(r"${:.2f}$".format(t), fontsize=10)

In [None]:
# Coefficient vector of the differentially rotated surface at every time
# stamp: row i of `y` holds the coefficients at time[i].
y = np.empty((len(time), map.Ny))
for i, t in enumerate(tqdm(time)):
    Irot = diff_rotate(I, lon_arr, lat_arr, prot, alpha, t)
    y[i] = Q.dot(Irot.flat)

In [None]:
# Store the base map (at t = 0)
# Coefficients of the unrotated image `I`, obtained with the same
# least-squares projection `Q` used for the time series above.
y0 = Q.dot(I.flat)

## Generate a light curve

In [None]:
# Flux design matrix of the ydeg_true map at every rotational phase in
# `theta`; `.eval()` materializes the lazy result. It is used below with one
# row per time stamp and one column per map coefficient.
A = map.design_matrix(theta=theta).eval()

In [None]:
# Because the map changes with time, each epoch uses its own coefficient
# vector: flux0[i] = sum_j A[i, j] * y[i, j] (row-wise dot product).
flux0 = np.einsum("ij,ji->i", A, y.T)
# Normalize to the median (NaNs ignored).
flux0 /= np.nanmedian(flux0)

In [None]:
# Add independent Gaussian noise with standard deviation `ferr`.
# NOTE(review): no random seed is set in the visible cells, so the noise
# realization may differ between runs -- confirm whether the setup script
# seeds numpy.
flux = flux0 + ferr * np.random.randn(len(flux0))

In [None]:
# Noisy observations as black points, noiseless model as a thin line.
plt.plot(time, flux, "k.", alpha=0.5, ms=3)
plt.plot(time, flux0, "C0-", lw=1, alpha=0.5)
# `prot` is 1.0, so the time axis is in units of the equatorial period.
plt.xlabel("time [periods]")
plt.ylabel("flux [normalized]");

## Inference

In [None]:
# Rebind `map` to a fresh, lower-degree map for the inference section.
# NOTE(review): unlike the ydeg_true map, no `inc` is passed here, so this map
# uses starry's default inclination -- confirm this is intended.
map = starry.Map(ydeg_inf)

In [None]:
# Base map
# Keep only the first (ydeg_inf + 1)**2 coefficients of the ydeg_true base
# map, i.e. as many as a degree-ydeg_inf map holds.
# Assumes the coefficients are ordered by increasing degree, so that the
# truncation drops only the degrees above ydeg_inf -- TODO confirm.
y0_inf = y0[: (ydeg_inf + 1) ** 2]
map[:, :] = y0_inf
map.show(projection="rect");

In [None]:
class DifferentialOperator:
    """Build matrices that shear a spherical harmonic map in longitude.

    The map is sampled on starry's pixel grid, the pixels are grouped into
    rings of equal latitude, and each ring is described by a truncated Fourier
    series in longitude. :meth:`get_D` re-evaluates every ring's series at
    shifted longitudes and transforms the pixels back, which gives a square
    matrix acting on the vector of map coefficients.

    Parameters
    ----------
    ydeg : int
        Spherical harmonic degree of the map; also used as the Fourier order.
    oversample : int, optional
        Forwarded to ``starry.Map.get_pixel_transforms``.
    eps : float, optional
        Ridge term added to the normal equations of each per-ring Fourier fit.
    """

    def __init__(self, ydeg, oversample=2, eps=1e-6):
        self.ydeg = ydeg
        self.order = ydeg

        # Pixel coordinates and the transforms returned by starry. Presumably
        # P maps map coefficients to pixel intensities and Q maps pixel
        # intensities back to coefficients -- TODO confirm against starry.
        sph_map = starry.Map(ydeg)
        (
            self.lat,
            self.lon,
            self.P,
            self.Q,
            _,
            _,
        ) = sph_map.get_pixel_transforms(oversample=oversample)

        # One boolean pixel mask per distinct latitude, in ascending latitude.
        self.unique_lat = np.unique(self.lat)
        self.idx = np.array(
            [self.lat == ring_lat for ring_lat in self.unique_lat]
        )

        # Sizes
        self.npix = len(self.lat)
        self.nlat = len(self.unique_lat)
        self.nfourier = 2 * self.order - 1
        self.ncoeff = (self.ydeg + 1) ** 2

        # T[i] takes a coefficient vector to the Fourier coefficients of
        # ring i: a ridge-regularized least-squares fit of the Fourier basis
        # to the ring's pixels, composed with that ring's rows of P.
        self.T = np.empty((self.nlat, self.nfourier, self.ncoeff))
        edge_rings = (0, self.nlat - 1)
        for i, mask in enumerate(self.idx):
            X = self._get_X(self.lon[mask])
            ridge = eps * np.eye(X.shape[1])
            A = np.linalg.solve(X.T.dot(X) + ridge, X.T)
            if i in edge_rings:
                # Lowest and highest latitude rings (presumably the poles --
                # TODO confirm): discard the fit and copy the ring's first
                # pixel into the constant term, so these rings carry no
                # longitude dependence.
                A.fill(0)
                A[0, 0] = 1
            self.T[i] = A.dot(self.P[mask])

        # Scratch pixel-space matrix that get_D overwrites on every call.
        self.Dp = np.empty((self.npix, self.ncoeff))
        # Per-ring shift weight sin(lat)**2: zero on the equator, one at
        # |lat| = 90 degrees.
        self.mag = np.sin(self.unique_lat * np.pi / 180.0) ** 2

    def _get_X(self, theta):
        """Return the Fourier design matrix at the angles ``theta`` (degrees).

        The result has one row per angle and ``2 * order - 1`` columns ordered
        as ``1, sin(t), cos(t), sin(2t), cos(2t), ...``.
        """
        # NOTE(review): the highest harmonic included is order - 1, not order
        # -- confirm that dropping the n = order term is intentional.
        rows = [np.ones(len(theta))]
        for n in range(1, self.order):
            rows.append(np.sin(n * theta * np.pi / 180))
            rows.append(np.cos(n * theta * np.pi / 180))
        return np.vstack(rows).T

    def get_D(self, theta):
        """Return the operator for a maximum longitude shift of ``theta``.

        Ring ``i`` is evaluated at its pixel longitudes offset by
        ``theta * sin(lat_i) ** 2`` degrees, wrapped into [-180, 180). The
        scratch buffer ``self.Dp`` is overwritten as a side effect; the
        returned array is ``Q.dot(Dp)``.
        """
        for mask, weight, ring_T in zip(self.idx, self.mag, self.T):
            shifted = self.lon[mask] + theta * weight
            wrapped = (shifted + 180) % 360 - 180
            self.Dp[mask] = self._get_X(wrapped).dot(ring_T)
        return self.Q.dot(self.Dp)


# Build the operator at the inference degree and evaluate it for theta = 60,
# i.e. a longitude shift of 60 * sin(lat)**2 degrees at each latitude.
DiffOp = DifferentialOperator(ydeg_inf)
D = DiffOp.get_D(60)

# Sanity check on a recognizable surface: load starry's "earth" map, apply
# the operator to its coefficient vector, and display the sheared result.
map.load("earth")
map[:, :] = D.dot(map.y.eval())
map.show(projection="rect")

In [None]:
# Visualize the entries of the operator matrix itself.
plt.imshow(D)