In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class LlamaMLP(nn.Module):
    """LLaMA-style feed-forward (MLP) layer using the SwiGLU activation.

    The layer computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))`` with
    three bias-free linear projections.

    Args:
        dim: Feature size of the layer's input and output.
        hidden_dim: Nominal hidden width before scaling. When ``None`` it
            falls back to ``4 * dim``.
        multiple_of: The scaled hidden width is rounded up to the nearest
            multiple of this value.
    """

    def __init__(self, dim, hidden_dim=None, multiple_of=256):
        super().__init__()
        # Fall back to four times the input width when no size is given.
        nominal = 4 * dim if hidden_dim is None else hidden_dim

        # LLaMA's scaling rule: the 2/3 factor is specific to SwiGLU.
        scaled = int(2 * nominal / 3)
        # Round up to the nearest multiple of ``multiple_of``.
        n_blocks = (scaled + multiple_of - 1) // multiple_of
        width = multiple_of * n_blocks

        # Three linear projections, none of which carries a bias term.
        self.gate_proj = nn.Linear(dim, width, bias=False)  # W1
        self.up_proj = nn.Linear(dim, width, bias=False)    # W3
        self.down_proj = nn.Linear(width, dim, bias=False)  # W2

    def forward(self, x):
        """Apply SwiGLU: ``W2(SiLU(W1 x) * W3 x)``.

        ``SiLU(z) = z * sigmoid(z)`` and ``*`` is element-wise.
        """
        gated = F.silu(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)

In [None]:
import torch
import torch.nn as nn
class MLP(nn.Module):
    """Gated MLP computing ``down_proj(SiLU(gate_proj(x)) * upper_proj(x))``.

    Args:
        dim: Feature size of the layer's input and output.
        up_dim: Feature size of the intermediate (expanded) representation.
    """

    def __init__(self, dim, up_dim):
        super().__init__()
        # All three projections are bias-free.
        self.gate_proj = nn.Linear(dim, up_dim, bias=False)
        self.upper_proj = nn.Linear(dim, up_dim, bias=False)
        self.down_proj = nn.Linear(up_dim, dim, bias=False)

    def forward(self, hidden_states):
        """Return the gated projection of ``hidden_states``.

        The last dimension of ``hidden_states`` must equal ``dim``; the
        output has the same shape as the input.
        """
        # Use the functional SiLU: ``nn.SiLU(tensor)`` would construct a module
        # (treating the tensor as its ``inplace`` flag) instead of activating.
        gate = nn.functional.silu(self.gate_proj(hidden_states))
        up = self.upper_proj(hidden_states)
        # The down projection consumes the gated product (width ``up_dim``),
        # not the raw input, and is applied as a layer rather than via ``@``.
        return self.down_proj(gate * up)
# Quick check: build an MLP with dim=1024 and up_dim=4096, then call it on a
# random input of shape (2, 1024).
mlp = MLP(dim=1024, up_dim=4096)
mlp(torch.randn(2, 1024))

: 