# Show that Triton autotune is broken due to warmup issues

## Aim

Construct a case, using a Triton kernel decorated with `@autotune` on an A100, where the autotuner picks the wrong config because `do_bench` does not do enough warmup by default.

## Pick a kernel

`ops.flash_attention` is a poor candidate because it is basically not tunable: it breaks for most meta-parameter values.

Try `ops.matmul` instead.

In [3]:
import torch
import triton
import triton.language as tl
from triton.ops.matmul import _kernel

# Extract the original (non-autotuned) kernel by unwrapping `_kernel` twice via
# its `.fn` attributes.
# NOTE(review): presumably the two layers are the `@triton.autotune` and
# `@triton.heuristics` decorators around the `@triton.jit` function -- confirm
# against the installed triton version.
matmul_kernel = _kernel.fn.fn


# Based on `triton.ops.matmul._matmul._call`, but with exposed meta-params.
def matmul_dispatch(a,
                    b,
                    BLOCK_M=64,
                    BLOCK_N=64,
                    BLOCK_K=64,
                    GROUP_M=8,
                    SPLIT_K=1,
                    EVEN_K=True):
    """Compute ``a @ b`` with the raw Triton matmul kernel and explicit meta-params.

    Unlike the autotuned entry point, this launches ``matmul_kernel`` (the
    unwrapped JIT function) directly, so the tile configuration is exactly the
    one the caller passes in and no autotuning takes place.

    Args:
        a: 2-D tensor of shape ``(M, K)``.
        b: 2-D tensor of shape ``(K, N)``.
        BLOCK_M: Tile size along M, forwarded to the kernel.
        BLOCK_N: Tile size along N, forwarded to the kernel.
        BLOCK_K: Tile size along K, forwarded to the kernel.
        GROUP_M: Forwarded to the kernel unchanged.
        SPLIT_K: Size of the second grid axis, forwarded to the kernel.
        EVEN_K: Forwarded to the kernel, but only when ``K`` really is a
            multiple of ``BLOCK_K * SPLIT_K``; otherwise ``False`` is passed.

    Returns:
        A new ``(M, N)`` tensor on ``a``'s device with ``a``'s dtype.

    Raises:
        AssertionError: If the inner dimensions of ``a`` and ``b`` differ.
    """
    device = a.device
    # handle non-contiguous inputs if necessary
    if a.stride(0) > 1 and a.stride(1) > 1:
        a = a.contiguous()
    if b.stride(0) > 1 and b.stride(1) > 1:
        b = b.contiguous()
    # checks constraints
    assert a.shape[1] == b.shape[0], "incompatible dimensions"
    M, K = a.shape
    _, N = b.shape
    # allocates output
    # NOTE(review): upstream zeroes C through an autotune pre-hook for split-K
    # configs (presumably because partial products are accumulated into C).
    # That hook is bypassed here, so zero-initialise explicitly -- confirm
    # against the installed triton version.
    allocate = torch.zeros if SPLIT_K > 1 else torch.empty
    c = allocate((M, N), device=device, dtype=a.dtype)
    # accumulator types
    ACC_TYPE = tl.float32 if a.dtype in [
        torch.float16, torch.bfloat16, torch.float32
    ] else tl.int32
    # Never claim an even K split when the shapes do not allow it.
    # NOTE(review): upstream derives EVEN_K from this same divisibility test
    # via `@triton.heuristics`; confirm against the installed triton version.
    even_k = bool(EVEN_K) and K % (BLOCK_K * SPLIT_K) == 0
    # launch kernel: the meta-params are known here, so the grid is a plain
    # tuple rather than a callable evaluated by an autotuner.
    grid = (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), SPLIT_K)
    matmul_kernel[grid](a,
                        b,
                        c,
                        M,
                        N,
                        K,
                        a.stride(0),
                        a.stride(1),
                        b.stride(0),
                        b.stride(1),
                        c.stride(0),
                        c.stride(1),
                        BLOCK_M=BLOCK_M,
                        BLOCK_N=BLOCK_N,
                        BLOCK_K=BLOCK_K,
                        GROUP_M=GROUP_M,
                        SPLIT_K=SPLIT_K,
                        EVEN_K=even_k,
                        ACC_TYPE=ACC_TYPE)
    return c

In [2]:
# Random operands for the check below: (2048 x 1024) @ (1024 x 512), on the GPU.
a = torch.rand(2048, 1024, device="cuda")
b = torch.rand(1024, 512, device="cuda")

In [5]:
# Largest absolute element-wise gap between torch's matmul and the Triton dispatch.
torch.max(torch.abs(a @ b - matmul_dispatch(a, b)))

tensor(0.1930, device='cuda:0')

In [1]:
from conch.bench import MetaParamGrid

In [None]:
# Build the meta-parameter grid helper imported from `conch.bench`.
# NOTE(review): the meaning of `min_val_prod` is not visible in this file --
# presumably a lower bound on the product of the meta-param values; confirm.
mp_grid = MetaParamGrid(min_val_prod=100_000)