In [4]:
import torch
from torch import nn

class LlamaRMSNorm(nn.Module):
    """Root-mean-square layer normalisation, equivalent to T5LayerNorm.

    Every vector along the last dimension is divided by its root mean square
    (no mean subtraction, no bias) and then multiplied by a learned
    per-feature ``weight`` that starts out as all ones.

    Args:
        hidden_size: length of the last dimension, i.e. the size of ``weight``.
        eps: small constant added to the mean square before the reciprocal
            square root so an all-zero vector does not divide by zero.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        """Normalise ``hidden_states`` over its last dimension.

        Returns a tensor of the same shape. The normalised values are cast
        back to the input dtype before the ``weight`` multiply, so the final
        dtype follows PyTorch type promotion between ``weight`` and the input.
        """
        orig_dtype = hidden_states.dtype
        # Run the reduction in float32 even when the input is half precision.
        x32 = hidden_states.to(torch.float32)
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = (mean_sq + self.variance_epsilon).rsqrt()
        normed = (x32 * inv_rms).to(orig_dtype)
        return self.weight * normed

In [17]:
# Toy input: one large outlier (10) next to two small random values, to show how
# a single big entry dominates the root-mean-square that RMSNorm divides by.
torch.manual_seed(0)  # make the random entries reproducible under Restart & Run All
ln = LlamaRMSNorm(3)
a = torch.rand(1, 3)
a[0, 0] = 10
a

tensor([[10.0000,  0.5684,  0.2856]])

In [18]:
# Normalise the toy row. `ln.weight` is still its all-ones initial value, so the
# output row should have unit root-mean-square: the outlier lands near sqrt(3)
# (~1.73) while the small entries are squashed toward 0.
normed = ln(a)
rms = normed.pow(2).mean(-1).sqrt()
assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4), rms
normed

tensor([[1.7286, 0.0982, 0.0494]], grad_fn=<MulBackward0>)