# Mixture of Experts (MoE)

<img src="./moe.png" width="500" height="400">

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Linear(nn.Module):
    """Thin wrapper around a single ``nn.Linear`` layer.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # Keep the attribute name ``fc``: it is part of the state_dict keys.
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the affine projection to ``x`` and return the result."""
        projected = self.fc(x)
        return projected

class MOE(nn.Module):
    """Dense (soft-routed) mixture-of-experts layer.

    Every expert is evaluated on every input, and the expert outputs are
    combined with softmax gate weights computed from that same input.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        num_experts: Number of expert sub-layers; must be at least 1.

    Raises:
        ValueError: If ``num_experts`` is less than 1.
    """

    def __init__(self, in_features, out_features, num_experts):
        super().__init__()
        # Fail early: with no experts, forward() would only fail later when
        # stacking an empty list of expert outputs.
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [Linear(in_features, out_features) for _ in range(num_experts)]
        )
        self.gate = Linear(in_features, num_experts)

    def forward(self, x):
        """Mix the outputs of all experts using the gate weights.

        Args:
            x: Tensor of shape ``[..., in_features]``; any number of leading
                dimensions is accepted (e.g. ``[batch, in_features]`` or
                ``[batch, seq_len, in_features]``).

        Returns:
            Tensor of shape ``[..., out_features]``.
        """
        # gate_scores: [..., num_experts]; softmax over the expert axis.
        gate_scores = F.softmax(self.gate(x), dim=-1)
        # expert_outputs: [..., num_experts, out_features]
        expert_outputs = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )
        # [..., 1, num_experts] @ [..., num_experts, out_features]
        #   -> [..., 1, out_features]
        # matmul broadcasts over leading dims, so the input is not restricted
        # to exactly two dimensions as it would be with bmm.
        mixed = torch.matmul(gate_scores.unsqueeze(-2), expert_outputs)
        return mixed.squeeze(-2)

