# Benchmark: quadratic form $x^\top A x$ via `einsum` vs. matmul (NumPy, MXNet, JAX, CuPy, PyTorch)

In [None]:
# Notebook-wide IPython configuration.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # echo the last expression or assignment of each cell
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Reload imported modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
%reset -f

import numpy as np

# Problem size: x holds n row vectors of dimension d; A is a d-by-d matrix.
n = 10_000_000
d = 10

# Standard-normal test data drawn with NumPy's global random state.
A = np.random.randn(d, d)
x = np.random.randn(n, d)


def f(x):
    """Evaluate the quadratic form ``x^T A x`` along the last axis of ``x`` with one einsum.

    ``A`` is read from the notebook namespace. The ellipsis subscripts keep
    any leading (batch) axes of ``x``; the last axis is contracted away.
    """
    subscripts = "...d, ...e, de -> ..."
    return np.einsum(subscripts, x, x, A, optimize=True)


def g(x):
    """Evaluate the same quadratic form as ``f`` using matmul and an elementwise product.

    ``A`` is read from the notebook namespace. The product is reduced over
    the last axis, leaving the leading axes of ``x``.
    """
    weighted = x @ A * x  # evaluates as (x @ A) * x
    return weighted.sum(axis=-1)


# Evaluate each implementation once and reuse the results: every call
# processes all n rows, so calling f(x) and g(x) twice doubled the work.
f_out, g_out = f(x), g(x)
assert f_out.shape == g_out.shape
np.allclose(f_out, g_out)  # last expression: shown as the cell output

In [None]:
%%timeit
# Time the einsum-based implementation.
f(x)

In [None]:
%%timeit
# Time the matmul-based implementation.
g(x)

In [None]:
%reset -f

import mxnet as mx
from mxnet import np

# Problem size: x holds n row vectors of dimension d; A is a d-by-d matrix.
n = 10_000_000
d = 10

# Standard-normal test data (mean 0, std 1) created on the GPU context.
A = np.random.normal(loc=0, scale=1, size=(d, d), ctx=mx.gpu())
x = np.random.normal(loc=0, scale=1, size=(n, d), ctx=mx.gpu())


def f(x):
    """Evaluate the quadratic form ``x^T A x`` along the last axis of ``x`` with one einsum.

    ``A`` is read from the notebook namespace. The ellipsis subscripts keep
    any leading (batch) axes of ``x``; the last axis is contracted away.
    """
    subscripts = "...d, ...e, de -> ..."
    return np.einsum(subscripts, x, x, A, optimize=True)


def g(x):
    """Evaluate the same quadratic form as ``f`` using matmul and an elementwise product.

    ``A`` is read from the notebook namespace. The product is reduced over
    the last axis, leaving the leading axes of ``x``.
    """
    weighted = x @ A * x  # evaluates as (x @ A) * x
    last_axis = x.ndim - 1
    return weighted.sum(last_axis)


# Evaluate each implementation once and reuse the results: every call
# processes all n rows, so calling f(x) and g(x) twice doubled the work.
f_out, g_out = f(x), g(x)
assert f_out.shape == g_out.shape
np.allclose(f_out, g_out)  # last expression: shown as the cell output

In [None]:
%%timeit
# wait_to_read() blocks until the result is ready, so the measurement covers
# the computation itself rather than only its asynchronous dispatch.
f(x).wait_to_read()

In [None]:
%%timeit
# wait_to_read() blocks until the result is ready, so the measurement covers
# the computation itself rather than only its asynchronous dispatch.
g(x).wait_to_read()

In [None]:
%reset -f
import jax

# Uncomment the next line to force JAX onto the CPU platform.
# jax.config.update('jax_platform_name', 'cpu')
print(jax.devices())  # show which devices JAX reports for this run
import jax.numpy as np
from jax import jit, random

# JAX PRNG keys must not be reused: drawing A and x from the same key makes
# the two samples statistically dependent. Split the root key into one
# subkey per array instead.
key = random.PRNGKey(0)
key_A, key_x = random.split(key)

# Problem size: x holds n row vectors of dimension d; A is a d-by-d matrix.
n, d = 10_000_000, 10
A = jax.random.normal(key_A, shape=(d, d))
x = jax.random.normal(key_x, shape=(n, d))


@jit
def f(x):
    """Evaluate the quadratic form ``x^T A x`` along the last axis of ``x`` with one einsum.

    ``A`` is read from the notebook namespace. The ellipsis subscripts keep
    any leading (batch) axes of ``x``; the last axis is contracted away.
    """
    subscripts = "...d, ...e, de -> ..."
    return np.einsum(subscripts, x, x, A, optimize=True)


@jit
def g(x):
    """Evaluate the same quadratic form as ``f`` using matmul and an elementwise product.

    ``A`` is read from the notebook namespace. The product is reduced over
    the last axis, leaving the leading axes of ``x``.
    """
    weighted = x @ A * x  # evaluates as (x @ A) * x
    return weighted.sum(axis=-1)


# Evaluate each implementation once and reuse the results: every call
# processes all n rows, so calling f(x) and g(x) twice doubled the work.
f_out, g_out = f(x), g(x)
assert f_out.shape == g_out.shape
np.allclose(f_out, g_out)  # last expression: shown as the cell output

In [None]:
%%timeit
# block_until_ready() waits for the asynchronously dispatched computation to
# finish, so the measurement covers the full evaluation of f.
f(x).block_until_ready()

In [None]:
%%timeit
# block_until_ready() waits for the asynchronously dispatched computation to
# finish, so the measurement covers the full evaluation of g.
g(x).block_until_ready()

In [None]:
%reset -f
import cupy as np

# Problem size: x holds n row vectors of dimension d; A is a d-by-d matrix.
n = 10_000_000
d = 10

# Standard-normal test data as CuPy arrays (``np`` is CuPy in this cell).
A = np.random.randn(d, d)
x = np.random.randn(n, d)


def f(x):
    """Evaluate the quadratic form ``x^T A x`` along the last axis of ``x`` with one einsum.

    ``A`` is read from the notebook namespace. The ellipsis subscripts keep
    any leading (batch) axes of ``x``; the last axis is contracted away.
    """
    subscripts = "...d, ...e, de -> ..."
    return np.einsum(subscripts, x, x, A, optimize=True)


def g(x):
    """Evaluate the same quadratic form as ``f`` using matmul and an elementwise product.

    ``A`` is read from the notebook namespace. The product is reduced over
    the last axis, leaving the leading axes of ``x``.
    """
    weighted = x @ A * x  # evaluates as (x @ A) * x
    return weighted.sum(axis=-1)


# Evaluate each implementation once and reuse the results: every call
# processes all n rows, so calling f(x) and g(x) twice doubled the work.
f_out, g_out = f(x), g(x)
assert f_out.shape == g_out.shape
np.allclose(f_out, g_out)  # last expression: shown as the cell output

In [None]:
%%timeit
f(x)

In [None]:
%%timeit
g(x)

In [None]:
%reset -f

import torch
from torch import Tensor, jit

# Problem size: x holds n row vectors of dimension d; A is a d-by-d matrix.
n = 10_000_000
d = 10

# Standard-normal test data created directly on the CUDA device.
A = torch.randn((d, d), device="cuda")
x = torch.randn((n, d), device="cuda")


@jit.script
def f(A: Tensor, x: Tensor) -> Tensor:
    """Evaluate the quadratic form ``x^T A x`` along the last axis of ``x`` with one einsum.

    The ellipsis subscripts keep any leading (batch) axes of ``x``; the last
    axis is contracted away.
    """
    equation = "...d, ...e, de -> ..."
    return torch.einsum(equation, x, x, A)


@jit.script
def g(A: Tensor, x: Tensor) -> Tensor:
    """Evaluate the quadratic form ``x^T A x`` using matmul and an elementwise product.

    Args:
        A: square matrix whose size matches the last axis of ``x``.
        x: tensor whose last axis holds the vectors; leading axes are kept.

    Returns:
        Tensor with the leading axes of ``x``; the last axis is summed away.
    """
    weighted = x @ A * x  # evaluates as (x @ A) * x
    return weighted.sum(dim=-1)


# Evaluate each implementation once and reuse the results: every call
# processes all n rows, so calling f and g twice doubled the work.
f_out, g_out = f(A, x), g(A, x)
assert f_out.shape == g_out.shape
torch.allclose(f_out, g_out)  # last expression: shown as the cell output

In [None]:
%%timeit
f(A, x)

In [None]:
%%timeit
g(A, x)