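# `einsum` reduction benchmark: NumPy vs PyTorch vs JAX on CPU

Times reductions of the form `ab,ijkl->cd` (a 2-D array contracted against a 4-D `ijkl` array over two of its indices) for several array sizes, one library per benchmark cell.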

In [None]:
# IPython/Jupyter display settings.
# 'last_expr_or_assign' also displays the value when a cell ends in a plain `name = ...` assignment.
%config InteractiveShell.ast_node_interactivity='last_expr_or_assign'  # always print last expr.
# Render inline matplotlib figures as SVG.
%config InlineBackend.figure_format = 'svg'
# Automatically reload edited, imported modules before running each cell
# (the extension must be loaded before `%autoreload` is used).
%load_ext autoreload
%autoreload 2
%matplotlib inline

import logging

# Configure the root logger to emit records at INFO level and above.
logging.basicConfig(level=logging.INFO)

In [None]:
import os
from itertools import product
from time import perf_counter

# These CUDA variables are set BEFORE the frameworks below are imported,
# presumably so each framework sees them when it initializes.
# "PCI_BUS_ID" orders device IDs by PCI bus; "-1" exposes no CUDA device,
# i.e. the benchmark is meant to run on CPU only.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import gc

import jax
import mxnet as mx
import numpy as np
import opt_einsum
import pandas as pd
import tensorflow as tf
import torch
from pandas import DataFrame, MultiIndex, Series
from tqdm.auto import tqdm

# Run JAX on CPU, like the torch (device="cpu") and CUDA-hidden setup above.
jax.config.update("jax_platform_name", "cpu")
# Uncomment to let JAX create real float64 arrays (by default it uses 32-bit types).
# jax.config.update("jax_enable_x64", True)
np.set_printoptions(precision=4, floatmode="fixed", suppress=True)

# Fixed seed so the randomly generated benchmark inputs are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)

In [None]:
# One reduction per ordered pair (a, b) of distinct indices from "ijkl":
# "ab,ijkl->cd", where cd are the two remaining indices in their original order.
reductions = [
    f"{first}{second},ijkl->" + "ijkl".replace(first, "").replace(second, "")
    for first, second in product("ijkl", repeat=2)
    if first != second
]

libs = ["contrat", "numpy", "torch", "jax", "tf", "mx"]
dtypes = ["float32"]
sizes = [64, 128, 256]

# dtype name (as used in `dtypes`) -> torch dtype object.
TORCH_DTYPES = {
    name: getattr(torch, name) for name in ("int32", "int64", "float32", "float64")
}

devices = [torch.device(kind) for kind in ("cpu", "cuda")]

# Empty (all-NaN) timing table: rows are reductions, columns are (size, dtype, lib).
index = Series(reductions, name="reduction")
columns = MultiIndex.from_product(
    [sizes, dtypes, libs],
    names=["size", "dtype", "lib"],
)
results = DataFrame(index=index, columns=columns, dtype=float)

In [None]:
# jax results
# JAX dispatches work asynchronously and compiles `jnp.einsum` for each new
# (reduction, shape, dtype) on first use. Each measurement therefore does one
# untimed warm-up call and then blocks until the timed result is ready;
# otherwise we would time only the dispatch or include the compilation.
for size in tqdm(sizes):
    _mat1 = rng.normal(size=(size, size, size, size))
    _mat2 = rng.normal(size=(size, size))

    for dtype in tqdm(dtypes, leave=False):
        mat1 = jax.numpy.array(_mat1, dtype=dtype)
        mat2 = jax.numpy.array(_mat2, dtype=dtype)

        for reduction in tqdm(reductions, leave=False):
            # Warm-up: compile outside the timed region.
            jax.numpy.einsum(reduction, mat2, mat1).block_until_ready()
            # Keep the garbage collector out of the timed region, but always
            # re-enable it: an exception or a kernel interrupt (KeyboardInterrupt)
            # would otherwise leave gc disabled for the rest of the session.
            gc.disable()
            try:
                start = perf_counter()
                jax.numpy.einsum(reduction, mat2, mat1).block_until_ready()
                stop = perf_counter()
            finally:
                gc.enable()
            results.loc[reduction, (size, dtype, "jax")] = stop - start

# Release the large inputs (size**4 elements) so they don't stay resident
# while the other benchmark cells allocate theirs.
del _mat1, _mat2, mat1, mat2

In [None]:
# torch_results
for size in tqdm(sizes):
    _mat1 = torch.randn((size, size, size, size), device="cpu")
    _mat2 = torch.randn((size, size), device="cpu")

    for dtype in tqdm(dtypes, leave=False):
        mat1 = _mat1.to(dtype=TORCH_DTYPES[dtype])
        mat2 = _mat2.to(dtype=TORCH_DTYPES[dtype])

        for reduction in tqdm(reductions, leave=False):
            # Keep the garbage collector out of the timed region, but always
            # re-enable it: an exception or a kernel interrupt (KeyboardInterrupt)
            # would otherwise leave gc disabled for the rest of the session.
            gc.disable()
            try:
                start = perf_counter()
                torch.einsum(reduction, mat2, mat1)
                stop = perf_counter()
            finally:
                gc.enable()
            results.loc[reduction, (size, dtype, "torch")] = stop - start

# Release the large inputs (size**4 elements) so they don't stay resident
# while the other benchmark cells allocate theirs.
del _mat1, _mat2, mat1, mat2

In [None]:
# numpy results
# optimize=False: no contraction-path optimization (np.einsum's default).
for size in tqdm(sizes):
    _mat1 = rng.normal(size=(size, size, size, size))
    _mat2 = rng.normal(size=(size, size))

    for dtype in tqdm(dtypes, leave=False):
        mat1 = _mat1.astype(dtype)
        mat2 = _mat2.astype(dtype)

        for reduction in tqdm(reductions, leave=False):
            # Keep the garbage collector out of the timed region, but always
            # re-enable it: an exception or a kernel interrupt (KeyboardInterrupt)
            # would otherwise leave gc disabled for the rest of the session.
            gc.disable()
            try:
                start = perf_counter()
                np.einsum(reduction, mat2, mat1, optimize=False)
                stop = perf_counter()
            finally:
                gc.enable()
            results.loc[reduction, (size, dtype, "numpy")] = stop - start

# Release the large inputs (size**4 elements) so they don't stay resident
# while other cells allocate theirs.
del _mat1, _mat2, mat1, mat2

In [None]:
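# tensorflow results — TODO: not implemented; no cell fills the "tf", "mx" or "contrat" columns of `results`, so they stay NaN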

In [None]:
# Export the timings rounded to 3 decimals and sorted by JAX time at the
# largest benchmarked size (first dtype). The rounded copy goes into a separate
# frame so the full-precision timings in `results` are kept for later cells.
# The sort key is derived from the config instead of hard-coding 256/"float32".
results_jax_sorted = results.round(3).sort_values(by=[(max(sizes), dtypes[0], "jax")])
results_jax_sorted.to_csv("einsum_slow_jax.csv")

In [None]:
# Show the timings rounded to 2 decimals, sorted by JAX time at the largest
# benchmarked size for the first dtype (key derived from `sizes`/`dtypes`).
results.round(2).sort_values(by=[(max(sizes), dtypes[0], "jax")])

In [None]:
import jaxlib

# Record the library versions used for the timings, one
# "<alias>.__version__='<version>'" line each (same text f"{x=}" produces).
print(
    *(
        f"{alias}.__version__={module.__version__!r}"
        for alias, module in (
            ("np", np),
            ("opt_einsum", opt_einsum),
            ("torch", torch),
            ("jax", jax),
            ("jaxlib", jaxlib),
            ("tf", tf),
            ("mx", mx),
        )
    ),
    sep="\n",
)

In [None]:
# Reload a previously saved benchmark table and reshape it so the reductions
# become the rows and (size, dtype, lib) the columns. Sort by numpy float64 time
# at size 256, breaking ties by numpy float32 time.
# NOTE(review): "einsum_slow.csv" is not written by any cell in this notebook
# (the export cell writes "einsum_slow_jax.csv"); it presumably comes from an
# earlier run that included float64 — confirm before Restart & Run All.
results = (
    pd.read_csv("einsum_slow.csv")
    .set_index(["size", "dtype", "lib"])
    .rename_axis(columns="reduction")
    .T
    .sort_values(by=[(256, "float64", "numpy"), (256, "float32", "numpy")])
)

In [None]:
# Append per-column summary rows: fastest time, slowest time and their ratio
# across reductions. The statistics are computed from the timing rows only, so
# re-running this cell does not fold earlier "min"/"max"/"ratio" rows back in.
summary_labels = ["min", "max", "ratio"]
timings = results.drop(index=summary_labels, errors="ignore")
results.loc["min"] = timings.min()
results.loc["max"] = timings.max()
results.loc["ratio"] = results.loc["max"] / results.loc["min"]

In [None]:
# Full table: one row per reduction (plus the "min"/"max"/"ratio" rows once the
# summary cell has run); columns are (size, dtype, lib).
results

In [None]:
# numpy timings only: every (size, dtype) column whose "lib" level is "numpy".
# `pd.IndexSlice` spells the row selector as an explicit `:`; a bare `...` inside
# `.loc` is not the documented MultiIndex slicer form and may be rejected by
# older pandas releases.
results.loc[:, pd.IndexSlice[:, :, "numpy"]].round(2)

In [None]:
# NOTE(review): this cell repeats the preceding numpy-slice display verbatim;
# remove it or point it at a different lib/dtype.
# numpy timings only, selected with an explicit `:` row slicer via `pd.IndexSlice`
# (a bare `...` inside `.loc` may be rejected by older pandas releases).
results.loc[:, pd.IndexSlice[:, :, "numpy"]].round(2)