In [5]:
import torch
import torch.nn as nn
from torchsummary import summary

import torch.nn.functional as F

# 定义单个专家网络
class Expert(nn.Module):
    """A single expert: a two-layer MLP with a ReLU between the layers.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Map the last dimension of ``x`` from input_dim to output_dim."""
        hidden = F.relu(self.layer1(x))
        return self.layer2(hidden)

# 定义门控网络
class Gate(nn.Module):
    """Gating network: a linear projection followed by a softmax.

    Produces one non-negative weight per expert; the weights sum to 1 along
    the last axis.

    Args:
        input_dim: Size of the last dimension of the input.
        num_experts: Number of weights produced for each input.
    """

    def __init__(self, input_dim, num_experts):
        super().__init__()
        self.layer = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Return softmax-normalised expert weights, shape (..., num_experts)."""
        scores = self.layer(x)
        return scores.softmax(dim=-1)

# 定义 MoE 模型
class MixtureOfExperts(nn.Module):
    """Dense mixture-of-experts layer.

    Every expert is evaluated on every input, and the expert outputs are
    blended with per-input weights produced by the gating network.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Hidden width of each expert.
        output_dim: Size of the last dimension of the output.
        num_experts: Number of experts (and of gate outputs).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts):
        super(MixtureOfExperts, self).__init__()
        # One independent expert network per slot.
        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])
        # Gating network producing one mixing weight per expert.
        self.gate = Gate(input_dim, num_experts)
        self.num_experts = num_experts

    def forward(self, x):
        """Blend all expert outputs using the gate's weights.

        Args:
            x: Tensor of shape (..., input_dim). Any number of leading
                dimensions is accepted, e.g. (batch, input_dim) or
                (batch, seq, input_dim).

        Returns:
            Tensor of shape (..., output_dim).
        """
        # Gate weights: (..., num_experts)
        gate_output = self.gate(x)

        # Stack on the LAST axis so the expert axis is always last, however
        # many leading dims x has: (..., output_dim, num_experts)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-1)

        # Weighted sum over the expert axis: (..., output_dim)
        return torch.einsum('...e,...de->...d', gate_output, expert_outputs)

# 测试代码
def main(device=None):
    """Build a small MixtureOfExperts, print its summary, and run one random batch.

    Args:
        device: Device to run on (a ``torch.device`` or a string such as
            "cpu" or "cuda"). When None, CUDA is used if available,
            otherwise CPU.
    """
    # Resolve the device here so the model, the summary's dummy input and
    # the demo batch all live on the same device.
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Demo hyper-parameters
    input_dim = 10
    hidden_dim = 20
    output_dim = 5
    num_experts = 3
    batch_size = 32

    # Build the model and place it on the chosen device
    model = MixtureOfExperts(input_dim, hidden_dim, output_dim, num_experts).to(device)
    # torchsummary's summary() builds its own dummy input on the device named
    # by its `device` argument (default "cuda"); pass it explicitly so it
    # matches the model. NOTE(review): summary() only accepts "cuda" or "cpu".
    summary(model, (input_dim,), device=device.type)

    # Random input batch on the same device as the model
    x = torch.randn(batch_size, input_dim, device=device)

    # Forward pass
    output = model(x)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Sample output: {output[0]}")

if __name__ == "__main__":
    # Report which device is available (GPU when visible, else CPU).
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda") if use_cuda else torch.device("cpu")
    print(f"Using device: {device}")
    main()

Using device: cpu
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Linear-1                    [-1, 3]              33
              Gate-2                    [-1, 3]               0
            Linear-3                   [-1, 20]             220
            Linear-4                    [-1, 5]             105
            Expert-5                    [-1, 5]               0
            Linear-6                   [-1, 20]             220
            Linear-7                    [-1, 5]             105
            Expert-8                    [-1, 5]               0
            Linear-9                   [-1, 20]             220
           Linear-10                    [-1, 5]             105
           Expert-11                    [-1, 5]               0
Total params: 1,008
Trainable params: 1,008
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB):