In [1]:
import torch
import torch.nn as nn


class ExpertNetWork(nn.Module):
    """Single feed-forward expert: Linear -> ReLU -> Linear.

    The input is projected from ``hidden_size`` to ``intermediate_size``,
    passed through a ReLU, and projected back to ``hidden_size``, so the
    output has the same trailing dimension as the input.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        intermediate_size: Size of the inner projection.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size

        # linear1 is created before linear2 so that parameter initialisation
        # consumes the random generator in a fixed order.
        self.linear1 = nn.Linear(in_features=hidden_size, out_features=intermediate_size)
        self.linear2 = nn.Linear(in_features=intermediate_size, out_features=hidden_size)

    def forward(self, x):
        """Apply the expert to ``x`` (last dimension must be ``hidden_size``)."""
        activated = torch.relu(self.linear1(x))
        return self.linear2(activated)


In [3]:
class Router(nn.Module):
    """Top-k gating network that picks experts for every token.

    Each token is scored against all experts with a linear layer, the scores
    are turned into probabilities with a softmax, and the ``top_k`` most
    probable experts are kept. The kept probabilities are renormalised so
    that they sum to 1 for every token.

    Args:
        hidden_size: Size of the last dimension of the input.
        expert_num: Number of experts to score.
        top_k: Number of experts selected per token.
    """

    def __init__(self, hidden_size, expert_num, top_k):
        super().__init__()
        self.router = nn.Linear(hidden_size, expert_num)
        self.top_k = top_k
        self.hidden_size = hidden_size

    def forward(self, x):
        """Route tokens to experts.

        Args:
            x: Tensor whose last dimension is ``hidden_size``; all leading
                dimensions are flattened into a single token dimension N.

        Returns:
            A tuple ``(topk_weight, topk_idx)``, both of shape ``(N, top_k)``:
            the renormalised gate weights and the matching expert indices.
            Entries within a row are not sorted (``sorted=False``).
        """
        # reshape instead of view: view raises on non-contiguous inputs
        # (e.g. a transposed activation), reshape copies only when it has to.
        tokens = x.reshape(-1, self.hidden_size)
        # Score every token against every expert.
        probs = nn.functional.softmax(self.router(tokens), dim=-1)
        topk_weight, topk_idx = torch.topk(probs, k=self.top_k, dim=-1, sorted=False)
        # Renormalise the selected weights so they sum to 1 per token.
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
        return topk_weight, topk_idx

In [None]:
class MOELayer(nn.Module):
    """Sparse mixture-of-experts layer.

    A router selects ``top_k`` of ``expert_num`` feed-forward experts for each
    token; the layer output for a token is the gate-weighted sum of the
    outputs of its selected experts.

    Args:
        hidden_size: Size of the last dimension of the input and output.
        intermediate_size: Inner size of each expert network.
        expert_num: Number of experts held by the layer.
        top_k: Number of experts each token is routed to.
    """

    def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.expert_num = expert_num
        self.top_k = top_k
        self.experts = nn.ModuleList(
            [ExpertNetWork(self.hidden_size, self.intermediate_size) for _ in range(self.expert_num)]
        )
        self.router = Router(self.hidden_size, self.expert_num, self.top_k)

    def forward(self, x):
        """Apply the mixture of experts to ``x``.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, hidden_size)``.

        Returns:
            Tensor of shape ``(batch_size, seq_len, hidden_size)``.
        """
        batch_size, seq_len, _ = x.size()
        # Every token is processed independently. reshape (not view) so that
        # non-contiguous inputs are accepted as well.
        x_flat = x.reshape(batch_size * seq_len, self.hidden_size)
        topk_weight, topk_idx = self.router(x_flat)  # both (N, K)

        output = torch.zeros_like(x_flat)
        # Dispatch per expert instead of per token: each expert runs once on
        # the batch of tokens routed to it, rather than N * K single-vector
        # forward passes driven by a Python loop.
        for expert_idx, expert in enumerate(self.experts):
            # Rows are token positions, columns are the top-k slot in which
            # this expert was selected for that token.
            token_ids, slot_ids = torch.where(topk_idx == expert_idx)
            if token_ids.numel() == 0:
                continue  # no token chose this expert
            gate = topk_weight[token_ids, slot_ids].unsqueeze(-1)
            weighted = gate * expert(x_flat[token_ids])
            # Cast guards against a dtype mismatch (e.g. under autocast),
            # which index_add_ would reject.
            output.index_add_(0, token_ids, weighted.to(output.dtype))

        return output.reshape(batch_size, seq_len, self.hidden_size)
    
# Demo configuration: 8 experts, every token is routed to its 2 best experts.
hidden_size, intermediate_size = 4096, 2048
expert_num, top_k = 8, 2

# Smoke test: the layer output must have the same shape as its input.
# The inputs are sampled before the layer is built, so a seeded run draws
# random numbers in a fixed order.
inputs = torch.randn(2, 11, hidden_size)
moe_layer = MOELayer(
    hidden_size=hidden_size,
    intermediate_size=intermediate_size,
    expert_num=expert_num,
    top_k=top_k,
)
outputs = moe_layer(inputs)
print(outputs.shape)

torch.Size([2, 11, 4096])
