<a href="https://colab.research.google.com/github/priyal6/NLP-Prac/blob/main/MOE_basic.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
class Expert(nn.Module):
    """A single feed-forward expert: Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Width of the hidden layer.
        output_dim: Size of the last dimension of the output.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the two-layer MLP to ``x`` and return the result."""
        return self.net(x)


In [5]:
class SparseMOE(nn.Module):
    """Top-1 (switch-style) sparse mixture of experts.

    A linear gate scores every expert for each input vector. Each vector is
    then processed only by its highest-scoring expert.

    Args:
        input_dim: Size of the last dimension of the input.
        hidden_dim: Hidden width of every expert.
        output_dim: Size of the last dimension of the output.
        num_experts: Number of experts; must be >= 1.
        weight_by_gate: If True (default), each expert output is multiplied
            by the softmax probability the gate assigned to the chosen
            expert. Without this factor, ``argmax`` is the only path from the
            gate to the output, so the gate receives no gradient and never
            learns. Pass False to reproduce plain unweighted routing.

    Raises:
        ValueError: If ``num_experts`` < 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, *,
                 weight_by_gate=True):
        super().__init__()
        if num_experts < 1:
            raise ValueError(f"num_experts must be >= 1, got {num_experts}")

        self.output_dim = output_dim
        self.weight_by_gate = weight_by_gate

        self.experts = nn.ModuleList([
            Expert(input_dim, hidden_dim, output_dim)
            for _ in range(num_experts)
        ])

        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Route each input vector to its top-1 expert.

        Args:
            x: Tensor whose last dimension is ``input_dim``. Any leading
                dimensions are treated as independent inputs. An empty batch
                is allowed and yields an empty output.

        Returns:
            Tensor of shape ``(*x.shape[:-1], output_dim)`` with the dtype and
            device of ``x``.
        """
        lead_shape = x.shape[:-1]
        tokens = x.reshape(-1, x.shape[-1])

        gate_logits = self.gate(tokens)
        # Select on the raw logits so the routing decision is exactly the
        # argmax over logits (softmax rounding could merge near-ties).
        top1_idx = torch.argmax(gate_logits, dim=-1)

        if self.weight_by_gate:
            gate_probs = F.softmax(gate_logits, dim=-1)
            # (T, 1): probability of the chosen expert for each row.
            top1_prob = gate_probs.gather(-1, top1_idx.unsqueeze(-1))

        # Preallocating avoids torch.cat on an empty list when there are no rows.
        out = tokens.new_zeros(tokens.shape[0], self.output_dim)
        # One batched call per expert instead of one call per input row.
        for e, expert in enumerate(self.experts):
            rows = torch.nonzero(top1_idx == e, as_tuple=True)[0]
            if rows.numel() == 0:
                continue
            y = expert(tokens[rows])
            if self.weight_by_gate:
                y = y * top1_prob[rows]
            # Index assignment requires matching dtypes (e.g. under autocast).
            out[rows] = y.to(out.dtype)

        return out.reshape(*lead_shape, self.output_dim)


In [6]:
# Smoke test: a 4-expert top-1 MoE mapping 10 features -> 1 output,
# applied to a random batch of 5 samples.
model = SparseMOE(input_dim=10, hidden_dim=32, output_dim=1, num_experts=4)

x = torch.randn(5, 10)  # (batch=5, features=10)
y = model(x)            # one output row per sample
print(y.shape)          # recorded run printed torch.Size([5, 1])

torch.Size([5, 1])
