In [1]:
import torch

In [3]:
def foo(x, y):
    """Return the elementwise sum ``sin(x) + cos(y)`` of two tensors."""
    return x.sin() + y.cos()

In [6]:
TORCH_LOGS="output_code"

In [7]:
# Wrap `foo` with torch.compile; `fun` is the compiled callable invoked in a later cell.
fun = torch.compile(foo)

In [8]:
# Display whether a CUDA device is visible. The timing helper and the model
# cells further down call torch.cuda.* and .cuda() directly.
torch.cuda.is_available()

True

In [10]:
# Call the compiled function once on two random 10x10 tensors and print the result.
print(fun(torch.randn(10, 10), torch.randn(10, 10)))

tensor([[ 3.9405e-01, -2.7816e-01,  6.5460e-01,  1.9136e+00,  1.1260e+00,
          1.1907e+00,  1.0387e+00,  6.7897e-01,  5.3868e-01,  6.1485e-02],
        [-7.0066e-01,  1.1511e+00,  9.8689e-01,  1.6656e+00,  1.5357e+00,
          7.4831e-01,  5.0244e-01,  9.7823e-01,  8.6504e-01, -1.1680e+00],
        [ 1.9611e+00,  7.6579e-01,  1.0716e-03, -1.7634e-01, -3.8509e-01,
          3.9524e-01,  1.5205e+00, -1.0848e+00,  3.6593e-01,  1.8708e-01],
        [ 3.8583e-01,  9.2706e-01,  1.1730e+00,  7.4232e-01,  9.8594e-01,
         -3.3094e-01,  3.6186e-01,  8.7859e-01, -1.1520e+00, -8.3283e-01],
        [ 4.1727e-01,  2.5993e-02,  1.0605e+00, -2.5975e-01,  7.7126e-01,
          1.7706e+00,  7.2792e-01,  1.7790e+00, -4.5036e-01,  1.0802e+00],
        [ 3.0801e-01,  1.2220e+00,  8.7680e-01,  1.1921e+00, -4.0569e-02,
          9.0909e-01,  2.1940e-01,  1.0495e+00,  1.9080e-01, -1.1733e-02],
        [ 7.8029e-01,  5.7761e-01, -7.7533e-01,  1.5216e+00,  3.6661e-01,
          1.5670e-01, -1.6596e-0

In [11]:
def timed(fn):
    """Run ``fn()`` once and return ``(result, elapsed_seconds)``.

    Args:
        fn: Zero-argument callable to execute and time.

    Returns:
        A tuple ``(result, seconds)`` where ``result`` is whatever ``fn()``
        returned and ``seconds`` is the elapsed time as a float.

    When CUDA is available the measurement uses CUDA events and synchronizes
    before reading the elapsed time. On a host without CUDA it falls back to
    ``time.perf_counter()`` instead of failing.
    """
    if not torch.cuda.is_available():
        # Local import keeps this helper self-contained within its cell.
        import time

        t0 = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - t0

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    # Wait for the recorded events to complete before querying them.
    torch.cuda.synchronize()
    # elapsed_time() reports milliseconds; convert to seconds.
    return result, start.elapsed_time(end) / 1000

In [12]:
def generate_data(b, device=None):
    """Generate random input and target data for the model.

    Args:
        b: Batch size.
        device: Device on which to allocate the tensors. Defaults to
            ``"cuda"`` when CUDA is available, otherwise ``"cpu"``.

    Returns:
        A tuple ``(inputs, targets)``: ``inputs`` is a float32 tensor of shape
        ``(b, 3, 128, 128)`` drawn from a standard normal distribution, and
        ``targets`` is a tensor of ``b`` integers drawn from ``[0, 1000)``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Allocate directly on the target device rather than building on the host
    # and copying; dtype is given explicitly so no follow-up conversion is needed.
    inputs = torch.randn(b, 3, 128, 128, dtype=torch.float32, device=device)
    targets = torch.randint(1000, (b,), device=device)
    return inputs, targets

# Number of timed iterations in each benchmark loop below.
N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    """Construct a ``densenet121`` model, cast it to float32, and move it to CUDA."""
    net = densenet121()
    net = net.to(torch.float32)
    return net.cuda()

In [14]:
# Model instance used for the inference comparison in this cell and the next.
model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

# Compiled wrapper around the same `model`, using the "reduce-overhead" mode.
model_opt = torch.compile(model, mode="reduce-overhead")

# Only the input tensor is needed here; the targets (index 1) are discarded.
inp = generate_data(16)[0]
with torch.no_grad():
    # timed() returns (result, elapsed time); index 1 picks the elapsed time.
    print("eager:", timed(lambda: model(inp))[1])
    # NOTE(review): this is the first call to model_opt, so the figure
    # presumably includes one-time compilation cost — confirm.
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.008099616050720215
compile: 9.948044921875


In [15]:
# Inference benchmark: time N_ITERS forward passes of the eager `model` and of
# the compiled `model_opt`, then compare the medians.
eager_times = []
for i in range(N_ITERS):
    # Fresh random batch per iteration; generated outside the timed region.
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

# `np` imported here is also used by the training benchmark cell further down.
import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
# Ratio > 1 means the compiled model's median time is lower than eager's.
speedup = eager_med / compile_med
# Fails the cell if the compiled model was not faster on this run.
assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager eval time 0: 0.008600576400756836
eager eval time 1: 0.00891478443145752
eager eval time 2: 0.006881279945373535
eager eval time 3: 0.006432672023773193
eager eval time 4: 0.006688704013824463
eager eval time 5: 0.006543360233306885
eager eval time 6: 0.006344704151153564
eager eval time 7: 0.006672383785247803
eager eval time 8: 0.006772736072540283
eager eval time 9: 0.007655424118041992
~~~~~~~~~~
compile eval time 0: 0.006307839870452881
compile eval time 1: 0.005724287986755371
compile eval time 2: 0.006801407814025879
compile eval time 3: 0.005718016147613526
compile eval time 4: 0.006551487922668457
compile eval time 5: 0.005708799839019775
compile eval time 6: 0.005701632022857666
compile eval time 7: 0.00570684814453125
compile eval time 8: 0.006833055973052979
compile eval time 9: 0.006255680084228515
~~~~~~~~~~
(eval) eager median: 0.006730720043182373, compile median: 0.005989984035491943, speedup: 1.1236624343740167x
~~~~~~~~~~


In [16]:
# Fresh model and Adam optimizer for the eager training benchmark. `opt` is a
# module-level name that `train` below reads directly.
model = init_model()
opt = torch.optim.Adam(model.parameters())

def train(mod, data):
    """Run one optimization step of ``mod`` on ``data``.

    ``data`` is indexed as ``data[0]`` (model inputs) and ``data[1]`` (targets).
    The step uses the module-level optimizer ``opt``: gradients are cleared,
    a cross-entropy loss is computed and back-propagated, and the optimizer
    is stepped. Nothing is returned.
    """
    opt.zero_grad(set_to_none=True)
    inputs, targets = data[0], data[1]
    criterion = torch.nn.CrossEntropyLoss()
    loss = criterion(mod(inputs), targets)
    loss.backward()
    opt.step()

# Training benchmark, eager baseline: time N_ITERS calls of `train` on the
# current `model`.
eager_times = []
for i in range(N_ITERS):
    # Full (inputs, targets) tuple this time; generated outside the timed region.
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

# Re-create the model and optimizer for the compiled run. `train` reads the
# module-level `opt`, so rebinding it here changes which optimizer is stepped.
model = init_model()
opt = torch.optim.Adam(model.parameters())
# Compile the whole training-step function, not just the model.
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

# `np` comes from the import in the earlier eval benchmark cell.
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
# Ratio > 1 means the compiled step's median time is lower than eager's.
speedup = eager_med / compile_med
# Fails the cell if the compiled training step was not faster on this run.
assert(speedup > 1)
print(f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager train time 0: 0.31984945678710935
eager train time 1: 0.030057344436645508
eager train time 2: 0.02201190376281738
eager train time 3: 0.02111692810058594
eager train time 4: 0.021227487564086912
eager train time 5: 0.01924073600769043
eager train time 6: 0.02023094367980957
eager train time 7: 0.019140607833862306
eager train time 8: 0.019456928253173827
eager train time 9: 0.02128166389465332
~~~~~~~~~~
compile train time 0: 32.9713671875
compile train time 1: 0.028633983612060546
compile train time 2: 0.02534604835510254
compile train time 3: 0.023723007202148438
compile train time 4: 0.019777536392211914
compile train time 5: 0.017926143646240233
compile train time 6: 0.01678950309753418
compile train time 7: 0.018381792068481444
compile train time 8: 0.017373376846313477
compile train time 9: 0.02227609634399414
~~~~~~~~~~
(train) eager median: 0.021172207832336427, compile median: 0.021026816368103027, speedup: 1.0069145733566187x
~~~~~~~~~~
