# RMSNorm

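The `RMSNorm` formula is:
$$\text{RMSNorm}: y = \frac{x}{\sqrt{\mathrm{Mean}(x^2)+\epsilon}} \odot \gamma$$

$$\mathrm{Mean}(x^2) = \frac{1}{N} \sum_{i=1}^N x_i^2$$

where $x = (x_1, \dots, x_N)$ is the input vector over the normalized dimension, $\gamma$ is a learnable per-feature gain applied element-wise, and $\epsilon$ is a small constant for numerical stability.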
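The formula can be checked by hand on a small vector. The sketch below is plain Python (no PyTorch) and uses a hypothetical helper `rms_norm` that takes a single 1-D list `x` and a gain list `gamma`. It illustrates the formula only and is not the notebook's implementation.

```python
import math

def rms_norm(x, gamma, eps=1e-5):
    """RMSNorm of a 1-D list x with per-feature gain gamma, following the formula above."""
    n = len(x)
    mean_sq = sum(v * v for v in x) / n          # Mean(x^2)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)     # 1 / sqrt(Mean(x^2) + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]  # element-wise gain

x = [1.0, 2.0, 3.0, 4.0]
gamma = [1.0, 1.0, 1.0, 1.0]
y = rms_norm(x, gamma)
print([round(v, 4) for v in y])  # → [0.3651, 0.7303, 1.0954, 1.4606]
```

Here Mean(x²) = (1 + 4 + 9 + 16) / 4 = 7.5, so each element is divided by √7.5 ≈ 2.7386. With γ = 1 the output has a mean square of approximately 1.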

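RMSNorm is more efficient because its authors found that the benefit of LayerNorm comes from rescaling invariance (the normalization adapts to a scaling of the input, so the network is insensitive to such scaling), not from recentering invariance (if the mean of the input shifts while the shape and spread of its distribution stay the same, a recentering-invariant function produces the same output). Based on this finding, they dropped the mean computation from the normalization step. This makes the algorithm simpler, keeps performance comparable, and noticeably improves computational efficiency.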
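The invariance argument can be illustrated numerically. This is a plain-Python sketch with hypothetical helpers `layer_norm` and `rms_norm` (both without the learnable γ/β, to isolate the normalization itself) and `max_diff`. It scales the input by 10, shifts it by 5, and compares the outputs.

```python
import math

def layer_norm(x, eps=1e-5):
    # LayerNorm without gain/bias: subtract the mean, divide by the standard deviation
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n
    return [(v - mu) / math.sqrt(var + eps) for v in x]

def rms_norm(x, eps=1e-5):
    # RMSNorm without gain: no mean subtraction, divide by the root mean square
    n = len(x)
    mean_sq = sum(v * v for v in x) / n
    return [v / math.sqrt(mean_sq + eps) for v in x]

def max_diff(a, b):
    # Largest element-wise absolute difference between two lists
    return max(abs(p - q) for p, q in zip(a, b))

x = [1.0, 2.0, 3.0, 4.0]
scaled = [10.0 * v for v in x]   # rescaled input
shifted = [v + 5.0 for v in x]   # recentered input

print(max_diff(rms_norm(x), rms_norm(scaled)))      # close to 0: RMSNorm is rescaling invariant
print(max_diff(layer_norm(x), layer_norm(scaled)))  # close to 0: so is LayerNorm
print(max_diff(layer_norm(x), layer_norm(shifted))) # close to 0: LayerNorm is also recentering invariant
print(max_diff(rms_norm(x), rms_norm(shifted)))     # clearly nonzero: RMSNorm is not
```

Only the shifted RMSNorm case changes noticeably: giving up recentering invariance is exactly what skipping the mean computation costs. The scaled cases are not bit-exact because ε does not scale with the input.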
<p align="center">
    <img src="./_img/LayerNorm_comp_RMSNorm.png" width="80%"/> <br>
    Difference between the equations of Layer Normalization (LayerNorm) and Root Mean Square Normalization (RMSNorm)
</p>
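Because the comparison above is an image, a text version may help. The block below writes out the standard LayerNorm formula, with mean $\mu$, variance $\sigma^2$, gain $\gamma$ and bias $\beta$ ($\beta$ does not appear in the RMSNorm formula above), followed by a one-line identity. It is derived from the usual definitions rather than transcribed from the figure.

$$\text{LayerNorm}: y = \frac{x-\mu}{\sqrt{\sigma^2+\epsilon}} \odot \gamma + \beta, \qquad \mu = \frac{1}{N}\sum_{i=1}^N x_i, \qquad \sigma^2 = \frac{1}{N}\sum_{i=1}^N (x_i-\mu)^2$$

$$\sigma^2 = \frac{1}{N}\sum_{i=1}^N x_i^2 - 2\mu\cdot\frac{1}{N}\sum_{i=1}^N x_i + \mu^2 = \mathrm{Mean}(x^2) - \mu^2$$

So for zero-mean input and with $\beta = 0$, LayerNorm and RMSNorm coincide; RMSNorm simply skips computing $\mu$.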

```python
import torch
import torch.nn as nn
from torch.nn import functional as F
```

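```python
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Args:
        size: length of the learnable gain vector; the gain broadcasts against the last axis of x.
        dim: dimension over which Mean(x^2) is computed.
        eps: small constant added to Mean(x^2) for numerical stability.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        # Learnable per-feature gain (gamma in the formula), initialized to 1
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the original RMSNorm paper implementation is not equivalent:
        # it adds eps to the RMS after the square root (d_x is the size of `dim`)
        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
        # rms_x = norm_x * d_x ** (-1. / 2)
        # x_normed = x / (rms_x + self.eps)
        # Mean(x^2) over `dim`, kept as a size-1 axis so it broadcasts against x
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        # x / sqrt(Mean(x^2) + eps), computed with the reciprocal square root
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        # Apply the learnable gain gamma
        return self.scale * x_normed
```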
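The `NOTE` in `forward` concerns where ε enters. The plain-Python sketch below uses hypothetical helpers: `rms_eps_inside` mirrors the active code (ε added to Mean(x²) under the square root), and `rms_eps_outside` mirrors the commented-out code (ε added to the RMS after the square root). For inputs of ordinary magnitude the two agree closely; when Mean(x²) is comparable to ε, they diverge.

```python
import math

def rms_eps_inside(x, eps=1e-5):
    # Like the active forward(): eps is added to Mean(x^2) inside the square root
    mean_sq = sum(v * v for v in x) / len(x)
    return [v / math.sqrt(mean_sq + eps) for v in x]

def rms_eps_outside(x, eps=1e-5):
    # Like the commented-out code: eps is added to RMS(x) after the square root
    rms = math.sqrt(sum(v * v for v in x) / len(x))
    return [v / (rms + eps) for v in x]

def max_diff(a, b):
    # Largest element-wise absolute difference between two lists
    return max(abs(p - q) for p, q in zip(a, b))

normal = [1.0, 2.0, 3.0, 4.0]
tiny = [1e-3, 2e-3, 3e-3, 4e-3]

print(max_diff(rms_eps_inside(normal), rms_eps_outside(normal)))  # very small
print(max_diff(rms_eps_inside(tiny), rms_eps_outside(tiny)))      # large: eps is comparable to Mean(x^2)
```

Both variants stay finite for an all-zero input, but they are not interchangeable: swapping one for the other can change outputs for small-magnitude activations.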