In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn import MultiheadAttention


In [2]:
def sinkhorn_knopp(matrix: torch.Tensor, num_iter: int = 20, epsilon: float = 1e-20) -> torch.Tensor:
    """Push score matrices towards doubly-stochastic matrices (Sinkhorn-Knopp).

    The scores are exponentiated, which makes every entry non-negative, and
    then rows and columns are alternately rescaled so that they sum to 1.
    Every round ends with the column step, so after the last round the
    columns sum to 1 (up to ``epsilon``) while the rows do so only as far as
    the iteration has converged.

    Args:
        matrix: Scores of shape ``[..., rows, cols]``; the last two axes are
            the ones being normalized.
        num_iter: Number of (row, column) normalization rounds. With
            ``num_iter <= 0`` the plain ``exp(matrix)`` is returned.
        epsilon: Added to every row/column sum to avoid division by zero.

    Returns:
        A tensor with the same shape as ``matrix`` and non-negative entries.
    """
    if num_iter <= 0 or matrix.numel() == 0:
        # No normalization round will run (or there is nothing to normalize),
        # so the per-row shift below would not be cancelled / is not defined.
        return torch.exp(matrix)

    # exp() overflows to inf for large scores (roughly > 88 in float32) and the
    # subsequent inf / inf turns the whole result into NaN. The first row
    # normalization cancels any per-row scale, so shifting every row by its
    # maximum keeps the result unchanged while bounding exp() by 1.
    row_max = matrix.amax(dim=-1, keepdim=True)
    # Rows whose maximum is not finite (all -inf, or containing +inf) have no
    # usable shift; leave them untouched so they behave as before.
    row_max = torch.where(torch.isfinite(row_max), row_max, torch.zeros_like(row_max))

    K = torch.exp(matrix - row_max)
    for _ in range(num_iter):
        # Row step: make every row sum to 1.
        K = K / (K.sum(dim=-1, keepdim=True) + epsilon)
        # Column step: make every column sum to 1.
        K = K / (K.sum(dim=-2, keepdim=True) + epsilon)
    return K

In [5]:
class mHC(nn.Module):
    """Input-dependent mixing of ``n`` parallel residual streams around one sub-layer.

    ``width_connection`` turns the ``[B, L, n, dim]`` stream tensor into
    (a) one aggregated ``[B, L, 1, dim]`` sub-layer input, (b) the streams
    re-mixed by an ``n x n`` matrix produced by ``sinkhorn_knopp`` and
    (c) per-stream write-back weights. ``depth_connection`` then adds the
    sub-layer output, scaled per stream, onto the re-mixed streams.

    Args:
        dim: Feature size of a single stream.
        n: Number of parallel streams.
        layer_id: Accepted for interface compatibility; not used by this module.
        eps: Lower bound applied to the RMS of the flattened streams before
            dividing by it, so an all-zero position does not produce NaN.
    """

    def __init__(self, dim, n, layer_id, eps=1e-6):
        super().__init__()
        self.dim = dim
        self.n = n
        self.nc = n * dim
        self.n2 = n * n
        self.eps = eps

        # One fused projection for all three groups of mixing logits. Layout of
        # the output (and of ``b``) along the last axis:
        #   [0, n)      -> pre  (how the streams are combined into the sub-layer input)
        #   [n, 2n)     -> post (how the sub-layer output is written back per stream)
        #   [2n, 2n+n*n) -> res (stream-to-stream mixing matrix, row-major)
        # It could equally be split into three separate projections.
        self.phi = nn.Linear(self.nc, self.n2 + 2 * self.n, bias=False)

        # One learnable gain per group, in the order (pre, post, res).
        self.a = nn.Parameter(torch.ones(3) * 0.01)
        # Learnable bias for every logit, same layout as the output of ``phi``.
        self.b = nn.Parameter(torch.zeros(self.n2 + 2 * self.n))

    # Information exchange between the parallel streams.
    def width_connection(self, hidden_states):
        """Compute the sub-layer input, the re-mixed streams and the write-back weights.

        Args:
            hidden_states: Output of the previous layer, ``[B, L, n, dim]``.

        Returns:
            Tuple ``(h_pre, h_res, H_post)`` with shapes ``[B, L, 1, dim]``,
            ``[B, L, n, dim]`` and ``[B, L, n]``. ``H_post`` lies in ``(0, 2)``.
        """
        B, L, _, _ = hidden_states.shape  # [B, L, n, dim]
        n = self.n
        flat = hidden_states.flatten(2)  # [B, L, n*dim]

        # Root-mean-square of every position across all streams and features.
        rms = flat.norm(dim=-1, keepdim=True) / math.sqrt(self.nc)  # [B, L, 1]
        # An all-zero position has rms == 0; without the clamp 1/rms is inf and
        # inf * 0 below would turn every mixing weight into NaN.
        inv_rms = 1 / rms.clamp_min(self.eps)

        H = self.phi(flat)  # [B, L, n*n + 2*n]
        H_pre = inv_rms * H[:, :, :n] * self.a[0] + self.b[0:n]  # [B, L, n]
        H_post = inv_rms * H[:, :, n:n * 2] * self.a[1] + self.b[n:n * 2]  # [B, L, n]
        H_res = inv_rms * H[:, :, n * 2:] * self.a[2] + self.b[n * 2:]  # [B, L, n*n]

        H_pre = torch.sigmoid(H_pre)
        H_post = 2 * torch.sigmoid(H_post)

        # Normalize the stream-to-stream logits into an n x n mixing matrix.
        H_res = sinkhorn_knopp(H_res.reshape(B, L, n, n))  # [B, L, n, n]

        # [B, L, 1, n] @ [B, L, n, dim] = [B, L, 1, dim]
        h_pre = torch.matmul(H_pre.unsqueeze(dim=2), hidden_states)
        # [B, L, n, n] @ [B, L, n, dim] = [B, L, n, dim]
        h_res = torch.matmul(H_res, hidden_states)

        return h_pre, h_res, H_post

    # Information transfer between layers (the residual connection).
    def depth_connection(self, h_res, hidden_states, H_post):
        """Write the sub-layer output back onto the streams.

        Args:
            h_res: Re-mixed streams from ``width_connection``, ``[B, L, n, dim]``.
            hidden_states: Output of the attention / FFN sub-layer, ``[B, L, dim]``.
            H_post: Per-stream write-back weights, ``[B, L, n]``.

        Returns:
            Updated streams, ``[B, L, n, dim]``.
        """
        # Outer product per position: [B, L, n, 1] @ [B, L, 1, dim] = [B, L, n, dim]
        h_post = torch.matmul(H_post.unsqueeze(-1), hidden_states.unsqueeze(-2))
        return h_post + h_res

In [6]:
# Smoke test of a single mHC block on random input (shapes only).
hidden_states = torch.randn(1, 6, 4, 64) # [B, L, n, dim]
mhc = mHC(dim=64, n=4, layer_id=0)
h_pre, h_res, H_post = mhc.width_connection(hidden_states)
# h_pre: [B, L, 1, dim]
# h_res: [B, L, n, dim]
# H_post: [B, L, n]
# No sub-layer is applied here: the aggregated input h_pre is fed straight
# back in where the attention / FFN output would normally go.
output = mhc.depth_connection(h_res, h_pre.squeeze(-2), H_post)
print(output.shape)  # [B, L, n, dim]

torch.Size([1, 6, 4, 64])


In [7]:
class FFN(nn.Module):
    """Two-layer feed-forward block: ``dim -> hidden_dim -> dim`` with a ReLU in between."""

    def __init__(self, dim, hidden_dim):
        super().__init__()
        # Expand to the hidden width, then project back to the model width.
        self.proj_up = nn.Linear(in_features=dim, out_features=hidden_dim)
        self.proj_down = nn.Linear(in_features=hidden_dim, out_features=dim)

    def forward(self, x):
        """Apply the block to the last axis of ``x``; the output has the same shape as ``x``."""
        expanded = self.proj_up(x)
        activated = F.relu(expanded)
        return self.proj_down(activated)

In [8]:
class DecoderLayer(nn.Module):
    """Decoder block: self-attention and FFN, each wrapped in its own ``mHC`` stream mixer.

    Args:
        dim: Feature size of a single stream (and the attention embed dim).
        n_heads: Number of attention heads.
        layer_id: Forwarded to both ``mHC`` modules.
        n: Number of parallel residual streams.
        causal: When True (default) every position may only attend to itself
            and to earlier positions. Pass False for bidirectional attention.
    """

    def __init__(self, dim, n_heads, layer_id, n=4, causal=True):
        super().__init__()
        self.causal = causal
        self.attn_mhc = mHC(dim=dim, n=n, layer_id=layer_id)
        self.ffn_mhc = mHC(dim=dim, n=n, layer_id=layer_id)
        self.attention = MultiheadAttention(embed_dim=dim, num_heads=n_heads, bias=False, batch_first=True)
        self.ffn = FFN(dim=dim, hidden_dim=4 * dim)

    def forward(self, hidden_states):
        """Run attention then FFN over the streams.

        Args:
            hidden_states: Stream tensor ``[B, L, n, dim]``.

        Returns:
            Updated stream tensor ``[B, L, n, dim]``.
        """
        # --- attention sub-layer ---
        h_pre, h_res, H_post = self.attn_mhc.width_connection(hidden_states)
        attn_input = h_pre.squeeze(-2)  # [B, L, 1, dim] -> [B, L, dim]

        attn_mask = None
        if self.causal:
            # Boolean mask, True = "may not attend": block everything strictly
            # above the diagonal so position i never sees positions j > i.
            seq_len = attn_input.size(1)
            attn_mask = torch.triu(
                torch.ones(seq_len, seq_len, dtype=torch.bool, device=attn_input.device),
                diagonal=1,
            )

        # The attention weights are never used, so do not ask for them.
        attn_output, _ = self.attention(
            attn_input, attn_input, attn_input,
            attn_mask=attn_mask, need_weights=False,
        )  # [B, L, dim]
        hidden_states = self.attn_mhc.depth_connection(h_res, attn_output, H_post)  # [B, L, n, dim]

        # --- feed-forward sub-layer ---
        h_pre, h_res, H_post = self.ffn_mhc.width_connection(hidden_states)
        ffn_output = self.ffn(h_pre.squeeze(-2))  # [B, L, dim]
        hidden_states = self.ffn_mhc.depth_connection(h_res, ffn_output, H_post)  # [B, L, n, dim]
        return hidden_states

In [9]:
# Smoke test of one decoder layer on random streams: the layer must return a
# tensor with the same [B, L, n, dim] layout it was given.
decoder = DecoderLayer(dim=64, n_heads=4, layer_id=0)
hidden_states = torch.randn(1, 6, 4, 64) # [B, L, n, dim]
output = decoder(hidden_states)
print(output.shape)  # [B, L, n, dim]

torch.Size([1, 6, 4, 64])


In [10]:
class LLM(nn.Module):
    """Token embedding, a stack of ``DecoderLayer`` blocks and a vocabulary head.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        dim: Embedding / stream feature size.
        n_heads: Number of attention heads in every layer.
        num_layers: Number of stacked decoder layers.
        n: Number of parallel residual streams carried through the layers.
    """

    def __init__(self, vocab_size, dim, n_heads, num_layers, n = 4):
        super().__init__()
        # Remember the stream count: forward() must expand the embeddings to
        # exactly as many streams as the layers were built for.
        self.n = n
        self.embedding = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList(
            [DecoderLayer(dim=dim, n_heads=n_heads, layer_id=i, n=n) for i in range(num_layers)]
        )
        self.output_layer = nn.Linear(dim, vocab_size)

    def forward(self, input_ids):
        """Map token ids ``(B, L)`` to logits ``(B, L, vocab_size)``."""
        hidden_states = self.embedding(input_ids)  # (B, L, dim)
        # Every stream starts as a copy of the token embedding.
        hidden_states = hidden_states.unsqueeze(2).expand(-1, -1, self.n, -1)  # (B, L, n, dim)
        for layer in self.layers:
            hidden_states = layer(hidden_states)  # (B, L, n, dim)
        # Collapse the streams by averaging before projecting to the vocabulary.
        output = self.output_layer(hidden_states.mean(dim=2))  # (B, L, vocab_size)
        return output

In [11]:
# Smoke test of the full model: random token ids in, one logit vector per
# position out.
model = LLM(vocab_size=5000, dim=64, n_heads=4, num_layers=2, n=4)
input_ids = torch.randint(0, 5000, (1, 10))  # (B, L)
output = model(input_ids)
print(output.shape)  # (B, L, vocab_size)

torch.Size([1, 10, 5000])
