In [1]:
import torch
import torch.nn as nn

In [2]:
class Expert(nn.Module):
    """Feed-forward expert: Linear -> GELU -> Linear.

    The input is projected from ``hidden_dim`` up to ``multiplier * hidden_dim``,
    passed through GELU, and projected back to ``hidden_dim``, so the output has
    the same trailing dimension as the input.

    Args:
        hidden_dim: Size of the last dimension of the input and output.
        multiplier: Expansion factor for the inner layer width.
    """

    def __init__(self, hidden_dim: int, multiplier: int):
        super().__init__()
        # Keep the attribute name `expert`: it determines the state_dict keys
        # (e.g. "expert.0.weight"), so renaming it would break saved checkpoints.
        self.expert = nn.Sequential(
            nn.Linear(hidden_dim, multiplier * hidden_dim),
            nn.GELU(),
            nn.Linear(multiplier * hidden_dim, hidden_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert MLP to ``x`` along its last dimension.

        Args:
            x: Tensor whose last dimension is ``hidden_dim``; leading
                dimensions are treated as batch dimensions by ``nn.Linear``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        return self.expert(x)

In [3]:
class DenseMoE(nn.Module):
    """Dense (soft) mixture of experts.

    Every expert processes every input, and the expert outputs are combined
    with per-position softmax weights produced by a linear router. No experts
    are skipped (there is no top-k selection).

    Args:
        num_experts: Number of ``Expert`` modules (and router outputs).
        hidden_dim: Size of the last dimension of the input and output.
        multiplier: Inner-layer expansion factor passed to each ``Expert``.
    """

    def __init__(self, num_experts: int, hidden_dim: int, multiplier: int):
        super().__init__()
        self.router = nn.Linear(hidden_dim, num_experts)
        self.expert_layer = nn.ModuleList(
            [Expert(hidden_dim, multiplier) for _ in range(num_experts)]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Mix all expert outputs using the router's softmax weights.

        Args:
            x: Tensor of shape ``(..., hidden_dim)``. Any number of leading
                dimensions is supported, e.g. ``(batch, hidden_dim)`` or
                ``(batch, seq, hidden_dim)``.

        Returns:
            Tensor of shape ``(..., hidden_dim)``. At each position it is the
            weighted sum of the expert outputs, with weights summing to 1
            across experts.
        """
        # (..., num_experts): routing probabilities per position.
        weights = torch.softmax(self.router(x), dim=-1)
        # (..., num_experts, hidden_dim). Call each expert via the loop variable.
        # Stacking on dim=-2 (rather than a fixed dim=2) keeps the expert axis
        # next to the feature axis for inputs of any rank; for 3-D input it is
        # the same axis as dim=2.
        expert_outputs = torch.stack([expert(x) for expert in self.expert_layer], dim=-2)
        # Broadcast the weights over the feature axis, then reduce over experts.
        return (expert_outputs * weights.unsqueeze(-1)).sum(dim=-2)