In [23]:
!rm -rf torch_compile_debug

In [19]:
import torch.utils.benchmark as benchmark
import torch
from torchvision.models import resnet
import torch._dynamo

device = torch.device("cuda") if torch.cuda.is_available() else "cpu"

In [20]:
def run_batch_inference(model, batch=1):
    x = torch.randn(batch, 3, 224, 224).to(device)
    model(x)

def run_batch_train(model, optimizer, batch=16):
    x = torch.randn(batch, 3, 224, 224).to(device)
    optimizer.zero_grad()
    out = model(x)
    out.sum().backward()
    optimizer.step()
    
model = resnet.resnet18(weights=resnet.ResNet18_Weights.IMAGENET1K_V1).to(device)

In [22]:
batch = 1
torch._dynamo.reset()
compiled_model = torch.compile(model, options={'triton.cudagraphs': False,
                                              'trace.enabled':True})

t_model = benchmark.Timer(
    stmt='run_batch_inference(model, batch)',
    setup='from __main__ import run_batch_inference',
    globals={'model': model, 'batch':batch})

t_compiled_model = benchmark.Timer(
    stmt='run_batch_inference(model, batch)',
    setup='from __main__ import run_batch_inference',
    globals={'model': compiled_model, 'batch':batch})

t_model_runs = t_model.timeit(100)
t_compiled_model_runs = t_compiled_model.timeit(100)

print(f"Inference speedup: {100*(t_model_runs.mean - t_compiled_model_runs.mean) / t_model_runs.mean: .2f}%")



Inference speedup:  1.96%


In [4]:
batch = 32
torch._dynamo.reset()
compiled_model = torch.compile(model)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

t_model = benchmark.Timer(
    stmt='run_batch_train(model, optimizer, batch)',
    setup='from __main__ import run_batch_train',
    globals={'model': model,'optimizer':optimizer, 'batch':batch})

t_compiled_model = benchmark.Timer(
    stmt='run_batch_train(model, optimizer, batch)',
    setup='from __main__ import run_batch_train',
    globals={'model': compiled_model, 'optimizer':optimizer, 'batch':batch})

t_model_runs = t_model.timeit(100)
t_compiled_model_runs = t_compiled_model.timeit(100)

print(t_model_runs)
print(t_compiled_model_runs)

print(f"Training speedup: {100*(t_model_runs.mean - t_compiled_model_runs.mean) / t_model_runs.mean: .2f}%")

<torch.utils.benchmark.utils.common.Measurement object at 0x7f9db1e1ae30>
run_batch_train(model, optimizer, batch)
setup: from __main__ import run_batch_train
  37.30 ms
  1 measurement, 100 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f9db1d97f10>
run_batch_train(model, optimizer, batch)
setup: from __main__ import run_batch_train
  34.57 ms
  1 measurement, 100 runs , 1 thread
Training speedup:  7.32%
