diff --git a/examples/layer_norm.py b/examples/layer_norm.py
index 33cb12fd3..e7ef49448 100644
--- a/examples/layer_norm.py
+++ b/examples/layer_norm.py
@@ -19,7 +19,7 @@
 @helion.kernel
 def layer_norm_fwd(
     x: torch.Tensor,
-    nomralized_shape: list[int],
+    normalized_shape: list[int],
     weight: torch.Tensor,
     bias: torch.Tensor,
     eps: float = 1e-5,
@@ -28,7 +28,7 @@ def layer_norm_fwd(
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
-        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
         bias (torch.Tensor): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
@@ -38,19 +38,19 @@
     m, n = x.size()
     assert weight.size(0) == n, f"weight size mismatch {weight.size(0)} != {m}"
     assert bias.size(0) == n, f"bias size mismatch {bias.size(0)} != {m}"
-    assert len(nomralized_shape) == 1, (
+    assert len(normalized_shape) == 1, (
         "Helion layer norm only supports 1D layer norm currently"
     )
-    assert nomralized_shape[0] == n, (
-        f"normalized shape mismatch {nomralized_shape[0]} != {n}"
+    assert normalized_shape[0] == n, (
+        f"normalized shape mismatch {normalized_shape[0]} != {n}"
     )
-    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
     for tile_m in hl.tile(m):
         acc = x[tile_m, :].to(torch.float32)
         var, mean = torch.var_mean(acc, dim=-1, keepdim=True, correction=0)
         normalized = (acc - mean) * torch.rsqrt(var + eps)
         acc = normalized * (weight[:].to(torch.float32)) + (bias[:].to(torch.float32))
-        out[tile_m, :] = acc
+        out[tile_m, :] = acc.to(x.dtype)
     return out
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 87ac8e55a..ba4451d3b 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -1303,12 +1303,12 @@ def _helion_layer_norm_fwd(bias, x, weight, out, bias_size_0, bias_stride_0, out
     v_15 = tl.cast(v_14, tl.float16)
     tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), v_15, mask_0[:, None] & mask_1[None, :])

-def layer_norm_fwd(x: torch.Tensor, nomralized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
+def layer_norm_fwd(x: torch.Tensor, normalized_shape: list[int], weight: torch.Tensor, bias: torch.Tensor, eps: float=1e-05, *, _launcher=_default_launcher):
     """
     Performs 1D layer normalization on the input tensor using Helion.
     Args:
         x (torch.Tensor): Input tensor of shape [batch_size, dim], expected to be FP16.
-        nomralized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
+        normalized_shape (list[int]): List containing the dimension to normalize over (should be length 1).
         weight (torch.Tensor): Learnable scale parameter of shape [dim].
         bias (torch.Tensor): Learnable bias parameter of shape [dim].
         eps (float, optional): Small value added to variance for numerical stability. Default is 1e-5.
@@ -1318,9 +1318,9 @@ def layer_norm_fwd(x: torch.Tensor, nomralized_shape: list[int], weight: torch.T
     m, n = x.size()
     assert weight.size(0) == n, f'weight size mismatch {weight.size(0)} != {m}'
     assert bias.size(0) == n, f'bias size mismatch {bias.size(0)} != {m}'
-    assert len(nomralized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
-    assert nomralized_shape[0] == n, f'normalized shape mismatch {nomralized_shape[0]} != {n}'
-    out = torch.empty([m, n], dtype=torch.float16, device=x.device)
+    assert len(normalized_shape) == 1, 'Helion layer norm only supports 1D layer norm currently'
+    assert normalized_shape[0] == n, f'normalized shape mismatch {normalized_shape[0]} != {n}'
+    out = torch.empty([m, n], dtype=x.dtype, device=x.device)
     _BLOCK_SIZE_0 = 32
     _RDIM_SIZE_1 = triton.next_power_of_2(bias.size(0))
     _launcher(_helion_layer_norm_fwd, (triton.cdiv(m, _BLOCK_SIZE_0),), bias, x, weight, out, bias.size(0), bias.stride(0), out.stride(0), out.stride(1), weight.stride(0), x.stride(0), x.stride(1), m, eps, _BLOCK_SIZE_0, _RDIM_SIZE_1, num_warps=4, num_stages=3)
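
A quick way to sanity-check the dtype change above: with the old hard-coded `torch.float16` output, a non-FP16 input would silently come back as FP16. The sketch below is a minimal check, assuming a CUDA device with Helion and Triton installed, and that `layer_norm_fwd` is importable from `examples/layer_norm.py` as laid out in this diff; the tolerance values are illustrative, not taken from the repo's tests.

    import torch
    from examples.layer_norm import layer_norm_fwd

    m, n = 32, 64
    x = torch.randn(m, n, device="cuda", dtype=torch.bfloat16)
    weight = torch.randn(n, device="cuda", dtype=torch.bfloat16)
    bias = torch.randn(n, device="cuda", dtype=torch.bfloat16)

    out = layer_norm_fwd(x, [n], weight, bias)
    # After this change the output dtype follows the input;
    # previously it was always torch.float16.
    assert out.dtype == x.dtype

    # Compare against PyTorch's reference layer norm.
    ref = torch.nn.functional.layer_norm(x, [n], weight, bias)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)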