In [2]:
from torch import nn
import torch

In [3]:
class RMSNorm(nn.Module):
    def __init__(self, d, p=-1, eps=1e-8, bias=False):
        """
        Root Mean Square Layer Normalization, https://arxiv.org/abs/1910.07467

        Computes ``scale * x / (RMS(x) + eps)`` (plus ``offset`` when ``bias``),
        with the RMS taken over the last dimension of ``x``.

        :param d: model size; the last dimension of inputs is expected to be ``d``
        :param p: partial RMSNorm ratio in [0, 1]: the RMS is estimated from only
            the first ``int(d * p)`` features. Any value outside [0, 1]
            (default -1.0) disables partial RMSNorm. When enabled,
            ``int(d * p)`` must be at least 1, otherwise ``forward`` raises
            ``ValueError``.
        :param eps: epsilon added to the RMS itself (not inside the square
            root), default 1e-8
        :param bias: whether use bias term for RMSNorm, disabled by
            default because RMSNorm doesn't enforce re-centering invariance.
        """
        super().__init__()

        self.eps = eps
        self.d = d
        self.p = p
        self.bias = bias

        # Assigning an nn.Parameter to a Module attribute registers it already;
        # no separate register_parameter call is needed.
        self.scale = nn.Parameter(torch.ones(d))

        if self.bias:
            self.offset = nn.Parameter(torch.zeros(d))

    def forward(self, x):
        """Normalize ``x`` by its (partial) RMS over the last dim and apply the gain.

        :param x: tensor whose last dimension has size ``self.d``
        :return: tensor with the same shape as ``x``
        :raises ValueError: if partial RMSNorm is enabled (0 <= p <= 1) but
            ``int(d * p) == 0``, i.e. there are no features to estimate the RMS from
        """
        if self.p < 0. or self.p > 1.:
            norm_x = x.norm(2, dim=-1, keepdim=True)
            d_x = self.d
        else:
            partial_size = int(self.d * self.p)
            if partial_size == 0:
                # Without this guard, 0 ** -0.5 below raises ZeroDivisionError.
                raise ValueError(
                    f"partial RMSNorm needs int(d * p) >= 1, got d={self.d}, p={self.p}"
                )
            partial_x = x[..., :partial_size]

            norm_x = partial_x.norm(2, dim=-1, keepdim=True)
            d_x = partial_size

        # RMS(x) = ||x||_2 / sqrt(n)
        rms_x = norm_x * d_x ** (-1. / 2)
        x_normed = x / (rms_x + self.eps)

        if self.bias:
            return self.scale * x_normed + self.offset
        return self.scale * x_normed


class RMSNormGemma(nn.Module):
    """Gemma-style RMSNorm: ``x * rsqrt(mean(x**2, -1) + eps)`` times a learned gain.

    The normalization is computed in float32 and cast back to the input dtype
    before the gain is applied. The gain is ``1 + weight`` when
    ``add_unit_offser`` is True, else ``weight``.
    """

    def __init__(self, dim: int, eps: float = 1e-6, add_unit_offser: bool = True):
        """
        :param dim: size of the last (normalized) dimension
        :param eps: added to the mean square inside the reciprocal square root
        :param add_unit_offser: (sic, "offset"; name kept for compatibility)
            if True the gain is ``1 + weight`` and ``weight`` starts at zeros;
            if False the gain is ``weight`` and it starts at ones. Either way a
            freshly constructed module applies an identity gain.
        """
        super().__init__()

        self.eps = eps
        self.add_unit_offser = add_unit_offser
        # Zeros are only an identity gain together with the unit offset; without
        # it, zeros would make every output (and every upstream gradient) zero.
        init_fn = torch.zeros if add_unit_offser else torch.ones
        self.weight = nn.Parameter(init_fn(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension (eps inside the sqrt)."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalize ``x`` over its last dimension, then apply the learned gain."""
        # Normalize in float32 for stability, then cast back to the input's dtype.
        x = self._norm(x.float()).type_as(x)
        if self.add_unit_offser:
            output = x * (1 + self.weight)
        else:
            output = x * self.weight
        return output



In [7]:
# Both implementations should agree on the same input. They differ only in
# where eps enters (added to the RMS vs. inside the rsqrt), which is negligible here.
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
rmsnorm1 = RMSNorm(x.shape[-1])
rmsnorm2 = RMSNormGemma(x.shape[-1])
out1 = rmsnorm1(x)
out2 = rmsnorm2(x)
print(out1)
print(out2)
# Turn the visual comparison into a check that fails loudly on Restart & Run All.
torch.testing.assert_close(out1.detach(), out2.detach())

tensor([[0.4629, 0.9258, 1.3887],
        [0.7895, 0.9869, 1.1843]], grad_fn=<MulBackward0>)
tensor([[0.4629, 0.9258, 1.3887],
        [0.7895, 0.9869, 1.1843]], grad_fn=<MulBackward0>)


In [6]:
# Sanity check: x / (||x||_2 / sqrt(n)) equals x * rsqrt(mean(x**2)), i.e. the two
# RMSNorm formulations above compute the same quantity (ignoring eps).
x = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
dim_size = x.shape[-1]  # size of the normalized (last) dimension
via_norm = x / (x.norm(2, dim=-1, keepdim=True) * dim_size ** (-1. / 2))
via_rsqrt = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True))
print(via_norm)
print(via_rsqrt)
torch.testing.assert_close(via_norm, via_rsqrt)

tensor([[0.4629, 0.9258, 1.3887],
        [0.7895, 0.9869, 1.1843]])
tensor([[0.4629, 0.9258, 1.3887],
        [0.7895, 0.9869, 1.1843]])
