14 changes: 7 additions & 7 deletions examples/layer_norm.py
@@ -19,7 +19,7 @@
 @helion.kernel
 def layer_norm_fwd(
     x: torch.Tensor,
-    nomralized_shape: list[int],
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float = 1e-5,
@@ -28,7 +28,7 @@ def layer_norm_fwd(
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
-        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
         bias (torch.Tensor): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
@@ -38,19 +38,19 @@ def layer_norm_fwd(
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
     assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
-    assert len(nomralized_shape) == 1, (
+    assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
-    assert nomralized_shape[0] == n, (
-        f"normalized shape mismatch {nomralized_shape[0]} != {n}"
+    assert normalized_shape[0] == n, (
+        f"normalized shape mismatch {normalized_shape[0]} != {n}"
     )
-    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
     for tile_m in hl.tile(m):
         acc = x[tile_m, :].to(torch.float32)
         var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
         normalized = (acc - mean) * torch.rsqrt(var + eps)
         acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
-        out[tile_m, :] = acc
+        out[tile_m, :] = acc.to(x.dtype)
     return out


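The kernel body above upcasts each tile to float32, computes mean and variance with torch.var_mean, normalizes, applies the affine parameters, and now casts back to the input's dtype. As a minimal sketch of the same math in plain PyTorch (illustration only, not part of this PR; the function name is hypothetical):

```python
import torch

def layer_norm_reference(x, weight, bias, eps=1e-5):
    # Accumulate in float32 for numerical stability, mirroring the kernel.
    acc = x.to(torch.float32)
    var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
    normalized = (acc - mean) * torch.rsqrt(var + eps)
    acc = normalized * weight.to(torch.float32) + bias.to(torch.float32)
    # The fix in this diff: cast back to x.dtype rather than assuming float16.
    return acc.to(x.dtype)
```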
10 changes: 5 additions & 5 deletions test/test_examples.expected
@@ -1303,12 +1303,12 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
     v_15 = tl.cast(v_14, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])
 
-def layer_norm_fwd(x: torch.Tensor, nomralized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
-        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
         bias (torch.Tensor): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
@@ -1318,9 +1318,9 @@ def layer_norm_fwd(x: torch.Tensor, nomralized_shape: list[int], weight: torch.T
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
     assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
-    assert len(nomralized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
-    assert nomralized_shape[0] == n, f'normalized shape mismatch {nomralized_shape[0]} != {n}'
-    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_1 = triton.next_power_of_2(bias.size(0))
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
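With the output allocation switched to dtype=x.dtype, the regenerated host function should now preserve the input dtype end to end. A quick sanity-check sketch against PyTorch's reference layer norm (assuming a CUDA device; the import path and tolerances are illustrative, not part of the PR):

```python
import torch
from examples.layer_norm import layer_norm_fwd  # assumed import path

x = torch.randn(256, 1024, device="cuda", dtype=torch.float16)
weight = torch.randn(1024, device="cuda", dtype=torch.float16)
bias = torch.randn(1024, device="cuda", dtype=torch.float16)

out = layer_norm_fwd(x, [1024], weight, bias)
ref = torch.nn.functional.layer_norm(x, [1024], weight, bias, eps=1e-5)

# The output dtype now follows the input instead of being hard-coded float16.
assert out.dtype == x.dtype
torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)
```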