In [None]:
import time
import torch
import math
from torchvision.models import resnet
import numpy as np
import torch._dynamo

torch._dynamo.reset()
device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

def run_batch_inference(model, batch=1):
    x = torch.randn(batch, 3, 224, 224).to(device)
    model(x)

model = resnet.resnet50(weights=resnet.ResNet50_Weights.IMAGENET1K_V1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
batch = 1
compiled_model = torch.compile(model, mode="reduce-overhead")

times_model = []
model(torch.randn(batch, 3, 224, 224).to(device))
for i in range(50):
    start_time = time.time()
    run_batch_inference(model, batch)
    times_model.append(time.time()-start_time)

compiled_model(torch.randn(batch, 3, 224, 224).to(device))
times_compiled_model = []
for i in range(50):
    start_time = time.time()
    run_batch_inference(compiled_model, batch)
    times_compiled_model.append(time.time()-start_time)

speedup = 100*(np.mean(times_model) - np.mean(times_compiled_model)) / np.mean(times_model)

print(f"Speedup: {speedup: .2f}%")
print(f"Average inference time model: {np.mean(times_model)}")
print(f"Average inference time compiled model: {np.mean(times_compiled_model)}")

In [1]:
import time
import numpy as np

def benchmark_inference(model, trials=50):
    # warmup
    model(torch.randn(1, 3, 224, 224).to(device))

    t0 = time.time()
    for i in range(trials):
        x = torch.randn(1, 3, 224, 224).to(device)
        model(x)
    t1 = time.time()
    return (t1 - t0) / trials

In [2]:
import torch
from torchvision.models import resnet

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"


model = resnet.resnet50(weights=resnet.ResNet50_Weights.IMAGENET1K_V1).to(device)
compiled_model = torch.compile(model, mode="reduce-overhead")


unoptimized_t = benchmark_inference(model)
print(unoptimized_t)

optimized_t = benchmark_inference(compiled_model)
print(optimized_t)

print(f"speedup: {100*(unoptimized_t-optimized_t) / unoptimized_t: .2f}%")

0.00666710376739502
0.0038507938385009765


speedup:  42.24%
