In [1]:
import os

# Ask XLA to expose as many host-platform devices as os.cpu_count() reports.
# This cell comes before the first `import jax` in the notebook.
os.environ.update(
    {"XLA_FLAGS": f"--xla_force_host_platform_device_count={os.cpu_count()}"}
)

In [151]:
import jax
from spotter import Star, show, light_curves

# NumPy is needed for the time grid below. Importing it in this cell keeps the
# cell runnable on a fresh kernel; the notebook otherwise imports it only in
# later cells, so `np` would be undefined here.
import numpy as np

# Build a star with `Star.from_sides(50, period=1.0)`, then replace its `y` map
# by `1 - star.spot(0, 0, 0.2)`.
# NOTE(review): presumably a single dark spot at (0, 0) with radius 0.2 — confirm
# the argument meaning against the spotter documentation.
star = Star.from_sides(50, period=1.0)
star = star.set(y=1 - star.spot(0, 0, 0.2))

# 1000 evenly spaced time samples on [0, 1] (one rotation, given period=1.0 above).
time = np.linspace(0, 1, 1000)

In [152]:
import jax.numpy as jnp

# NOTE: a bare `jax.jit` expression used to stand here. It only looked the
# attribute up and discarded it, so `light_curve` was never compiled. To really
# jit this function, `normalize` has to be marked static (e.g.
# `static_argnames="normalize"`) because it drives a Python-level conditional.
def light_curve(star, time, normalize=True):
    """
    Compute the light curve of a rotating Star.

    For every entry of ``time`` the design matrix returned by
    ``light_curves.design_matrix(star, t)`` is contracted with ``star.y``
    (``einsum("ij,ij->i")``); the per-time results are stacked with ``jax.vmap``.

    Parameters
    ----------
    star : Star
        Star object. Only ``star.y`` is read here directly; the object is also
        passed through to ``light_curves.design_matrix``.
    time : ArrayLike
        Time array in days. Its leading axis is the one mapped over.
    normalize : bool, optional
        Whether to normalize the light curve, i.e. multiply it by
        ``1 / mean(star.y)`` (default True).

    Returns
    -------
    lc : Array
        Light curve array: the stacked per-time results transposed (``.T``),
        so the time axis comes last, multiplied by the normalisation factor.
    """

    def _flux_at(star, t):
        # Row-wise dot product between the design matrix at time `t` and star.y.
        design = light_curves.design_matrix(star, t)
        return jnp.einsum("ij,ij->i", design, star.y)

    # Plain Python branch: `normalize` must be a concrete bool, not a traced value.
    scale = 1 / jnp.mean(star.y) if normalize else 1.0

    # The star is broadcast (in_axes None); only `time` is mapped along axis 0.
    per_time = jax.vmap(_flux_at, in_axes=(None, 0))(star, time)
    return per_time.T * scale

In [165]:
from spotter.core import design_matrix, mask_projected_limb
import numpy as np

# Number of shards used by the pmap-based variant defined below.
# NOTE(review): pmap needs at least this many JAX devices; this presumably relies
# on the XLA_FLAGS value set in the first cell — confirm.
n = os.cpu_count()

# Rebind the imported function to its jit-compiled version.
# NOTE(review): re-running this cell wraps the already-jitted function again.
mask_projected_limb = jax.jit(mask_projected_limb)


def mask_projected_limb2(x, phase=None, inc=None, u=None, obl=None):
    """
    Evaluate ``mask_projected_limb`` on ``x`` split across ``n`` devices with ``jax.pmap``.

    The rows of ``x`` are zero-padded up to the next multiple of ``n``, split
    into ``n`` equal chunks along axis 0, mapped in parallel, and each of the
    three outputs is re-joined with ``jnp.hstack`` and trimmed back to
    ``x.shape[0]`` entries.

    Parameters
    ----------
    x : ArrayLike
        2D array; axis 0 is the axis that gets padded and sharded.
    phase, inc, u, obl : optional
        Forwarded unchanged (broadcast to every shard) to ``mask_projected_limb``.

    Returns
    -------
    tuple
        The three outputs of ``mask_projected_limb``, concatenated over shards.
        NOTE(review): assumes each output carries one entry per row of ``x`` — confirm.
    """
    n_rows = x.shape[0]
    # Rows needed to reach the next multiple of n (0 when n already divides n_rows).
    pad_rows = -n_rows % n
    xpad = jnp.pad(x, ((0, pad_rows), (0, 0)), mode="constant", constant_values=0)
    xsplit = jnp.array(jnp.split(xpad, n, axis=0))
    pmap_mpl = jax.pmap(mask_projected_limb, in_axes=(0, None, None, None, None))
    ms, ps, ls = pmap_mpl(xsplit, phase, inc, u, obl)
    # Trim by the original length rather than `[0:-p]`: without padding, `-0`
    # slices everything away and the function would return empty arrays.
    return tuple(jnp.hstack(out)[:n_rows] for out in (ms, ps, ls))

In [163]:
# Call once outside the timer, then time repeated calls. In both cases
# block_until_ready() is invoked on the first returned array.
mask_projected_limb(star.x, 0)[0].block_until_ready()
%timeit mask_projected_limb(star.x, 0)[0].block_until_ready()

1.82 ms ± 45.6 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [166]:
# Same protocol for the pmap-based variant: one untimed call, then %timeit.
mask_projected_limb2(star.x, 0)[0].block_until_ready()
%timeit mask_projected_limb2(star.x, 0)[0].block_until_ready()

4.31 ms ± 497 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [98]:
# NOTE(review): in the visible notebook `pmap_mpl` and `xsplit` are only assigned
# inside `mask_projected_limb2`, so this cell has no global definitions to use
# on a fresh kernel.
pmap_mpl(xsplit, star.phase(0), star.inc, star.u, star.obl)

(Array([[ True, False, False, ..., False, False, False],
        [False, False, False, ..., False,  True,  True],
        [ True,  True,  True, ..., False, False, False],
        ...,
        [False, False, False, ...,  True,  True,  True],
        [ True, False, False, ..., False, False, False],
        [False, False, False, ...,  True,  True,  True]], dtype=bool),
 Array([[ 4.4389635e-02, -4.4389602e-02, -4.4389602e-02, ...,
         -5.9446234e-01, -5.7982469e-01, -5.5090988e-01],
        [-5.0842965e-01, -4.5343032e-01, -3.8726622e-01, ...,
         -4.3711388e-08,  9.5010236e-02,  1.8863508e-01],
        [ 2.7950913e-01,  3.6630732e-01,  4.4776398e-01, ...,
         -5.1805818e-01, -4.2381376e-01, -3.2338905e-01],
        ...,
        [-4.2381376e-01, -5.1805818e-01, -6.0474807e-01, ...,
          3.6630732e-01,  2.7950913e-01,  1.8863508e-01],
        [ 9.5010236e-02, -4.3711388e-08, -9.5010206e-02, ...,
         -5.0842965e-01, -5.5090988e-01, -5.7982469e-01],
        [-5.944623

In [None]:
# The cell previously held the incomplete expression `xpad.` (a SyntaxError).
# NOTE(review): completed to a shape inspection to match the neighbouring
# cells — confirm this was the intent.
xpad.shape

In [88]:
# Zero-pad the rows of x up to the next multiple of n so that it divides into
# n equal parts; the second axis is left untouched.
xpad = np.pad(x, ((0, -x.shape[0] % n), (0, 0)), mode="constant", constant_values=0)

In [87]:
# Number of rows to append so that x.shape[0] becomes a multiple of n (as an int).
int(np.ceil(x.shape[0] / n) * n - x.shape[0])

2

In [None]:
# Same quantity without the int() cast; the recorded result is a NumPy float.
np.ceil(x.shape[0] / n) * n - x.shape[0]

np.float64(2.0)

In [None]:
# Rows per shard after padding; this should be a whole number if the padding is right.
# NOTE(review): the recorded 203.6 is not, so `xpad` and/or `n` were presumably
# stale when this cell ran.
xpad.shape[0] / n

203.6

In [47]:
# NOTE(review): `x` is never assigned at top level in the visible notebook
# (possibly `star.x`) — confirm where it comes from.
x.shape

(2028, 3)

In [34]:
# Total wall time in seconds over 10 calls of the notebook's own `light_curve`.
# NOTE(review): `timeit` is only imported in a later cell of this notebook.
timeit(lambda: light_curve(star, time).block_until_ready(), number=10)

0.30065070802811533

In [35]:
# Same measurement (10 calls, total seconds) for the library's `light_curves.light_curve`.
timeit(lambda: light_curves.light_curve(star, time).block_until_ready(), number=10)

0.3433654169784859

In [10]:
import numpy as np

# Time grid: 1000 evenly spaced samples on [0, 1].
time = np.linspace(0, 1, 1000)


@jax.jit
def flux(star, time):
    """Jit-wrapped call to ``light_curves.light_curve(star, time)``."""
    return light_curves.light_curve(star, time)


# Evaluate once and display the result; this call also precedes the timing cell below.
flux(star, time)

Array([[0.9646643 , 0.9646649 , 0.964666  , 0.9646706 , 0.9646774 ,
        0.9646857 , 0.9646977 , 0.9647089 , 0.9647227 , 0.96473783,
        0.9647573 , 0.96477616, 0.96479785, 0.9648207 , 0.9648452 ,
        0.9648738 , 0.9649026 , 0.9649338 , 0.96496534, 0.9650013 ,
        0.9650366 , 0.9650762 , 0.9651162 , 0.9651576 , 0.96520096,
        0.96524835, 0.965296  , 0.96534663, 0.96539533, 0.9654506 ,
        0.9655043 , 0.96556246, 0.96562207, 0.9656807 , 0.96574295,
        0.9658076 , 0.9658751 , 0.9659416 , 0.96601063, 0.9660827 ,
        0.9661552 , 0.96623206, 0.9663086 , 0.9663889 , 0.9664665 ,
        0.9665492 , 0.9666361 , 0.9667196 , 0.9668081 , 0.96689826,
        0.9669887 , 0.96708316, 0.967177  , 0.9672741 , 0.9673735 ,
        0.9674732 , 0.9675749 , 0.96767884, 0.9677852 , 0.9678915 ,
        0.9680016 , 0.96810937, 0.9682235 , 0.96833974, 0.96845454,
        0.96857166, 0.9686912 , 0.9688119 , 0.96893525, 0.96905935,
        0.9691868 , 0.9693139 , 0.9694436 , 0.96

In [11]:
from timeit import timeit

# Total wall time in seconds over 20 calls of the jitted `flux`, waiting on each
# result with jax.block_until_ready.
timeit(lambda: jax.block_until_ready(flux(star, time)), number=20)

3.532639249926433

In [12]:
import jax.numpy as jnp
from functools import partial

# One chunk of time samples per device.
n = os.cpu_count()
# np.split requires len(padded_time) to be an exact multiple of n, so append
# (-len(time)) % n zeros. Padding by len(time) % n, as before, only yields a
# multiple of n in special cases and otherwise makes np.split raise.
# The appended samples are zeros, so the matching trailing outputs must be
# dropped after the computation.
padded_time = np.pad(time, (0, -len(time) % n), mode="constant")
# Shape (n, len(padded_time) // n): one row of time samples per device.
splitted_time = np.array(np.split(padded_time, n))


# @partial(jax.pmap, in_axes=(None, 0))
def impl(star, time):
    """Row-wise contraction (``ij,ij->i``) of the design matrix with ``star.y``."""
    design = light_curves.design_matrix(star, time)
    return jnp.einsum("ij,ij->i", design, star.y)

In [13]:
from timeit import timeit

# pmap version of the library light curve: the star is broadcast (in_axes None)
# and axis 0 of the second argument is mapped across devices.
pmaped_lc = jax.pmap(light_curves.light_curve, in_axes=(None, 0))
# Evaluate once and display the result before timing it in the next cell.
pmaped_lc(star, splitted_time)

Array([[[0.9646643 , 0.9646649 , 0.964666  , 0.9646706 , 0.9646774 ,
         0.9646857 , 0.9646977 , 0.9647089 , 0.9647227 , 0.96473783,
         0.9647573 , 0.96477616, 0.96479785, 0.9648207 , 0.9648452 ,
         0.9648738 , 0.9649026 , 0.9649338 , 0.96496534, 0.9650013 ,
         0.9650366 , 0.9650762 , 0.9651162 , 0.9651576 , 0.96520096,
         0.96524835, 0.965296  , 0.96534663, 0.96539533, 0.9654506 ,
         0.9655043 , 0.96556246, 0.96562207, 0.9656807 , 0.96574295,
         0.9658076 , 0.9658751 , 0.9659416 , 0.96601063, 0.9660827 ,
         0.9661552 , 0.96623206, 0.9663086 , 0.9663889 , 0.9664665 ,
         0.9665492 , 0.9666361 , 0.9667196 , 0.9668081 , 0.96689826,
         0.9669887 , 0.96708316, 0.967177  , 0.9672741 , 0.9673735 ,
         0.9674732 , 0.9675749 , 0.96767884, 0.9677852 , 0.9678915 ,
         0.9680016 , 0.96810937, 0.9682235 , 0.96833974, 0.96845454,
         0.96857166, 0.9686912 , 0.9688119 , 0.96893525, 0.96905935,
         0.9691868 , 0.9693139 , 0

In [14]:
# Total wall time in seconds over 20 calls of the pmap'd light curve (compare
# with the jitted `flux` timing above, which uses the same number of calls).
timeit(lambda: jax.block_until_ready(pmaped_lc(star, splitted_time)), number=20)

2.1492919160518795

In [15]:
# NOTE(review): the recorded AttributeError says `splitted_time` was a list when
# this ran, whereas the cell that defines it wraps np.split in np.array(...) —
# the output looks stale.
splitted_time.shape

AttributeError: 'list' object has no attribute 'shape'

In [None]:
# NOTE(review): the recorded output is an array of shape (1106,), not a function,
# so `impl` was presumably bound to something else when this cell ran.
impl

array([0.        , 0.00099602, 0.00199203, ..., 0.        , 0.        ,
       0.        ], shape=(1106,))

In [None]:
import jax.numpy as jnp


def impl(star, time):
    """Row-wise contraction (``ij,ij->i``) of the design matrix with ``star.y``."""
    # Same definition as the earlier `impl` cell in this notebook.
    design = light_curves.design_matrix(star, time)
    return jnp.einsum("ij,ij->i", design, star.y)

In [None]:
def lc(X, time):
    """Unfinished stub: neither argument is used and the bare ``return`` yields ``None``."""
    return

[(168, 1, 1452),
 (168, 1, 1452),
 (168, 1, 1452),
 (167, 1, 1452),
 (167, 1, 1452),
 (167, 1, 1452)]

In [40]:
# NOTE(review): `X` is never assigned in the visible notebook — confirm its origin.
X.shape

(1000, 1, 1200)