# Modeling differential rotation

In [None]:
%matplotlib inline

In [None]:
%run notebook_setup.py

In [None]:
import matplotlib.pyplot as plt
from matplotlib import colors
from mpl_toolkits.axes_grid1 import make_axes_locatable
import numpy as np
from tqdm.notebook import tqdm
from scipy.special import factorial
from scipy.interpolate import interp1d
import starry

# Turn off starry's lazy mode. NOTE(review): presumably this makes starry
# return concrete NumPy values instead of symbolic graphs (later cells feed its
# outputs straight into NumPy) — confirm against the starry docs.
starry.config.lazy = False
# NOTE(review): presumably silences starry's informational messages — confirm.
starry.config.quiet = True

## Differential rotation operator

In [None]:
# Degree-15 starry map loaded with the "earth" image; this is the test image
# the differential rotation operator is applied to further below.
# NOTE(review): `map` shadows the Python builtin, but later cells rely on it.
map = starry.Map(15)
map.load("earth")
map.amp = 1
# Keep a NumPy copy of the original Ylm coefficients: a later cell overwrites
# the map's coefficients with the rotated ones.
y0 = np.array(map.y)

In [None]:
class DifferentialOperator(object):
    """Linear operator that differentially rotates a spherical harmonic map.

    Each unique pixel latitude ``lat`` of a starry pixel grid is rotated by
    an angle proportional to ``sin(lat)**2``. The steps are:

    1. Rotate the Ylm basis into a "polar" frame (``dotR`` with axis
       (1, 0, 0) and angle -pi/2).
    2. Rotate about z by ``-theta * sin(lat)**2`` for every latitude band.
    3. Rotate back out of the polar frame.
    4. Evaluate each band's rotation on the pixels at that latitude.
    5. Project back onto the Ylm basis.

    Parameters
    ----------
    ydeg : int
        Spherical harmonic degree of the maps the operator acts on.
    oversample : int, optional
        Oversampling factor forwarded to ``starry.Map.get_pixel_transforms``.
    eps : float, optional
        Currently unused; kept only for interface compatibility.
    """

    def __init__(self, ydeg, oversample=3, eps=1e-6):
        self.ydeg = ydeg

        # Pixel grid plus the Ylm -> pixel (P) and pixel -> Ylm (Q) transforms
        smap = starry.Map(ydeg)
        self.lat, self.lon, self.P, self.Q, _, _ = smap.get_pixel_transforms(
            oversample=oversample
        )
        self.ops = smap.ops

        # Sorted unique latitudes, and a boolean mask of shape (nlat, npix)
        # whose row i selects the pixels lying at latitude unique_lat[i]
        lat = np.asarray(self.lat)
        self.unique_lat = np.unique(lat)
        self.idx = lat[None, :] == self.unique_lat[:, None]

        # sin^2(latitude): per-band scaling of the rotation angle
        self.mag = np.sin(self.unique_lat * np.pi / 180.0) ** 2

        self.npix = self.P.shape[0]
        self.nlat = len(self.mag)
        self.Ny = (self.ydeg + 1) ** 2

        # Polar transform of the identity basis. Each basis row is repeated
        # once per latitude band, so rows are ordered coefficient-major /
        # latitude-minor: row j * nlat + i holds basis row j for band i.
        RP = self.ops.dotR(
            np.eye(self.Ny),
            np.array(1.0),
            np.array(0.0),
            np.array(0.0),
            np.array(-0.5 * np.pi),
        )
        self.RP = np.repeat(RP, self.nlat, axis=0)

    def get_D(self, theta):
        """Return the differential rotation matrix for angle ``theta``.

        Parameters
        ----------
        theta : float
            Rotation parameter in degrees (a 0-d array also works). Pixels
            at latitude ``lat`` are rotated by ``-theta * sin(lat)**2``. The
            caller's value is never modified.

        Returns
        -------
        D : ndarray
            Matrix of shape (``Q.shape[0]``, ``Ny``) such that ``D.dot(y)``
            is the differentially rotated version of the Ylm vector ``y``.
        """
        # Convert to radians into a NEW object. An in-place ``theta *= ...``
        # would overwrite a float array passed by the caller and raise a
        # casting error for an integer array.
        theta = theta * (np.pi / 180)

        # Rotation angle for every row of self.RP, laid out in the same
        # coefficient-major / latitude-minor order
        t = np.tile(-theta * self.mag, self.Ny)
        Yzr = np.reshape(self.ops.tensordotRz(self.RP, t), (-1, self.Ny))

        # Transform back out of the polar frame. After the swap, Yr[j, k, i]
        # is coefficient k of rotated basis row j for latitude band i.
        Yr = np.reshape(
            self.ops.dotR(
                Yzr, np.array(1.0), np.array(0.0), np.array(0.0), np.array(0.5 * np.pi),
            ),
            (self.Ny, -1, self.Ny),
        )
        Yr = np.swapaxes(Yr, 1, 2)

        # Convert to pixels: Pr has shape (Ny, npix, nlat), one pixel image
        # per latitude band's rotation
        Pr = np.swapaxes(np.tensordot(self.P, Yr, (1, 0)), 0, 1)

        # Each pixel keeps only the rotation of its own latitude band. Every
        # column of self.idx has exactly one True entry (unique_lat is built
        # from the pixel latitudes themselves), so argmax recovers the band.
        band = np.argmax(self.idx, axis=0)
        Lr = Pr[:, np.arange(self.npix), band].T

        # Convert back to Ylms
        return np.dot(self.Q, Lr)

In [None]:
# Build the operator at the same degree as the map loaded above (instead of
# hard-coding 15 a second time), apply a differential rotation with
# theta = 60 degrees to the original coefficients y0, and plot the result.
DiffOp = DifferentialOperator(map.ydeg)
map[:, :] = DiffOp.get_D(60).dot(y0)
map.show(projection="rect")

In [None]:
# Scratch check: evaluate `tt.tile` on a small vector, for comparison with
# `np.tile` on the same input.
# NOTE(review): `tt` is not imported in any visible cell. It is presumably
# `theano.tensor`, possibly defined by notebook_setup.py — confirm, otherwise
# this cell raises NameError.
tt.tile(np.array([0.0, 1.0]), 2).eval()

In [None]:
# NumPy reference: tiling [0, 1] twice gives [0., 1., 0., 1.]
np.tile(np.array([0.0, 1.0]), 2)

In [None]:
# Scratch: inspect `tt.swapaxes`.
# NOTE(review): `tt` is not imported in any visible cell. It is presumably
# `theano.tensor` — confirm, otherwise this cell raises NameError.
tt.swapaxes

In [None]:
# IPython help lookup for np.tensordot (the trailing `?` is IPython syntax,
# not plain Python).
np.tensordot?

In [None]:
# Scratch prototype that appears to generalize the angle vector built in
# DifferentialOperator.get_D (`np.tile(-theta * self.mag, self.Ny)`) to
# several theta values at once:
#   - theta is padded to (nt, 1) and mag to (1, nlat);
#   - their product is an (nt, nlat) outer product;
#   - that product is tiled ncoeff times (presumably along the last axis,
#     as np.tile does).
# Note that theta is NOT converted to radians here.
# NOTE(review): `tt` is not imported in any visible cell. It is presumably
# `theano.tensor` — confirm, otherwise this cell raises NameError.
ncoeff = 4
theta = np.array([30.0, 35.0])
mag = np.array([1.0, 2.0, 3.0])
tt.tile(-tt.shape_padright(theta) * tt.shape_padleft(mag), ncoeff).eval()