# Doppler timing tests

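Timing benchmarks for methods of the `starry.DopplerMap` class. Each measurement is the wall-clock time of a single call. It is taken after one untimed warm-up call, and is averaged over 100 calls when a single call takes less than 0.1 s.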

In [None]:
# Enable progress bars? Set to True to show tqdm progress bars for the
# benchmark loops below; the ``tqdm`` wrapper reads this flag on every call.
TQDM: bool = False

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Execute the shared setup script; with %run, names it defines become available
# in the notebook namespace. NOTE(review): its contents are not shown here —
# presumably plotting/style configuration; confirm before relying on it.
%run notebook_setup.py

In [None]:
import starry

# Global starry options, set before any DopplerMap is constructed below.
# NOTE(review): presumably lazy=False makes starry evaluate eagerly (returning
# numeric arrays rather than symbolic graphs) and quiet=True suppresses
# informational logging — confirm against the starry configuration docs.
starry.config.lazy = False
starry.config.quiet = True

In [None]:
import starry
import numpy as np
import matplotlib.pyplot as plt
import timeit
from tqdm.notebook import tqdm as _tqdm

def tqdm(*args, **kwargs):
    """Wrap ``tqdm.notebook.tqdm``, hiding the bar unless ``TQDM`` is True.

    ``TQDM`` is read at call time, so toggling it later takes effect for
    subsequent loops. An explicit ``disable=`` argument overrides the ``TQDM``
    default instead of raising a duplicate-keyword ``TypeError``.
    """
    kwargs.setdefault("disable", not TQDM)
    return _tqdm(*args, **kwargs)

In [None]:
def get_time(statement="map.flux()", number=100, **kwargs):
    """Return the wall-clock time, in seconds, of one execution of ``statement``.

    For each timing run, a fresh ``starry.DopplerMap(**kwargs)`` is built and
    bound to the name ``map``. ``statement`` is then executed once as an
    untimed warm-up before timing starts. If a single timed call takes more
    than 0.1 s, that duration is returned directly. Otherwise the statement is
    timed over ``number`` calls and the mean per-call time is returned.

    Parameters
    ----------
    statement : str
        Code to time; it may refer to the map as ``map``.
    number : int
        Number of calls to average over when a single call is fast.
    **kwargs
        Keyword arguments forwarded to ``starry.DopplerMap``.

    Returns
    -------
    float
        Seconds per call.
    """
    # ``map`` deliberately shadows the builtin inside the timed code only.
    setup = f"map = starry.DopplerMap(**kwargs); {statement}"
    # Merge globals first so this function's own locals (notably ``kwargs``)
    # take precedence over any same-named notebook globals.
    namespace = {**globals(), **locals()}
    # Timer.timeit() re-runs ``setup`` on every call, so each measurement
    # below gets its own freshly constructed, warmed-up map.
    timer = timeit.Timer(statement, setup=setup, globals=namespace)
    t0 = timer.timeit(number=1)
    if t0 > 0.1:
        return t0
    return timer.timeit(number=number) / number

## `DopplerMap.flux()`

Timings of `DopplerMap.flux()` for each evaluation `method` (`dotconv`, `convdot`, `conv`, and `design`), varying one problem size at a time.

### As a function of `ydeg`

Holding `nt = 1`, `nc = 1`, and `nw = 200` fixed.

In [None]:
# Time DopplerMap.flux() for each method as the spherical harmonic degree grows,
# holding nt, nc and the wavelength grid fixed.
methods = ["dotconv", "convdot", "conv", "design"]
ydegs = [1, 2, 3, 5, 8, 10, 13, 15]
nt = 1
nc = 1
wav = np.linspace(500, 501, 200)
# Seconds per call; rows are methods, columns are degrees. Named ``timings``
# rather than ``time`` to avoid shadowing the stdlib module name.
timings = np.zeros((len(methods), len(ydegs)))
for i, method in tqdm(enumerate(methods), total=len(methods)):
    for j, ydeg in tqdm(enumerate(ydegs), total=len(ydegs), leave=False):
        timings[i, j] = get_time(
            f"map.flux(method='{method}')", ydeg=ydeg, nt=nt, nc=nc, wav=wav
        )
fig, ax = plt.subplots(figsize=(8, 5))
# Transpose so each column (one per method) becomes its own labelled curve.
ax.plot(ydegs, timings.T, "o-", label=methods)
ax.legend(fontsize=10)
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("spherical harmonic degree")
ax.set_ylabel("time [s]")
ax.set_title(f"DopplerMap.flux() timing: nt = {nt}, nc = {nc}, nw = {len(wav)}");

### As a function of `nt`

Holding `ydeg = 3`, `nc = 1`, and `nw = 200` fixed.

In [None]:
# Time DopplerMap.flux() for each method as the number of epochs grows,
# holding ydeg, nc and the wavelength grid fixed.
methods = ["dotconv", "convdot", "conv", "design"]
ydeg = 3
nts = [1, 2, 3, 5, 10, 20]
nc = 1
wav = np.linspace(500, 501, 200)
# Seconds per call; rows are methods, columns are epoch counts. Named
# ``timings`` rather than ``time`` to avoid shadowing the stdlib module name.
timings = np.zeros((len(methods), len(nts)))
for i, method in tqdm(enumerate(methods), total=len(methods)):
    for j, nt in tqdm(enumerate(nts), total=len(nts), leave=False):
        timings[i, j] = get_time(
            f"map.flux(method='{method}')", ydeg=ydeg, nt=nt, nc=nc, wav=wav
        )
fig, ax = plt.subplots(figsize=(8, 5))
# Transpose so each column (one per method) becomes its own labelled curve.
ax.plot(nts, timings.T, "o-", label=methods)
ax.legend(fontsize=10)
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("number of epochs")
ax.set_ylabel("time [s]")
ax.set_title(f"DopplerMap.flux() timing: ydeg = {ydeg}, nc = {nc}, nw = {len(wav)}");

### As a function of `nw`

Holding `ydeg = 3`, `nt = 1`, and `nc = 1` fixed.

In [None]:
# Time DopplerMap.flux() for each method as the number of wavelength bins grows,
# holding ydeg, nt and nc fixed.
methods = ["dotconv", "convdot", "conv", "design"]
ydeg = 3
nt = 1
nc = 1
nws = [100, 200, 300, 400, 500, 800, 1000]
wavs = [np.linspace(500, 501, nw) for nw in nws]
# Seconds per call; rows are methods, columns are grid sizes. Named
# ``timings`` rather than ``time`` to avoid shadowing the stdlib module name.
timings = np.zeros((len(methods), len(wavs)))
for i, method in tqdm(enumerate(methods), total=len(methods)):
    for j, wav in tqdm(enumerate(wavs), total=len(wavs), leave=False):
        timings[i, j] = get_time(
            f"map.flux(method='{method}')", ydeg=ydeg, nt=nt, nc=nc, wav=wav
        )
fig, ax = plt.subplots(figsize=(8, 5))
# Transpose so each column (one per method) becomes its own labelled curve.
ax.plot(nws, timings.T, "o-", label=methods)
ax.legend(fontsize=10)
ax.set_yscale("log")
ax.set_xscale("log")
ax.set_xlabel("number of wavelength bins")
ax.set_ylabel("time [s]")
ax.set_title(f"DopplerMap.flux() timing: ydeg = {ydeg}, nt = {nt}, nc = {nc}");