In [None]:
import platform
import psutil
from typing import Tuple, Union
from timeit import timeit
from typing import Callable
from warnings import warn

# PyTorch dependencies
import torch
import torch.backends.opt_einsum as opt_einsum
from torch import Tensor

# JAX dependencies
import jax
import jax.numpy as jnp
from jax import Array
from jax.experimental.jet import jet
from jax._src.lib import xla_client

# Internal dependencies
from thoad import backward, Controller

In [None]:
# Global scaling knobs for the benchmark cells.
#   REPEAT_SCALE: multiplier applied to the number of timed repetitions.
#   TENSOR_SCALE: multiplier applied to tensor dimensions (control size of tensors).
REPEAT_SCALE: Union[int, float] = 1
TENSOR_SCALE: Union[int, float] = 1

In [3]:
# Report the host operating system the benchmarks run on.
# NOTE(review): the name `sys` shadows the stdlib module name; it is kept
# unchanged in case other cells read it.
sys: platform.uname_result = platform.uname()
print("system           {} {} {}".format(sys.system, sys.release, sys.version))

system           Windows 11 10.0.26100


In [4]:
# Pick the PyTorch device (CUDA when available, otherwise CPU) and report
# the hardware behind it so the timings below can be put into context.
torch_dev: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch_dev.type == 'cuda':
    # torch.device('cuda') carries no explicit index; fall back to device 0.
    idx = torch_dev.index if torch_dev.index is not None else 0
    props: "_CudaDeviceProperties" = torch.cuda.get_device_properties(idx)
    name: str = props.name
    total_mem_gb: float = props.total_memory / (1024**3)  # bytes -> GiB
    print(f"using device     {torch_dev} -> {name}")
    # fixed: the original message ended with an unbalanced ")"
    print(f"device memory    {total_mem_gb:.1f} GB")
else:
    # platform.processor() may return an empty string; use a generic label then.
    cpu_name: str = platform.processor() or "CPU"
    print(f"using device     {torch_dev} -> {cpu_name}")
    print(f"physical cores   {psutil.cpu_count(logical=False)}")
    print(f"logical cores    {psutil.cpu_count(logical=True)}")

using device     cpu -> AMD64 Family 23 Model 160 Stepping 0, AuthenticAMD
physical cores   4
logical cores    8


In [5]:
# Choose the JAX device to report: the first GPU if one is visible,
# otherwise the first CPU device.
_available_gpus: list[xla_client.Device] = [
    device for device in jax.devices() if device.platform == "gpu"
]
dev_name: xla_client.Device
if _available_gpus:
    dev_name = _available_gpus[0]
else:
    dev_name = jax.devices("cpu")[0]
print("using device     {}".format(dev_name))


using device     TFRT_CPU_0


In [None]:
# Turn on torch's opt_einsum backend with the "optimal" contraction strategy
# when the backend is available; otherwise warn that performance may suffer.
if not opt_einsum.is_available():
    warn(
        "opt_einsum backend is not available. "
        "For better performance, install and enable opt_einsum.",
        UserWarning,
    )
else:
    opt_einsum.enabled = True
    print("opt_einsum backend enabled")
    opt_einsum.strategy = "optimal"

opt_einsum backend enabled


jax.experimental.jet documentation example (https://docs.jax.dev/en/latest/jax.experimental.jet.html)

In [7]:
# jet sanity check on sin: push the input series (h1, h2) around the point h0
# through f and compare the returned coefficients with closed-form values.
h0: Array = jnp.array([0.1, 0.2, 0.3])
h1: Array = jnp.ones_like(h0)
h2: Array = jnp.zeros_like(h0)

# the function together with its first and second derivative
f: Callable[[Array], Array] = jnp.sin
df: Callable[[Array], Array] = jnp.cos
ddf: Callable[[Array], Array] = lambda x: -jnp.sin(x)

# f0 is the primal output, (f1, f2) the propagated series coefficients
f0: Array
f1: Array
f2: Array
f0, (f1, f2) = jet(f, (h0,), ((h1, h2),))

# closed-form reference values
jet_f0 = f(h0)
jet_f1 = df(h0) * h1
jet_f2 = ddf(h0) * h1**2 + df(h0) * h2

print("f   :", jet_f0)
print("df*h1:", jet_f1)
print("ddf*h1^2 + df*h2:", jet_f2)

assert jnp.allclose(f0, jet_f0)
assert jnp.allclose(f1, jet_f1)
assert jnp.allclose(f2, jet_f2)

f   : [0.09983342 0.19866933 0.29552022]
df*h1: [0.9950042 0.9800666 0.9553365]
ddf*h1^2 + df*h2: [-0.09983342 -0.19866933 -0.29552022]


## **Benchmark differentiations on full MLP**

definition of MLP

In [8]:
def jax_forward_pass(X: jax.Array, *params: jax.Array) -> jax.Array:
    """Run the MLP forward pass in JAX.

    Each entry of ``params`` is right-multiplied onto the running activation.
    Every layer except the last is followed by a ReLU (``maximum(., 0)``);
    the last layer is followed by a softmax over axis 1.

    Args:
        X: input activations.
        *params: weight matrices, applied in order.

    Returns:
        The output of the last layer, or ``X`` unchanged when no parameters
        are given.
    """
    depth: int = len(params)
    activation: Array = X
    for layer_number, weight in enumerate(params, start=1):
        activation = activation @ weight
        if layer_number == depth:
            activation = jax.nn.softmax(activation, axis=1)
        else:
            activation = jnp.maximum(activation, 0)
    return activation

def torch_foward_pass(X: Tensor, *params) -> Tensor:
    """Run the MLP forward pass in PyTorch.

    Every weight matrix but the last is applied as ``relu(T @ P)``; the last
    one is applied as ``softmax(T @ P, dim=1)``.

    Args:
        X: input activations.
        *params: weight matrices, applied in order.

    Returns:
        The output of the last layer, or ``X`` unchanged when no parameters
        are given.
    """
    if not params:
        return X
    *hidden_weights, head_weight = params
    activation: Tensor = X
    for weight in hidden_weights:
        activation = torch.relu(activation @ weight)
    return torch.softmax(activation @ head_weight, dim=1)

definition of helper functions to measure differentiation times

In [9]:
def time_jet_differentiation(reps: int, order: int, X: jax.Array, *params) -> float:
    """Time ``jet`` through the JAX MLP with respect to the input ``X``.

    The series fed to ``jet`` consists of ``order`` all-ones arrays shaped
    like ``X``. ``params`` are closed over and therefore not differentiated.

    Args:
        reps: number of timed repetitions.
        order: number of series terms passed to ``jet``.
        X: input activations.
        *params: weight matrices forwarded to ``jax_forward_pass``.

    Returns:
        Total wall-clock seconds for ``reps`` calls, as reported by ``timeit``
        (the untimed warm-up call is excluded).
    """
    def _fwd(x) -> Array:
        return jax_forward_pass(x, *params)

    seed_tangents: Tuple[Array, ...] = tuple(jnp.ones_like(X) for _ in range(order))
    series: Tuple[Tuple[Array, ...]] = (seed_tangents,)

    def _run() -> None:
        # JAX dispatches work asynchronously: without blocking on the result
        # the timer could stop before the computation has actually finished.
        jax.block_until_ready(jet(_fwd, (X,), series))

    _run()  # warm up once to avoid including compile time
    return timeit(_run, number=reps)

def time_thoad_differentiation(reps: int, order:int, X: Tensor, *params) -> float:
    """Time a torch forward pass followed by thoad's ``backward``.

    ``X`` is switched to ``requires_grad=True`` in place and every parameter
    to ``requires_grad=False`` in place, so only ``X`` requires gradients.

    Args:
        reps: number of timed repetitions.
        order: derivative order forwarded to ``backward``.
        X: input activations (mutated: ``requires_grad`` is enabled).
        *params: weight matrices (mutated: ``requires_grad`` is disabled).

    Returns:
        Total wall-clock seconds for ``reps`` forward+backward calls, as
        reported by ``timeit`` (the untimed warm-up call is excluded).
    """
    X.requires_grad_(True)
    params: list[Tensor] = [P.requires_grad_(False) for P in params]
    assert all(not param.requires_grad for param in params)
    on_cuda: bool = X.is_cuda

    def _foward_and_backward() -> None:
        T: Tensor = torch_foward_pass(X, *params)
        ctrl: Controller = backward(tensor=T, order=order, crossings=False, keep_batch=True)
        # presumably releases the state held by the controller - TODO confirm
        ctrl.clear()
        if on_cuda:
            # CUDA kernels run asynchronously; wait for them so the timer
            # measures the work itself and not just the kernel launches.
            torch.cuda.synchronize()
        return None

    # Warm up once, mirroring the jet timer, so one-time costs are excluded
    # from both sides of the comparison.
    _foward_and_backward()
    elapsed: float = timeit(_foward_and_backward, number=reps)
    return elapsed

differentiation computational cost w.r.t. **order** and **batch size**

In [None]:
# Sweep the batch size at a fixed layer width for derivative orders 1-3.
# The nominal batch sizes are floor-divided by the order.
for o in [1, 2, 3]:
    print(f"\nORDER {o}")
    for batch_size in [15, 30, 45, 60, 75, 90]:
        batch_size //= o
        param_size: int = int(5 * TENSOR_SCALE)
        x_shape: Tuple[int, int] = (batch_size, param_size)
        p_shape: Tuple[int, int] = (param_size, param_size)

        # create jax tensors; every draw gets its own subkey because a PRNG
        # key must not be used both for sampling and for further splitting
        key: Array = jax.random.PRNGKey(0)
        x_key: Array
        key, x_key = jax.random.split(key)
        jax_X: Array = jax.random.uniform(x_key, x_shape)
        jax_params: list[Array] = []
        for _layer in range(3):
            subkey: Array
            key, subkey = jax.random.split(key)
            jax_params.append(jax.random.uniform(subkey, p_shape))

        # create torch tensors
        torch_X: Tensor = torch.rand(size=x_shape, device=torch_dev)
        torch_params: list[Tensor] = [
            torch.rand(size=p_shape, device=torch_dev) for _ in range(3)
        ]

        # run at least once: a truncated count of 0 would make the per-call
        # averages below divide by zero
        reps: int = max(1, int(600 * (1 / batch_size) * (1/o) * REPEAT_SCALE))
        jax_time: float = time_jet_differentiation(reps, o, jax_X, *jax_params)
        thoad_time: float = time_thoad_differentiation(reps, o, torch_X, *torch_params)
        print(
            f"batch size: {batch_size:02d} -> "
            f"jax time: {jax_time/reps:.4f}  thoad time: {thoad_time/reps:.4f}"
        )



ORDER 1
batch size: 15 -> jax time: 0.0035  thoad time: 0.0086
batch size: 30 -> jax time: 0.0034  thoad time: 0.0085
batch size: 45 -> jax time: 0.0049  thoad time: 0.0102
batch size: 60 -> jax time: 0.0042  thoad time: 0.0122
batch size: 75 -> jax time: 0.0037  thoad time: 0.0118
batch size: 90 -> jax time: 0.0040  thoad time: 0.0148

ORDER 2
batch size: 07 -> jax time: 0.0055  thoad time: 0.0154
batch size: 15 -> jax time: 0.0043  thoad time: 0.0173
batch size: 22 -> jax time: 0.0043  thoad time: 0.0171
batch size: 30 -> jax time: 0.0062  thoad time: 0.0262
batch size: 37 -> jax time: 0.0044  thoad time: 0.0248
batch size: 45 -> jax time: 0.0052  thoad time: 0.0253

ORDER 3
batch size: 05 -> jax time: 0.0057  thoad time: 0.0268
batch size: 10 -> jax time: 0.0056  thoad time: 0.0264
batch size: 15 -> jax time: 0.0069  thoad time: 0.0320
batch size: 20 -> jax time: 0.0081  thoad time: 0.0438
batch size: 25 -> jax time: 0.0052  thoad time: 0.0577
batch size: 30 -> jax time: 0.0062  th

differentiation computational cost w.r.t. **order** and **param size**

In [None]:
# Sweep the layer width at a fixed batch size for derivative orders 1-3.
# The nominal widths are floor-divided by the order.
for o in [1, 2, 3]:
    print(f"\nORDER {o}")
    for param_size in [10, 20, 30, 40, 50, 60]:
        batch_size: int = int(5 * TENSOR_SCALE)
        param_size //= o
        x_shape: Tuple[int, int] = (batch_size, param_size)
        p_shape: Tuple[int, int] = (param_size, param_size)

        # create jax tensors; every draw gets its own subkey because a PRNG
        # key must not be used both for sampling and for further splitting
        key: Array = jax.random.PRNGKey(0)
        x_key: Array
        key, x_key = jax.random.split(key)
        jax_X: Array = jax.random.uniform(x_key, x_shape)
        jax_params: list[Array] = []
        for _layer in range(3):
            subkey: Array
            key, subkey = jax.random.split(key)
            jax_params.append(jax.random.uniform(subkey, p_shape))

        # create torch tensors
        torch_X: Tensor = torch.rand(size=x_shape, device=torch_dev)
        torch_params: list[Tensor] = [
            torch.rand(size=p_shape, device=torch_dev) for _ in range(3)
        ]

        # run at least once: a truncated count of 0 would make the per-call
        # averages below divide by zero
        reps: int = max(1, int(600 * (1 / param_size) * (1/o) * REPEAT_SCALE))
        jax_time: float = time_jet_differentiation(reps, o, jax_X, *jax_params)
        thoad_time: float = time_thoad_differentiation(reps, o, torch_X, *torch_params)
        print(
            f"param size: {param_size:02d} -> "
            f"jax time: {jax_time/reps:.4f}  thoad time: {thoad_time/reps:.4f}"
        )




ORDER 1
param size: 10 -> jax time: 0.0044  thoad time: 0.0089
param size: 20 -> jax time: 0.0045  thoad time: 0.0125
param size: 30 -> jax time: 0.0039  thoad time: 0.0104
param size: 40 -> jax time: 0.0044  thoad time: 0.0102
param size: 50 -> jax time: 0.0045  thoad time: 0.0106
param size: 60 -> jax time: 0.0065  thoad time: 0.0199

ORDER 2
param size: 05 -> jax time: 0.0066  thoad time: 0.0161
param size: 10 -> jax time: 0.0064  thoad time: 0.0202
param size: 15 -> jax time: 0.0061  thoad time: 0.0224
param size: 20 -> jax time: 0.0074  thoad time: 0.0225
param size: 25 -> jax time: 0.0044  thoad time: 0.0275
param size: 30 -> jax time: 0.0079  thoad time: 0.0483

ORDER 3
param size: 03 -> jax time: 0.0054  thoad time: 0.0297
param size: 06 -> jax time: 0.0057  thoad time: 0.0268
param size: 10 -> jax time: 0.0090  thoad time: 0.0477
param size: 13 -> jax time: 0.0059  thoad time: 0.0599
param size: 16 -> jax time: 0.0060  thoad time: 0.0941
param size: 20 -> jax time: 0.0071  th

differentiation computational cost w.r.t. **graph depth** (param gradients included)

In [None]:
# Sweep the network depth (number of weight matrices) for derivative orders
# 1-3. Batch size and layer width are fixed per order and shrink with it.
for o in [1, 2, 3]:
    print(f"\nORDER {o}")
    for depth in [2, 3, 4, 5, 6, 7, 8]:
        batch_size: int = 40 // o
        param_size: int = int(10 // o * TENSOR_SCALE)
        x_shape: Tuple[int, int] = (batch_size, param_size)
        p_shape: Tuple[int, int] = (param_size, param_size)

        # create jax tensors; every draw gets its own subkey because a PRNG
        # key must not be used both for sampling and for further splitting
        key: Array = jax.random.PRNGKey(0)
        x_key: Array
        key, x_key = jax.random.split(key)
        jax_X: Array = jax.random.uniform(x_key, x_shape)
        jax_params: list[Array] = []
        for _layer in range(depth):
            subkey: Array
            key, subkey = jax.random.split(key)
            jax_params.append(jax.random.uniform(subkey, p_shape))

        # create torch tensors
        torch_X: Tensor = torch.rand(size=x_shape, device=torch_dev)
        torch_params: list[Tensor] = [
            torch.rand(size=p_shape, device=torch_dev) for _ in range(depth)
        ]

        # run at least once: a truncated count of 0 would make the per-call
        # averages below divide by zero
        reps: int = max(1, int(600 * (1 / depth) * (1/o) * REPEAT_SCALE))
        jax_time: float = time_jet_differentiation(reps, o, jax_X, *jax_params)
        thoad_time: float = time_thoad_differentiation(reps, o, torch_X, *torch_params)
        print(
            f"depth size: {depth:02d} -> "
            f"jax time: {jax_time/reps:.4f}  thoad time: {thoad_time/reps:.4f}"
        )


ORDER 1
depth size: 02 -> jax time: 0.0047  thoad time: 0.0115
depth size: 03 -> jax time: 0.0039  thoad time: 0.0150
depth size: 04 -> jax time: 0.0059  thoad time: 0.0184
depth size: 05 -> jax time: 0.0066  thoad time: 0.0243
depth size: 06 -> jax time: 0.0080  thoad time: 0.0243
depth size: 07 -> jax time: 0.0069  thoad time: 0.0287
depth size: 08 -> jax time: 0.0098  thoad time: 0.0290

ORDER 2
depth size: 02 -> jax time: 0.0052  thoad time: 0.0152
depth size: 03 -> jax time: 0.0048  thoad time: 0.0179
depth size: 04 -> jax time: 0.0060  thoad time: 0.0227
depth size: 05 -> jax time: 0.0104  thoad time: 0.0274
depth size: 06 -> jax time: 0.0086  thoad time: 0.0333
depth size: 07 -> jax time: 0.0109  thoad time: 0.0381
depth size: 08 -> jax time: 0.0119  thoad time: 0.0394

ORDER 3
depth size: 02 -> jax time: 0.0048  thoad time: 0.0265
depth size: 03 -> jax time: 0.0074  thoad time: 0.0257
depth size: 04 -> jax time: 0.0077  thoad time: 0.0340
depth size: 05 -> jax time: 0.0124  th