### FFN, norm in LLM

In [5]:
import torch
from torch import nn
import torch.nn.functional as F

In [8]:
class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super(FeedForward, self).__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.w1 = nn.Linear(self.dim, self.hidden_dim)
        self.w2 = nn.Linear(self.dim, self.hidden_dim)
        self.w3 = nn.Linear(self.hidden_dim, self.dim)
    def forward(self, x: torch.Tensor):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

In [10]:
ffn = FeedForward(64, 256)
input = torch.rand((32, 16, 64))
out = ffn(input)
out, out.shape

(tensor([[[-0.0324,  0.0382,  0.0036,  ...,  0.0129,  0.0053, -0.0337],
          [-0.0264,  0.0280,  0.0458,  ...,  0.0265, -0.0198, -0.0292],
          [-0.0722,  0.0506, -0.0119,  ...,  0.0156,  0.0095,  0.0085],
          ...,
          [-0.0248,  0.0268,  0.0094,  ..., -0.0036, -0.0083, -0.0375],
          [-0.0650,  0.0223,  0.0308,  ..., -0.0044, -0.0198,  0.0015],
          [-0.0594,  0.0161, -0.0010,  ...,  0.0190, -0.0027, -0.0419]],
 
         [[-0.0659,  0.0358,  0.0241,  ...,  0.0315,  0.0461, -0.0354],
          [-0.0520,  0.0563, -0.0149,  ...,  0.0064, -0.0176, -0.0308],
          [-0.0007,  0.0532,  0.0424,  ..., -0.0126,  0.0382, -0.0633],
          ...,
          [-0.0668,  0.0489,  0.0088,  ...,  0.0140,  0.0172, -0.0055],
          [-0.0122,  0.0284,  0.0109,  ...,  0.0213,  0.0162, -0.0819],
          [-0.0381,  0.0546,  0.0286,  ...,  0.0238,  0.0037, -0.0294]],
 
         [[-0.0947,  0.1174, -0.0368,  ..., -0.0188,  0.0275,  0.0025],
          [-0.0315,  0.0335,

In [11]:
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-8):
        super(RMSNorm, self).__init__()
        self.eps = eps
        self.dim = dim
        self.weight = nn.Parameter(torch.ones((self.dim, )))
    def forward(self, x: torch.Tensor):
        return x / torch.mean(x.pow(2), dim=-1, keepdim=True) * self.weight

In [13]:
rmsnorm = RMSNorm(32)
input = torch.rand((16, 64, 32))
out = rmsnorm(input)
out, out.shape

(tensor([[[1.7381, 2.1965, 0.8053,  ..., 2.0828, 1.3132, 2.5500],
          [0.9285, 0.9444, 3.1861,  ..., 1.0493, 0.5654, 1.6959],
          [2.2777, 1.6707, 0.6436,  ..., 0.0892, 0.4929, 2.4850],
          ...,
          [2.4916, 3.3569, 2.0914,  ..., 1.8038, 2.1059, 2.1821],
          [1.7375, 1.9472, 0.3976,  ..., 0.1947, 2.4472, 1.9364],
          [2.2786, 2.8348, 1.0656,  ..., 0.0505, 1.2916, 0.8681]],
 
         [[2.2399, 0.1426, 2.0694,  ..., 0.8386, 2.4350, 0.2784],
          [0.9224, 2.5744, 0.9108,  ..., 2.2597, 2.2281, 0.1601],
          [1.3694, 2.2160, 1.0938,  ..., 1.4120, 0.9671, 2.8722],
          ...,
          [2.2928, 0.4488, 2.3437,  ..., 0.0244, 0.3041, 1.3093],
          [0.5280, 0.7425, 2.4410,  ..., 2.1261, 0.5390, 1.0719],
          [1.5403, 2.3955, 2.7408,  ..., 0.2198, 1.4365, 0.9891]],
 
         [[1.1112, 0.4813, 1.0927,  ..., 0.3830, 1.5777, 1.1195],
          [1.2439, 0.6443, 2.7290,  ..., 1.6907, 2.7991, 0.0906],
          [2.2355, 2.0358, 2.2947,  ...,