In [3]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [280]:
class Expert(nn.Module):
    """Three-layer feed-forward block mapping ``dim -> 4*dim -> 4*dim -> dim``.

    A LeakyReLU follows each of the first two linear layers; the final
    projection is returned without an activation.
    """

    def __init__(self, dim):
        """Create the linear layers and the shared activation.

        Args:
            dim: Size of the last axis of both the input and the output.
        """
        super().__init__()

        hidden_dim = 4 * dim

        # Layers are created in the order the data flows through them.
        self.mlp1 = nn.Linear(dim, hidden_dim)
        self.mlp2 = nn.Linear(hidden_dim, hidden_dim)
        self.mlp3 = nn.Linear(hidden_dim, dim)

        # In-place is safe here: the activation only ever overwrites the
        # freshly allocated output of a linear layer, never the caller's tensor.
        self.act = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through the block.

        Args:
            x: Tensor whose last axis has size ``dim``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.mlp1(x)
        hidden = self.act(hidden)
        hidden = self.act(self.mlp2(hidden))
        return self.mlp3(hidden)


class MoE(nn.Module):
    """Top-k routed mixture of experts.

    A linear router scores every token against ``num_experts`` experts. The
    ``max_experts`` highest-scoring experts are kept per token, their softmax
    weights are renormalised to sum to one, and the token's output is the
    weighted sum of those experts' outputs.

    Args:
        num_experts: Number of ``Expert`` sub-networks to create.
        dim: Size of the last axis of the input (and of the output).
        max_experts: Number of experts each token is routed to (the ``k`` of
            the top-k selection).
    """

    def __init__(self, num_experts, dim, max_experts):

        super().__init__()

        self.num_experts = num_experts
        self.dim = dim
        self.max_experts = max_experts

        # We need a router and a set of experts.
        self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
        self.router = nn.Linear(dim, num_experts)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: Tensor of shape ``(..., dim)``, e.g. ``(batch_size, seq_len, dim)``.

        Returns:
            Tensor with the same shape as ``x``: for each token, the sum over
            its selected experts of ``normalised_weight * expert(token)``.
        """
        # Gating distribution over all experts.
        # (..., num_experts)
        expert_weights = self.softmax(self.router(x))

        # Keep only the k largest gates per token, together with the ids of
        # the experts they belong to.
        # (..., max_experts)
        top_weights, top_indices = torch.topk(
            expert_weights, self.max_experts, dim=-1, sorted=False
        )

        # Renormalise so the kept gates sum to one for every token.
        # (..., max_experts)
        top_weights = top_weights / top_weights.sum(dim=-1, keepdim=True)

        # The accumulator is created exactly once, before any expert is
        # visited, so that every selected expert contributes to the result.
        # (..., dim)
        out = torch.zeros_like(x)

        for expert_id, expert in enumerate(self.experts):

            # Which (token, slot) pairs picked this expert?
            # (..., max_experts)
            routed = top_indices == expert_id

            # Skip experts that no token in this batch was routed to.
            if not routed.any():
                continue

            # topk returns distinct indices, so an expert occupies at most one
            # slot per token; summing over the slot axis therefore yields that
            # slot's weight, or 0 for tokens that did not select this expert.
            # (..., 1)
            gate = (top_weights * routed.to(top_weights.dtype)).sum(dim=-1, keepdim=True)

            # The expert is evaluated once per forward pass (its output does
            # not depend on the slot); unselected tokens are zeroed by `gate`.
            # (..., dim)
            out = out + gate * expert(x)

        return out

In [290]:
# Build a mixture of 4 experts over 768-dim features; max_experts=2 is the k used for top-k routing.
moe = MoE(num_experts=4, dim=768, max_experts=2)

In [291]:
# Random test batch of shape (10, 64, 768); the last axis matches dim=768 of the model.
# NOTE(review): the name `input` shadows Python's builtin input() for the rest of the kernel session.
input = torch.randn(10, 64, 768)

In [293]:
# Smoke test: run one forward pass and print the output shape (the recorded output matches the input shape).
print(moe(input).shape)

torch.Size([10, 64, 768])
