In [1]:
import torch
import time
import os

from torch import nn
import torchvision.models as models
from triton.testing import do_bench
import torch._dynamo

In [2]:
torch.set_float32_matmul_precision('high')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
def run_benchmark(fn):
    exec_time, prctl20, prctl80 = do_bench(fn,warmup=100,rep=1000)
    print(f"Exec time (median): {exec_time}")
    print(f"Exec time (20th percentile): {prctl20}")
    print(f"Exec time (80th percentile): {prctl80}\n")
    return exec_time

## 1. ResNet50 Speedup on NVIDIA A10G

In [4]:
def run_batch_train(model, optimizer, batch=16):
    x = torch.randn(16, 3, 224, 224).to(device)
    optimizer.zero_grad()
    out = model(x)
    out.sum().backward()
    optimizer.step()
    
def run_batch_inference(model, batch=16):
    x = torch.randn(16, 3, 224, 224).to(device)
    with torch.inference_mode():
        out = model(x)

In [5]:
model = models.resnet101(pretrained=True).to(device)

Downloading: "https://download.pytorch.org/models/resnet101-63fe2227.pth" to /root/.cache/torch/hub/checkpoints/resnet101-63fe2227.pth
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 171M/171M [00:06<00:00, 27.5MB/s]


In [6]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Benchmark Eager
print("Resnet50 Eager mode")
exec_time = run_benchmark(lambda: run_batch(model, optimizer))

# Benchmark torch.compile defaults
print("Resnet50 Compiled defaults")
opt_model = torch.compile(model)
opt_exec_time = run_benchmark(lambda: run_batch(opt_model, optimizer))

# Print speedups
print(f"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%")

Resnet50 Eager mode
Exec time (median): 60.049407958984375
Exec time (20th percentile): 59.96953582763672
Exec time (80th percentile): 60.07562255859375

Resnet50 Compiled defaults
Exec time (median): 56.342529296875
Exec time (20th percentile): 56.3138542175293
Exec time (80th percentile): 56.39004898071289

speedup:  6.58%


## 2. Custom model Speedup on NVIDIA A10G

In [7]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        
    def forward(self, x):
        x = self.fc1(x).relu() ** 2
        return self.fc2(x).relu() ** 2

In [8]:
model = MLP().to(device)
x = torch.randn(1024, 1024).to(device)

In [9]:
# Benchmark Eager
exec_time = run_benchmark(lambda: model(x).sum().backward())

torch._dynamo.reset()
# Benchmark torch.compile defaults
cmodel = torch.compile(model, backend='inductor')
opt_exec_time = run_benchmark(lambda: cmodel(x).sum().backward())

# Print speedups
print(f"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%")

Exec time (median): 1.3148159980773926
Exec time (20th percentile): 1.3096959590911865
Exec time (80th percentile): 1.3281279802322388

Exec time (median): 1.1950080394744873
Exec time (20th percentile): 1.1909120082855225
Exec time (80th percentile): 1.2083200216293335

speedup:  10.03%


## 3. HuggingFace model Speedup on NVIDIA A10G

In [10]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset

def run_inference(model, input_values):
    
    # retrieve logits
    logits = model(input_values).logits
    
    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)

In [17]:
# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").cuda()

# load dummy dataset and read soundfiles
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

# tokenize
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values.cuda()

Downloading (…)rocessor_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/162 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

KeyboardInterrupt: 

In [12]:
exec_time = run_benchmark(lambda: run_inference(model, input_values))

torch._dynamo.reset()
model = torch.compile(model, mode="max-autotune")
opt_exec_time = run_benchmark(lambda: run_inference(model, input_values))

# Print speedups
print(f"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%")

Exec time (median): 29.5565128326416
Exec time (20th percentile): 29.487781524658203
Exec time (80th percentile): 29.754579544067383



AUTOTUNE bias_addmm(232x1024, 232x512, 512x1024)
  bias_addmm 0.0389s 100.0%
  addmm 0.0522s 74.5%
  triton_mm_0 0.0819s 47.5%
  triton_mm_10 0.0840s 46.3%
  triton_mm_8 0.1209s 32.2%
  triton_mm_4 0.1239s 31.4%
  triton_mm_2 0.1239s 31.4%
  triton_mm_3 0.1260s 30.9%
  triton_mm_1 0.1260s 30.9%
  triton_mm_11 0.1300s 29.9%
AUTOTUNE bias_addmm(232x1024, 232x1024, 1024x1024)
  bias_addmm 0.0717s 100.0%
  addmm 0.0798s 89.9%
  triton_mm_12 0.1577s 45.5%
  triton_mm_22 0.1628s 44.0%
  triton_mm_20 0.2353s 30.5%
  triton_mm_16 0.2398s 29.9%
  triton_mm_14 0.2405s 29.8%
  triton_mm_13 0.2447s 29.3%
  triton_mm_15 0.2447s 29.3%
  triton_mm_23 0.2540s 28.2%
AUTOTUNE bmm(16x232x64, 16x64x232)
  bmm 0.0236s 100.0%
  triton_bmm_55 0.0328s 71.9%
  triton_bmm_58 0.0391s 60.2%
  triton_bmm_51 0.0399s 59.0%
  triton_bmm_49 0.0399s 59.0%
  triton_bmm_48 0.0399s 59.0%
  triton_bmm_50 0.0410s 57.5%
  triton_bmm_52 0.0410s 57.5%
  triton_bmm_59 0.0614s 38.3%
  triton_bmm_56 0.0645s 36.5%
AUTOTUNE bmm(16x

Exec time (median): 28.666912078857422
Exec time (20th percentile): 28.57915496826172
Exec time (80th percentile): 28.744203567504883

speedup:  3.10%


In [13]:
torch._dynamo.list_backends()

['aot_ts_nvfuser',
 'cudagraphs',
 'inductor',
 'ipex',
 'nvprims_nvfuser',
 'onnxrt',
 'tvm']