5 changes: 5 additions & 0 deletions benchmarks/run.py
@@ -142,6 +142,11 @@ class RunResult:
"examples.exp",
"exp_tritonbench",
),
"vector_exp-bwd": (
"tritonbench.operators.vector_exp.operator",
"examples.exp",
"exp_tritonbench",
),
"rms_norm": (
"tritonbench.operators.rms_norm.operator",
"examples.rms_norm",
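Note on the new entry: each value in this table is a (tritonbench operator module, Helion example module, wrapper function name) triple, and the "-bwd" suffix presumably tells the benchmark runner to exercise the backward pass of the same wrapper. Below is a minimal sketch of how such an entry could be resolved with importlib-style lookup; it is illustrative only, not the actual benchmarks/run.py logic.

import importlib

def resolve_benchmark(name: str, mappings: dict[str, tuple[str, str, str]]):
    # Look up the registered triple and import both modules lazily.
    operator_path, example_path, wrapper_name = mappings[name]
    operator_module = importlib.import_module(operator_path)   # tritonbench operator
    example_module = importlib.import_module(example_path)     # Helion example
    wrapper = getattr(example_module, wrapper_name)            # e.g. exp_tritonbench
    return operator_module, wrapper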
65 changes: 60 additions & 5 deletions examples/exp.py
@@ -20,10 +20,8 @@


# %%
# Exponential Kernel
# ---------------
@helion.kernel()
def exp(x: torch.Tensor) -> torch.Tensor:
def exp_fwd(x: torch.Tensor) -> torch.Tensor:
"""
Computes the exponential of all elements in the input tensor.

@@ -39,6 +37,63 @@ def exp(x: torch.Tensor) -> torch.Tensor:
return out


# %%
@helion.kernel()
def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor) -> torch.Tensor:
"""
Computes the gradient of the exponential function with respect to the input tensor.

Args:
dy: Gradient of the output tensor
exp_x: Saved activation from the forward pass

Returns:
Gradient of the input tensor
"""
dx = torch.empty_like(exp_x)
for tile in hl.tile(exp_x.size()):
dx[tile] = dy[tile] * exp_x[tile]
return dx


# %%
# Exponential Kernel
# ---------------
class ExpFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx: object,
x: torch.Tensor,
) -> torch.Tensor:
"""Forward pass for exp."""
y = exp_fwd(x)
ctx.save_for_backward(y) # type: ignore[arg-type]
return y

@staticmethod
def backward( # type: ignore[override]
ctx: object,
grad_output: torch.Tensor,
) -> torch.Tensor:
"""Backward pass for exp."""
(x,) = ctx.saved_tensors # type: ignore[attr-defined]
return exp_bwd(grad_output, x)


# %%
def exp(x: torch.Tensor) -> torch.Tensor:
"""
Exponential with forward and backward support.

Args:
x: Input tensor

Returns:
Output tensor with the exponential of each element in the input
"""
return ExpFunction.apply(x) # type: ignore[no-any-return]


# %%
# Benchmark Wrapper
# --------------
@@ -68,8 +123,8 @@ def check(n: int) -> None:
Args:
n: Size of the test tensor
"""
x = torch.randn(n, device="cuda", dtype=torch.float32)
run_example(exp, torch.exp, (x,))
x = torch.randn(n, device="cuda", dtype=torch.float32, requires_grad=True)
run_example(exp, torch.exp, (x,), bwd=True)


# %%
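With the ExpFunction wrapper in place, exp is now differentiable end to end: the forward pass runs the exp_fwd kernel and saves its output, and because d/dx exp(x) = exp(x), the backward pass only needs that saved output to compute dx = grad_output * exp(x) via exp_bwd. A minimal usage sketch, assuming a CUDA device and that examples.exp is importable:

import torch
from examples.exp import exp

x = torch.randn(1024, device="cuda", requires_grad=True)
y = exp(x)                       # forward: exp_fwd kernel
y.backward(torch.ones_like(y))   # backward: exp_bwd kernel, dx = dy * exp(x)

# The gradient of sum(exp(x)) is exp(x) itself, so x.grad should match torch.exp(x).
torch.testing.assert_close(x.grad, torch.exp(x.detach()))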
71 changes: 71 additions & 0 deletions test/test_examples.expected
@@ -705,6 +705,77 @@ def embedding(x: torch.Tensor, weight: torch.Tensor, *, _launcher=_default_launc
_launcher(_helion_embedding, (x_flat.size(0) * triton.cdiv(embedding_dim, _BLOCK_SIZE_1),), x_flat, weight, out, x_flat.size(0), out.stride(0), out.stride(1), weight.stride(0), weight.stride(1), x_flat.stride(0), embedding_dim, _BLOCK_SIZE_1, num_warps=4, num_stages=3)
return out.view(*x.size(), embedding_dim)

--- assertExpectedJournal(TestExamples.test_exp_bwd)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_exp_bwd(exp_x, dy, dx, exp_x_size_0, dx_stride_0, dy_stride_0, exp_x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < exp_x_size_0
load = tl.load(dy + indices_0 * dy_stride_0, mask_0, other=0)
load_1 = tl.load(exp_x + indices_0 * exp_x_stride_0, mask_0, other=0)
v_0 = load * load_1
tl.store(dx + indices_0 * dx_stride_0, v_0, mask_0)

def exp_bwd(dy: torch.Tensor, exp_x: torch.Tensor, *, _launcher=_default_launcher):
"""
Computes the gradient of the exponential function with respect to the input tensor.

Args:
dy: Gradient of the output tensor
exp_x: Saved activation from the forward pass

Returns:
Gradient of the input tensor
"""
dx = torch.empty_like(exp_x)
_BLOCK_SIZE_0 = 16
_launcher(_helion_exp_bwd, (triton.cdiv(exp_x.size(0), _BLOCK_SIZE_0),), exp_x, dy, dx, exp_x.size(0), dx.stride(0), dy.stride(0), exp_x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return dx

--- assertExpectedJournal(TestExamples.test_exp_fwd)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from torch._inductor.runtime.triton_compat import libdevice
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_exp_fwd(x, out, x_size_0, out_stride_0, x_stride_0, _BLOCK_SIZE_0: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < x_size_0
load = tl.load(x + indices_0 * x_stride_0, mask_0, other=0)
v_0 = tl.cast(load, tl.float32)
v_1 = libdevice.exp(v_0)
v_2 = tl.cast(v_1, tl.float16)
tl.store(out + indices_0 * out_stride_0, v_2, mask_0)

def exp_fwd(x: torch.Tensor, *, _launcher=_default_launcher):
"""
Computes the exponential of all elements in the input tensor.

Args:
x: Input tensor

Returns:
Output tensor with the exponential of each element in the input
"""
out = torch.empty_like(x)
_BLOCK_SIZE_0 = 16
_launcher(_helion_exp_fwd, (triton.cdiv(x.size(0), _BLOCK_SIZE_0),), x, out, x.size(0), out.stride(0), x.stride(0), _BLOCK_SIZE_0, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_fp8_attention)
from __future__ import annotations

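The expected journal above records the Triton code Helion generates for exp_fwd and exp_bwd; both lower to the same masked, one-dimensional elementwise pattern over ceil(n / _BLOCK_SIZE_0) program instances. As a rough sketch of what the _launcher call corresponds to, assuming _default_launcher is a thin wrapper over a plain kernel[grid](...) launch (an assumption, not Helion's documented behavior):

import torch
import triton

def exp_bwd_direct(dy: torch.Tensor, exp_x: torch.Tensor) -> torch.Tensor:
    # Launch the generated backward kernel from the journal directly with a 1D grid.
    dx = torch.empty_like(exp_x)
    block = 16
    grid = (triton.cdiv(exp_x.size(0), block),)
    _helion_exp_bwd[grid](  # kernel as emitted in the journal above
        exp_x, dy, dx,
        exp_x.size(0),
        dx.stride(0), dy.stride(0), exp_x.stride(0),
        block, num_warps=4, num_stages=3,
    )
    return dx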
37 changes: 37 additions & 0 deletions test/test_examples.py
@@ -1291,6 +1291,43 @@ def test_jagged_layer_norm(self):
)
)

def test_exp_fwd(self):
x = torch.randn([1024], device=DEVICE, dtype=torch.float16)
args = (x,)
self.assertExpectedJournal(
check_example(
"exp",
args,
torch.exp(x),
fn_name="exp_fwd",
block_sizes=[16],
num_warps=4,
num_stages=3,
)
)

def test_exp_bwd(self):
x = torch.randn([1024], device=DEVICE, dtype=torch.float16).requires_grad_(True)
y = torch.exp(x)
grad_out = torch.randn_like(y)
y.backward(grad_out)
torch_out = x.grad
args = (
grad_out,
y,
)
self.assertExpectedJournal(
check_example(
"exp",
args,
torch_out,
fn_name="exp_bwd",
block_sizes=[16],
num_warps=4,
num_stages=3,
)
)


if __name__ == "__main__":
unittest.main()
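Since the file ends with unittest.main(), the new cases can be run individually from the repository root (assuming no extra test harness is required):

python test/test_examples.py TestExamples.test_exp_fwd
python test/test_examples.py TestExamples.test_exp_bwd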