In [1]:
import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )

In [2]:
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000

# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

N_ITERS = 10

from torchvision.models import densenet121
def init_model():
    return densenet121().to(torch.float32).cuda()

In [4]:
model = init_model()

# Reset since we are using a different mode.
import torch._dynamo
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.6182738037109375
compile: 73.2932265625


In [5]:
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.029868032455444334
compile: 0.9026836547851562


In [9]:
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.02528358459472656
compile: 0.006692863941192627


In [None]:
from world_machine.profile import profile_range

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.model = init_model()
    
    
    def forward(self, x):
        return self.model(x)

In [18]:
model = MyModule()

In [19]:
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.021358591079711914
compile: 8.074880859375


In [25]:
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

eager: 0.01586995220184326
compile: 0.0067276802062988285


In [3]:
from world_machine import WorldMachine, WorldMachineBuilder
from world_machine.layers import PointwiseFeedforward


def get_benchmark_model() -> WorldMachine:
    builder = WorldMachineBuilder(128, 100, "alibi", False)

    builder.add_sensorial_dimension("dim0",
                                    128,
                                    PointwiseFeedforward(
                                        3, 2*128, output_dim=128),
                                    PointwiseFeedforward(128, 2*128, output_dim=3))

    builder.add_sensorial_dimension("dim1",
                                    128,
                                    PointwiseFeedforward(
                                        3, 2*128, output_dim=128),
                                    PointwiseFeedforward(128, 2*128, output_dim=3))

    builder.add_block(1, "dim0", n_attention_head=4)
    builder.add_block(1, "dim1", n_attention_head=1)

    builder.remove_positional_encoding = False
    builder.state_activation = "tanh"
    builder.state_dropout = False

    model = builder.build()

    return model

import torch

from world_machine.data import WorldMachineDataLoader, WorldMachineDataset


class BenchmarkDataset(WorldMachineDataset):
    def __init__(self):
        sensorial_dimensions = ["dim0", "dim1"]
        size = 32*10
        has_state_decoded = False
        has_masks = True
        super().__init__(sensorial_dimensions, size, has_state_decoded, has_masks)

    def get_dimension_item(self, dimension: str, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.empty([100, 3]), torch.empty([100, 3])

    def get_dimension_mask(self, dimension, index) -> tuple[torch.Tensor, torch.Tensor]:
        return torch.ones(100, dtype=bool), torch.ones(100, dtype=bool)


def get_benchmark_dataloaders():
    dataset = BenchmarkDataset()

    train_loader = WorldMachineDataLoader(dataset, 32, True)
    val_loader = WorldMachineDataLoader(dataset, 32, True)

    return train_loader, val_loader


In [4]:
model = get_benchmark_model()
model.eval()
model = model.cuda()

In [5]:
loader, _ = get_benchmark_dataloaders()

In [6]:
item = next(iter(loader))
item = item.cuda()

state = torch.zeros([32, 100, 128], device="cuda")

In [7]:
torch._dynamo.reset()

model_opt = torch.compile(model, mode="reduce-overhead")

inp = item
with torch.no_grad():
    print("eager:", timed(lambda: model.inference(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))[1])
    print("compile:", timed(lambda: model_opt.inference(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))[1])

eager: 0.8495134887695313
compile: 0.25009152221679687


In [8]:
with torch.no_grad():
    print("eager:", timed(lambda: model.inference(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))[1])
    print("compile:", timed(lambda: model_opt.inference(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))[1])

eager: 0.27407461547851564
compile: 0.23983821105957032


In [9]:
with torch.no_grad():
    print("eager:", timed(lambda: model(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))[1])
    print("compile:", timed(lambda: model_opt(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))[1])

eager: 0.00329420804977417
compile: 8.89790625


In [10]:
eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(
                state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"]))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np
eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
#assert(speedup > 1)
print(f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x")
print("~" * 10)

eager eval time 0: 0.005171199798583984
eager eval time 1: 0.0037355520725250242
eager eval time 2: 0.0030412800312042236
eager eval time 3: 0.0032481279373168946
eager eval time 4: 0.002983936071395874
eager eval time 5: 0.003458048105239868
eager eval time 6: 0.002942975997924805
eager eval time 7: 0.003034111976623535
eager eval time 8: 0.0027596800327301025
eager eval time 9: 0.0030146560668945313
~~~~~~~~~~


  super().capture_end()
  super().capture_end()


compile eval time 0: 4.11998291015625
compile eval time 1: 0.004907008171081543
compile eval time 2: 0.0033894400596618654
compile eval time 3: 0.004100096225738525
compile eval time 4: 0.0036382720470428467
compile eval time 5: 0.003408895969390869
compile eval time 6: 0.005074944019317627
compile eval time 7: 0.0037969920635223388
compile eval time 8: 0.003471359968185425
compile eval time 9: 0.004237311840057373
~~~~~~~~~~
(eval) eager median: 0.003037696003913879, compile median: 0.003948544144630432, speedup: 0.7693205122310202x
~~~~~~~~~~


In [11]:
torch.cuda.set_sync_debug_mode("warn")

  torch._C._cuda_set_sync_debug_mode(debug_mode)


In [12]:
model(state, sensorial_data=item["inputs"], sensorial_masks=item["input_masks"])

  E = torch.nn.functional.scaled_dot_product_attention(


TensorDict(
    fields={
        dim0: Tensor(shape=torch.Size([32, 100, 3]), device=cuda:0, dtype=torch.float32, is_shared=True),
        dim1: Tensor(shape=torch.Size([32, 100, 3]), device=cuda:0, dtype=torch.float32, is_shared=True),
        state: Tensor(shape=torch.Size([32, 100, 128]), device=cuda:0, dtype=torch.float32, is_shared=True),
        state_decoded: Tensor(shape=torch.Size([32, 100, 128]), device=cuda:0, dtype=torch.float32, is_shared=True)},
    batch_size=torch.Size([32, 100]),
    device=cuda:0,
    is_shared=True)