In [1]:
import torch
import torch.nn as nn

In [3]:
class LayerNorm(nn.Module):
    """Layer normalization over the trailing ``normalized_shape`` dimensions.

    Every slice spanning the normalized dimensions is shifted to zero mean and
    scaled to unit variance (biased estimate, i.e. divided by the element
    count), then passed through a learnable element-wise affine transform
    ``weight * x_hat + bias``.

    Args:
        normalized_shape: Size of the trailing dimension(s) to normalize over.
            An ``int`` normalizes the last dimension only; a sequence of ints
            normalizes jointly over that many trailing dimensions.
        eps: Constant added to the variance before taking the square root, so
            a constant slice (variance 0) does not cause a division by zero.
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        # A bare int is shorthand for a single trailing dimension.
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.normalized_shape = tuple(normalized_shape)
        # Affine parameters start as the identity transform (scale 1, shift 0).
        self.weight = nn.Parameter(torch.ones(self.normalized_shape))
        self.bias = nn.Parameter(torch.zeros(self.normalized_shape))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` over its trailing dimensions and apply the affine map.

        Args:
            x: Tensor whose trailing dimensions broadcast against
                ``normalized_shape``.

        Returns:
            Tensor with the broadcast shape of ``x`` and the affine parameters
            (the same shape as ``x`` when the trailing dimensions match).
        """
        # Reduce over every dimension covered by the affine parameters, not
        # just the last one; otherwise a multi-dimensional normalized_shape
        # would be normalized per row instead of jointly. An empty shape keeps
        # the original last-axis reduction.
        dims = tuple(range(-len(self.normalized_shape), 0)) or (-1,)
        u = x.mean(dim=dims, keepdim=True)
        s = (x - u).pow(2).mean(dim=dims, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        return self.weight * x + self.bias

In [5]:
# Seed the global RNG so the sample tensor (and everything derived from it)
# is identical on every Restart & Run All.
torch.manual_seed(0)
# Standard-normal sample of shape (20, 30, 40) used as the demo input.
input_tensor = torch.randn(20, 30, 40)
input_tensor

tensor([[[ 0.7255, -0.6066, -0.6184,  ..., -0.2686, -0.6328,  2.1450],
         [ 0.9917, -2.4804,  0.5674,  ...,  0.2782,  0.0135,  0.7060],
         [ 0.7337,  0.7179, -0.7640,  ..., -0.3219,  0.5617,  0.6680],
         ...,
         [ 0.2087,  0.8269,  1.1416,  ...,  1.4732,  0.7298, -0.1580],
         [-0.1336, -0.9155, -1.2460,  ..., -0.3334,  1.3623, -0.4850],
         [ 2.0274,  2.2748, -0.7103,  ..., -0.0397,  0.5341,  0.7279]],

        [[-1.4518,  0.6383, -2.3612,  ...,  0.5441,  0.3309, -1.0578],
         [ 0.4910, -0.2876, -0.6747,  ...,  0.5881,  0.5189,  0.7032],
         [ 0.6569,  0.1317, -0.3139,  ...,  0.7912, -0.8796,  1.5632],
         ...,
         [ 0.1050,  2.1956,  0.6099,  ..., -0.1878, -0.0764,  0.5571],
         [ 1.4613,  1.0355, -0.4763,  ...,  1.6708, -0.6169, -1.4809],
         [-0.9514,  0.8928, -0.1484,  ..., -0.3699,  2.5144, -0.4620]],

        [[ 1.1680, -0.2562, -0.8612,  ...,  1.5782,  0.1415, -1.3018],
         [-1.3422, -0.7426,  1.1206,  ..., -1

In [7]:
# Size the layer from the input's last dimension instead of repeating the
# literal 40, so this cell keeps working if the sample tensor's shape changes.
layer_norm = LayerNorm(normalized_shape=input_tensor.shape[-1])
output_tensor = layer_norm(input_tensor)
output_tensor

tensor([[[ 6.2201e-01, -6.4169e-01, -6.5287e-01,  ..., -3.2104e-01,
          -6.6658e-01,  1.9687e+00],
         [ 1.0998e+00, -2.3756e+00,  6.7506e-01,  ...,  3.8564e-01,
           1.2067e-01,  8.1384e-01],
         [ 9.1971e-01,  9.0310e-01, -6.5848e-01,  ..., -1.9268e-01,
           7.3847e-01,  8.5053e-01],
         ...,
         [-3.3803e-02,  5.7626e-01,  8.8687e-01,  ...,  1.2141e+00,
           4.8041e-01, -3.9570e-01],
         [-1.6432e-01, -8.7389e-01, -1.1738e+00,  ..., -3.4563e-01,
           1.1933e+00, -4.8325e-01],
         [ 1.6317e+00,  1.8591e+00, -8.8420e-01,  ..., -2.6792e-01,
           2.5947e-01,  4.3756e-01]],

        [[-1.3466e+00,  7.9828e-01, -2.2798e+00,  ...,  7.0169e-01,
           4.8288e-01, -9.4231e-01],
         [ 7.2889e-01, -3.7280e-02, -4.1825e-01,  ...,  8.2443e-01,
           7.5634e-01,  9.3770e-01],
         [ 8.2664e-01,  1.9626e-01, -3.3874e-01,  ...,  9.8794e-01,
          -1.0178e+00,  1.9147e+00],
         ...,
         [ 2.2280e-01,  2

In [8]:
# The normalization is expected to preserve shape; fail loudly if the output
# shape ever diverges from the input's.
assert output_tensor.shape == input_tensor.shape
# Bare last expression: the notebook displays it as the cell result.
output_tensor.shape

torch.Size([20, 30, 40])
