# Root Mean Square Normalization (RMSNorm) Benchmarks

## MSVC Compiler (cl.exe) Configuration

In [1]:
import os

os.environ['PATH'] += r';C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\bin\HostX86\x86'

!where cl

C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\bin\Hostx86\x86\cl.exe


## Install requirements

In [2]:
!pip install torch numpy pydantic



DEPRECATION: Loading egg at c:\users\ovuru\appdata\local\programs\python\python312\lib\site-packages\rope_cuda-0.0.0-py3.12-win-amd64.egg is deprecated. pip 24.3 will enforce this behaviour change. A possible replacement is to use pip for package installation. Discussion can be found at https://github.com/pypa/pip/issues/12330


## Import requirements

In [3]:
import torch
import torch.nn as N
from torch.utils.cpp_extension import load_inline
import numpy as np
from pydantic import BaseModel, Field
from typing import List, Union, Dict, Any, Type, Tuple
import time


## Define Utility Models

In [4]:
class BenchmarkCase(BaseModel):
    """
    Represents a test case for benchmarking neural network layers.

    Attributes:
        test_name (str): Name of the test case.
        original_layer (Type[N.Module]): The original (baseline) layer class to be tested.
        new_layer (Type[N.Module]): The new layer class to be compared against the original.
        layer_args (Dict[str, Any]): Arguments for layer initialization. NOTE(review): described
            as positional arguments, but the field is typed as a name -> value dict, so a
            positional sequence cannot be stored here; confirm how the runner consumes it.
        layer_kwargs (Dict[str, Any]): Keyword arguments for layer initialization.
        forward_args (Dict[str, Any]): Arguments for the forward pass (same caveat as layer_args).
        forward_kwargs (Dict[str, Any]): Keyword arguments for the forward pass.
        num_iterations (int): Number of iterations for each test run. Defaults to 1000.
        num_runs (int): Number of test runs to perform. Defaults to 100.

    Example:
        BenchmarkCase(
            test_name="RMSNorm Test",
            original_layer=BaseRMSNorm,
            new_layer=OurRMSNorm,
            layer_kwargs={"dim": 512},
            forward_kwargs={"x": torch.randn(32, 128, 512, device="cuda")},
            num_iterations=1000,
            num_runs=100,
        )
    """
    test_name: str
    original_layer: Type[N.Module]
    new_layer: Type[N.Module]
    layer_args: Dict[str, Any] = Field(default_factory=dict)
    layer_kwargs: Dict[str, Any] = Field(default_factory=dict)
    forward_args: Dict[str, Any] = Field(default_factory=dict)
    forward_kwargs: Dict[str, Any] = Field(default_factory=dict)
    num_iterations: int = 1000
    num_runs: int = 100

class LayerTestResult(BaseModel):
    """Outcome of running and timing one layer's forward pass.

    Fields:
        execution_time: Measured duration of the forward call (presumably seconds,
            as produced by time.perf_counter in this notebook; confirm with the producer).
        output: Whatever the layer's forward pass returned; intentionally untyped.
    """

    execution_time: float
    output: Any

class TestRunResult(BaseModel):
    """Side-by-side comparison of the new and original layers for one run.

    Fields:
        speedup: How many times faster the new layer ran than the original.
        max_diff: Largest absolute element-wise difference between the two outputs.
        mean_diff: Average absolute element-wise difference between the two outputs.
        new_layer_result: Timing and output of the new layer.
        original_layer_result: Timing and output of the original layer.
    """

    # Aggregate comparison metrics.
    speedup: float
    max_diff: float
    mean_diff: float
    # Raw per-layer measurements.
    new_layer_result: LayerTestResult
    original_layer_result: LayerTestResult

class BlackBoxTestResult(BaseModel):
    """All run results collected for one benchmark case.

    Fields:
        test_case: The BenchmarkCase that was executed.
        results: One TestRunResult per run, in execution order.
    """

    test_case: BenchmarkCase
    results: list[TestRunResult]

class BenchmarkAnalysis(BaseModel):
    """Summary statistics computed over every run of a benchmark.

    Fields:
        avg_speedup / std_speedup: Mean and standard deviation of the speedup.
        min_speedup / max_speedup: Smallest and largest speedup observed.
        avg_max_diff: Mean over runs of the per-run maximum absolute difference.
        avg_mean_diff: Mean over runs of the per-run mean absolute difference.
        new_layer_stats: Named timing statistics for the new layer.
        original_layer_stats: Named timing statistics for the original layer.
    """

    # Speedup distribution.
    avg_speedup: float
    std_speedup: float
    min_speedup: float
    max_speedup: float
    # Numerical agreement between the two implementations.
    avg_max_diff: float
    avg_mean_diff: float
    # Per-implementation timing summaries.
    new_layer_stats: dict[str, float]
    original_layer_stats: dict[str, float]


## Base RMSNorm Implementation (PyTorch Reference)

In [5]:
# Reference implementation adapted from Black Forest Labs (FLUX):
# https://github.com/black-forest-labs/flux/blob/478338d52759f92af9eeb92cc9eaa49582b20c78/src/flux/modules/layers.py#L63
class BaseRMSNorm(N.Module):
    """RMSNorm over the last axis, computed in float32.

    The input is upcast to float32 and multiplied by the reciprocal root-mean-square
    of its last axis (1e-6 is added to the mean square). The result is cast back to
    the input dtype and multiplied by a learnable per-channel ``scale`` that is
    initialised to ones.

    Args:
        dim: Size of the normalized (last) axis; length of ``scale``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = N.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor):
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = (x32 * x32).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + 1e-6)
        # Cast back before scaling, so the multiply happens in the input dtype.
        normalized = (x32 * inv_rms).to(dtype=in_dtype)
        return normalized * self.scale


## Our Implementation (Custom CUDA Kernel)

In [18]:
cuda_source ="""
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cooperative_groups.h>
#include <algorithm>


template <typename scalar_t>
__global__ void rmsnorm_cuda_forward_kernel(
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ scale,
    scalar_t* __restrict__ output,
    uint64_t dim
    ) {
    __shared__ float shared_sum;

    const float float_dim = static_cast<float>(dim);

    const uint64_t input_idx = threadIdx.x + blockIdx.z * blockDim.x + blockIdx.y * blockDim.x * gridDim.z + blockIdx.x * blockDim.x * gridDim.y * gridDim.z;

    if (threadIdx.x == 0){
        shared_sum = 0;
    }
    __syncthreads();

    float item_mean = static_cast<float>(input[input_idx] * input[input_idx]);
    atomicAdd(&shared_sum, item_mean);
    __syncthreads();

    if (threadIdx.x == 0){
        shared_sum = rsqrt(shared_sum / float_dim + 1e-6);
    }
    __syncthreads();

    output[input_idx] = static_cast<scalar_t>(input[input_idx] * shared_sum * scale[input_idx % dim]);
}

torch::Tensor rmsnorm_forward(
    torch::Tensor input,
    torch::Tensor scale) {
    auto output = torch::empty_like(input);
    // Check if 3-d or 4-d, it cannot be 2-d or 5-d>
    const int input_dims = input.dim();

    if (input_dims < 3 || input_dims > 4) {
        throw std::invalid_argument("Input must be 3-d or 4-d tensor");
    }
    if (input_dims == 3) {
        input = input.unsqueeze(0);
    }

    uint64_t x = input.size(0);
    uint64_t y = input.size(1);
    uint64_t z = input.size(2);
    uint64_t threads_per_block = input.size(3);

    dim3 blocks_per_grid(x, y, z);

    AT_DISPATCH_FLOATING_TYPES_AND_HALF(input.scalar_type(), "rmsnorm_forward_cuda", ([&] {
        rmsnorm_cuda_forward_kernel<scalar_t><<<blocks_per_grid, threads_per_block>>>(
            input.data_ptr<scalar_t>(),
            scale.data_ptr<scalar_t>(),
            output.data_ptr<scalar_t>(),
            threads_per_block
        );
    }));

    return output;
}

"""

# C++ forward declaration: load_inline generates the Python binding for
# `rmsnorm_forward` from this, while the definition lives in cuda_source.
cpp_source = """
torch::Tensor rmsnorm_forward(
    torch::Tensor input,
    torch::Tensor scale);
"""

# JIT-compile and load the extension. Builds are cached under the
# torch_extensions root and rebuilt when the sources or flags change.
our_extension = load_inline(
    name='our_extension',
    cuda_sources=cuda_source,
    cpp_sources=cpp_source,
    functions=['rmsnorm_forward'],
    extra_cuda_cflags=["-O3"],
    with_cuda=True,
    verbose=True,
)

class OurRMSNorm(BaseRMSNorm):
    """RMSNorm whose forward pass runs the custom kernel in ``our_extension``.

    Construction and the learnable ``scale`` parameter are inherited unchanged
    from BaseRMSNorm, so the two modules are interchangeable in the benchmark.
    """

    def forward(self, x: torch.Tensor):
        kernel = our_extension.rmsnorm_forward
        return kernel(x, self.scale)


Using C:\Users\ovuru\AppData\Local\torch_extensions\torch_extensions\Cache\py312_cu124 as PyTorch extensions root...
The input conditions for extension module our_extension have changed. Bumping to version 3 and re-building as our_extension_v3...
Detected CUDA files, patching ldflags
Emitting ninja build file C:\Users\ovuru\AppData\Local\torch_extensions\torch_extensions\Cache\py312_cu124\our_extension\build.ninja...
Building extension module our_extension_v3...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module our_extension_v3...


In [15]:
def test_rmsnorm_forward(norm_class, shape, dtype, device, input_tensor):
    dim = shape[-1]
    input_tensor = input_tensor.to(device, dtype)
    norm = norm_class(dim).to(device, dtype)

    start_time = time.perf_counter()
    output = norm(input_tensor)
    end_time = time.perf_counter()

    return end_time - start_time, output


def run_tests(
    shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    test_name: str,
    num_runs: int = 100,
) -> List[Tuple[float, float, float, float, float]]:
    """Benchmark OurRMSNorm against BaseRMSNorm on ``num_runs`` fresh random inputs.

    Args:
        shape: Input shape; the last axis is the normalized dimension.
        dtype: dtype for the inputs and module parameters.
        device: Device to run on.
        test_name: Label for this configuration (not used inside this function).
        num_runs: Number of independent random inputs to benchmark.

    Returns:
        One ``(speedup, max_abs_diff, mean_abs_diff, new_time, original_time)``
        tuple per run, where ``speedup = original_time / new_time``. The speedup
        is ``inf`` when the new time is not positive.
    """

    def single_run() -> Tuple[float, float, float, float, float]:
        sample = torch.randn(*shape, dtype=dtype, device=device)
        new_time, new_out = test_rmsnorm_forward(OurRMSNorm, shape, dtype, device, sample)
        ref_time, ref_out = test_rmsnorm_forward(BaseRMSNorm, shape, dtype, device, sample)

        ratio = float("inf") if not new_time > 0 else ref_time / new_time
        abs_err = torch.abs(new_out - ref_out)
        return (ratio, torch.max(abs_err).item(), torch.mean(abs_err).item(), new_time, ref_time)

    return [single_run() for _ in range(num_runs)]


def analyze_results(
    results: List[Tuple[float, float, float, float, float]]
) -> Tuple[float, float, float, float, float, float, List[float], List[float]]:
    """Summarize per-run benchmark tuples produced by ``run_tests``.

    Args:
        results: ``(speedup, max_diff, mean_diff, new_time, original_time)`` per run.

    Returns:
        ``(avg_speedup, std_speedup, min_speedup, max_speedup, avg_max_diff,
        avg_mean_diff, new_stats, original_stats)``. Each ``*_stats`` entry is
        ``[mean, min, max, std]`` of the corresponding timings, with std being the
        population standard deviation (numpy default).
    """

    def summarize(samples) -> List[float]:
        return [np.mean(samples), np.min(samples), np.max(samples), np.std(samples)]

    speedups, max_diffs, mean_diffs, new_times, ref_times = zip(*results)
    speed_mean, speed_min, speed_max, speed_std = summarize(speedups)

    return (
        speed_mean,
        speed_std,
        speed_min,
        speed_max,
        np.mean(max_diffs),
        np.mean(mean_diffs),
        summarize(new_times),
        summarize(ref_times),
    )


def print_analysis(
    test_name: str,
    input_shape: Tuple[int, ...],
    dtype: torch.dtype,
    device: str,
    analysis: Tuple[float, float, float, float, float, float, List[float], List[float]],
):
    """Print a human-readable report for one benchmark configuration.

    Args:
        test_name: Label shown in the report header.
        input_shape: Shape of the benchmarked input.
        dtype: Data type used for the run.
        device: Device the run executed on.
        analysis: The 8-tuple returned by ``analyze_results``. Its two trailing
            lists hold ``[mean, min, max, std]`` timings in seconds.
    """
    (
        speed_avg,
        speed_std,
        speed_lo,
        speed_hi,
        diff_max_avg,
        diff_mean_avg,
        new_stats,
        ref_stats,
    ) = analysis

    print(f"\n{test_name} Test")
    print(f"Input shape: {input_shape}")
    print(f"Data type: {dtype}")
    print(f"Device: {device}")
    print(f"Average Speedup: {speed_avg:.2f}x (±{speed_std:.2f})")
    print(f"Speedup Range: {speed_lo:.2f}x - {speed_hi:.2f}x")
    print(f"Average Max Diff: {diff_max_avg:.6f}")
    print(f"Average Mean Diff: {diff_mean_avg:.6f}")

    stat_labels = ("Average", "Min", "Max", "Std Dev")
    sections = (("CUDA Implementation", new_stats), ("Original Implementation", ref_stats))
    for title, stats in sections:
        print(f"\n{title}:")
        for idx, label in enumerate(stat_labels):
            print(f"  {label} Time: {stats[idx]:.6f} seconds")


In [19]:
# Benchmark configuration.
num_runs = 10
SEED = 0
device = "cuda" if torch.cuda.is_available() else "cpu"

# OurRMSNorm calls a CUDA-only extension; fail fast with a clear message instead
# of handing CPU tensors to the kernel.
if device != "cuda":
    raise RuntimeError("These benchmarks require a CUDA device (OurRMSNorm is CUDA-only).")

# Seed so every Restart & Run All benchmarks the same random inputs.
torch.manual_seed(SEED)

test_cases = [
    ("Small 3D", (1, 1, 32)),
    ("Medium 3D", (32, 128, 512)),
    ("Large 3D", (128, 256, 512)),
    ("Small 4D", (1, 1, 1, 32)),
    ("Medium 4D", (2, 16, 32, 1024)),
    ("Large 4D", (8, 16, 32, 1024)),
]

for name, shape in test_cases:
    results = run_tests(shape, torch.float32, device, name, num_runs)
    analysis = analyze_results(results)
    print_analysis(name, shape, torch.float32, device, analysis)

# Additional tests for different data types (same shape as "Medium 3D").
data_types = [torch.float32, torch.float16]
for dtype in data_types:
    shape = (32, 128, 512)
    results = run_tests(shape, dtype, device, f"Data type: {dtype}", num_runs)
    analysis = analyze_results(results)
    print_analysis(f"Data type: {dtype}", shape, dtype, device, analysis)



Small 3D Test
Input shape: (1, 1, 32)
Data type: torch.float32
Device: cuda
Average Speedup: 3.12x (±0.77)
Speedup Range: 1.08x - 3.87x
Average Max Diff: 0.000000
Average Mean Diff: 0.000000

CUDA Implementation:
  Average Time: 0.000192 seconds
  Min Time: 0.000055 seconds
  Max Time: 0.001182 seconds
  Std Dev Time: 0.000332 seconds

Original Implementation:
  Average Time: 0.000369 seconds
  Min Time: 0.000185 seconds
  Max Time: 0.001276 seconds
  Std Dev Time: 0.000317 seconds

Medium 3D Test
Input shape: (32, 128, 512)
Data type: torch.float32
Device: cuda
Average Speedup: 4.22x (±1.70)
Speedup Range: 2.43x - 8.16x
Average Max Diff: 0.000002
Average Mean Diff: 0.000000

CUDA Implementation:
  Average Time: 0.000093 seconds
  Min Time: 0.000052 seconds
  Max Time: 0.000155 seconds
  Std Dev Time: 0.000035 seconds

Original Implementation:
  Average Time: 0.000357 seconds
  Min Time: 0.000204 seconds
  Max Time: 0.000521 seconds
  Std Dev Time: 0.000106 seconds

Large 3D Test
Inpu