In [15]:
import torch
import torch.nn as nn

In [16]:
# Record which PyTorch release produced these outputs; the built-in
# nn.RMSNorm is not available in older releases, hence the custom module below.
print(f"{torch.__version__}")

2.3.1+cu121


In [17]:
class RMSNorm(nn.Module):
    def __init__(self, dim, eps = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        rms = torch.sqrt(torch.mean(x ** 2, dim = 1, keepdim=True) + self.eps)
        # Normalize and scale
        x_normed = x / rms
        return self.scale * x_normed

In [18]:
N, C, H, W = 2, 3, 4, 4

torch.manual_seed(42)
# Built-in equivalent (PyTorch >= 2.4): nn.RMSNorm([C, H, W], eps=1e-6), applied to the
# un-flattened (N, C, H, W) input. Pass eps explicitly: the built-in's default eps differs.
rms_norm = RMSNorm(C*H*W)  # expects each sample flattened to a vector of length C*H*W
# NOTE(review): `input` shadows the Python builtin; kept because later cells reference it.
input = torch.randn(N, C, H, W)
input

tensor([[[[ 1.9269,  1.4873,  0.9007, -2.1055],
          [ 0.6784, -1.2345, -0.0431, -1.6047],
          [-0.7521,  1.6487, -0.3925, -1.4036],
          [-0.7279, -0.5594, -0.7688,  0.7624]],

         [[ 1.6423, -0.1596, -0.4974,  0.4396],
          [-0.7581,  1.0783,  0.8008,  1.6806],
          [ 1.2791,  1.2964,  0.6105,  1.3347],
          [-0.2316,  0.0418, -0.2516,  0.8599]],

         [[-1.3847, -0.8712, -0.2234,  1.7174],
          [ 0.3189, -0.4245,  0.3057, -0.7746],
          [-1.5576,  0.9956, -0.8798, -0.6011],
          [-1.2742,  2.1228, -1.2347, -0.4879]]],


        [[[-0.9138, -0.6581,  0.0780,  0.5258],
          [-0.4880,  1.1914, -0.8140, -0.7360],
          [-1.4032,  0.0360, -0.0635,  0.6756],
          [-0.0978,  1.8446, -1.1845,  1.3835]],

         [[ 1.4451,  0.8564,  2.2181,  0.5232],
          [ 0.3466, -0.1973, -1.0546,  1.2780],
          [-0.1722,  0.5238,  0.0566,  0.4263],
          [ 0.5750, -0.6417, -2.2064, -0.7508]],

         [[ 0.0109, -0.3387,

In [19]:
# Collapse (C, H, W) into one trailing axis so each sample matches the single
# normalized dimension (C*H*W) the module was constructed with.
normalized = rms_norm(input.reshape(N, -1))
normalized

tensor([[ 1.7691,  1.3655,  0.8269, -1.9331,  0.6229, -1.1334, -0.0395, -1.4732,
         -0.6905,  1.5137, -0.3603, -1.2887, -0.6683, -0.5136, -0.7059,  0.7000,
          1.5078, -0.1465, -0.4567,  0.4036, -0.6960,  0.9900,  0.7352,  1.5430,
          1.1744,  1.1902,  0.5605,  1.2254, -0.2127,  0.0383, -0.2310,  0.7894,
         -1.2713, -0.7999, -0.2051,  1.5767,  0.2928, -0.3898,  0.2807, -0.7112,
         -1.4300,  0.9141, -0.8077, -0.5519, -1.1698,  1.9489, -1.1335, -0.4480],
        [-0.9633, -0.6938,  0.0822,  0.5543, -0.5144,  1.2559, -0.8581, -0.7758,
         -1.4792,  0.0380, -0.0669,  0.7122, -0.1031,  1.9444, -1.2487,  1.4584,
          1.5234,  0.9028,  2.3381,  0.5515,  0.3654, -0.2080, -1.1117,  1.3472,
         -0.1815,  0.5521,  0.0597,  0.4494,  0.6061, -0.6765, -2.3258, -0.7914,
          0.0115, -0.3571, -1.4132, -0.6171,  0.5652,  0.5530,  1.2030,  0.0544,
          0.7842, -0.5077, -1.1063,  0.6366, -1.8155, -0.8726,  1.4069,  0.5097]],
       grad_fn=<MulBackwa

In [20]:
# Print stats
def print_stats(batch: torch.Tensor, title: str) -> None:
    """Print the mean and root-mean-square of every sample in `batch`.

    Each entry along the leading dimension is flattened and reported as
    "Image i". RMS (not standard deviation) is shown because that is the
    statistic RMSNorm drives to 1.
    """
    print(title)
    for i, sample in enumerate(batch):
        values = sample.flatten()
        print(f"Image {i}:")
        print(f"Mean: {values.mean():.6f}")
        print(f"RMS: {torch.sqrt(torch.mean(values ** 2)):.6f}")

print_stats(input, 'Mean and RMS of input:')
print_stats(normalized, '\nMean and RMS after RMSNorm:')

Mean and std of input:
Image 0:
Mean: 0.056767
RMS: 1.089204
Image 1:
Mean: 0.035141
RMS: 0.948650

Mean and std after RMSNorm:
Image 0:
Mean: 0.052118
RMS: 1.000000
Image 1:
Mean: 0.037043
RMS: 1.000000
