In [19]:
import platform
import psutil
from typing import Tuple, Union
from timeit import timeit
from warnings import warn

# PyTorch dependencies
import torch
import torch.backends.opt_einsum as opt_einsum
from torch import Tensor

# Internal dependencies
from thoad import config
from thoad import backward, Controller

In [None]:
# Workload knobs for the benchmark sweeps (both default to 1):
# - TENSOR_SCALE multiplies the width of the square parameter matrices
# - REPEAT_SCALE multiplies the number of timed repetitions per measurement
TENSOR_SCALE: Union[int, float] = 1
REPEAT_SCALE: Union[int, float] = 1

In [21]:
# Print the host operating system (name, release, version) for the benchmark log.
# NOTE(review): the name `sys` shadows the stdlib module of the same name; it is
# left unchanged only because other cells may refer to it.
sys: platform.uname_result = platform.uname()
print("system           " + " ".join([sys.system, sys.release, sys.version]))

system           Windows 11 10.0.26100


In [22]:
# Pick the compute device (CUDA when available, otherwise CPU) and report the
# hardware the benchmarks below run on.
dev: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if dev.type == 'cuda':
    # torch.device('cuda') carries no index; fall back to the active CUDA device
    # instead of assuming device 0.
    idx: int = dev.index if dev.index is not None else torch.cuda.current_device()
    props: "_CudaDeviceProperties" = torch.cuda.get_device_properties(idx)
    name: str = props.name
    total_mem_gb: float = props.total_memory / (1024**3)
    print(f"using device     {dev} -> {name}")
    print(f"device memory    {total_mem_gb:.1f} GB")
else:
    # platform.processor() can return an empty string; use a generic label then.
    cpu_name: str = platform.processor() or "CPU"
    print(f"using device     {dev} -> {cpu_name}")
    print(f"physical cores   {psutil.cpu_count(logical=False)}")
    print(f"logical cores    {psutil.cpu_count(logical=True)}")

using device     cpu -> AMD64 Family 23 Model 160 Stepping 0, AuthenticAMD
physical cores   4
logical cores    8


In [23]:
# Turn on PyTorch's opt_einsum contraction-path backend (greedy strategy) when
# the package is installed; otherwise emit a UserWarning and leave it off.
if not opt_einsum.is_available():
    warn(
        "opt_einsum backend is not available. "
        "For better performance, install and enable opt_einsum.",
        UserWarning
    )
else:
    opt_einsum.enabled = True
    opt_einsum.strategy = "greedy"
    print("opt_einsum backend enabled")

  warn(


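Definition of the MLP: bias-free linear layers with ReLU activations, a softmax after the last layer, and the output summed to a scalar.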

In [24]:
def foward_pass(X: Tensor, *params) -> Tensor:
    """Run ``X`` through a bias-free MLP and reduce the result to a scalar.

    Each tensor in ``params`` is applied as a right matrix multiplication.
    Every layer except the last is followed by ReLU; the last one is followed
    by a softmax over dim 1. With no ``params`` the input is summed unchanged.
    (Presumably ``X`` is ``(batch, features)`` and each param is square --
    TODO confirm; the benchmarks below build them that way.)

    Args:
        X: input tensor.
        *params: weight matrices, applied in order.

    Returns:
        0-dim tensor holding the sum of the network output.
    """
    out: Tensor = X
    # Hidden layers: linear map followed by ReLU.
    for weight in params[:-1]:
        out = torch.relu(out @ weight)
    # Output layer: linear map followed by softmax (skipped when no params).
    if params:
        out = torch.softmax(out @ params[-1], dim=1)
    return out.sum()

Definition of a helper function to measure differentiation times

In [25]:
def time_differentiation(
        reps: int,
        param_grad: bool,
        keep_batch: bool,
        keep_schwarz: bool,
        order: int,
        X: Tensor,
        *params,
        warmup: int = 0,
    ) -> float:
    """Time ``reps`` forward passes plus thoad ``backward`` calls on ``foward_pass``.

    Args:
        reps: number of timed repetitions passed to ``timeit``.
        param_grad: value set in place as ``requires_grad`` on every tensor in
            ``params``; also forwarded to ``backward`` as ``crossings``.
        keep_batch: forwarded to ``backward``.
        keep_schwarz: forwarded to ``backward``.
        order: forwarded to ``backward`` as ``order``.
        X: input tensor; ``requires_grad`` is set to True on it in place.
        *params: weight tensors passed to ``foward_pass``.
        warmup: keyword-only; number of untimed repetitions run before timing,
            to keep one-time setup costs out of the result. Defaults to 0
            (no warm-up).

    Returns:
        Total wall-clock seconds for all ``reps`` repetitions (not per repetition).
    """
    X.requires_grad_(True)
    params: list[Tensor] = [P.requires_grad_(param_grad) for P in params]
    on_cuda: bool = X.is_cuda

    def _foward_and_backward() -> None:
        T: Tensor = foward_pass(X, *params)
        ctrl: Controller = backward(
            tensor=T,
            order=order,
            crossings=param_grad,
            keep_batch=keep_batch,
            keep_schwarz=keep_schwarz,
        )
        ctrl.clear()
        # CUDA kernels run asynchronously; wait for them so each repetition's
        # GPU work is included in the measured wall-clock time.
        if on_cuda:
            torch.cuda.synchronize(X.device)

    for _ in range(warmup):
        _foward_and_backward()
    # Drain any previously queued GPU work (e.g. tensor creation) before timing.
    if on_cuda:
        torch.cuda.synchronize(X.device)
    return timeit(_foward_and_backward, number=reps)

## **Benchmark optimizations**

Computational cost w.r.t. **batch size**

In [26]:
# Batch-size sweep: time a 3-layer MLP with config.BATCH_OPTIMIZATION off
# (baseline) and on (optimized). SCHWARZ_OPTIMIZATION stays off for the whole
# sweep so the batch setting is the only difference between the two timings.
config.SCHWARZ_OPTIMIZATION = False

for o in [1, 2, 3]:
    print(f"\nORDER {o}")
    # Layer width shrinks as 1/o (presumably to keep higher orders tractable);
    # max(1, ...) keeps it a valid size when TENSOR_SCALE is small.
    param_size: int = max(1, int(10 * (1 / o) * TENSOR_SCALE))
    p_shape: Tuple[int, int] = (param_size, param_size)
    for batch_size in [10, 20, 30, 40, 50, 60, 70, 80]:
        x_shape: Tuple[int, int] = (batch_size, param_size)

        X: Tensor = torch.rand(size=x_shape, device=dev)
        params: list[Tensor] = [torch.rand(size=p_shape, device=dev) for _ in range(3)]

        # Fewer repetitions for larger batches, but always at least one so the
        # per-repetition averages printed below never divide by zero.
        reps: int = max(1, int(200 * (1 / batch_size) * REPEAT_SCALE))

        # Positional flags: param_grad=True, keep_batch=True, keep_schwarz=False.
        config.BATCH_OPTIMIZATION = False
        regular_time: float = time_differentiation(
            reps, True, True, False, o, X, *params
        )
        config.BATCH_OPTIMIZATION = True
        optimized_time: float = time_differentiation(
            reps, True, True, False, o, X, *params
        )

        print(
            f"batch size: {batch_size:03d} -> "
            f"baseline: {regular_time / reps:.4f}  "
            f"optimized: {optimized_time / reps:.4f}"
        )


ORDER 1
batch size: 010 -> baseline: 0.0302  optimized: 0.0226
batch size: 020 -> baseline: 0.0204  optimized: 0.0182
batch size: 030 -> baseline: 0.0167  optimized: 0.0198
batch size: 040 -> baseline: 0.0167  optimized: 0.0191
batch size: 050 -> baseline: 0.0185  optimized: 0.0205
batch size: 060 -> baseline: 0.0183  optimized: 0.0168
batch size: 070 -> baseline: 0.0185  optimized: 0.0182
batch size: 080 -> baseline: 0.0197  optimized: 0.0191

ORDER 2
batch size: 010 -> baseline: 0.0636  optimized: 0.0593
batch size: 020 -> baseline: 0.0591  optimized: 0.0556
batch size: 030 -> baseline: 0.0596  optimized: 0.0629
batch size: 040 -> baseline: 0.0529  optimized: 0.0582
batch size: 050 -> baseline: 0.0544  optimized: 0.0552
batch size: 060 -> baseline: 0.0658  optimized: 0.0574
batch size: 070 -> baseline: 0.0744  optimized: 0.0649
batch size: 080 -> baseline: 0.0700  optimized: 0.0618

ORDER 3
batch size: 010 -> baseline: 0.2732  optimized: 0.2962
batch size: 020 -> baseline: 0.3937  o

Computational cost w.r.t. **graph depth** (~ number of terminal nodes)

In [27]:
# Depth sweep: time MLPs with 2..6 layers with config.SCHWARZ_OPTIMIZATION off
# (baseline) and on (optimized). BATCH_OPTIMIZATION stays off for the whole
# sweep so the Schwarz setting is the only difference between the two timings.
config.BATCH_OPTIMIZATION = False

for o in [1, 2, 3]:
    print(f"\nORDER {o}")
    batch_size: int = 10
    # Layer width shrinks as 1/o (presumably to keep higher orders tractable);
    # max(1, ...) keeps it a valid size when TENSOR_SCALE is small.
    param_size: int = max(1, int(10 * (1 / o) * TENSOR_SCALE))
    x_shape: Tuple[int, int] = (batch_size, param_size)
    p_shape: Tuple[int, int] = (param_size, param_size)
    for depth in range(2, 7):
        X: Tensor = torch.rand(size=x_shape, device=dev)
        # One square parameter matrix per layer.
        params: list[Tensor] = [torch.rand(size=p_shape, device=dev) for _ in range(depth)]

        # Fewer repetitions for deeper graphs, but always at least one so the
        # per-repetition averages printed below never divide by zero.
        reps: int = max(1, int(200 * (1 / depth) * REPEAT_SCALE))

        # Positional flags: param_grad=True, keep_batch=False, keep_schwarz=True.
        config.SCHWARZ_OPTIMIZATION = False
        baseline_time: float = time_differentiation(
            reps, True, False, True, o, X, *params
        )
        config.SCHWARZ_OPTIMIZATION = True
        optimized_time: float = time_differentiation(
            reps, True, False, True, o, X, *params
        )

        print(
            f"graph depth: {depth:02d} -> "
            f"baseline: {baseline_time / reps:.4f}  "
            f"optimized: {optimized_time / reps:.4f}"
        )


ORDER 1
graph depth: 02 -> baseline: 0.0185  optimized: 0.0146
graph depth: 03 -> baseline: 0.0212  optimized: 0.0189
graph depth: 04 -> baseline: 0.0254  optimized: 0.0286
graph depth: 05 -> baseline: 0.0315  optimized: 0.0298
graph depth: 06 -> baseline: 0.0386  optimized: 0.0417

ORDER 2
graph depth: 02 -> baseline: 0.0329  optimized: 0.0297
graph depth: 03 -> baseline: 0.0919  optimized: 0.0509
graph depth: 04 -> baseline: 0.0987  optimized: 0.0801
graph depth: 05 -> baseline: 0.1432  optimized: 0.0914
graph depth: 06 -> baseline: 0.1742  optimized: 0.1564

ORDER 3
graph depth: 02 -> baseline: 0.1404  optimized: 0.0668
graph depth: 03 -> baseline: 0.2618  optimized: 0.1318
graph depth: 04 -> baseline: 0.5782  optimized: 0.2635
graph depth: 05 -> baseline: 1.0093  optimized: 0.4209
graph depth: 06 -> baseline: 1.8741  optimized: 0.7664
