# 1 基础班 MoE（Mixture of Experts）课程

In [2]:
import torch

# Draw two 4x10 tensors and print them: first uniform samples from [0, 1)
# (torch.rand), then standard-normal samples (torch.randn).
for sampler in (torch.rand, torch.randn):
    print(sampler(4, 10))

tensor([[0.4159, 0.7608, 0.7949, 0.0397, 0.3288, 0.6377, 0.1688, 0.9971, 0.1471,
         0.9540],
        [0.2243, 0.5076, 0.3180, 0.3391, 0.2576, 0.2278, 0.6785, 0.4587, 0.2295,
         0.0387],
        [0.0165, 0.9382, 0.6330, 0.2166, 0.2831, 0.2508, 0.6871, 0.6813, 0.0354,
         0.7468],
        [0.5860, 0.0913, 0.4539, 0.5084, 0.1148, 0.2923, 0.5111, 0.7588, 0.5885,
         0.3750]])
tensor([[ 0.0872, -0.8610, -0.6415, -0.0382,  1.5984, -0.9150,  0.3136, -1.2365,
          1.0620,  1.1745],
        [ 0.6456, -0.9643, -0.5202, -1.5816, -0.3633, -0.3776,  1.5500, -0.1296,
          0.3385,  0.6432],
        [ 1.3777,  0.9867, -0.9896, -0.2385,  1.3886, -0.8474, -0.5092, -0.7831,
         -0.3012, -1.6260],
        [ 2.1372, -1.4853,  0.2536,  0.3981,  0.6687, -1.7441, -1.3074, -0.7002,
         -0.7426, -0.9205]])


In [1]:
import torch.nn as nn
class Expert(nn.Module):
    """A single expert: one fully connected (affine) layer.

    Args:
        feature_in: size of the last dimension of the input.
        feature_out: size of the last dimension of the output.
    """

    def __init__(self, feature_in, feature_out):
        super().__init__()
        # The attribute name ``linear`` determines the state_dict keys
        # ("linear.weight", "linear.bias"), so it is kept as-is.
        self.linear = nn.Linear(in_features=feature_in, out_features=feature_out)

    def forward(self, x):
        """Apply the expert's linear map; the last dim of ``x`` must be ``feature_in``."""
        projection = self.linear
        return projection(x)

In [2]:
class Gate(nn.Module):
    """Gating network: scores every expert and normalizes the scores.

    A linear layer maps each input vector to ``num_experts`` logits; a softmax
    over the last dimension turns them into mixing weights that are
    non-negative and sum to 1 for every input vector.

    Args:
        feature_in: size of the last dimension of the input.
        num_experts: number of experts to produce a weight for.
    """

    def __init__(self, feature_in, num_experts):
        super(Gate, self).__init__()
        self.linear = nn.Linear(feature_in, num_experts)

    def forward(self, x):
        """Return expert weights of shape ``(*x.shape[:-1], num_experts)``."""
        # Normalize across the experts axis, which is always the last one.
        # A fixed ``dim=1`` only matches it for 2-D (batch, feature_in) input;
        # for e.g. (batch, seq, feature_in) it would normalize over the
        # sequence instead of over the experts.
        return torch.softmax(self.linear(x), dim=-1)


class MoE(nn.Module):
    """Dense (soft) mixture of experts.

    Every expert is evaluated on every input vector, and the expert outputs
    are combined as a weighted sum using the weights produced by ``Gate``.

    Args:
        num_experts: number of ``Expert`` modules.
        feature_in: size of the last dimension of the input.
        feature_out: size of the last dimension of each expert's output.
    """

    def __init__(self, num_experts, feature_in, feature_out):
        super(MoE, self).__init__()
        self.experts = nn.ModuleList([Expert(feature_in, feature_out) for _ in range(num_experts)])
        self.gate = Gate(feature_in, num_experts)

    def forward(self, x):
        """Mix the expert outputs for ``x`` of shape ``(..., feature_in)``.

        Returns a tensor of shape ``(..., feature_out)``; for 2-D input this is
        ``(batch_size, feature_out)``.
        """
        # (..., num_experts, feature_out): stacking on the second-to-last axis
        # (instead of a fixed axis 1) supports any number of leading dims.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        # (..., num_experts) -> (..., 1, num_experts)
        gate_weights = self.gate(x).unsqueeze(dim=-2)
        # (..., 1, num_experts) @ (..., num_experts, feature_out) -> (..., 1, feature_out)
        moe_output = gate_weights @ expert_outputs
        return moe_output.squeeze(dim=-2)
        

In [8]:
import torch

# Random batch: 4 samples with 10 features each.
input_x = torch.randn(4, 10)
print(input_x)
print(input_x.shape)

# Mixture of 3 experts mapping 10 input features to 5 output features.
moe = MoE(num_experts=3, feature_in=10, feature_out=5)
moe(input_x)

tensor([[-1.3221,  0.1324, -0.0975, -1.4474,  0.1934, -1.4832, -0.6117,  0.6445,
         -0.8924,  0.7569],
        [-0.0894, -0.9504, -0.0928, -1.1415,  0.2462,  0.8481,  0.0710,  0.4090,
         -0.8152, -0.8653],
        [-0.8259, -0.4623, -1.5597, -1.0041,  0.7669, -0.0485,  0.1031, -0.6506,
          0.1711, -0.4212],
        [ 1.5897,  0.1043, -0.7626, -1.1799, -0.1994,  0.6402,  1.1418, -0.2604,
          1.8493,  0.5944]])
torch.Size([4, 10])


tensor([[ 0.9591,  0.3557, -0.5081, -0.3962, -0.1108],
        [ 0.2066, -0.2297, -0.0362, -0.0333, -0.4391],
        [ 0.1110, -0.0592,  0.1146, -0.2023,  0.0063],
        [-0.1800, -0.0082,  0.5062, -0.6453, -0.0814]],
       grad_fn=<SqueezeBackward1>)