# 11 model

In [3]:
from zxfMLtools_dev.models import BaseNet

In [6]:
# 多头自注意力

In [7]:
class Head(BaseNet):
    """One head of causal (masked) self-attention.

    The layer sizes come from the module-level names ``n_embed`` (input
    width), ``block_size`` (side length of the causal mask) and ``dropout``
    (dropout probability); only ``head_size`` is passed in.

    Args:
        head_size: output width of the key, query and value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        # Lower-triangular ones matrix; a buffer so it follows the module
        # across devices without being treated as a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: 3-D tensor (B, T, n_embed). The mask is sliced to ``[:T, :T]``,
                so T must not exceed ``block_size``.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)
        # Compute attention scores ("affinities"). The scale is 1/sqrt(d_k)
        # with d_k the key/query width (head_size), not the input width.
        scale = k.shape[-1] ** -0.5
        wei = q @ k.transpose(-2, -1) * scale  # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # Positions above the diagonal get -inf so softmax gives them zero weight.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))  # (B, T, T)
        wei = F.softmax(wei, dim=-1)  # (B, T, T)
        wei = self.dropout(wei)
        # Perform the weighted aggregation of the values.
        v = self.value(x)  # (B, T, head_size)
        out = wei @ v  # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out
    
#Multi-Headed Self Attention
class MultiHeadAttention(BaseNet):
    """Several ``Head`` modules run on the same input and merged by a linear layer.

    Args:
        num_heads: number of ``Head`` instances to create.
        head_size: output width of each head.

    The output width is the module-level ``n_embed``; the dropout probability
    is the module-level ``dropout``.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The concatenated heads are num_heads * head_size wide. Sizing the
        # projection from that (instead of assuming it equals n_embed) keeps
        # forward() working for any head configuration and is identical to the
        # previous layer whenever num_heads * head_size == n_embed.
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply every head to ``x``, concatenate on the last dim, then project.

        Returns:
            Tensor whose last dimension is ``n_embed``.
        """
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [None]:
# 专家模型

In [None]:
class Expert(BaseNet):
    """A single expert: a two-layer MLP with a ReLU in between.

    The hidden layer is four times as wide as the input, and the output has
    the same width as the input. The final dropout uses the module-level
    ``dropout`` probability.

    Args:
        n_embed: width of the input and output features.
    """

    def __init__(self, n_embed):
        super().__init__()
        hidden_width = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP and return the result."""
        transformed = self.net(x)
        return transformed
    

In [8]:
# 路由模型

In [10]:
#noisy top-k gating
class NoisyTopkRouter(BaseNet):
    """Noisy top-k gating: picks ``top_k`` experts per position.

    Args:
        n_embed: width of the incoming features.
        num_experts: number of experts to score.
        top_k: how many experts are kept per position.
    """

    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.top_k = top_k
        # One linear layer produces the routing logits, a second one produces
        # the (pre-softplus) scale of the noise added to them.
        self.topkroute_linear = nn.Linear(n_embed, num_experts)
        self.noise_linear = nn.Linear(n_embed, num_experts)

    def forward(self, mh_output):
        """Return ``(router_output, indices)`` for ``mh_output``.

        ``router_output`` is a softmax over the experts in which every expert
        outside the top-k has weight zero; ``indices`` holds the positions of
        the selected experts along the last dimension.
        """
        clean_logits = self.topkroute_linear(mh_output)
        noise_scale = F.softplus(self.noise_linear(mh_output))

        # Gaussian noise, scaled per expert, perturbs the logits before selection.
        perturbed = clean_logits + torch.randn_like(clean_logits) * noise_scale

        kept_values, indices = torch.topk(perturbed, self.top_k, dim=-1)
        # Start from -inf everywhere and write back only the kept logits, so
        # the softmax assigns zero probability to the unselected experts.
        masked = torch.full_like(perturbed, float('-inf')).scatter(-1, indices, kept_values)
        return F.softmax(masked, dim=-1), indices

In [11]:
#Now create the sparse mixture of experts module
class MoE(BaseNet):
    """Sparse mixture of experts: a noisy top-k router plus a list of experts.

    Args:
        n_embed: feature width of the input and of every expert.
        num_experts: number of ``Expert`` modules to create.
        top_k: number of experts the router selects per position.
    """

    def __init__(self, n_embed, num_experts, top_k):
        super().__init__()
        self.router = NoisyTopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([Expert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        """Route each position of ``x`` to its selected experts and mix the results.

        Returns:
            Tensor with the same shape as ``x``: for every position, the sum of
            the selected experts' outputs weighted by their gating scores.
        """
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)

        # Flatten all leading dimensions so each row is one position.
        # reshape (rather than view) also accepts non-contiguous inputs.
        flat_x = x.reshape(-1, x.size(-1))
        flat_gating_output = gating_output.reshape(-1, gating_output.size(-1))

        for i, expert in enumerate(self.experts):
            # Positions for which expert i is among the selected top-k.
            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.reshape(-1)

            # Skip experts that received no positions in this call.
            if flat_mask.any():
                expert_output = expert(flat_x[flat_mask])

                # One gating score per routed position, broadcast over features.
                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)

                # Accumulate: a position routed to several experts receives the
                # sum of their weighted outputs. No squeeze is applied, so the
                # (positions, features) shape is kept even when n_embed == 1.
                final_output[expert_mask] += expert_output * gating_scores

        return final_output

In [None]:
%%aigen -a torch 中的 out = torch.cat([h(x) for h in self.heads], dim=-1) 在做什么 输入输出形状是什么样的
pass


torch 中的 out = torch.cat([h(x) for h in self.heads], dim=-1) 在做什么 输入输出形状是什么样的



这行代码是将多个头部的输出连接起来。输入是一个列表，列表中的每个元素都是一个头部函数 h(x) 的输出。输出是将这些输出连接起来的张量。

在本文件中，每个头 h(x) 的输出形状为 (B, T, head_size)，其中 B 是批量大小，T 是序列长度，head_size 是每个头的输出维度。若列表中有 n 个头，torch.cat(..., dim=-1) 会沿最后一个维度拼接，输出形状为 (B, T, n*head_size)。二维情况同理：n 个形状为 (batch_size, d) 的张量按最后一个维度拼接后，形状为 (batch_size, n*d)。


In [None]:
import numpy as np

In [None]:
a = torch.randn(2,3)

In [None]:
b = torch.randn(2,3)

In [None]:
a.shape,b.shape

(torch.Size([2, 3]), torch.Size([2, 3]))

In [37]:
# Concatenate along the last dimension: two (2, 3) tensors give (2, 6).
torch.cat([a, b], dim=-1).shape

torch.Size([2, 6])

In [None]:
import torch

In [None]:
tt = torch.randn(2,3)

In [None]:
# Tensor with the same shape as tt, filled with -inf (the same call the
# router uses to build the target it scatters the top-k logits into).
torch.full_like(tt, float('-inf'))

tensor([[-inf, -inf, -inf],
        [-inf, -inf, -inf]])

In [None]:
tt

tensor([[ 0.0701, -1.0484, -1.9100],
        [ 0.4687,  1.8424,  1.7367]])

In [None]:
# Two largest values of each row along the last dim, plus their column indices.
tt.topk(2,dim = -1)

torch.return_types.topk(
values=tensor([[ 0.0701, -1.0484],
        [ 1.8424,  1.7367]]),
indices=tensor([[0, 1],
        [1, 2]]))

In [41]:
import torch.nn as nn
# NOTE(review): this raises NameError on n_embed. Head and MultiHeadAttention
# read the module-level names n_embed, block_size and dropout, and no visible
# cell defines them; they must be set before this call.
MultiHeadAttention(2,10)

NameError: name 'n_embed' is not defined