In [None]:
import os
# os.environ["TORCHINDUCTOR_DEBUG"] = "1"
# os.environ["TORCH_LOGS"] = "output_code"

import torch
from pythia.seq.v2.perf.presentation.kernel import op_triton

# Global setting: applies to every float32 matmul issued after this point in the
# notebook. NOTE(review): per the PyTorch docs, "high" lets float32 matmuls use
# faster reduced-precision internals (e.g. TF32) — keep in mind when reading the
# numerical diffs printed further down.
torch.set_float32_matmul_precision("high")

# torch
def op_torch(x, w, scale, bias):
    """Eager PyTorch reference implementation of the op.

    Multiplies ``x`` by ``w`` (broadcast over the leading axis), sums over
    axis 1, applies a per-feature affine transform followed by ReLU, and adds
    the result back onto ``x`` (broadcast over axis 1).

    Shapes as produced by ``get_inputs``: ``x`` is (N, K, D), ``w`` is (K, D),
    ``scale`` and ``bias`` are (D,). The result has the shape of ``x``.
    """
    w_b = w[None, :, :]  # (1, K, D)
    pooled = torch.sum(x * w_b, dim=1)  # (N, D)
    affine = pooled * scale[None, :] + bias[None, :]  # (N, D)
    gate = affine.relu()  # (N, D)
    return x + gate[:, None, :]  # (N, K, D)

def op_torch2(x, w, scale, bias):
    """Same computation as the eager reference, written as a single expression.

    No intermediate result is bound to a Python name; the weighted sum over
    axis 1, the affine transform, the ReLU and the residual add are chained.

    Open question carried over from the author: will this speed up? Is this
    kernel fusion?
    """
    return x + (
        (x * w[None, :, :])
        .sum(dim=1)  # (N, D)
        .mul(scale[None, :])
        .add(bias[None, :])
        .relu()
    )[:, None, :]  # (N, K, D)

def op_torch3(x, w, scale, bias):
    """Variant of the eager reference that uses ``torch.einsum`` for the reduction.

    The contraction "nkd, kd -> nd" replaces the explicit multiply-then-sum
    over axis 1; the affine transform, ReLU and residual add are unchanged.
    """
    pooled = torch.einsum("nkd, kd -> nd", x, w)  # (N, D)
    shifted = torch.add(torch.mul(pooled, scale[None, :]), bias[None, :])  # (N, D)
    gate = torch.relu(shifted)  # (N, D)
    return torch.add(x, gate[:, None, :])  # (N, K, D)

# torch compile: the eager reference `op_torch` wrapped with `torch.compile`,
# kept as a separate callable so it can be compared against the other variants.
op_torch_compile = torch.compile(op_torch)

def get_inputs(N=100_000, K=32, D=128, requires_grad=False, *, device="cuda"):
    """Build a reproducible set of random inputs for the op variants.

    Re-seeds the global torch RNG with 42 on every call, so repeated calls
    with the same arguments yield the same tensors (note that this also
    resets the global RNG state for the caller).

    Args:
        N: Size of the leading axis of ``x``.
        K: Size of the axis that the ops reduce over.
        D: Feature dimension.
        requires_grad: Whether every returned tensor tracks gradients.
        device: Device the tensors are created on. Defaults to ``"cuda"``
            (the original behaviour); pass ``"cpu"`` to run without a GPU.

    Returns:
        A dict with keys ``"x"`` (N, K, D), ``"w"`` (K, D), ``"scale"`` (D,)
        and ``"bias"`` (D,), suitable for ``op(**inputs)``.
    """
    # Seed
    torch.manual_seed(42)

    tensor_kwargs = {"device": device, "requires_grad": requires_grad}

    # Representation
    x = torch.randn(N, K, D, **tensor_kwargs)

    # Learnable weights, scale, bias (creation order matters for reproducibility)
    w = torch.randn(K, D, **tensor_kwargs)
    scale = torch.randn(D, **tensor_kwargs)
    bias = torch.randn(D, **tensor_kwargs)

    return {"x": x, "w": w, "scale": scale, "bias": bias}

# Numerical sanity check: compare each alternative implementation against the
# eager reference on the same inputs and report the mean / max absolute error.
inputs = get_inputs(N=100_000, requires_grad=True)

# Compute the reference output once instead of once per comparison.
reference = op_torch(**inputs)

# Each label names the implementation that is actually being compared.
for label, candidate in (
    ("torch v torch3", op_torch3),
    ("torch v triton", op_triton),
):
    diff = (reference - candidate(**inputs)).abs()
    print(f"{label:<14} | diff mean: {diff.mean():.2e} | diff max: {diff.max():.2e}")

# Release the large reference tensor; it is not needed past this point.
del reference
# op_torch_compile(**inputs);

In [None]:
import pandas as pd
from functools import partial

from torch.profiler import ProfilerActivity
from torch.profiler.profiler import profile, schedule

from pythia.seq.v2.perf.utils import profiler_to_dataframe, stats_fn

%load_ext autoreload
%autoreload 2

# Benchmark / profile each op variant on identical inputs.
inputs = get_inputs(N=100_000)

# Define function variants
fns = {
    "op_torch": partial(op_torch, **inputs),
    # "op_torch2": partial(op_torch2, **inputs),
    # "op_torch3": partial(op_torch3, **inputs),
    # "op_torch_compile": partial(op_torch_compile, **inputs),
    "op_triton": partial(op_triton, **inputs),
}

ref = fns["op_torch"]().clone()
n_warmup = 10
n_repeat = 100

# False: collect timing statistics via `stats_fn` only (the previous behaviour).
# True: skip `stats_fn` and print a torch.profiler table per variant instead.
run_profiler = False

for name, fn in fns.items():
    if not run_profiler:
        stats_fn(fn, inputs, None, label=name, n_warmup=n_warmup, n_repeat=n_repeat)
        continue

    # Warm up before profiling so one-off start-up costs are not recorded.
    for _ in range(5):
        fn()

    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        record_shapes=False,
        with_stack=False,
        profile_memory=False,
    ) as prof:
        for _ in range(n_repeat):
            fn()

    banner = "-" * (len(name) + 4)
    print(f"{banner}\n| {name} |\n{banner}")
    print(
        prof.key_averages().table(
            sort_by="device_time_total",
            row_limit=20,
            max_name_column_width=50,
            top_level_events_only=True,
        )
    )
