13 changes: 13 additions & 0 deletions benchmarks/run.py
@@ -280,6 +280,11 @@ class RunResult:
"examples.jagged_sum",
"jagged_sum_tritonbench",
),
"low_mem_dropout": (
"tritonbench.operators.low_mem_dropout.operator",
"examples.low_mem_dropout",
"low_mem_dropout_tritonbench",
),
}


@@ -538,6 +543,14 @@ class RunResult:
"helion_fp8_gemm_tritonbench-speedup": "helion_speedup",
"helion_fp8_gemm_tritonbench-accuracy": "helion_accuracy",
},
"low_mem_dropout": {
"seeded_dropout-accuracy": "triton_accuracy",
"seeded_dropout-speedup": "triton_speedup",
"torch_compile_dropout-accuracy": "torch_compile_accuracy",
"torch_compile_dropout-speedup": "torch_compile_speedup",
"helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
"helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
},
}


136 changes: 136 additions & 0 deletions examples/low_mem_dropout.py
@@ -0,0 +1,136 @@
"""
Low-Memory Dropout Example
==========================

This example demonstrates how to implement low-memory dropout using Helion.
Instead of storing the dropout mask, the mask is regenerated from a seed in
both the forward and backward passes.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch

import helion
import helion.language as hl


# %%
# Low-memory dropout forward implementation
# -------------------
@helion.kernel()
def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
"""
    Applies dropout to x with probability p, regenerating the mask from seed.
    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
        seed (int): RNG seed used to regenerate the dropout mask
    Returns:
        Output tensor
"""
scale = 1.0 / (1.0 - p)
# flatten to 1D so we can use tile
n = x.numel()
x_flat = x.view(-1)
out_flat = torch.empty_like(x_flat)
for tidx in hl.tile(n):
xi = x_flat[tidx].to(torch.float32)
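        # hl.rand draws per-element randomness from (seed, element index), so the
        # keep-mask is reproducible from the seed alone and never stored as a tensor.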
r = hl.rand([tidx], seed=seed)
keep = r > p
yscaled = xi * scale
yi = torch.where(keep, yscaled, 0.0)
out_flat[tidx] = yi.to(x.dtype)
return out_flat.view_as(x)


# %%
# Low-memory dropout backward implementation
# -------------------
@helion.kernel()
def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
"""
    Low-memory dropout regenerates the randomness from the seed in both the
    forward and backward passes instead of storing a mask, so the backward
    pass is the same masked scaling applied to grad_y.
    Args:
        p (float): dropout probability
        grad_y (torch.Tensor): gradient of the output
        seed (int): RNG seed (must match the forward pass)
    Returns:
        Output tensor
"""
scale = 1.0 / (1.0 - p)
n = grad_y.numel()
grad_y_flat = grad_y.view(-1)
out_flat = torch.empty_like(grad_y_flat)
for tidx in hl.tile(n):
gi = grad_y_flat[tidx].to(torch.float32)
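        # Reusing the forward seed regenerates exactly the same keep-mask here.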
r = hl.rand([tidx], seed=seed)
keep = r > p
g_scaled = gi * scale
gxi = torch.where(keep, g_scaled, 0.0)
out_flat[tidx] = gxi.to(grad_y.dtype)
return out_flat.view_as(grad_y)


# %%
# TritonBench Wrapper
# -------------------
def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable:
"""
Wrapper for TritonBench compatibility.

Args:
tb_op: TritonBench operator instance
p (float): dropout probability
x (torch.Tensor): Input tensor

Returns:
        Callable: A function that runs low_mem_dropout on the given inputs.
"""

def _inner() -> torch.Tensor:
return low_mem_dropout(p, x, seed=123)

return _inner


# %%
# Verification Function
# -------------------
def check(p: float, size: int) -> None:
"""
    Verify that the forward and backward low-memory dropout kernels drop the same elements for a given seed.

Args:
p (float): dropout probability
size (int): input tensor size
"""
x = torch.randn(size=(size,)).cuda()
seed = 123

out = low_mem_dropout(p, x, seed)
grad_y = torch.ones_like(x)
grad_x = low_mem_dropout_bwd(p, grad_y, seed)
mask_fwd = out != 0
mask_bwd = grad_x != 0
assert torch.equal(mask_fwd, mask_bwd)


# %%
# Main Function
# -----------
def main() -> None:
"""
    Main entry point that runs the low-memory dropout kernel verification with different tensor sizes.
    Tests two configurations:
    - p=0.25, size=8192
    - p=0.25, size=32768
"""
check(0.25, 8192)
check(0.25, 32768)


if __name__ == "__main__":
main()
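For an eager-mode mental model, the sketch below (illustrative only, not part of this PR; the `*_reference` helper names are made up) reproduces the low-memory-dropout recipe in plain PyTorch: the keep-mask is regenerated from a seed in both passes instead of being stored. It uses torch.Generator rather than Triton's seeded counter, so its masks are reproducible across forward/backward but not bit-identical to the kernel's.

import torch

def low_mem_dropout_reference(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
    # Regenerate the keep-mask from the seed instead of materializing and storing it.
    gen = torch.Generator(device=x.device)
    gen.manual_seed(seed)
    r = torch.rand(x.shape, generator=gen, device=x.device)
    scale = 1.0 / (1.0 - p)
    return torch.where(r > p, x * scale, torch.zeros_like(x))

def low_mem_dropout_bwd_reference(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
    # Backward is the same masked scaling; reusing the seed reproduces the forward mask.
    return low_mem_dropout_reference(p, grad_y, seed)

# With the same seed, forward and backward drop exactly the same elements:
# x = torch.randn(8192, device="cuda")
# assert torch.equal(low_mem_dropout_reference(0.25, x, 123) != 0,
#                    low_mem_dropout_bwd_reference(0.25, torch.ones_like(x), 123) != 0)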
40 changes: 40 additions & 0 deletions test/test_examples.expected
@@ -2766,6 +2766,46 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
_launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, mean, rstd, mean.stride(0), out.stride(0), out.stride(1), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
return (out, mean, rstd)

--- assertExpectedJournal(TestExamples.test_low_mem_dropout)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_low_mem_dropout(x_flat, out_flat, out_flat_stride_0, x_flat_stride_0, n, seed, p, scale, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < n
xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
rand = tl.rand(seed, indices_0)
v_0 = rand > p
v_1 = xi * scale
v_2 = 0.0
v_3 = v_2[None]
v_4 = tl.where(v_0, v_1, v_3)
tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)

def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
"""
    Applies dropout to x with probability p, regenerating the mask from seed.
    Args:
        p (float): dropout probability
        x (torch.Tensor): input tensor
        seed (int): RNG seed used to regenerate the dropout mask
    Returns:
        Output tensor
"""
scale = 1.0 / (1.0 - p)
n = x.numel()
x_flat = x.view(-1)
out_flat = torch.empty_like(x_flat)
_BLOCK_SIZE_0 = 1024
_launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, out_flat.stride(0), x_flat.stride(0), n, seed, p, scale, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out_flat.view_as(x)

--- assertExpectedJournal(TestExamples.test_matmul)
from __future__ import annotations

45 changes: 45 additions & 0 deletions test/test_examples.py
@@ -310,6 +310,51 @@ def test_welford(self):
)
)

def test_low_mem_dropout(self):
Contributor:
Test backwards (not just fwd) and assert that the same elements are dropped out in bwd as fwd (and different elements are dropped out if you change the seed).

Contributor Author:
Thanks, I've updated the test case with dropout mask checking.

from examples.low_mem_dropout import low_mem_dropout
from examples.low_mem_dropout import low_mem_dropout_bwd

from helion._testing import code_and_output

p = 0.25
size = 8192
seed = 123
seed2 = 456
x = torch.randn(size=(size,)).cuda()

_, out_fwd = code_and_output(
low_mem_dropout,
(p, x, seed),
)

grad_y = torch.ones_like(x)
_, grad_x = code_and_output(
low_mem_dropout_bwd,
(p, grad_y, seed),
)

_, grad_x2 = code_and_output(
low_mem_dropout_bwd,
(p, grad_y, seed2),
)

mask_fwd = out_fwd != 0
mask_bwd = grad_x != 0
self.assertTrue(
torch.equal(mask_fwd, mask_bwd),
"Same elements should be dropped in fwd and bwd with the same seed",
)

mask_bwd2 = grad_x2 != 0
self.assertFalse(
torch.equal(mask_bwd, mask_bwd2),
"Different elements should be dropped when using a different seed",
)

self.assertExpectedJournal(
check_example("low_mem_dropout", (p, grad_y, seed), grad_x),
)

def test_rms_norm_fwd(self):
args = (
torch.randn([128, 256], device=DEVICE, dtype=torch.float16),