In [3]:
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn

from balm.router import ExpertChoiceRouter, TopKRouter
from balm.modules import Expert


In [4]:
class SparseMLP(nn.Module):
    """
    Implementation of a Sparse MLP module, for use in Mixture-of-Experts models.

    A router assigns tokens to experts; each expert processes only the tokens
    selected for it, and the per-expert outputs are combined (scaled by the
    router's expert mask) into a single output tensor.

    Parameters:
    -----------
    embed_dim : int
        Embedding dimension.

    ffn_dim : int
        Feedforward dimension.

    num_experts : int
        Number of experts.

    expert_capacity : int
        Capacity of each expert. Forwarded to the router.

    num_shared_experts : int, optional
        Number of shared experts. This value is only forwarded to the router;
        this module does not instantiate any shared experts itself.
        The default is 0.

    top_k : int, optional
        Top k for the router. The default is 1.

    activation : str, optional
        Activation function to use. The default is "swiglu".

    expert_ffn_dropout : float, optional
        Dropout rate for the expert feedforward layer. The default is 0.0.

    router_dtype : str, optional
        Dtype for the router. The default is "float32".

    router_bias : bool, optional
        Whether to use bias for the router. The default is False.

    router_jitter : float, optional
        Jitter for the router. The default is 0.0.

    router_ignore_padding_tokens : bool, optional
        Whether to ignore padding tokens for the router. The default is True.

    router_class : Type[nn.Module], optional
        Router class to instantiate. Its ``forward`` must return
        ``(expert_mask, router_probs, router_logits)``.
        The default is ``TopKRouter``.

    expert_class : Type[nn.Module], optional
        Expert class to instantiate once per expert. The default is ``Expert``.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_experts: int,
        expert_capacity: int,
        num_shared_experts: int = 0,
        top_k: int = 1,
        activation: str = "swiglu",
        expert_ffn_dropout: float = 0.0,
        router_dtype: str = "float32",
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        router_class: Type[nn.Module] = TopKRouter,
        expert_class: Type[nn.Module] = Expert,
    ):
        super().__init__()
        self.router = router_class(
            embed_dim=embed_dim,
            num_experts=num_experts,
            expert_capacity=expert_capacity,
            top_k=top_k,
            num_shared_experts=num_shared_experts,
            dtype=router_dtype,
            bias=router_bias,
            jitter=router_jitter,
            ignore_padding_tokens=router_ignore_padding_tokens,
        )
        # Keys are "expert_0" ... "expert_{n-1}" (kept for state_dict
        # compatibility); insertion order matches the expert index, which
        # `forward` relies on when enumerating the experts.
        self.experts = nn.ModuleDict()
        for idx in range(num_experts):
            self.experts[f"expert_{idx}"] = expert_class(
                embed_dim=embed_dim,
                ffn_dim=ffn_dim,
                dropout=expert_ffn_dropout,
                activation=activation,
            )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
        """
        Route tokens to experts and process them.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        Returns:
        --------
        output : Tuple[torch.Tensor, Tuple]
            A tuple containing the following:
             - x : torch.Tensor
                Output tensor with the same shape as the input. Each position
                is the sum over experts of ``expert_output * expert_mask``;
                positions selected by no expert are zero. Its dtype is the
                promotion of ``x.dtype`` and ``expert_mask.dtype``.
             - router_outputs : Tuple[torch.Tensor, torch.Tensor]
                A tuple containing the following:
                 - router_logits : torch.Tensor
                    Router logits of shape (batch_size, sequence_length, num_experts).
                 - expert_mask : torch.Tensor
                    Expert mask of shape (batch_size, sequence_length, num_experts).
        """
        # router
        expert_mask, router_probs, router_logits = self.router(x)
        # NOTE(review): `router_probs` is unpacked but never used, so the
        # combined output is not scaled by router probabilities. Unless
        # `expert_mask` itself carries differentiable router weights, the
        # router receives no gradient through this path — confirm against the
        # router implementation.

        # Accumulate every expert's contribution into a single buffer rather
        # than materialising one full-size zero tensor per expert and stacking
        # them (which costs num_experts x the activation memory). The dtype
        # matches what multiplying x-typed outputs by `expert_mask` produces.
        combined = x.new_zeros(
            x.shape, dtype=torch.promote_types(x.dtype, expert_mask.dtype)
        )
        for expert_idx, expert in enumerate(self.experts.values()):
            weights = expert_mask[..., expert_idx]
            selected = weights.bool()
            # The expert runs even when no token is selected (empty input) so
            # its parameters always participate in the autograd graph.
            expert_output = expert(x[selected]).to(x.dtype)
            # `+=` (not `=`) because with top_k > 1 a token can be routed to
            # several experts.
            combined[selected] += expert_output * weights[selected].unsqueeze(-1)

        return combined, (router_logits, expert_mask)

In [None]:
torch.manual_seed(0)  # make the random input and the weight init reproducible

BATCH_SIZE, SEQ_LEN, EMBED_DIM = 10, 96, 320
NUM_EXPERTS = 16
CAPACITY_FACTOR = 1.5
# `expert_capacity` is documented as an int, but 320 / 16 * 1.5 evaluates to the
# float 30.0, so cast explicitly.
# NOTE(review): capacity is derived from EMBED_DIM here (as in the original demo);
# MoE capacity is usually derived from the number of routed tokens — confirm
# against the router's definition of expert_capacity.
EXPERT_CAPACITY = int(EMBED_DIM / NUM_EXPERTS * CAPACITY_FACTOR)

input_tensor = torch.rand(BATCH_SIZE, SEQ_LEN, EMBED_DIM)
sparse_mlp = SparseMLP(
    embed_dim=EMBED_DIM,
    ffn_dim=EMBED_DIM * 4,
    num_experts=NUM_EXPERTS,
    expert_capacity=EXPERT_CAPACITY,
)
output = sparse_mlp(input_tensor)  # (x, (router_logits, expert_mask))

# The combined MoE output always has the same shape as its input.
assert output[0].shape == input_tensor.shape
print(output[0].shape)
print(output[1][0].shape)
print(output[1][1].shape)

