# RMS Norm

## Summary

Root Mean Square Layer Normalization, or **RMSNorm**, is a simplified version of the standard Layer Normalization.

First, I will explain the basics of RMSNorm, then the pros of using it, followed by some technical remarks.

**The Mathematical Formula**

Unlike LayerNorm, which re-centers (subtracts the mean) and re-scales the input, RMSNorm only re-scales it, dividing by the root mean square of the elements. For an input vector $x$, the normalized output $\bar{x}$ is defined element-wise as:

$$
\bar{x}_i = \frac{x_i}{\text{RMS}(x)}\,g_i, \qquad \text{where}\quad \text{RMS}(x)=\sqrt{\frac{1}{n}\sum_{j=1}^{n}x_j^2+\epsilon}
$$

**Notation**

* $x$: Input feature vector, with elements $x_i$.
* $n$: The number of features (the dimension of the vector).
* $\epsilon$: A small constant that prevents division by zero, often set to `1e-5`.
* $g_i$: A learnable scaling parameter applied element-wise, allowing the model to adjust the signal's magnitude.
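As a quick sanity check of the formula, here is a minimal pure-Python sketch (no PyTorch) that applies RMSNorm to a single vector. The function names `rms_norm` and `layer_norm` and the example values are illustrative assumptions; the LayerNorm variant omits its learnable bias and is included only to contrast the mean subtraction that RMSNorm skips.

```python
import math


def rms_norm(x: list[float], g: list[float], eps: float = 1e-5) -> list[float]:
    """RMSNorm of one vector: x_i / RMS(x) * g_i, with eps inside the square root."""
    n = len(x)
    rms = math.sqrt(sum(v * v for v in x) / n + eps)
    return [v / rms * gi for v, gi in zip(x, g)]


def layer_norm(x: list[float], g: list[float], eps: float = 1e-5) -> list[float]:
    """LayerNorm without the bias term: subtract the mean, divide by the std."""
    n = len(x)
    mu = sum(x) / n
    var = sum((v - mu) ** 2 for v in x) / n
    return [(v - mu) / math.sqrt(var + eps) * gi for v, gi in zip(x, g)]


x = [1.0, 2.0, 3.0, 4.0]
g = [1.0] * len(x)  # gain initialized to ones

y_rms = rms_norm(x, g)
y_ln = layer_norm(x, g)

# RMSNorm fixes the mean of squares at ~1 but leaves the mean nonzero
print([round(v, 4) for v in y_rms])                  # [0.3651, 0.7303, 1.0954, 1.4606]
print(round(sum(v * v for v in y_rms) / len(x), 4))  # 1.0
print(round(sum(y_rms) / len(x), 4))                 # 0.9129

# LayerNorm additionally centers the output around 0
print(abs(sum(y_ln)) < 1e-9)                         # True
```

Skipping the mean computation is exactly what makes RMSNorm cheaper: one reduction (the mean of squares) instead of two (mean and variance).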


### Implementation

```python
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class RMSNorm(nn.Module):
    """RMSNorm over the last dimension: x_i / RMS(x) * g_i."""

    def __init__(
        self,
        d_model: int,
        eps: float = 1e-5,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ) -> None:
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.d_model = d_model
        self.eps = eps

        # Learnable per-feature gain g, shape (d_model,)
        self.weight = Parameter(torch.empty(self.d_model, **factory_kwargs))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Start from g_i = 1 so the layer initially only normalizes (no random rescaling)
        nn.init.ones_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        in_dtype = x.dtype

        # Upcast so squaring and averaging are numerically stable for fp16/bf16 inputs
        x = x.to(torch.float32)

        # RMS(x) = sqrt(mean(x^2) + eps) over the feature dimension; eps added once, inside the sqrt
        rms = torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms

        # Element-wise gain, then cast back to the caller's dtype
        result = x_norm * self.weight

        return result.to(in_dtype)
```
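Where $\epsilon$ sits matters mostly for tiny inputs. The formula above puts it *inside* the square root, $\sqrt{\frac{1}{n}\sum_j x_j^2 + \epsilon}$; a common slip is to add it *outside*, $\sqrt{\frac{1}{n}\sum_j x_j^2} + \epsilon$, or in both places. The pure-Python sketch below uses two hypothetical helpers (not part of the class) to compare the placements on vectors of shrinking magnitude.

```python
import math


def rms_inside(x: list[float], eps: float) -> float:
    # Matches the formula: sqrt(mean(x^2) + eps)
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)


def rms_outside(x: list[float], eps: float) -> float:
    # Alternative placement: sqrt(mean(x^2)) + eps
    return math.sqrt(sum(v * v for v in x) / len(x)) + eps


eps = 1e-5
for scale in (1.0, 1e-3, 0.0):
    x = [scale, -scale, 2 * scale]
    inside, outside = rms_inside(x, eps), rms_outside(x, eps)
    print(f"scale={scale:g}: inside={inside:.6g}, outside={outside:.6g}")
```

For ordinary activations (`scale=1`) the two agree to about five digits, but for near-zero inputs they differ by more than a factor of two, which changes the normalized output. Both avoid division by zero; to match the formula, the forward pass should compute `torch.sqrt(torch.mean(x**2, dim=-1, keepdim=True) + self.eps)` and add `eps` nowhere else.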

## Details

### Pre-Norm vs Post-Norm

**Training Stability: `pre-norm` > `post-norm`**

<img src="../images/PreNormPostNorm.png" width="70%">

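This section describes **where the normalization layer (LayerNorm or RMSNorm) is placed** inside a Transformer block. Throughout, $F$ denotes a sublayer: either self-attention or the feed-forward MLP.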


**1. Post-Norm (Original Transformer, 2017)**

The original paper (Vaswani et al., 2017) applies the normalization *after* the residual addition:

```
x → F(x) → + → LayerNorm → output
│          ↑
└──────────┘ residual
```

So mathematically:

$$
y = \text{LayerNorm}(x + F(x))
$$

<!-- Then:

$$
x = \text{LayerNorm}(x + \text{MLP}(x))
$$
 -->
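A minimal sketch of a Post-Norm residual block, assuming plain Python lists as vectors and a toy sublayer `toy_sublayer` (a hypothetical stand-in for attention or the MLP). RMSNorm with unit gain is used as the normalization here; LayerNorm would slot in the same way.

```python
import math
from typing import Callable

Vector = list[float]


def rms_norm(x: Vector, eps: float = 1e-5) -> Vector:
    # RMSNorm with gain g_i = 1
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def post_norm_block(x: Vector, F: Callable[[Vector], Vector]) -> Vector:
    # y = Norm(x + F(x)): normalize AFTER the residual addition
    return rms_norm([xi + fi for xi, fi in zip(x, F(x))])


def toy_sublayer(x: Vector) -> Vector:
    # Hypothetical sublayer standing in for attention / MLP
    return [2.0 * v for v in x]


def zero_sublayer(x: Vector) -> Vector:
    # A sublayer that contributes nothing
    return [0.0 for _ in x]


x = [1.0, 2.0, 3.0, 4.0]
y = post_norm_block(x, toy_sublayer)
print([round(v, 4) for v in y])               # [0.3651, 0.7303, 1.0954, 1.4606]

# Even when F contributes nothing, the block is NOT the identity
print(post_norm_block(x, zero_sublayer) == x)  # False
```

Because the residual sum itself is renormalized in every block, there is no path from a block's output back to its input that bypasses the normalization.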

---

**2. Pre-Norm (Modern Transformers)**

Most modern models instead normalize the *input* of each sublayer:

```
x → LayerNorm → F → + → output
│                   ↑
└──── residual ─────┘
```

So mathematically:

$$
y = x + F(\text{LayerNorm}(x))
$$

<!-- Then:

$$
x = x + \text{MLP}(\text{LayerNorm}(x))
$$ -->

Now normalization happens **before** the sublayer.
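The same toy setup as a Pre-Norm block: a minimal sketch assuming plain Python lists as vectors, unit-gain RMSNorm, and a hypothetical `toy_sublayer` standing in for attention or the MLP. Only the sublayer's input is normalized; `x` itself is added back untouched.

```python
import math
from typing import Callable

Vector = list[float]


def rms_norm(x: Vector, eps: float = 1e-5) -> Vector:
    # RMSNorm with gain g_i = 1
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def pre_norm_block(x: Vector, F: Callable[[Vector], Vector]) -> Vector:
    # y = x + F(Norm(x)): normalize only the sublayer's INPUT,
    # then add the raw residual x back
    return [xi + fi for xi, fi in zip(x, F(rms_norm(x)))]


def toy_sublayer(x: Vector) -> Vector:
    # Hypothetical sublayer standing in for attention / MLP
    return [2.0 * v for v in x]


def zero_sublayer(x: Vector) -> Vector:
    # A sublayer that contributes nothing
    return [0.0 for _ in x]


x = [1.0, 2.0, 3.0, 4.0]
y = pre_norm_block(x, toy_sublayer)
print([round(v, 4) for v in y])              # [1.7303, 3.4606, 5.1909, 6.9212]

# If F contributes nothing, the block is exactly the identity
print(pre_norm_block(x, zero_sublayer) == x)  # True
```

The last line is the "clean residual path" in miniature: with the sublayer switched off, the input flows through unchanged, which is not the case for the Post-Norm block.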

---

**Why is Pre-Norm better?**

The key idea:

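In pre-norm, the residual stream $x$ is added back without ever passing through a normalization layer, so there is a **clean residual path** from each block's output to its input. **This helps gradients flow better.** Why?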

First, though: **what does “clean residual path” mean?**

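In Pre-Norm, fold the normalization into the sublayer and write the whole branch $F(\text{LayerNorm}(x))$ simply as $F(x)$. The block is then:

$$
y = x + F(x)
$$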

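So the derivative with respect to $x$ is (treating $x$ as a scalar for simplicity; for a vector input, the $1$ becomes the identity matrix $I$):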

$$
\frac{dy}{dx} = 1 + \frac{dF}{dx}
$$

There is always a direct identity term, the $1$, no matter what $F$ does. Even if $\frac{dF}{dx} \approx 0$, the derivative does not vanish:

$$
\frac{dy}{dx} \approx 1
$$

This is the key.
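A quick numerical check of this claim, assuming a hypothetical scalar sublayer $F(x) = \tanh(x)$ evaluated deep in its saturated region, with derivatives estimated by central finite differences:

```python
import math
from typing import Callable


def F(x: float) -> float:
    # Toy saturating sublayer (hypothetical): tanh is almost flat for large |x|
    return math.tanh(x)


def residual_block(x: float) -> float:
    # y = x + F(x)
    return x + F(x)


def derivative(f: Callable[[float], float], x: float, h: float = 1e-5) -> float:
    # Central finite-difference approximation of df/dx
    return (f(x + h) - f(x - h)) / (2 * h)


x0 = 10.0
print(derivative(F, x0))               # ~0: dF/dx has vanished (saturation)
print(derivative(residual_block, x0))  # ~1: the identity term is still there
```

The sublayer alone passes almost no gradient at $x = 10$, while the residual block still has a derivative of essentially $1$.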

**Why does this help gradient flow?**

When backpropagating:

$$
\frac{dL}{dx} = \frac{dL}{dy} \cdot \frac{dy}{dx}\;\;\;\;\;\; \text{where}\;\;\; L: \text{loss function}
$$

If:

$$
\frac{dy}{dx} = 1 + \text{small term}
$$

Then:

$$
\frac{dL}{dx} \approx \frac{dL}{dy}
$$

The gradient passes through the block almost unchanged.
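The effect compounds with depth, because the chain rule multiplies one $\frac{dy}{dx}$ factor per block. The toy sketch below assumes every block is a scalar function with the same constant local derivative $\frac{dF}{dx}$ (a simplification that ignores the normalization layer entirely) and compares a stack with and without the residual connection:

```python
def gradient_through_stack(local_grad: float, depth: int, residual: bool) -> float:
    """Ratio dL/dx_0 : dL/dx_depth for a stack of identical scalar blocks.

    Each block contributes one chain-rule factor dy/dx:
    (1 + dF/dx) with a residual connection, or just dF/dx without one.
    """
    factor = 1.0 + local_grad if residual else local_grad
    grad = 1.0
    for _ in range(depth):
        grad *= factor
    return grad


depth = 50
small = 0.01  # hypothetical small sublayer derivative dF/dx

print(gradient_through_stack(small, depth, residual=False))  # ~1e-100: vanished
print(gradient_through_stack(small, depth, residual=True))   # ~1.64: still order 1
```

Without the residual path the gradient shrinks geometrically, $0.01^{50}$; with it, each factor stays close to $1$ and the product $1.01^{50}$ remains of order one.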

So even if:

* $F(x)$ saturates
* $F(x)$ has small gradients
* $F(x)$ has exploding gradients

The identity path guarantees:
