From 0f4844a59bb0682efa415261e5c3b455fa34d86e Mon Sep 17 00:00:00 2001
From: karthickai
Date: Fri, 19 Sep 2025 15:29:32 -0700
Subject: [PATCH] [Benchmark] Add low mem dropout example

stack-info: PR: https://github.com/pytorch/helion/pull/641, branch: karthickai/stack/1
---
 benchmarks/run.py           |  13 ++++
 examples/low_mem_dropout.py | 136 ++++++++++++++++++++++++++++++++++++
 test/test_examples.expected |  40 +++++++++++
 test/test_examples.py       |  45 ++++++++++++
 4 files changed, 234 insertions(+)
 create mode 100644 examples/low_mem_dropout.py

diff --git a/benchmarks/run.py b/benchmarks/run.py
index 4195b16e8..678b2b432 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -280,6 +280,11 @@ class RunResult:
         "examples.jagged_sum",
         "jagged_sum_tritonbench",
     ),
+    "low_mem_dropout": (
+        "tritonbench.operators.low_mem_dropout.operator",
+        "examples.low_mem_dropout",
+        "low_mem_dropout_tritonbench",
+    ),
 }


@@ -538,6 +543,14 @@ class RunResult:
         "helion_fp8_gemm_tritonbench-speedup": "helion_speedup",
         "helion_fp8_gemm_tritonbench-accuracy": "helion_accuracy",
     },
+    "low_mem_dropout": {
+        "seeded_dropout-accuracy": "triton_accuracy",
+        "seeded_dropout-speedup": "triton_speedup",
+        "torch_compile_dropout-accuracy": "torch_compile_accuracy",
+        "torch_compile_dropout-speedup": "torch_compile_speedup",
+        "helion_low_mem_dropout_tritonbench-accuracy": "helion_accuracy",
+        "helion_low_mem_dropout_tritonbench-speedup": "helion_speedup",
+    },
 }


diff --git a/examples/low_mem_dropout.py b/examples/low_mem_dropout.py
new file mode 100644
index 000000000..92b5d630f
--- /dev/null
+++ b/examples/low_mem_dropout.py
@@ -0,0 +1,136 @@
+"""
+Low-Memory Dropout Example
+==========================
+
+This example demonstrates how to implement a low-memory dropout kernel using Helion.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import Callable
+
+import torch
+
+import helion
+import helion.language as hl
+
+
+# %%
+# Low-memory dropout forward implementation
+# -----------------------------------------
+@helion.kernel()
+def low_mem_dropout(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
+    """
+    Applies dropout to x with probability p.
+    Args:
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    # flatten to 1D so we can use tile
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    for tidx in hl.tile(n):
+        xi = x_flat[tidx].to(torch.float32)
+        r = hl.rand([tidx], seed=seed)
+        keep = r > p
+        yscaled = xi * scale
+        yi = torch.where(keep, yscaled, 0.0)
+        out_flat[tidx] = yi.to(x.dtype)
+    return out_flat.view_as(x)
+
+
+# %%
+# Low-memory dropout backward implementation
+# ------------------------------------------
+@helion.kernel()
+def low_mem_dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
+    """
+    Low-memory dropout regenerates the random mask from the seed in both the
+    forward and backward passes, so the backward pass is the same as the forward.
+    Args:
+        p (float): Dropout probability
+        grad_y (torch.Tensor): Gradient tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    n = grad_y.numel()
+    grad_y_flat = grad_y.view(-1)
+    out_flat = torch.empty_like(grad_y_flat)
+    for tidx in hl.tile(n):
+        gi = grad_y_flat[tidx].to(torch.float32)
+        r = hl.rand([tidx], seed=seed)
+        keep = r > p
+        g_scaled = gi * scale
+        gxi = torch.where(keep, g_scaled, 0.0)
+        out_flat[tidx] = gxi.to(grad_y.dtype)
+    return out_flat.view_as(grad_y)
+
+
+# %%
+# TritonBench Wrapper
+# -------------------
+def low_mem_dropout_tritonbench(tb_op: object, p: float, x: torch.Tensor) -> Callable:
+    """
+    Wrapper for TritonBench compatibility.
+
+    Args:
+        tb_op: TritonBench operator instance
+        p (float): dropout probability
+        x (torch.Tensor): Input tensor
+
+    Returns:
+        Callable: A function that performs the low_mem_dropout operation.
+    """
+
+    def _inner() -> torch.Tensor:
+        return low_mem_dropout(p, x, seed=123)
+
+    return _inner
+
+
+# %%
+# Verification Function
+# ---------------------
+def check(p: float, size: int) -> None:
+    """
+    Verify that the forward and backward low-memory dropout kernels drop the same elements for a given seed.
+
+    Args:
+        p (float): dropout probability
+        size (int): input tensor size
+    """
+    x = torch.randn(size=(size,)).cuda()
+    seed = 123
+
+    out = low_mem_dropout(p, x, seed)
+    grad_y = torch.ones_like(x)
+    grad_x = low_mem_dropout_bwd(p, grad_y, seed)
+    mask_fwd = out != 0
+    mask_bwd = grad_x != 0
+    assert torch.equal(mask_fwd, mask_bwd)
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main entry point that runs the low-memory dropout kernel verification with different tensor sizes.
+    Tests with two configurations:
+    - p=0.25, size=8192
+    - p=0.25, size=32768
+    """
+    check(0.25, 8192)
+    check(0.25, 32768)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_examples.expected b/test/test_examples.expected
index b36008989..7c0bde755 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -2766,6 +2766,46 @@ def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.T
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), x, weight, out, mean, rstd, mean.stride(0), out.stride(0), out.stride(1), rstd.stride(0), weight.stride(0), x.stride(0), x.stride(1), m, n, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
     return (out, mean, rstd)

+--- assertExpectedJournal(TestExamples.test_low_mem_dropout)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_low_mem_dropout(x_flat, out_flat, out_flat_stride_0, x_flat_stride_0, n, seed, p, scale, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < n
+    xi = tl.load(x_flat + indices_0 * x_flat_stride_0, mask_0, other=0)
+    rand = tl.rand(seed, indices_0)
+    v_0 = rand > p
+    v_1 = xi * scale
+    v_2 = 0.0
+    v_3 = v_2[None]
+    v_4 = tl.where(v_0, v_1, v_3)
+    tl.store(out_flat + indices_0 * out_flat_stride_0, v_4, mask_0)
+
+def low_mem_dropout(p: float, x: torch.Tensor, seed: int, *, _launcher=_default_launcher):
+    """
+    Applies dropout to x with probability p.
+    Args:
+        p (float): dropout probability
+        x (torch.Tensor): input tensor
+    Returns:
+        Output tensor
+    """
+    scale = 1.0 / (1.0 - p)
+    n = x.numel()
+    x_flat = x.view(-1)
+    out_flat = torch.empty_like(x_flat)
+    _BLOCK_SIZE_0 = 1024
+    _launcher(_helion_low_mem_dropout, (triton.cdiv(n, _BLOCK_SIZE_0),), x_flat, out_flat, out_flat.stride(0), x_flat.stride(0), n, seed, p, scale, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out_flat.view_as(x)
+
 --- assertExpectedJournal(TestExamples.test_matmul)
 from __future__ import annotations

diff --git a/test/test_examples.py b/test/test_examples.py
index dc4ec9494..7fde641ce 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -310,6 +310,51 @@ def test_welford(self):
             )
         )

+    def test_low_mem_dropout(self):
+        from examples.low_mem_dropout import low_mem_dropout
+        from examples.low_mem_dropout import low_mem_dropout_bwd
+
+        from helion._testing import code_and_output
+
+        p = 0.25
+        size = 8192
+        seed = 123
+        seed2 = 456
+        x = torch.randn(size=(size,)).cuda()
+
+        _, out_fwd = code_and_output(
+            low_mem_dropout,
+            (p, x, seed),
+        )
+
+        grad_y = torch.ones_like(x)
+        _, grad_x = code_and_output(
+            low_mem_dropout_bwd,
+            (p, grad_y, seed),
+        )
+
+        _, grad_x2 = code_and_output(
+            low_mem_dropout_bwd,
+            (p, grad_y, seed2),
+        )
+
+        mask_fwd = out_fwd != 0
+        mask_bwd = grad_x != 0
+        self.assertTrue(
+            torch.equal(mask_fwd, mask_bwd),
+            "Same elements should be dropped in fwd and bwd with the same seed",
+        )
+
+        mask_bwd2 = grad_x2 != 0
+        self.assertFalse(
+            torch.equal(mask_bwd, mask_bwd2),
+            "Different elements should be dropped when using a different seed",
+        )
+
+        self.assertExpectedJournal(
+            check_example("low_mem_dropout", (p, grad_y, seed), grad_x),
+        )
+
     def test_rms_norm_fwd(self):
         args = (
             torch.randn([128, 256], device=DEVICE, dtype=torch.float16),
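
Background on the technique, not part of the patch itself: the "low mem" in the name comes from never materializing or saving the dropout mask between the forward and backward passes; both kernels regenerate the same keep-mask from the integer seed via hl.rand, so only the scalar seed has to be carried to backward. A minimal PyTorch-only sketch of the same idea follows for reference; the helper names _keep_mask, dropout_fwd, and dropout_bwd are hypothetical and do not appear in this PR.

    import torch


    def _keep_mask(shape, p: float, seed: int, device: torch.device) -> torch.Tensor:
        # Rebuild the identical keep-mask from the scalar seed instead of storing it.
        gen = torch.Generator(device=device).manual_seed(seed)
        return torch.rand(shape, generator=gen, device=device) > p


    def dropout_fwd(p: float, x: torch.Tensor, seed: int) -> torch.Tensor:
        keep = _keep_mask(x.shape, p, seed, x.device)
        return torch.where(keep, x / (1.0 - p), torch.zeros_like(x))


    def dropout_bwd(p: float, grad_y: torch.Tensor, seed: int) -> torch.Tensor:
        # Same math as the forward pass: the mask is recomputed, never saved.
        keep = _keep_mask(grad_y.shape, p, seed, grad_y.device)
        return torch.where(keep, grad_y / (1.0 - p), torch.zeros_like(grad_y))

With a fixed seed, dropout_fwd and dropout_bwd keep exactly the same elements, which is the property check() and test_low_mem_dropout assert above; the saving is that only the integer seed, rather than a mask tensor the size of x, has to survive between the two passes.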