diff --git a/tests/kernels/test_fla_layernorm_guard.py b/tests/kernels/test_fla_layernorm_guard.py
new file mode 100644
index 000000000000..f944c6dcfa73
--- /dev/null
+++ b/tests/kernels/test_fla_layernorm_guard.py
@@ -0,0 +1,388 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+import torch
+import torch.nn.functional as F
+
+from vllm.model_executor.layers.fla.ops.layernorm_guard import (
+    layer_norm_fwd,
+    layernorm_fn,
+    rms_norm_ref,
+)
+from vllm.platforms import current_platform
+
+
+def layer_norm_ref(
+    x,
+    weight,
+    bias,
+    z=None,
+    eps=1e-6,
+    group_size=None,
+    norm_before_gate=True,
+    is_rms_norm=False,
+):
+    """Reference implementation for both layer norm and RMS norm."""
+    if is_rms_norm:
+        # Use the imported rms_norm_ref for RMS norm cases
+        return rms_norm_ref(
+            x,
+            weight,
+            bias,
+            z=z,
+            eps=eps,
+            group_size=group_size,
+            norm_before_gate=norm_before_gate,
+            upcast=True,
+        )
+
+    # Layer norm implementation
+    dtype = x.dtype
+    x = x.float()
+    weight = weight.float()
+    bias = bias.float() if bias is not None else None
+    z = z.float() if z is not None else None
+
+    if z is not None and not norm_before_gate:
+        x = x * F.silu(z)
+
+    if group_size is None:
+        # Layer norm: subtract mean
+        mean = x.mean(dim=-1, keepdim=True)
+        var = ((x - mean).square()).mean(dim=-1, keepdim=True)
+        rstd = 1 / torch.sqrt(var + eps)
+        out = (x - mean) * rstd * weight
+        if bias is not None:
+            out = out + bias
+    else:
+        # Group norm
+        from einops import rearrange
+
+        x_group = rearrange(x, "... (g d) -> ... g d", d=group_size)
+        mean = x_group.mean(dim=-1, keepdim=True)
+        var = ((x_group - mean).square()).mean(dim=-1, keepdim=True)
+        rstd = 1 / torch.sqrt(var + eps)
+        x_group = (x_group - mean) * rstd
+        out = rearrange(x_group, "... g d -> ... (g d)") * weight
+        if bias is not None:
+            out = out + bias
+
+    if z is not None and norm_before_gate:
+        out *= F.silu(z)
+
+    return out.to(dtype)
+
+
+DTYPES = [torch.bfloat16, torch.float32]
+# Test various M sizes to ensure rows_per_block logic works correctly
+NUM_TOKENS = [
+    1,
+    7,
+    16,
+    63,
+    128,
+    256,
+    512,
+    1024,
+    2048,
+    4096,
+    5789,
+    8189,
+    8191,
+    16383,
+    32767,
+]
+HIDDEN_SIZES = [64, 128, 256, 1024]
+GROUP_SIZES = [None, 64, 128]  # None means full hidden size
+NORM_BEFORE_GATE = [True, False]
+IS_RMS_NORM = [True, False]
+SEEDS = [0, 42]
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
+@torch.inference_mode()
+def test_layer_norm_fwd_basic(
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    seed: int,
+    is_rms_norm: bool,
+) -> None:
+    """Test basic layer norm forward pass without z (gate) tensor."""
+    current_platform.seed_everything(seed)
+    device = torch.device("cuda:0")
+
+    # Create inputs
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
+    eps = 1e-6
+
+    # Run the triton kernel
+    out, mean, rstd = layer_norm_fwd(
+        x, weight, bias, eps, z=None, is_rms_norm=is_rms_norm
+    )
+
+    # Run reference implementation
+    ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=is_rms_norm)
+
+    # Check outputs
+    assert out.shape == x.shape
+    assert out.dtype == x.dtype
+    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
+
+    # Check mean and rstd shapes
+    if not is_rms_norm:
+        assert mean.shape == (num_tokens,)
+        assert rstd.shape == (num_tokens,)
+
+
+@pytest.mark.parametrize("num_tokens", NUM_TOKENS)
+@pytest.mark.parametrize("hidden_size", [128, 256, 1024])
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("norm_before_gate", NORM_BEFORE_GATE)
+@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM)
+@torch.inference_mode()
+def test_layer_norm_fwd_with_gate(
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    norm_before_gate: bool,
+    is_rms_norm: bool,
+) -> None:
+    """Test layer norm forward pass with z (gate) tensor."""
+    current_platform.seed_everything(42)
+    device = torch.device("cuda:0")
+
+    # Create inputs
+    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
+    z = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device)
+    weight = torch.randn(hidden_size, dtype=dtype, device=device)
+    bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device)
+    eps = 1e-6
+
+    # Run the triton kernel
+    out, mean, rstd = layer_norm_fwd(
+        x,
+        weight,
+        bias,
+        eps,
+        z=z,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=is_rms_norm,
+    )
+
+    # Run reference implementation
+    ref_out = layer_norm_ref(
+        x,
+        weight,
+        bias,
+        z=z,
+        eps=eps,
+        norm_before_gate=norm_before_gate,
+        is_rms_norm=is_rms_norm,
+    )
+
+    # Check outputs
+    assert out.shape == x.shape
+    assert out.dtype == x.dtype
+    torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2)
+
+
+@pytest.mark.parametrize("num_tokens", [128, 512])
+@pytest.mark.parametrize("hidden_size", [512, 1024])
+@pytest.mark.parametrize("group_size", [64, 128, 256])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("is_rms_norm", IS_RMS_NORM) +@torch.inference_mode() +def test_layer_norm_fwd_with_groups( + num_tokens: int, + hidden_size: int, + group_size: int, + dtype: torch.dtype, + is_rms_norm: bool, +) -> None: + """Test layer norm forward pass with group normalization.""" + if hidden_size % group_size != 0: + pytest.skip( + f"hidden_size {hidden_size} not divisible by group_size {group_size}" + ) + + current_platform.seed_everything(42) + device = torch.device("cuda:0") + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = None if is_rms_norm else torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + ngroups = hidden_size // group_size + + # Run the triton kernel + out, mean, rstd = layer_norm_fwd( + x, weight, bias, eps, z=None, group_size=group_size, is_rms_norm=is_rms_norm + ) + + # Run reference implementation + ref_out = layer_norm_ref( + x, weight, bias, z=None, eps=eps, group_size=group_size, is_rms_norm=is_rms_norm + ) + + # Check outputs + assert out.shape == x.shape + assert out.dtype == x.dtype + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + # Check mean and rstd shapes for groups + if not is_rms_norm: + assert mean.shape == (ngroups * num_tokens,) + assert rstd.shape == (ngroups * num_tokens,) + + +@pytest.mark.parametrize("num_tokens", [7, 63, 128, 513, 1024, 2049]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_layer_norm_rows_per_block( + num_tokens: int, + dtype: torch.dtype, +) -> None: + """Test that rows_per_block logic works correctly for various M sizes.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + hidden_size = 1024 + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run the triton kernel + out, mean, rstd = layer_norm_fwd(x, weight, bias, eps, z=None, is_rms_norm=False) + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) + + # Check outputs + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@torch.inference_mode() +def test_strided_input(dtype: torch.dtype) -> None: + """Test that the kernel handles non-contiguous (strided) + inputs correctly.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + num_tokens = 128 + hidden_size = 1024 + + # Create a larger tensor and take a strided slice + x_large = torch.randn(num_tokens, hidden_size * 2, dtype=dtype, device=device) + x = x_large[:, :hidden_size] + + # Make it contiguous for the kernel + x_contiguous = x.contiguous() + + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run the triton kernel with contiguous input + out, mean, rstd = layer_norm_fwd( + x_contiguous, weight, bias, eps, z=None, is_rms_norm=False + ) + + # Run reference implementation + ref_out = layer_norm_ref( + x_contiguous, weight, bias, z=None, eps=eps, is_rms_norm=False + ) + + # Check outputs + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("num_tokens", [1, 128, 2048]) +@pytest.mark.parametrize("hidden_size", [768, 4096]) 
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_output_buffer_provided( + num_tokens: int, + hidden_size: int, + dtype: torch.dtype, +) -> None: + """Test that the kernel works when an output buffer is provided.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + + # Create inputs + x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Pre-allocate output buffer + out_buffer = torch.empty_like(x) + + # Run the triton kernel with provided output + out, mean, rstd = layer_norm_fwd( + x, weight, bias, eps, z=None, out=out_buffer, is_rms_norm=False + ) + + # Check that the provided buffer was used + assert out.data_ptr() == out_buffer.data_ptr() + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) + + # Check outputs + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize( + "shape", + [ + (4, 16, 1024), # 3D tensor + (2, 8, 512, 256), # 4D tensor + ], +) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +@torch.inference_mode() +def test_multidimensional_input( + shape: tuple, + dtype: torch.dtype, +) -> None: + """Test that the autograd function handles multidimensional inputs.""" + current_platform.seed_everything(42) + device = torch.device("cuda:0") + hidden_size = shape[-1] + + # Create inputs + x = torch.randn(*shape, dtype=dtype, device=device) + weight = torch.randn(hidden_size, dtype=dtype, device=device) + bias = torch.randn(hidden_size, dtype=dtype, device=device) + eps = 1e-6 + + # Run through autograd function + out = layernorm_fn(x, weight, bias, z=None, eps=eps) + + # Run reference implementation + ref_out = layer_norm_ref(x, weight, bias, z=None, eps=eps, is_rms_norm=False) + + # Check outputs + assert out.shape == x.shape + torch.testing.assert_close(out, ref_out, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + # Run a quick smoke test + test_layer_norm_fwd_basic(128, 1024, torch.float16, 42, False) + test_layer_norm_fwd_with_gate(128, 1024, torch.float16, True, False) + test_layer_norm_rows_per_block(513, torch.float16) + print("All smoke tests passed!") diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py index 655cdb3f30eb..6d039efe5876 100644 --- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py +++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py @@ -13,6 +13,7 @@ # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling. # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine. 
diff --git a/vllm/model_executor/layers/fla/ops/layernorm_guard.py b/vllm/model_executor/layers/fla/ops/layernorm_guard.py
index 655cdb3f30eb..6d039efe5876 100644
--- a/vllm/model_executor/layers/fla/ops/layernorm_guard.py
+++ b/vllm/model_executor/layers/fla/ops/layernorm_guard.py
@@ -13,6 +13,7 @@
 # This backward pass is faster for dimensions up to 8k, but after that it's much slower due to register spilling.
 # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
 
+from functools import lru_cache
 from typing import Optional
 
 import torch
@@ -21,6 +22,7 @@
 from einops import rearrange
 
 from vllm.triton_utils import tl, triton
+from vllm.utils import cdiv, next_power_of_2
 
 from .utils import input_guard
 
@@ -76,55 +78,103 @@ def layer_norm_fwd_kernel(
     stride_y_row,
     stride_z_row,
     M,  # number of rows in X
-    N,  # number of columns in X
+    N: tl.constexpr,  # number of columns in X
     eps,  # epsilon to avoid division by zero
     BLOCK_N: tl.constexpr,
+    ROWS_PER_BLOCK: tl.constexpr,
     HAS_BIAS: tl.constexpr,
     HAS_Z: tl.constexpr,
     NORM_BEFORE_GATE: tl.constexpr,
     IS_RMS_NORM: tl.constexpr,
 ):
-    # Map the program id to the row of X and Y it should compute.
-    row = tl.program_id(0)
+    # Map the program id to the starting row of X and Y it should compute.
+    row_start = tl.program_id(0) * ROWS_PER_BLOCK
     group = tl.program_id(1)
-    X += row * stride_x_row + group * N
-    Y += row * stride_y_row + group * N
-    if HAS_Z:
-        Z += row * stride_z_row + group * N
-    if not IS_RMS_NORM:
-        Mean += group * M
-    Rstd += group * M
-    W += group * N
-    if HAS_BIAS:
-        B += group * N
-    # Compute mean and variance
+
+    # Create 2D tile: [ROWS_PER_BLOCK, BLOCK_N]
+    rows = row_start + tl.arange(0, ROWS_PER_BLOCK)
     cols = tl.arange(0, BLOCK_N)
-    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+
+    # Compute offsets for 2D tile
+    row_offsets = rows[:, None] * stride_x_row
+    col_offsets = cols[None, :] + group * N
+
+    # Base pointers
+    X_base = X + row_offsets + col_offsets
+    Y_base = Y + rows[:, None] * stride_y_row + col_offsets
+
+    # Create mask for valid rows and columns
+    row_mask = rows[:, None] < M
+    col_mask = cols[None, :] < N
+    mask = row_mask & col_mask
+
+    # Load input data with 2D tile
+    x = tl.load(X_base, mask=mask, other=0.0).to(tl.float32)
+
     if HAS_Z and not NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
+        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
+        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
         x *= z * tl.sigmoid(z)
+
+    # Compute mean and variance per row (reduce along axis 1)
     if not IS_RMS_NORM:
-        mean = tl.sum(x, axis=0) / N
-        tl.store(Mean + row, mean)
-        xbar = tl.where(cols < N, x - mean, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
+        mean = tl.sum(x, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
+        # Store mean for each row
+        mean_offsets = group * M + rows
+        mean_mask = rows < M
+        tl.store(Mean + mean_offsets, mean, mask=mean_mask)
+        # Broadcast mean back to 2D for subtraction
+        xbar = tl.where(mask, x - mean[:, None], 0.0)
+        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
     else:
-        xbar = tl.where(cols < N, x, 0.0)
-        var = tl.sum(xbar * xbar, axis=0) / N
-    rstd = 1 / tl.sqrt(var + eps)
-    tl.store(Rstd + row, rstd)
-    # Normalize and apply linear transformation
-    mask = cols < N
-    w = tl.load(W + cols, mask=mask).to(tl.float32)
+        xbar = tl.where(mask, x, 0.0)
+        var = tl.sum(xbar * xbar, axis=1) / N  # Shape: [ROWS_PER_BLOCK]
+        mean = 0.0  # Placeholder so mean is defined in both constexpr branches
+
+    rstd = tl.rsqrt(var + eps)  # Shape: [ROWS_PER_BLOCK]
+
+    # Store rstd for each row
+    rstd_offsets = group * M + rows
+    rstd_mask = rows < M
+    tl.store(Rstd + rstd_offsets, rstd, mask=rstd_mask)
+
+    # Load weights and biases (broadcast across rows)
+    w_offsets = cols + group * N
+    w_mask = cols < N
+    w = tl.load(W + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
+
     if HAS_BIAS:
-        b = tl.load(B + cols, mask=mask).to(tl.float32)
-    x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
-    y = x_hat * w + b if HAS_BIAS else x_hat * w
+        b = tl.load(B + w_offsets, mask=w_mask, other=0.0).to(tl.float32)
+
+    # Normalize and apply linear transformation
+    if not IS_RMS_NORM:
+        x_hat = (x - mean[:, None]) * rstd[:, None]
+    else:
+        x_hat = x * rstd[:, None]
+
+    y = x_hat * w[None, :] + b[None, :] if HAS_BIAS else x_hat * w[None, :]
+
     if HAS_Z and NORM_BEFORE_GATE:
-        z = tl.load(Z + cols, mask=mask).to(tl.float32)
+        Z_base = Z + rows[:, None] * stride_z_row + col_offsets
+        z = tl.load(Z_base, mask=mask, other=0.0).to(tl.float32)
         y *= z * tl.sigmoid(z)
+
     # Write output
-    tl.store(Y + cols, y, mask=mask)
+    tl.store(Y_base, y, mask=mask)
+
+
+@lru_cache
+def _get_sm_count(device: torch.device) -> int:
+    """Get and cache the SM count for a given device."""
+    props = torch.cuda.get_device_properties(device)
+    return props.multi_processor_count
+
+
+def calc_rows_per_block(M: int, device: torch.device) -> int:
+    """Pick a power-of-two rows-per-block, targeting ~2 blocks per SM, capped at 4."""
+    sm_count = _get_sm_count(device)
+    rows_per_block = next_power_of_2(cdiv(M, 2 * sm_count))
+    rows_per_block = min(rows_per_block, 4)
+    return rows_per_block
 
 
 def layer_norm_fwd(
@@ -171,7 +221,10 @@
         raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
     # heuristics for number of warps
     num_warps = min(max(BLOCK_N // 256, 1), 8)
-    grid = (M, ngroups)
+    # Calculate rows per block based on SM count
+    rows_per_block = calc_rows_per_block(M, x.device)
+    # Update grid to use rows_per_block
+    grid = (cdiv(M, rows_per_block), ngroups)
     layer_norm_fwd_kernel[grid](
         x,
         out,
@@ -187,6 +240,7 @@
         group_size,
         eps,
         BLOCK_N=BLOCK_N,
+        ROWS_PER_BLOCK=rows_per_block,
         NORM_BEFORE_GATE=norm_before_gate,
        IS_RMS_NORM=is_rms_norm,
         num_warps=num_warps,
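To see what the new launch-grid math does for various `M`, here is a small self-contained sketch of the `calc_rows_per_block` heuristic. The `cdiv` and `next_power_of_2` helpers are re-implemented locally to mirror their `vllm.utils` counterparts, and the SM count of 132 is only an assumed example value (the real code queries it via `torch.cuda.get_device_properties`):

```python
# Pure-Python sketch of the rows_per_block heuristic, for illustration only.
def cdiv(a: int, b: int) -> int:
    # Ceiling division, mirroring vllm.utils.cdiv.
    return -(a // -b)


def next_power_of_2(n: int) -> int:
    # Smallest power of two >= n, mirroring vllm.utils.next_power_of_2.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def calc_rows_per_block(M: int, sm_count: int) -> int:
    # Target roughly 2 row-blocks per SM, round up to a power of two
    # (tl.arange requires a power-of-two extent), and cap at 4 rows.
    return min(next_power_of_2(cdiv(M, 2 * sm_count)), 4)


sm_count = 132  # assumed example value (e.g. an H100); varies per GPU
for m in (1, 128, 513, 5789, 32767):
    rpb = calc_rows_per_block(m, sm_count)
    print(f"M={m:>6}: rows_per_block={rpb}, grid_x={cdiv(m, rpb)}")
```

The power-of-two rounding keeps `ROWS_PER_BLOCK` a valid `tl.arange` extent, and the cap at 4 bounds each program's 2D tile, presumably to limit its register footprint; the "~2 blocks per SM" target is implied by the `2 * sm_count` divisor in the kernel code above.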