diff --git a/benchmarks/run.py b/benchmarks/run.py
index 4b921c455..884cd0ef3 100644
--- a/benchmarks/run.py
+++ b/benchmarks/run.py
@@ -142,6 +142,11 @@ class RunResult:
         "examples.exp",
         "exp_tritonbench",
     ),
+    "vector_exp-bwd": (
+        "tritonbench.operators.vector_exp.operator",
+        "examples.exp",
+        "exp_tritonbench",
+    ),
     "rms_norm": (
         "tritonbench.operators.rms_norm.operator",
         "examples.rms_norm",
diff --git a/examples/exp.py b/examples/exp.py
index 417c8f5cb..fae90e033 100644
--- a/examples/exp.py
+++ b/examples/exp.py
@@ -20,10 +20,8 @@
 
 
 # %%
-# Exponential Kernel
-# ---------------
 @helion.kernel()
-def exp(x: torch.Tensor) -> torch.Tensor:
+def exp_fwd(x: torch.Tensor) -> torch.Tensor:
     """
     Computes the exponential of all elements in the input tensor.
 
@@ -39,6 +37,63 @@ def exp(x: torch.Tensor) -> torch.Tensor:
     return out
 
 
+# %%
+@helion.kernel()
+def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor) -> torch.Tensor:
+    """
+    Computes the gradient of the exponential function with respect to the input tensor.
+
+    Args:
+        dy: Gradient of the output tensor
+        exp_x: Saved activation from the forward pass
+
+    Returns:
+        Gradient of the input tensor
+    """
+    dx = torch.empty_like(exp_x)
+    for tile in hl.tile(exp_x.size()):
+        dx[tile] = dy[tile] * exp_x[tile]
+    return dx
+
+
+# %%
+# Exponential Kernel
+# ---------------
+class ExpFunction(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx: object,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        """Forward pass for exp."""
+        y = exp_fwd(x)
+        ctx.save_for_backward(y)  # type: ignore[arg-type]
+        return y
+
+    @staticmethod
+    def backward(  # type: ignore[override]
+        ctx: object,
+        grad_output: torch.Tensor,
+    ) -> torch.Tensor:
+        """Backward pass for exp."""
+        (exp_x,) = ctx.saved_tensors  # type: ignore[attr-defined]
+        return exp_bwd(grad_output, exp_x)
+
+
+# %%
+def exp(x: torch.Tensor) -> torch.Tensor:
+    """
+    Exponential with forward and backward support.
+
+    Args:
+        x: Input tensor
+
+    Returns:
+        Output tensor with the exponential of each element in the input
+    """
+    return ExpFunction.apply(x)  # type: ignore[no-any-return]
+
+
 # %%
 # Benchmark Wrapper
 # --------------
@@ -68,8 +123,8 @@ def check(n: int) -> None:
     Args:
         n: Size of the test tensor
     """
-    x = torch.randn(n, device="cuda", dtype=torch.float32)
-    run_example(exp, torch.exp, (x,))
+    x = torch.randn(n, device="cuda", dtype=torch.float32, requires_grad=True)
+    run_example(exp, torch.exp, (x,), bwd=True)
 
 
 # %%
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 3436d1d7a..2c004def2 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -705,6 +705,77 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
     _launcher(_helion_embedding, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
     return out.view(*x.size(), embedding_dim)
 
+--- assertExpectedJournal(TestExamples.test_exp_bwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_exp_bwd(exp_x, dy, dx, exp_x_size_0, dx_stride_0, dy_stride_0, exp_x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < exp_x_size_0
+    load = tl.load(dy + indices_0 * dy_stride_0, mask_0, other=0)
+    load_1 = tl.load(exp_x + indices_0 * exp_x_stride_0, mask_0, other=0)
+    v_0 = load * load_1
+    tl.store(dx + indices_0 * dx_stride_0, v_0, mask_0)
+
+def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Computes the gradient of the exponential function with respect to the input tensor.
+
+    Args:
+        dy: Gradient of the output tensor
+        exp_x: Saved activation from the forward pass
+
+    Returns:
+        Gradient of the input tensor
+    """
+    dx = torch.empty_like(exp_x)
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_exp_bwd, (triton.cdiv(exp_x.size(0), _BLOCK_SIZE_0),), exp_x, dy, dx, exp_x.size(0), dx.stride(0), dy.stride(0), exp_x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return dx
+
+--- assertExpectedJournal(TestExamples.test_exp_fwd)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_exp_fwd(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < x_size_0
+    load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
+    v_0 = tl.cast(load, tl.float32)
+    v_1 = libdevice.exp(v_0)
+    v_2 = tl.cast(v_1, tl.float16)
+    tl.store(out + indices_0 * out_stride_0, v_2, mask_0)
+
+def exp_fwd(x: torch.Tensor, *, _launcher=_default_launcher):
+    """
+    Computes the exponential of all elements in the input tensor.
+
+    Args:
+        x: Input tensor
+
+    Returns:
+        Output tensor with the exponential of each element in the input
+    """
+    out = torch.empty_like(x)
+    _BLOCK_SIZE_0 = 16
+    _launcher(_helion_exp_fwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_fp8_attention)
 from __future__ import annotations
 
diff --git a/test/test_examples.py b/test/test_examples.py
index a0e3492ff..ed1cc82f0 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -1291,6 +1291,43 @@ def test_jagged_layer_norm(self):
             )
         )
 
+    def test_exp_fwd(self):
+        x = torch.randn([1024], device=DEVICE, dtype=torch.float16)
+        args = (x,)
+        self.assertExpectedJournal(
+            check_example(
+                "exp",
+                args,
+                torch.exp(x),
+                fn_name="exp_fwd",
+                block_sizes=[16],
+                num_warps=4,
+                num_stages=3,
+            )
+        )
+
+    def test_exp_bwd(self):
+        x = torch.randn([1024], device=DEVICE, dtype=torch.float16).requires_grad_(True)
+        y = torch.exp(x)
+        grad_out = torch.randn_like(y)
+        y.backward(grad_out)
+        torch_out = x.grad
+        args = (
+            grad_out,
+            y,
+        )
+        self.assertExpectedJournal(
+            check_example(
+                "exp",
+                args,
+                torch_out,
+                fn_name="exp_bwd",
+                block_sizes=[16],
+                num_warps=4,
+                num_stages=3,
+            )
+        )
+
 
 if __name__ == "__main__":
     unittest.main()
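
Usage sketch (not part of the patch): a minimal example of the autograd-enabled exp() wrapper this change adds, assuming a CUDA device and that the repository root is on PYTHONPATH so that examples.exp is importable.

    # Hypothetical usage of examples/exp.py after this patch (assumes a CUDA GPU).
    import torch
    from examples.exp import exp

    x = torch.randn(1024, device="cuda", dtype=torch.float32, requires_grad=True)
    y = exp(x)          # ExpFunction.forward runs the exp_fwd Helion kernel, saves exp(x)
    y.sum().backward()  # ExpFunction.backward runs exp_bwd, computing dy * exp(x)
    print(torch.allclose(x.grad, torch.exp(x)))  # d/dx exp(x) == exp(x); dy is all ones here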