# CUDA vs CPU RLCT estimation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/timaeus-research/devinterp/blob/main/examples/cuda_benchmark.ipynb)

This notebook measures how fast RLCT estimation is on CUDA vs on CPU. We check this using a standard normal crossing model.

In [1]:
%pip install devinterp 

You should consider upgrading via the '/home/paperspace/devinterp/testvenv/bin/python -m pip install --upgrade pip' command.[0m[33m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
from functools import partial
import timeit

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from devinterp.optim.sgld import SGLD
from devinterp.slt import estimate_learning_coeff
from devinterp.zoo.normal_crossing import PolyModel

assert torch.cuda.is_available()

In [3]:
sigma = 0.25
lr = 0.0005
criterion = F.mse_loss


def timeit_rlct_estimation_wrapper(model, device, cores):
    return estimate_learning_coeff(
        model,
        train_loader,
        criterion=criterion,
        optimizer_kwargs=dict(
            lr=lr,
            bounding_box_size=1.0,
            num_samples=len(train_data),
        ),
        sampling_method=SGLD,
        num_chains=cores,
        num_draws=1_000,
        num_burnin_steps=0,
        num_steps_bw_draws=1,
        verbose=False,
        device=device,
        cores=cores,
    )


num_train_samples = 50_000
batch_size = num_train_samples
x = torch.normal(0, 2, size=(num_train_samples,))
y = sigma * torch.normal(0, 1, size=(num_train_samples,))
train_data = TensorDataset(x, y)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
for cores in (1, 4):
    for device in ("cuda", "cpu"):
        powers = torch.tensor([1, 2], device=device)
        model = PolyModel(powers)
        w_true = torch.zeros_like(powers)
        timeit_rlct_function = partial(timeit_rlct_estimation_wrapper, *(model, device, cores))
        time_taken = timeit.timeit(timeit_rlct_function, number=5)
        print(
            f"{num_train_samples} samples on {device}, {cores} cores/chains: {time_taken:.2f}s per estimate"
        )

50000 samples on cuda, 1 cores/chains: 12.27s per estimate
50000 samples on cpu, 1 cores/chains: 17.86s per estimate
50000 samples on cuda, 4 cores/chains: 37.81s per estimate
50000 samples on cpu, 4 cores/chains: 493.21s per estimate
