In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Expert network: a single fully connected (linear) layer.
class BasicExpert(nn.Module):
    """One expert: an affine map from ``features_in`` to ``features_out``."""

    def __init__(self, features_in, features_out):
        super(BasicExpert, self).__init__()
        # The attribute name ``f`` is kept on purpose: it determines the
        # state_dict keys ("f.weight", "f.bias").
        self.f = nn.Linear(in_features=features_in, out_features=features_out)

    def forward(self, x):
        """Apply the linear layer; the last dim of ``x`` must be ``features_in``."""
        projected = self.f(x)
        return projected

In [5]:
class BasicMoE(nn.Module):
    """Dense (soft) mixture-of-experts layer.

    Every expert is evaluated on every input, and the expert outputs are mixed
    with softmax weights produced by a linear gate (no sparse routing).
    """

    def __init__(self, features_in, features_out, num_experts):
        """Build the gate and the list of experts.

        Args:
            features_in: size of the last dimension of the input.
            features_out: size of the last dimension of each expert's output.
            num_experts: number of experts (and of gate logits per input).
        """
        super().__init__()
        # One logit per expert for every input vector.
        self.gate = nn.Linear(features_in, num_experts)

        self.experts = nn.ModuleList(
            BasicExpert(features_in, features_out) for _ in range(num_experts)
        )

    def forward(self, x):
        """Mix the outputs of all experts with softmax gate weights.

        Args:
            x: tensor of shape (..., features_in). The notebook uses
               (batch, features_in); extra leading dims (e.g. a sequence
               axis) also work because every op below addresses the
               trailing axes.

        Returns:
            Tensor of shape (..., features_out).
        """
        # (..., num_experts): mixing weights that sum to 1 over the experts.
        expert_weights = F.softmax(self.gate(x), dim=-1)

        # (..., num_experts, features_out): every expert applied to every input.
        expert_output = torch.stack(
            [expert(x) for expert in self.experts],
            dim=-2,
        )

        # (..., 1, num_experts) @ (..., num_experts, features_out)
        #   -> (..., 1, features_out): weighted sum over the experts.
        output = expert_weights.unsqueeze(-2) @ expert_output
        return output.squeeze(-2)

In [7]:
def text_basic_moe(batch_size=4, features_in=512, features_out=128,
                   num_experts=4, seed=None):
    """Smoke-test BasicMoE on a random batch and print the result.

    The original spelling ("text" rather than "test") is kept so existing
    calls keep working.

    Args:
        batch_size: number of random input rows.
        features_in: input feature size of the layer.
        features_out: output feature size of the layer.
        num_experts: number of experts in the layer.
        seed: if not None, passed to ``torch.manual_seed`` before the random
            input and the layer's weights are created, making the printed
            values reproducible.

    Returns:
        The output tensor of shape (batch_size, features_out).
    """
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.rand(batch_size, features_in)
    basic_moe = BasicMoE(features_in, features_out, num_experts)
    # Inference-only demo: no_grad avoids building an autograd graph and
    # replaces the discouraged `.data` access.
    with torch.no_grad():
        output = basic_moe(x)
    assert tuple(output.shape) == (batch_size, features_out), output.shape
    print(output, output.shape)
    return output


text_basic_moe(seed=0);

tensor([[ 6.6325e-02,  2.0662e-01,  2.2977e-02, -8.4462e-02, -2.5342e-01,
          2.3081e-01, -1.0448e-02, -1.5514e-01, -2.2544e-01, -1.9022e-01,
         -3.0879e-03, -2.5351e-01,  3.5677e-01,  8.6907e-02,  8.7096e-02,
          3.6754e-02,  2.0798e-01, -6.3753e-03,  3.9717e-01,  2.1782e-02,
          4.2338e-02,  1.4011e-01, -1.6863e-01, -1.6302e-01,  2.1730e-01,
         -3.5477e-02,  1.3021e-01, -4.5069e-01,  4.7473e-03, -1.7201e-01,
          7.1447e-02, -4.2137e-02, -4.3878e-03,  1.1961e-01, -9.9498e-02,
          1.0333e-01,  2.2167e-01, -3.2443e-02, -1.8130e-02,  1.9047e-01,
         -2.5822e-01, -1.0243e-01,  1.8449e-01, -3.7699e-03,  8.2029e-02,
          2.4209e-01,  2.1167e-01,  8.6227e-02,  1.5288e-01,  1.3335e-01,
          3.4756e-01, -2.6244e-01,  5.3349e-02,  7.4197e-02, -9.7490e-02,
         -1.7634e-01, -5.3507e-02,  7.4848e-02,  9.3969e-03,  3.8816e-01,
          1.2948e-01,  2.4515e-01,  3.2426e-02, -7.8446e-02,  5.7149e-02,
          1.7013e-02, -1.7828e-01,  2.