In [1]:
import torch
import time
import os

from torch import nn
import torchvision.models as models
from triton.testing import do_bench
import torch._dynamo

In [2]:
torch.set_float32_matmul_precision('high')
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
def run_benchmark(fn):
    exec_time, prctl20, prctl80 = do_bench(fn,warmup=100,rep=1000)
    print(f"Exec time (median): {exec_time}")
    print(f"Exec time (20th percentile): {prctl20}")
    print(f"Exec time (80th percentile): {prctl80}\n")
    return exec_time

## 1. ResNet50 Speedup on NVIDIA A10G

In [4]:
def run_batch(model, optimizer):
    x = torch.randn(16, 3, 224, 224).to(device)
    optimizer.zero_grad()
    out = model(x)
    out.sum().backward()
    optimizer.step()

In [5]:
model = models.resnet50().to(device)

In [6]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Benchmark Eager
print("Resnet50 Eager mode")
exec_time = run_benchmark(lambda: run_batch(model, optimizer))

# Benchmark torch.compile defaults
print("Resnet50 Compiled defaults")
opt_model = torch.compile(model)
opt_exec_time = run_benchmark(lambda: run_batch(opt_model, optimizer))

# Print speedups
print(f"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%")

Resnet50 Eager mode
Exec time (median): 50.456573486328125
Exec time (20th percentile): 50.4313850402832
Exec time (80th percentile): 50.50531768798828

Resnet50 Compiled defaults
Exec time (median): 46.913536071777344
Exec time (20th percentile): 46.89653778076172
Exec time (80th percentile): 46.9372673034668

speedup:  7.55%


Process ForkProcess-4:
Process ForkProcess-1:
Process ForkProcess-3:
Process ForkProcess-2:


## 2. Custom model Speedup on NVIDIA A10G

In [7]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        
    def forward(self, x):
        x = self.fc1(x).relu() ** 2
        return self.fc2(x).relu() ** 2

In [8]:
model = MLP().to(device)
x = torch.randn(1024, 1024).to(device)

In [9]:
# Benchmark Eager
exec_time = run_benchmark(lambda: model(x).sum().backward())

torch._dynamo.reset()
# Benchmark torch.compile defaults
cmodel = torch.compile(model, backend='inductor')
opt_exec_time = run_benchmark(lambda: cmodel(x).sum().backward())

# Print speedups
print(f"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%")

Exec time (median): 0.7147520184516907
Exec time (20th percentile): 0.7127040028572083
Exec time (80th percentile): 0.7157760262489319

Exec time (median): 0.6010879874229431
Exec time (20th percentile): 0.6000639796257019
Exec time (80th percentile): 0.6021119952201843

speedup:  18.91%


## 3. HuggingFace model Speedup on NVIDIA A10G

In [10]:
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
from datasets import load_dataset

def run_inference(model, input_values):
    
    # retrieve logits
    logits = model(input_values).logits
    
    # take argmax and decode
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)

In [11]:
# load model and processor
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-960h-lv60-self")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").cuda()

# load dummy dataset and read soundfiles
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

# tokenize
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values.cuda()

Downloading (…)rocessor_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/162 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/291 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Downloading builder script:   0%|          | 0.00/5.16k [00:00<?, ?B/s]

Downloading and preparing dataset librispeech_asr_dummy/clean to /root/.cache/huggingface/datasets/patrickvonplaten___librispeech_asr_dummy/clean/2.1.0/f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating validation split: 0 examples [00:00, ? examples/s]

Dataset librispeech_asr_dummy downloaded and prepared to /root/.cache/huggingface/datasets/patrickvonplaten___librispeech_asr_dummy/clean/2.1.0/f2c70a4d03ab4410954901bde48c54b85ca1b7f9bf7d616e7e2a72b5ee6ddbfc. Subsequent calls will reuse this data.


It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


In [12]:
exec_time = run_benchmark(lambda: run_inference(model, input_values))

torch._dynamo.reset()
model = torch.compile(model, mode="max-autotune")
opt_exec_time = run_benchmark(lambda: run_inference(model, input_values))

# Print speedups
print(f"speedup: {100*(exec_time-opt_exec_time) / opt_exec_time: .2f}%")

Exec time (median): 30.511903762817383
Exec time (20th percentile): 30.503211975097656
Exec time (80th percentile): 30.539411544799805



AUTOTUNE bias_addmm(544x1024, 544x512, 512x1024)
  triton_mm_1 0.0276s 100.0%
  triton_mm_3 0.0276s 100.0%
  triton_mm_2 0.0287s 96.4%
  triton_mm_4 0.0287s 96.4%
  bias_addmm 0.0287s 96.4%
  triton_mm_0 0.0297s 93.1%
  triton_mm_8 0.0328s 84.4%
  triton_mm_10 0.0328s 84.4%
  triton_mm_11 0.0348s 79.4%
  triton_mm_5 0.0389s 71.1%
AUTOTUNE bias_addmm(544x1024, 544x1024, 1024x1024)
  triton_mm_13 0.0481s 100.0%
  triton_mm_15 0.0481s 100.0%
  triton_mm_14 0.0492s 97.9%
  triton_mm_16 0.0492s 97.9%
  triton_mm_12 0.0532s 90.4%
  triton_mm_20 0.0532s 90.4%
  bias_addmm 0.0543s 88.7%
  triton_mm_22 0.0543s 88.7%
  triton_mm_23 0.0594s 81.0%
  triton_mm_18 0.0635s 75.8%
AUTOTUNE bmm(16x544x64, 16x64x544)
  triton_bmm_56 0.0553s 100.0%
  triton_bmm_48 0.0553s 100.0%
  triton_bmm_58 0.0573s 96.4%
  triton_bmm_59 0.0584s 94.7%
  triton_bmm_52 0.0604s 91.5%
  triton_bmm_49 0.0604s 91.5%
  triton_bmm_50 0.0604s 91.5%
  triton_bmm_51 0.0614s 90.0%
  bmm 0.0614s 90.0%
  triton_bmm_55 0.0676s 81.8%


Exec time (median): 27.311792373657227
Exec time (20th percentile): 27.307104110717773
Exec time (80th percentile): 27.3222713470459

speedup:  11.72%


In [13]:
torch._dynamo.list_backends()

['aot_ts_nvfuser',
 'cudagraphs',
 'inductor',
 'ipex',
 'nvprims_nvfuser',
 'onnxrt',
 'tvm']