# Benchmark


In [None]:
#| default_exp benchmark

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
# |export
# |hide
import time

from itertools import count
import torch
from torch import nn
from torch.nn import functional as F
from torch.cuda.amp.autocast_mode import autocast


import timm
from tqdm.auto import tqdm


In [None]:
# |export
def benchmark(model: nn.Module, # Model to run
                bs: int =32,    # Batch size
                n_batches: int|None =None,  # Number of batches to run. `seconds` must be None
                n_seconds: int|None =None,  # Number of seconds to run. `n_batches` must be None
                fp16: int =False,           # Use Automatic Mixed Precision
                size: int=224,              # Mock-train on this size "images"
                dev: torch.device=torch.device("cuda:0"),): # Device to run on

    """Mock-train the model on random noise input."""

    # There can be only one
    assert not n_batches or not n_seconds
    assert n_batches or n_seconds

    torch.backends.cudnn.benchmark=True
    assert torch.backends.cudnn.is_available()

    model.to(dev)
    optim = torch.optim.Adam(model.parameters(), lr=0.00001, weight_decay=0.00005)

    X = torch.randn((bs, 3, size, size), device=dev)

    # Assume the head is for ImageNet with 1000 catagories.
    y = torch.randint(0, 999, (bs,), device=dev)

    # Warm-up to run cudnn.benchmark first.
    yhat = model(X)

    loss = F.cross_entropy(yhat, y)
    loss.backward()

    optim.step()
    optim.zero_grad(set_to_none=True)

    if n_batches:
        pbar = tqdm(total=n_batches, unit="Batch")
    else:
        pbar = tqdm(total=n_seconds,
            bar_format="{l_bar}{bar}| {n:.1f}/{total} s [{elapsed}<{remaining} {postfix}]")

    start_time = time.time()
    last_time = start_time
    for c in count():
        with autocast(enabled=fp16):
            yhat = model(X)
            loss = F.cross_entropy(yhat, y)

        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)

        if n_batches:
            pbar.update()
            if c+1 == n_batches:
                break

        else:
            now = time.time()
            iter_time =  now - last_time
            run_time = now - start_time
            pbar.update(iter_time)
            if run_time >= n_seconds:
                break
            last_time = now
    pbar.close()

    return ((time.time() - start_time), c*bs)


In [None]:
# |eval: false
# Randomly initialised ResNet-50 (no pretrained weights); mock-train it for ~10 s.
model = timm.create_model("resnet50", pretrained=False)
benchmark(model=model, n_seconds=10)

In [None]:
# |eval: false
# Same model, but a fixed workload of 10 batches instead of a time budget.
benchmark(model=model, n_batches=10)

In [None]:
#| hide
import nbdev
nbdev.nbdev_export()  # run nbdev's export for this notebook