# Modeling differential rotation

In [None]:
%matplotlib inline

In [None]:
%run notebook_setup.py

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
import starry
from scipy.linalg import dft

# Global starry settings for this notebook: turn off starry's ``lazy`` mode
# and turn on ``quiet`` mode.
# NOTE(review): presumably non-lazy mode makes starry calls return concrete
# numpy values (the operators below call np.linalg on them) and quiet mode
# silences log output -- confirm against the starry configuration docs.
starry.config.lazy = False
starry.config.quiet = True

## Shifting a periodic function with the discrete Fourier transform (DFT)

In [None]:
def func(x):
    """Evaluate a smooth periodic test signal at longitudes ``x`` (degrees).

    The signal is a sum of three harmonics that complete 1, 2 and 4 cycles
    per 360 degrees, each with its own amplitude and phase offset.
    """
    # Degrees -> radians once; scaling by 2 or 4 afterwards is exact in
    # floating point, so this matches computing each phase separately.
    rad = np.pi / 180 * x
    first = 2 * np.sin(rad - 0.5)
    second = 2 * np.cos(2 * rad)
    fourth = -0.5 * np.sin(4 * rad + 1)
    return first + second + fourth


# Sample the test signal on a fine grid (x0, f0) for plotting and on a
# coarse periodic grid of N points (x, f) covering [-180, 180)
N = 15
x0 = np.linspace(-180, 180, 1000, endpoint=False)
x = np.linspace(-180, 180, N, endpoint=False)
f0 = func(x0)
f = func(x)

plt.plot(x0, f0)
plt.plot(x, f, "o");

In [None]:
# Shift the coarse samples by dx degrees using the DFT shift theorem:
# f(x - dx) = IDFT[exp(-2 pi i xi dx) * DFT[f]](x)
dx = 30
X = dft(N)
XInv = np.linalg.inv(X)
xi = np.fft.fftfreq(N, d=360 / N)  # frequencies in cycles per degree
T = np.diag(np.exp(-2 * np.pi * 1j * xi * dx))
# N x N shift operator; keep only the real part since f is real-valued
A = np.real(XInv @ T @ X)

plt.plot(x0, func(x0), "C0--")  # original signal
plt.plot(x0, func(x0 - dx), "C0-")  # exactly shifted signal
plt.plot(x, A @ f, "C1.");  # DFT-shifted samples

In [None]:
# Visualize the N x N DFT shift operator A from the previous cell as an image
plt.imshow(A)

## Differential rotation operators

In [None]:
class DiffOpVander(object):
    """Differential-rotation operator built from per-latitude Fourier fits.

    The map is sampled on starry's pixel grid (``get_pixel_transforms``).
    On each ring of constant latitude the pixel values are fit by ridge
    least squares with a truncated sine/cosine series in longitude. The fit
    is then re-evaluated at longitudes shifted by ``theta * sin(lat)**2`` and
    projected back onto the spherical harmonic coefficients.

    Args:
        ydeg: Spherical harmonic degree of the map.
        oversample: Oversampling factor passed to ``get_pixel_transforms``.
        eps: Regularization added to ``X^T X`` in each per-ring fit.
    """

    def __init__(self, ydeg, oversample=2, eps=1e-12):

        # Get pixel transforms. As used in ``get_D``, ``P`` presumably maps
        # Ylm coefficients to pixel values and ``Q`` maps pixels back.
        self.ydeg = ydeg
        ylm_map = starry.Map(ydeg)  # don't shadow the builtin ``map``
        self.lat, self.lon, self.P, self.Q, _, _ = ylm_map.get_pixel_transforms(
            oversample=oversample
        )

        # Sorted unique latitudes and a boolean mask selecting each ring
        self.unique_lat = np.unique(self.lat)
        self.idx = np.array([self.lat == l for l in self.unique_lat])

        # Dimensions
        self.npix = len(self.lat)
        self.nlat = len(self.unique_lat)
        self.ncoeff = (self.ydeg + 1) ** 2

        # Per-ring transform: Ylm coefficients -> Fourier-fit weights
        self.T = []
        for i, row in enumerate(self.idx):
            X = self._get_X(self.lon[row])
            A = np.linalg.solve(X.T.dot(X) + eps * np.eye(X.shape[1]), X.T)
            # The first and last rings are treated as the poles, which don't
            # rotate: keep only the constant term, taken from the first pixel
            if i == 0 or i == self.nlat - 1:
                A[:, :] = 0
                A[0, 0] = 1
            self.T.append(A.dot(self.P[row]))

        # Scratch buffer for the pixel-space operator, and the per-ring shift
        # scale sin(lat)**2 (``lat`` is in degrees)
        self.Dp = np.empty((self.npix, self.ncoeff))
        self.mag = np.sin(self.unique_lat * np.pi / 180.0) ** 2

    def _get_X(self, theta):
        """Return the Fourier design matrix for longitudes ``theta`` (degrees).

        Columns are ``[1, sin(n*theta)..., cos(n*theta)...]`` for
        ``n = 1, ..., order - 1`` with ``order = len(theta) // 2 - 1``. When
        that range is empty only the constant column is returned.
        """
        # TODO: Figure out the best regression order
        theta = np.asarray(theta)
        order = len(theta) // 2 - 1
        # Shape (len(theta), max(order - 1, 0)); evaluated in the same order
        # as ``n * theta * pi / 180`` so results are bit-identical
        arg = np.outer(theta, np.arange(1, order)) * np.pi / 180
        return np.hstack((np.ones((len(theta), 1)), np.sin(arg), np.cos(arg)))

    def get_D(self, theta):
        """Return the differential-rotation operator for angle ``theta``.

        The Fourier fit on each ring is re-evaluated at longitude
        ``lon + theta * sin(lat)**2`` (degrees, wrapped into [-180, 180)),
        then projected back with ``Q``.

        Returns:
            A new array acting on the Ylm coefficient vector. Its first row
            is reset so the ``Y_{0,0}`` coefficient passes through unchanged.
        """
        for i, row in enumerate(self.idx):
            new_lon = self.lon[row] + theta * self.mag[i]
            new_lon = ((new_lon + 180) % 360) - 180
            self.Dp[row] = self._get_X(new_lon).dot(self.T[i])
        D = self.Q.dot(self.Dp)
        # Preserve luminosity
        D[0, :] = 0
        D[0, 0] = 1
        return D

In [None]:
class DiffOpDFT(object):
    """Differential-rotation operator built from per-latitude DFT phase shifts.

    The sphere is sampled on a grid that is uniform in Mollweide (x, y), plus
    four extra points (two poles, two on the equator). Samples are grouped
    into rings of constant latitude. On each ring the samples are transformed
    with a dense DFT matrix, phase-shifted by an amount proportional to
    ``sin(lat)**2``, and transformed back. The real part of that operator is
    sandwiched between the pixel transforms ``P`` and ``Q`` (a ridge
    pseudo-inverse of ``P``).

    Args:
        ydeg: Spherical harmonic degree of the map.
        dr_oversample: Scales the grid size (``dr_oversample * (ydeg + 1)**2``
            points in the bounding box before trimming to the ellipse).
        dr_lam: Ridge regularization used when inverting ``P``.
    """

    def __init__(self, ydeg, dr_oversample=2, dr_lam=1e-12):

        # DEBUG: borrow low-level ops from a starry map of this degree
        ylm_map = starry.Map(ydeg)  # don't shadow the builtin ``map``
        self.ydeg = ydeg
        self.RAxisAngle = ylm_map.ops.RAxisAngle
        self.pT = ylm_map.ops.pT
        self._c_ops = ylm_map.ops._c_ops
        # / DEBUG

        # Regular (x, y) grid over the Mollweide bounding box
        npix = dr_oversample * (self.ydeg + 1) ** 2
        Ny = int(np.sqrt(npix * np.pi / 4.0))
        Nx = 2 * Ny
        y, x = np.meshgrid(
            np.sqrt(2) * np.linspace(-1, 1, Ny),
            2 * np.sqrt(2) * np.linspace(-1, 1, Nx),
        )
        x = x.flatten()
        y = y.flatten()

        # Remove off-grid points (outside the Mollweide ellipse)
        a = np.sqrt(2)
        b = 2 * np.sqrt(2)
        idx = (y / a) ** 2 + (x / b) ** 2 <= 1
        y = y[idx]
        x = x[idx]

        # Inverse Mollweide projection to (lat, lon) in radians
        # https://en.wikipedia.org/wiki/Mollweide_projection
        theta = np.arcsin(y / np.sqrt(2))
        lat = np.arcsin((2 * theta + np.sin(2 * theta)) / np.pi)
        lon0 = 3 * np.pi / 2
        lon = lon0 + np.pi * x / (2 * np.sqrt(2) * np.cos(theta))

        # Add points at the poles (and two on the equator)
        lat = np.append(lat, [-np.pi / 2, 0, 0, np.pi / 2])
        lon = np.append(lon, [1.5 * np.pi, 1.5 * np.pi, 2.5 * np.pi, 1.5 * np.pi])

        # Back to Cartesian, this time on the *sky*
        x = np.reshape(np.cos(lat) * np.cos(lon), [1, -1])
        y = np.reshape(np.cos(lat) * np.sin(lon), [1, -1])
        z = np.reshape(np.sin(lat), [1, -1])
        R = self.RAxisAngle(np.array([1.0, 0.0, 0.0]), np.array(-np.pi / 2))
        x, y, z = np.dot(R, np.concatenate((x, y, z)))
        x = x.reshape(-1)
        y = y.reshape(-1)
        z = z.reshape(-1)

        # Flatten and fix the longitude offset, then sort by latitude
        # (lexsort uses the last key as the primary one)
        lat = lat.reshape(-1)
        lon = (lon - 1.5 * np.pi).reshape(-1)
        idx = np.lexsort([lon, lat])
        lat = lat[idx]
        lon = lon[idx]
        x = x[idx]
        y = y[idx]
        z = z[idx]

        # Sorted unique latitudes and a boolean mask selecting each ring
        unique_lat = np.unique(lat)
        self._dr_idx = np.array([lat == l for l in unique_lat])
        ncoeff = (self.ydeg + 1) ** 2

        # Get the forward (P) and reverse (Q) pixel transforms
        pT = self.pT(x, y, z)[:, :ncoeff]
        # NOTE(review): ``*`` is a matrix product here only if ``A1`` is a
        # scipy sparse matrix (presumably the case) -- confirm.
        P = pT * self._c_ops.A1
        Q = np.linalg.solve(P.T.dot(P) + dr_lam * np.eye(ncoeff), P.T)

        # Pre-compute the per-ring transform matrices, split into real and
        # imaginary parts so ``get_D`` only needs real arithmetic
        self._dr_arg = []
        self._dr_QXInvR = []
        self._dr_QXInvI = []
        self._dr_XRP = []
        self._dr_XIP = []
        for lat_i, row in zip(unique_lat, self._dr_idx):
            lon_row = lon[row]
            N = len(lon_row)
            X = dft(N)
            XInv = np.linalg.inv(X)
            d = self._sample_spacing(lon_row)
            # Phase per unit rotation angle: DFT frequencies scaled by the
            # sin(lat)**2 differential-rotation profile
            self._dr_arg.append(
                -2 * np.pi * np.fft.fftfreq(N, d=d) * np.sin(lat_i) ** 2
            )
            self._dr_QXInvR.append(Q[:, row].dot(XInv.real))
            self._dr_QXInvI.append(Q[:, row].dot(XInv.imag))
            self._dr_XRP.append(X.real.dot(P[row]))
            self._dr_XIP.append(X.imag.dot(P[row]))

        self._dr_D = np.zeros((ncoeff, ncoeff))

    @staticmethod
    def _sample_spacing(lon_row):
        """Return the longitude spacing of one evenly spaced, sorted ring.

        ``N`` samples spanning ``lon_row[-1] - lon_row[0]`` are separated by
        ``N - 1`` intervals. Single-point rings (the poles) and zero-span
        rings fall back to ``2 * pi``; for a single point the value does not
        matter because ``fftfreq(1, d)`` is ``[0]``.
        """
        n = len(lon_row)
        if n < 2:
            return 2 * np.pi
        d = (lon_row[-1] - lon_row[0]) / (n - 1)
        return d if d != 0 else 2 * np.pi

    def get_D(self, theta):
        """Return the differential-rotation operator for angle ``theta``.

        Args:
            theta: Rotation angle in degrees (converted to radians here).
                Each ring's phase shift is scaled by ``sin(lat)**2``.

        Returns:
            A new ``(ncoeff, ncoeff)`` array acting on the Ylm coefficient
            vector. Its first row is reset so the ``Y_{0,0}`` coefficient
            passes through unchanged. The array is also stored as
            ``self._dr_D``, but later calls do not overwrite it.
        """
        # DEBUG: degrees -> radians. Rebind rather than ``*=`` so an array
        # passed by the caller is not modified in place.
        theta = theta * (np.pi / 180)

        D = np.zeros_like(self._dr_D)
        for arg, QXInvR, QXInvI, XRP, XIP in zip(
            self._dr_arg,
            self._dr_QXInvR,
            self._dr_QXInvI,
            self._dr_XRP,
            self._dr_XIP,
        ):
            # Diagonal phase factor exp(-i * arg * theta), split into parts
            TR = np.cos(arg * theta)
            TI = -np.sin(arg * theta)
            # Real part of Q (XInvR + i XInvI) diag(TR + i TI) (XR + i XI) P
            D += (QXInvR * TR).dot(XRP)
            D -= (QXInvR * TI).dot(XIP)
            D -= (QXInvI * TR).dot(XIP)
            D -= (QXInvI * TI).dot(XRP)

        # Preserve luminosity
        D[0, :] = 0
        D[0, 0] = 1

        # Keep the attribute for compatibility, but hand out a fresh array
        # each call so earlier results are not silently overwritten
        self._dr_D = D
        return D

In [None]:
# Reference map: starry's "earth" map at degree 20 with unit amplitude.
# NOTE(review): ``sigma=0.08`` presumably smooths the loaded image, and
# ``map[15:, :] = 0`` presumably zeroes every degree l >= 15 (starry indexes
# coefficients as map[l, m]) -- confirm against the starry docs.
# NOTE(review): ``map`` shadows the builtin; kept because later cells use it.
map = starry.Map(20)
map.load("earth", sigma=0.08)
map.amp = 1
map[15:, :] = 0
# Snapshot the coefficients and the Mollweide rendering for the round trips
y0 = np.array(map.y)
img0 = map.render(projection="moll")

In [None]:
# Round trip with DiffOpVander: start from the reference coefficients, apply
# the differential rotation for +30 and then -30 degrees, and render.
# Differences between imgV and img0 measure the round-trip error.
map[:, :] = y0
DiffOp = DiffOpVander(20)
map[:, :] = DiffOp.get_D(30).dot(map.y)
map[:, :] = DiffOp.get_D(-30).dot(map.y)
imgV = map.render(projection="moll")

In [None]:
# Same round trip with DiffOpDFT; additionally display the intermediate
# (+30 degree) map in the "rect" projection before rotating back.
map[:, :] = y0
DiffOp = DiffOpDFT(20)
map[:, :] = DiffOp.get_D(30).dot(map.y)
map.show(projection="rect")
map[:, :] = DiffOp.get_D(-30).dot(map.y)
imgD = map.render(projection="moll")

In [None]:
# Round-trip residuals relative to the reference image
resV = imgV - img0
resD = imgD - img0

vmin0, vmax0 = np.nanmin(resV), np.nanmax(resV)
vmin1, vmax1 = np.nanmin(resD), np.nanmax(resD)
# Color limits symmetric about zero, shared by both panels
vmin = min(min(vmin0, vmin1), -max(vmax0, vmax1))
vmax = -vmin

fig, ax = plt.subplots(2, figsize=(8, 8))
for axis, res in zip(ax, (resV, resD)):
    axis.imshow(
        res,
        extent=(-2, 2, -1, 1),
        origin="lower",
        cmap="RdBu",
        vmin=vmin,
        vmax=vmax,
    )
    # Title: summed squared residual (NaN pixels outside the map ignored)
    axis.set_title("{:.3f}".format(np.nansum(res ** 2)))
print(vmin, vmax);