In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
"""
B: batch_size
L: seq_len
D: d_model
H: d_ffn
V: vocab_size
"""

'\nB: batch_size\nL: seq_len\nD: d_model\nH: d_ffn\nV: vocab_size\n'

In [2]:
class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes, element-wise,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x ** 3)))``.
    The module has no parameters.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the activation element-wise.

        Args:
            x: input tensor of any shape.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Cast the numpy scalar to a plain Python float so the product below
        # does not depend on numpy/tensor operator dispatch.
        scale = float(np.sqrt(2 / np.pi))
        inner = scale * (x + 0.044715 * torch.pow(x, 3))
        # torch.tanh replaces the deprecated torch.nn.functional.tanh.
        return 0.5 * x * (1 + torch.tanh(inner))
# https://blog.csdn.net/w137093940/article/details/112756141

In [3]:
class SpatialGatingUnit(nn.Module):
    """Spatial gating unit.

    Splits the input in two along the last axis, normalises one half, mixes it
    along the sequence axis with a kernel-size-1 convolution, and multiplies it
    element-wise with the other half.

    Args:
        seq_len: sequence length L; used as both the input and output channel
            count of the convolution.
        d_ffn: width H of the input's last axis; each half has H // 2 features.
    """

    def __init__(self, seq_len, d_ffn):
        super().__init__()
        half_width = d_ffn // 2
        # The gate half has H // 2 features on its last axis, so LayerNorm is
        # sized to match that axis.
        self.norm = nn.LayerNorm(half_width)
        # seq_len input channels, seq_len output channels (number of kernels),
        # kernel size 1.
        self.proj = nn.Conv1d(in_channels=seq_len, out_channels=seq_len, kernel_size=1)
        # Zero weight and unit bias: right after construction the projection
        # outputs ones, so the unit returns the un-gated half unchanged.
        nn.init.constant_(self.proj.weight, 0)
        nn.init.constant_(self.proj.bias, 1)

    def forward(self, x):
        """Gate one half of ``x`` with the projected other half.

        Args:
            x: tensor of shape (B, L, H).

        Returns:
            Tensor of shape (B, L, H // 2).
        """
        # Split on the last axis into two (B, L, H/2) halves.
        content, gate = torch.chunk(x, chunks=2, dim=-1)
        gate = self.proj(self.norm(gate))  # (B, L, H/2)
        return content * gate
        

In [5]:
class gMLPblock(nn.Module):
    """One gMLP block with a residual connection.

    Pipeline: LayerNorm -> Linear(D -> H) -> GELU -> SpatialGatingUnit ->
    Linear(H // 2 -> D), then the block input is added back.

    Args:
        seq_len: length L of the sequence.
        d_model: embedding width D of the block's input and output.
        d_ffn: width H of the intermediate feed-forward layer.
    """

    def __init__(self, seq_len, d_model, d_ffn):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj1 = nn.Linear(in_features=d_model, out_features=d_ffn)
        self.sgu = SpatialGatingUnit(seq_len, d_ffn)
        # The gating unit halves the feature count, hence d_ffn // 2 inputs.
        self.channel_proj2 = nn.Linear(in_features=d_ffn // 2, out_features=d_model)
        self.act = GELU()

    def forward(self, x):
        """Run the block.

        Args:
            x: tensor of shape (B, L, D).

        Returns:
            Tensor of shape (B, L, D).
        """
        hidden = self.channel_proj1(self.norm(x))  # (B, L, H)
        hidden = self.act(hidden)  # (B, L, H)
        gated = self.sgu(hidden)  # (B, L, H/2)
        projected = self.channel_proj2(gated)  # (B, L, D)
        # Residual connection around the whole block.
        return x + projected

In [6]:
class gMLP(nn.Module):
    """A stack of ``num_layers`` gMLP blocks applied in sequence.

    Args:
        seq_len: sequence length L expected by every block.
        d_model: embedding width D.
        d_ffn: intermediate feed-forward width H.
        num_layers: number of blocks in the stack.
    """

    def __init__(self, seq_len=256, d_model=256, d_ffn=512, num_layers=6):
        super().__init__()
        # Build a fresh block per layer. Multiplying a one-element list would
        # repeat the same module object, making every layer share one set of
        # parameters.
        self.model = nn.Sequential(
            *(gMLPblock(seq_len, d_model, d_ffn) for _ in range(num_layers))
        )

    def forward(self, x):
        """Pass ``x`` of shape (B, L, D) through all blocks; the shape is preserved."""
        return self.model(x)

In [7]:
class gMLPofLanguageModel(gMLP):
    """gMLP stack with a token embedding in front and a vocabulary projection on top.

    Args:
        vocab_size: number of tokens V; sizes both the embedding table and the
            output projection.
        seq_len: sequence length L.
        d_model: embedding width D.
        d_ffn: intermediate feed-forward width H.
        num_layers: number of gMLP blocks.
    """

    def __init__(self, vocab_size=20000, seq_len=256, d_model=256, d_ffn=512, num_layers=6):
        super().__init__(seq_len, d_model, d_ffn, num_layers)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Map token indices to per-position vocabulary scores.

        Args:
            x: integer index tensor of shape (B, L).

        Returns:
            Tensor of shape (B, L, V).
        """
        hidden = self.model(self.embedding(x))  # (B, L, D)
        return self.fc(hidden)  # (B, L, V)

In [8]:
# Smoke test: push a random batch of token ids through the gMLP language model.
num_tokens = 10000
bs = 50
len_sen = 49
num_layers = 6
# NOTE: `input` shadows the Python builtin; the name is kept so that any later
# cell relying on it keeps working.
input = torch.randint(num_tokens, (bs, len_sen))  # (bs, len_sen)
gmlp = gMLPofLanguageModel(
    vocab_size=num_tokens,
    seq_len=len_sen,
    d_model=512,
    d_ffn=1024,
    num_layers=num_layers,  # previously defined but never passed to the model
)
output = gmlp(input)
print(output.shape)



torch.Size([50, 49, 10000])


# aMLP

In [15]:
class TinyAttn(nn.Module):
    """Single-head scaled dot-product self-attention with a small head width.

    Args:
        d_out: feature width of the output.
        d_model: feature width of the input's last axis.
        d_attn: width of the query/key/value projections.
    """

    def __init__(self, d_out,d_model,d_attn=64):
        super().__init__()
        # One linear layer produces queries, keys and values side by side.
        self.proj1 = nn.Linear(d_model, 3 * d_attn)
        self.proj2 = nn.Linear(d_attn, d_out)
        self.d_attn = d_attn

    def forward(self, x):
        """Attend over the sequence axis.

        Args:
            x: tensor of shape (B, L, d_model).

        Returns:
            Tensor of shape (B, L, d_out).
        """
        qkv = self.proj1(x)  # (B, L, 3 * d_attn)
        q, k, v = qkv.chunk(3, dim=-1)  # each (B, L, d_attn)
        scores = torch.matmul(q, k.transpose(1, 2))  # (B, L, L)
        # Scale by 1 / sqrt(d_attn). The width is read from self (the bare name
        # `d_attn` is undefined here) and computed with Python arithmetic,
        # because torch.rsqrt only accepts tensors.
        scale = self.d_attn ** -0.5
        # Normalise over the key axis; without an explicit dim, softmax on a
        # 3-D tensor would normalise across the batch axis.
        weights = F.softmax(scores * scale, dim=-1)  # (B, L, L)
        attended = torch.matmul(weights, v)  # (B, L, d_attn)
        return self.proj2(attended)  # (B, L, d_out)

In [16]:
class aSpatialGatingUnit(nn.Module):
    """Spatial gating unit whose gate is augmented with a TinyAttn output.

    The gate half is normalised, mixed along the sequence axis by a
    kernel-size-1 convolution, summed with the attention output computed from
    the full input, and multiplied element-wise with the other half.

    Args:
        seq_len: sequence length L; input and output channel count of the
            convolution.
        d_ffn: width H of the input's last axis; each half has H // 2 features.
        d_model: accepted for signature compatibility; not used by this class.
    """

    def __init__(self, seq_len, d_ffn, d_model):
        super().__init__()
        half_width = d_ffn // 2
        # The gate half has H // 2 features on its last axis, so LayerNorm is
        # sized to match that axis.
        self.norm = nn.LayerNorm(half_width)
        # seq_len input channels, seq_len output channels (number of kernels),
        # kernel size 1.
        self.proj = nn.Conv1d(in_channels=seq_len, out_channels=seq_len, kernel_size=1)
        nn.init.constant_(self.proj.weight, 0)
        nn.init.constant_(self.proj.bias, 1)
        # The attention reads the full H-wide input and emits H // 2 features
        # so it can be added to the gate.
        self.attn = TinyAttn(d_out=half_width, d_model=d_ffn)

    def forward(self, x):
        """Gate one half of ``x`` with projection-plus-attention of the input.

        Args:
            x: tensor of shape (B, L, H).

        Returns:
            Tensor of shape (B, L, H // 2).
        """
        attn_term = self.attn(x)  # (B, L, H/2)
        # Split on the last axis into two (B, L, H/2) halves.
        content, gate = torch.chunk(x, chunks=2, dim=-1)
        gate = self.proj(self.norm(gate))  # (B, L, H/2)
        gate = gate + attn_term
        return content * gate

In [17]:
class aMLPblock(nn.Module):
    """One aMLP block with a residual connection.

    Pipeline: LayerNorm -> Linear(D -> H) -> GELU -> aSpatialGatingUnit ->
    Linear(H // 2 -> D), then the block input is added back.

    Args:
        seq_len: length L of the sequence.
        d_model: embedding width D of the block's input and output.
        d_ffn: width H of the hidden feed-forward layer.
    """

    def __init__(self, seq_len, d_model, d_ffn):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.channel_proj1 = nn.Linear(in_features=d_model, out_features=d_ffn)
        self.sgu = aSpatialGatingUnit(seq_len, d_ffn, d_model)
        # The gating unit halves the feature count, hence d_ffn // 2 inputs.
        self.channel_proj2 = nn.Linear(in_features=d_ffn // 2, out_features=d_model)
        self.act = GELU()

    def forward(self, x):
        """Run the block.

        Args:
            x: tensor of shape (B, L, D).

        Returns:
            Tensor of shape (B, L, D).
        """
        hidden = self.channel_proj1(self.norm(x))  # (B, L, H)
        hidden = self.act(hidden)  # (B, L, H)
        gated = self.sgu(hidden)  # (B, L, H/2)
        projected = self.channel_proj2(gated)  # (B, L, D)
        # Residual connection around the whole block.
        return x + projected

In [18]:
class aMLP(nn.Module):
    """A stack of ``num_layers`` aMLP blocks applied in sequence.

    Args:
        seq_len: sequence length L expected by every block.
        d_model: embedding width D.
        d_ffn: hidden feed-forward width H.
        num_layers: number of blocks in the stack.
    """

    def __init__(self, seq_len=256, d_model=256, d_ffn=512, num_layers=6):
        super().__init__()
        # Build a fresh block per layer. Multiplying a one-element list would
        # repeat the same module object, making every layer share one set of
        # parameters.
        self.model = nn.Sequential(
            *(aMLPblock(seq_len, d_model, d_ffn) for _ in range(num_layers))
        )

    def forward(self, x):
        """Pass ``x`` of shape (B, L, D) through all blocks; the shape is preserved."""
        return self.model(x)

In [19]:
class aMLPofLanguageModel(aMLP):
    """aMLP stack with a token embedding in front and a vocabulary projection on top.

    Args:
        vocab_size: number of tokens V; sizes both the embedding table and the
            output projection.
        seq_len: sequence length L.
        d_model: embedding width D.
        d_ffn: hidden feed-forward width H.
        num_layers: number of aMLP blocks.
    """

    def __init__(self, vocab_size=20000, seq_len=256, d_model=256, d_ffn=512, num_layers=6):
        super().__init__(seq_len, d_model, d_ffn, num_layers)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)
        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)

    def forward(self, x):
        """Map token indices to per-position vocabulary scores.

        Args:
            x: integer index tensor of shape (B, L).

        Returns:
            Tensor of shape (B, L, V).
        """
        hidden = self.model(self.embedding(x))  # (B, L, D)
        return self.fc(hidden)  # (B, L, V)

In [20]:
# Smoke test: push a random batch of token ids through a language model.
# NOTE(review): this cell follows the aMLP definitions but instantiates
# gMLPofLanguageModel, so the aMLP classes are never exercised here - confirm
# whether aMLPofLanguageModel was intended.
num_tokens = 10000
bs = 50
len_sen = 49
num_layers = 6
# NOTE: `input` shadows the Python builtin; the name is kept so that any later
# cell relying on it keeps working.
input = torch.randint(num_tokens, (bs, len_sen))  # (bs, len_sen)
gmlp = gMLPofLanguageModel(
    vocab_size=num_tokens,
    seq_len=len_sen,
    d_model=512,
    d_ffn=1024,
    num_layers=num_layers,  # previously defined but never passed to the model
)
output = gmlp(input)
print(output.shape)

torch.Size([50, 49, 10000])
