In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hyper-parameters shared by the model components in this notebook.

    Field order matters: it is the positional-argument order of the
    dataclass-generated ``__init__``.
    """

    # Model width: size of RMS_Norm's gain vector and input width of the FFN gate.
    hidden_size: int = 4_096
    # Output width of the FFN gate projection (SwiGLU.W_gate).
    intermediate_size: int = 14_336
    # Presumably the maximum supported sequence length -- not used in the cells shown.
    max_position_embeddings: int = 8_192
    # Number of query attention heads -- not used in the cells shown.
    num_attention_heads: int = 32
    # Number of stacked decoder layers -- not used in the cells shown.
    num_hidden_layers: int = 32
    # Fewer KV heads than query heads, presumably for grouped-query attention
    # -- not used in the cells shown.
    num_key_value_heads: int = 8
    # Added to the mean square inside RMS_Norm before the reciprocal sqrt.
    rms_norm_eps: float = 1e-5
    # Should equal the tokenizer's vocabulary size.
    # NOTE(review): if special tokens are appended to the base vocabulary,
    # len(tokenizer) can exceed tokenizer.vocab_size -- confirm which one the
    # embedding table has to cover.
    vocab_size: int = 128_000


# Default configuration instance used by the demo cells in this notebook.
config = ModelConfig()

In [9]:
class RMS_Norm(nn.Module):
    """Root-mean-square layer normalization (RMSNorm).

    Intended for the pre-norm layout: normalize the input first, then pass it
    through the sub-layer.

    Each vector along the last dimension is multiplied by the reciprocal of
    its root mean square and then scaled element-wise by a learnable gain.

    Args:
        config: object providing ``hidden_size`` (length of the gain vector;
            must match the last dimension of the input) and ``rms_norm_eps``
            (added to the mean square before the reciprocal square root).
    """
    def __init__(self, config) -> None:
        super(RMS_Norm, self).__init__()

        # Learnable per-channel gain, initialised to ones so the module starts
        # out as a pure normalization. The attribute name is kept unchanged
        # because it is the parameter key in state_dict().
        self.refactor = nn.Parameter(torch.ones(config.hidden_size))
        self.rms_norm_eps = config.rms_norm_eps

    def __norm(self, x):
        # eps keeps the denominator away from zero. rsqrt returns 1/sqrt(.),
        # so the normalization is a multiplication rather than a division.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.rms_norm_eps)
        return x * inv_rms

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the learned gain.

        The statistics are computed in at least float32, and the normalized
        tensor is cast back to ``x``'s dtype before the gain is applied. In
        float16, squaring any element with |x| > ~256 overflows to inf. That
        makes rsqrt return 0 and silently zeroes the whole row. float32 and
        float64 inputs are computed in their own precision, as before.
        """
        compute_dtype = torch.promote_types(x.dtype, torch.float32)
        normed = self.__norm(x.to(compute_dtype)).to(x.dtype)
        return normed * self.refactor
    
# Smoke test: normalize a batch of 3 all-ones vectors of width hidden_size.
x = torch.ones((3, config.hidden_size), dtype=torch.float32)
print(x.shape)
model = RMS_Norm(config)
y = model(x)
print(y.shape)

# RMS-norm must preserve the shape. With the gain still at its initial ones,
# an all-ones row has mean square 1, so every output is 1/sqrt(1 + eps).
assert y.shape == x.shape
assert torch.allclose(y, torch.full_like(x, (1.0 + config.rms_norm_eps) ** -0.5))

torch.Size([3, 4096])
torch.Size([3, 4096])


In [12]:
class SwiGLU(nn.Module):
    """Gate branch of a SwiGLU feed-forward layer -- only the gate.

    Projects the last dimension from ``config.hidden_size`` to
    ``config.intermediate_size`` with a linear layer, then applies SiLU.
    The rest of the SwiGLU feed-forward computation is not implemented here.
    """

    def __init__(self, config) -> None:
        super().__init__()

        # The activation sits between two linear layers, so the width changes
        # here. The gate's output width must match the parameters of the
        # projection it is combined with, so the element-wise product lines up.
        self.W_gate = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.intermediate_size,
        )

    def forward(self, x):
        gate_pre_activation = self.W_gate(x)
        return F.silu(gate_pre_activation)
    
# Smoke test: the gate maps hidden_size features to intermediate_size features.
x = torch.ones((3, config.hidden_size), dtype=torch.float32)
print(x.shape)
model = SwiGLU(config)
y = model(x)
print(y.shape)

# The batch dimension is untouched; only the last dimension changes width.
assert y.shape == (3, config.intermediate_size)

torch.Size([3, 4096])
torch.Size([3, 14336])
