In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timeit import default_timer
import torchvision
import contextlib

In [2]:
# Datagen
def datagen(batch_size = 1, device='cpu'):
    while(True):
        inp = torch.rand(batch_size, 3, 224, 224, dtype=torch.float32, device=device)
        label = torch.rand(batch_size, 1000, dtype=torch.float32, device=device)
        yield (inp, label)


In [3]:
@contextlib.contextmanager
def timer(enable = True):
    if not enable:
        yield
    else:
        try:
            start = default_timer()
            yield
        finally:
            torch.cuda.synchronize()
            stop = default_timer()
            print(f'Time: {(stop - start) * 1000}ms')

In [4]:
def train(model, loss_fn, opt, datagen, num_iter):
    model.train()
    num_iter = num_iter
    for _ in range(num_iter):
        inp, lbl = next(datagen)
        with timer(enable=True) as t:
            out = model(inp)
            loss = loss_fn(out, lbl)
            loss.backward()
            opt.zero_grad()
            opt.step()

In [5]:
def train_cg(g, static_inp, static_lbl, data, num_iter):
    for _ in range(num_iter):
        inp, lbl = next(data)
        with timer(enable=True) as t:
            static_inp.copy_(inp)
            static_lbl.copy_(lbl)
            g.replay()

In [9]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_iter = 100

model = torchvision.models.resnet50().to(device)
print(model)

loss_fn = nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters())
batch_size = 128

### DATAGEN
data = datagen(batch_size, device)

### WARMUP
train(model, loss_fn, opt, data, 3)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [12]:

print('Training Started')
train(model, loss_fn, opt, data, num_iter)
print('Training Done')


Training Started
Time: 321.23929634690285ms
Time: 370.78429386019707ms
Time: 377.4437755346298ms
Time: 375.01807883381844ms
Time: 377.60650366544724ms
Time: 382.1890316903591ms
Time: 370.5022782087326ms
Time: 385.4169398546219ms
Time: 383.84488597512245ms
Time: 374.608650803566ms
Time: 378.77947464585304ms
Time: 383.0159120261669ms
Time: 385.30611991882324ms
Time: 386.6407461464405ms
Time: 388.49441334605217ms
Time: 373.73048439621925ms
Time: 382.7702924609184ms
Time: 384.64372232556343ms
Time: 389.25352320075035ms
Time: 387.6347355544567ms
Time: 383.311253041029ms
Time: 380.1971711218357ms
Time: 386.09829545021057ms
Time: 383.5417330265045ms
Time: 386.02547720074654ms
Time: 388.1518207490444ms
Time: 380.47298416495323ms
Time: 388.4025812149048ms
Time: 378.3841282129288ms
Time: 381.9281794130802ms
Time: 390.5490040779114ms
Time: 379.3046772480011ms
Time: 392.48355850577354ms
Time: 383.18243622779846ms
Time: 379.3654218316078ms
Time: 387.1603533625603ms
Time: 383.1409737467766ms
Time: 3

In [13]:

g = torch.cuda.CUDAGraph()
static_inp = torch.randn(batch_size, 3, 224, 224, device=device)
static_lbl = torch.randn(batch_size, 1000, device=device)
opt.zero_grad()
with torch.cuda.graph(g):
    model.train()
    static_out = model(static_inp)
    static_loss = loss_fn(static_out, static_lbl)
    static_loss.backward()
    opt.step()
### CUDA Graph Train
print('CUDA Graph Training Started')
train_cg(g, static_inp, static_lbl, data, num_iter)
print('CUDA Graph Training Done')

CUDA Graph Training Started
Time: 155.4575152695179ms
Time: 161.14727780222893ms
Time: 172.48966172337532ms
Time: 172.34009131789207ms
Time: 168.95577311515808ms
Time: 163.2806472480297ms
Time: 168.46423596143723ms
Time: 186.6588443517685ms
Time: 255.35886734724045ms
Time: 270.74969559907913ms
Time: 275.9062945842743ms
Time: 284.33337062597275ms
Time: 280.3274691104889ms
Time: 281.9564491510391ms
Time: 278.3804051578045ms
Time: 274.300180375576ms
Time: 277.1734707057476ms
Time: 290.35187512636185ms
Time: 313.6036656796932ms
Time: 313.472256064415ms
Time: 300.31537637114525ms
Time: 314.2014890909195ms
Time: 355.7199090719223ms
Time: 356.15959390997887ms
Time: 352.2270806133747ms
Time: 357.5357720255852ms
Time: 362.9982881247997ms
Time: 366.3595952093601ms
Time: 359.4311624765396ms
Time: 351.4114320278168ms
Time: 355.26909306645393ms
Time: 352.2370830178261ms
Time: 364.85184356570244ms
Time: 385.47782972455025ms
Time: 386.76170632243156ms
Time: 383.4626004099846ms
Time: 387.8740146756172