In [3]:
import torch
import triton
import triton.language as tl

DEVICE = triton.runtime.driver.active.get_active_torch_device()


@triton.jit
def dropout_kernel(
    x_ptr,
    output_ptr,
    n,
    p,
    seed,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask)
    to_keep = tl.rand(seed, offsets) > p
    output = tl.where(to_keep, x / (1 - p), 0.0)
    tl.store(output_ptr + offsets, output, mask=mask)


def seeded_dropout(x, p, seed):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    dropout_kernel[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)
    return output


x = torch.randn(size=(10,), device=DEVICE)
# Compare this to the baseline - dropout mask is never instantiated!
output = seeded_dropout(x, p=0.5, seed=123)
output2 = seeded_dropout(x, p=0.5, seed=123)
output3 = seeded_dropout(x, p=0.5, seed=512)

import tabulate
print(
    tabulate.tabulate(
        [
            ["input"] + x.tolist(),
            ["output (seed = 123)"] + output.tolist(),
            ["output (seed = 123)"] + output2.tolist(),
            ["output (seed = 512)"] + output3.tolist(),
        ]
    )
)

-------------------  ---------  ---------  -------  ---------  --------  --------  ---------  -------  --------  --------
input                -0.588533  -0.380392  1.23812  -0.199316  0.548186  0.901865  -0.540132  1.10326  0.515807  0.116383
output (seed = 123)   0         -0.760784  0         0         0         1.80373    0         0        1.03161   0.232766
output (seed = 123)   0         -0.760784  0         0         0         1.80373    0         0        1.03161   0.232766
output (seed = 512)   0          0         2.47623  -0.398631  0         1.80373   -1.08026   0        0         0
-------------------  ---------  ---------  -------  ---------  --------  --------  ---------  -------  --------  --------
