In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from typing import Optional, Tuple, Union

from balm.data import load_dataset, DataCollator
# from balm.models.balm import BalmForMaskedLM
# from balm.models.balm_moe import BalmMoEForMaskedLM
# from balm.models.balm_moe_rope import BalmMoERoPEForMaskedLM

from balm.embedding import RelativePositionalEmbedding
from balm.loss import router_z_loss, router_load_balancing_loss
from balm.modules import (
    Top1Router,
    # SparseTransformerLayer,
    Expert,
    BalmLMHead,
    MaskedLMOutput,
)

from balm.tokenizer import Tokenizer
from balm.training.trainer import Trainer

In [2]:
class TopKRouter(nn.Module):
    """
    This router uses the "token choice of top-k experts" strategy introduced in the
    `Switch Transformers`_ paper. Tokens are routed to their expert of choice until the
    expert's `expert_capacity` is reached.

    .. note::
        There is no guarantee that each token will be processed by an expert,
        or that every expert will receive at least one token.

    If tokens are routed to an expert which is above capacity, they are not processed by any expert
    and their hidden states are passed to the subsequent layer unchanged.


    Parameters:
    -----------
    embed_dim : int
        Embedding dimension.

    num_experts : int
        Number of experts.

    expert_capacity : int
        Maximum number of tokens that can be routed to each expert.

    dtype : str, optional
        Data type to use for router probabilities. The default is "float32".

    bias : bool, optional
        Whether to add bias to the router classifier. The default is ``False``.

    jitter : float, optional
        Amount of jitter to add to the router probabilities. The default is ``0.0``.

    ignore_padding_tokens : bool, optional
        Whether to ignore padding tokens when computing router probabilities.
        The default is ``True``.


    .. _Switch Transformers:
        https://arxiv.org/abs/2101.03961
    """

    def __init__(
        self,
        embed_dim: int,
        num_experts: int,
        expert_capacity: int,
        dtype: str = "float32",
        bias: bool = False,
        jitter: float = 0.0,
        ignore_padding_tokens: bool = True,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.dtype = getattr(torch, dtype)
        self.classifier = nn.Linear(
            embed_dim,
            self.num_experts,
            bias=bias,
            dtype=self.dtype,
        )
        self.jitter = jitter
        self.ignore_padding_tokens = ignore_padding_tokens

    def _compute_router_probabilities(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes router probabilities from input hidden states.

        Parameters:
        -----------
        x : torch.Tensor
            Tensor of shape (batch_size, sequence_length, hidden_dim) from which
            router probabilities are computed.

        Returns:
        --------
        router_probabilities : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) corresponding to
            the probabilities for each token and expert. Used for routing tokens to experts.

        router_logits : torch.Tensor
            Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding
            to raw router logits. This is used for computing router z-loss.
        """
        # float32 is used to ensure stability. See the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961.
        # we also store the input dtype so we can cast the output back to the original dtype
        self.input_dtype = x.dtype
        x = x.to(self.dtype)
        if self.jitter > 0:
            x *= torch.empty_like(x).uniform_(1.0 - self.jitter, 1.0 + self.jitter)

        # shape: [batch_size, sequence_length, num_experts]
        logits = self.classifier(x)

        # apply softmax and cast back to the original dtype
        probabilities = F.softmax(logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
        return probabilities, logits

    def forward(
        self, x: torch.Tensor, top_k: int = 1
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Route tokens to top-k experts.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        top_k : int
            Number of top experts to route each token to.

        Returns:
        --------
        expert_indices : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) indicating
            which experts the token should be routed to.

        router_probabilities : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) containing
            the router probabilities.

        router_logits : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) containing
            the router logits.
        """
        router_probs, router_logits = self._compute_router_probabilities(x)
        top_k_values, top_k_indices = torch.topk(router_probs, k=top_k, dim=-1)
        expert_indices = F.one_hot(top_k_indices, num_classes=self.num_experts).sum(
            dim=-2
        )

        # mask tokens if their desired experts are above capacity
        token_priority = torch.cumsum(expert_indices, dim=-2)
        expert_capacity_mask = token_priority <= self.expert_capacity
        expert_indices = expert_indices * expert_capacity_mask

        # get the probabilities of the top-choice experts for each token
        router_probs = top_k_values * expert_indices

        return expert_indices, router_probs, router_logits

In [33]:
class ExpertChoiceRouter(nn.Module):
    """
    This router uses the "expert choice of top-k tokens" strategy introduced in the
    `Switch Transformers`_ paper. Tokens are routed to their expert of choice until the
    expert's `expert_capacity` is reached.

    .. note::
        There is no guarantee that each token will be processed by an expert,
        or that every expert will receive at least one token.

    If tokens are routed to an expert which is above capacity, they are not processed by any expert
    and their hidden states are passed to the subsequent layer unchanged.


    Parameters:
    -----------
    embed_dim : int
        Embedding dimension.

    num_experts : int
        Number of experts.

    expert_capacity : int
        Maximum number of tokens that can be routed to each expert.

    dtype : str, optional
        Data type to use for router probabilities. The default is "float32".

    bias : bool, optional
        Whether to add bias to the router classifier. The default is ``False``.

    jitter : float, optional
        Amount of jitter to add to the router probabilities. The default is ``0.0``.

    ignore_padding_tokens : bool, optional
        Whether to ignore padding tokens when computing router probabilities.
        The default is ``True``.


    .. _Switch Transformers:
        https://arxiv.org/abs/2101.03961
    """

    def __init__(
        self,
        embed_dim: int,
        num_experts: int,
        expert_capacity: int,
        dtype: str = "float32",
        bias: bool = False,
        jitter: float = 0.0,
        ignore_padding_tokens: bool = True,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.expert_capacity = expert_capacity
        self.dtype = getattr(torch, dtype)
        self.classifier = nn.Linear(
            embed_dim,
            self.num_experts,
            bias=bias,
            dtype=self.dtype,
        )
        self.jitter = jitter
        self.ignore_padding_tokens = ignore_padding_tokens

    def _compute_router_probabilities(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes router probabilities from input hidden states.

        Parameters:
        -----------
        x : torch.Tensor
            Tensor of shape (batch_size, sequence_length, hidden_dim) from which
            router probabilities are computed.

        Returns:
        --------
        router_probabilities : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) corresponding to
            the probabilities for each token and expert. Used for routing tokens to experts.

        router_logits : torch.Tensor
            Logits tensor of shape (batch_size, sequence_length, num_experts) corresponding
            to raw router logits. This is used for computing router z-loss.
        """
        # float32 is used to ensure stability. See the discussion of "selective precision" in
        # https://arxiv.org/abs/2101.03961.
        # we also store the input dtype so we can cast the output back to the original dtype
        self.input_dtype = x.dtype
        x = x.to(self.dtype)
        if self.jitter > 0:
            x *= torch.empty_like(x).uniform_(1.0 - self.jitter, 1.0 + self.jitter)

        # shape: [batch_size, sequence_length, num_experts]
        logits = self.classifier(x)

        # apply softmax and cast back to the original dtype
        probabilities = F.softmax(logits, dim=-1, dtype=self.dtype).to(self.input_dtype)
        return probabilities, logits

    def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, Tuple]:
        """
        Route tokens to experts, selecting top-k tokens for each expert.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        Returns:
        --------
        expert_mask : torch.Tensor
            Binary mask tensor of shape (batch_size, sequence_length, num_experts) indicating
            which tokens are selected for each expert.

        router_probabilities : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) containing
            the router probabilities.

        router_logits : torch.Tensor
            Tensor of shape (batch_size, sequence_length, num_experts) containing
            the router logits.
        """
        router_probs, router_logits = self._compute_router_probabilities(x)
        expert_mask = torch.zeros_like(router_probs)

        # Select top-k tokens for each expert
        for i in range(self.num_experts):
            _, top_k_indices = torch.topk(
                router_probs[..., i], k=self.expert_capacity, dim=1
            )
            expert_mask.scatter_(1, top_k_indices.unsqueeze(-1), 1, reduce="add")

        # Ensure that the mask is binary
        expert_mask = expert_mask.clamp(max=1)

        return expert_mask, router_probs, router_logits

In [34]:
class SparseMLP(nn.Module):
    """
    Implementation of the Switch Transformers Sparse MLP module.

    Parameters:
    -----------
    config : BalmMoEConfig
        Model configuration class with all the parameters of the model.
        Initializing with a config file does not load the weights associated with the model, only the
        configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.

    router_class : nn.Module, optional
        Router class to use. The default is ``Router``.

    expert_class : nn.Module, optional
        Expert class to use. The default is ``Expert``.

    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_experts: int,
        expert_capacity: int,
        top_k: int = 1,
        expert_activation: str = "gelu",
        expert_ffn_dropout: float = 0.0,
        router_dtype: str = "float32",
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        router_class: nn.Module = TopKRouter,
        expert_class: nn.Module = Expert,
    ):
        super().__init__()
        self.top_k = top_k
        self.router = router_class(
            embed_dim=embed_dim,
            num_experts=num_experts,
            expert_capacity=expert_capacity,
            dtype=router_dtype,
            bias=router_bias,
            jitter=router_jitter,
            ignore_padding_tokens=router_ignore_padding_tokens,
        )
        self.experts = nn.ModuleDict()
        for idx in range(num_experts):
            self.experts[f"expert_{idx}"] = expert_class(
                embed_dim=embed_dim,
                ffn_dim=ffn_dim,
                dropout_rate=expert_ffn_dropout,
                activation=expert_activation,
            )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
        """
        Route tokens to experts and process them.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        Returns:
        --------
        x : torch.Tensor
            Output tensor of shape (batch_size, sequence_length, embed_dim).
        """
        # get the router mask, probabilities, and logits
        expert_mask, router_probs, router_logits = self.router(x, top_k=self.top_k)
        expert_outputs = []

        for idx, expert in self.experts.items():
            int_idx = int(idx.split("_")[-1])
            token_indices = expert_mask[..., int_idx].bool()
            expert_output = expert(x[token_indices]).to(x.dtype)
            expanded_output = torch.zeros_like(x)
            expanded_output[token_indices] = expert_output
            expert_outputs.append(expanded_output)

        # Combine the outputs from the selected tokens for each expert
        x = torch.stack(expert_outputs, dim=-1) * expert_mask.unsqueeze(-2)
        x = x.sum(dim=-1)

        return x, (router_logits, expert_mask)

In [35]:
class SparseTransformerLayer(nn.Module):
    """
    BALM transformer layer with Mixture of Experts. Approximately follows the ESM-2
    implementation, but differs in a few ways:
        - includes (optional) dropout for self-attention and feedforward layers
        - normalize **after**, not before, the self-attention and feedforward layers
        - we don't use rotary embeddings, which aren't (yet?) compatible with
          torch's optimized implementation of ``nn.MultiheadAttention``

    Parameters:
    -----------
    config : BalmMoEConfig
        Model configuration class with all the parameters of the model.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_heads: int,
        num_experts: int,
        expert_capacity: int,
        top_k: int = 1,
        expert_activation: str = "gelu",
        expert_ffn_dropout: float = 0.0,
        ffn_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        attention_batch_first: bool = True,
        layer_norm_eps: float = 1e-5,
        router_dtype: str = "float32",
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        router_class: nn.Module = TopKRouter,
        expert_class: nn.Module = Expert,
        # config: BalmMoEConfig,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.attention_dropout = attention_dropout
        self.ffn_dropout = ffn_dropout
        self.expert_ffn_dropout = expert_ffn_dropout
        self.layer_norm_eps = layer_norm_eps

        # can't use rotary embeddings with nn.MultiheadAttention
        # see: https://discuss.pytorch.org/t/is-there-a-way-to-implement-rope-around-nn-multiheadattention-somehow/175051
        # it is possible to use rotary embeddings with F.scaled_dot_product_attention,
        # but it's not clear that it's worth the effort
        # see: https://github.com/pytorch/pytorch/issues/97899 for an example
        # self.use_rotary_embeddings = use_rotary_embeddings

        self.self_attn = nn.MultiheadAttention(
            embed_dim=self.embed_dim,
            num_heads=self.num_heads,
            dropout=self.attention_dropout,
            batch_first=attention_batch_first,
        )

        self.mlp = SparseMLP(
            embed_dim=self.embed_dim,
            ffn_dim=self.ffn_dim,
            num_experts=num_experts,
            top_k=top_k,
            expert_capacity=expert_capacity,
            expert_activation=expert_activation,
            expert_ffn_dropout=expert_ffn_dropout,
            router_dtype=router_dtype,
            router_bias=router_bias,
            router_jitter=router_jitter,
            router_ignore_padding_tokens=router_ignore_padding_tokens,
            router_class=router_class,
            expert_class=expert_class,
        )
        self.ff_dropout = nn.Dropout(self.ffn_dropout)

        self.norm1 = nn.LayerNorm(self.embed_dim, eps=self.layer_norm_eps)
        self.norm2 = nn.LayerNorm(self.embed_dim, eps=self.layer_norm_eps)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        output_router_logits: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
        """
        Process the input hidden states.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        attn_mask : torch.Tensor, optional
            Attention mask of shape (batch_size * num_heads, sequence_length, sequence_length). The default is None.

        key_padding_mask : torch.Tensor, optional
            Mask of shape (batch_size, sequence_length). The default is None.

        need_weights : bool, optional
            Whether to return attention weights. The default is False.

            .. note::
                if `need_weights` is ``True``, the output will be a tuple of (x, attn). Also,
                nn.MultiHeadAttention will not be able to use the optimized torch implementation
                of ``scaled_dot_product_attention``. See `here`_ for more details.

        output_router_logits : bool, optional
            Whether to output router logits. The default is True.

        Returns:
        --------
        x : torch.Tensor or Tuple

            Output tensor of shape (batch_size, sequence_length, embed_dim). If `need_weights`, is ``True``,
            output is a tuple of (x, attn). If `output_router_logits` is ``True``, the output will be a tuple
            of (x, router_logits) or (x, attn, router_logts) depending on the value of `need_weights`.


        .. _here:
            https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward
        """
        # attention
        residual = x
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attention_mask,
        )
        if need_weights:
            x, attn = x
        x = residual + x
        x = self.norm1(x)

        # sparse feedforward
        residual = x
        x, router_tuple = self.mlp(x)  # router_tuple is (router_logits, expert_index)
        x = self.ff_dropout(x)
        x = self.norm2(residual + x)
        if output_router_logits and router_tuple is not None:
            if need_weights:
                return (x, attn, router_tuple)
            return (x, router_tuple)
        if need_weights:
            return (x, attn)
        return x

In [36]:
class BalmMoE(nn.Module):
    """
    BALM Mixture of Experts model.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_layers: int,
        num_heads: int,
        num_experts: int,
        expert_capacity: int,
        vocab_size: int,
        max_length: int = 320,
        expert_activation: str = "gelu",
        expert_ffn_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        attention_batch_first: bool = True,
        layer_norm_eps: float = 1e-5,
        router_dtype: str = "float32",
        router_top_k: int = 1,
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        padding_idx: int = 0,
        router_class: nn.Module = TopKRouter,
        expert_class: nn.Module = Expert,
        # config: BalmMoEConfig,
    ):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.embed_positions = RelativePositionalEmbedding(embed_dim)
        self.layers = nn.ModuleList(
            [
                SparseTransformerLayer(
                    embed_dim=embed_dim,
                    ffn_dim=ffn_dim,
                    num_heads=num_heads,
                    num_experts=num_experts,
                    top_k=router_top_k,
                    expert_capacity=expert_capacity,
                    expert_activation=expert_activation,
                    expert_ffn_dropout=expert_ffn_dropout,
                    attention_dropout=attention_dropout,
                    attention_batch_first=attention_batch_first,
                    layer_norm_eps=layer_norm_eps,
                    router_dtype=router_dtype,
                    router_bias=router_bias,
                    router_jitter=router_jitter,
                    router_ignore_padding_tokens=router_ignore_padding_tokens,
                    router_class=router_class,
                    expert_class=expert_class,
                )
                for _ in range(num_layers)
            ]
        )
        self.embedding_dropout = nn.Dropout(token_embedding_dropout)
        self.final_norm = nn.LayerNorm(embed_dim)

        self.attention_batch_first = attention_batch_first

    @property
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: bool = False,
        return_dict: bool = True,
    ):
        """
        Parameters:
        -----------

        input_ids: torch.LomgTensor
            Tokenized input IDs

        attention_mask: torch.BoolTensor
            Attention mask

        output_attentions: bool
            Whether to output attention weights

        output_hidden_states: bool
            Whether to output hidden states

        output_router_logits: bool
            Whether to output router logits

        return_dict: bool
            Whether to return a dictionary of outputs (returns a tuple by default)


        Returns:
        --------
        output (tuple or dict):
            If `return_dict` is ``True``, the output is a ``dict`` of outputs:
                - last_hidden_state (torch.FloatTensor): last hidden state
                - router_z_loss (torch.FloatTensor): router z loss
                - router_aux_loss (torch.FloatTensor): router auxiliary loss
                - attentions (torch.FloatTensor): attention weights
                - hidden_states (torch.FloatTensor): hidden states
                - router_logits (torch.FloatTensor): router logits
            If `return_dict` is ``False``, the output is a ``tuple`` with the f0llowing elements:
                - last_hidden_state (torch.FloatTensor): last hidden state
                - attentions (torch.FloatTensor): attention weights
                - hidden_states (torch.FloatTensor): hidden states
                - router_logits (torch.FloatTensor): router logits
        """
        # init
        attn_weights = []
        hidden_states = {}
        router_logits = []
        expert_indexes = []

        # embeddings
        x = self.embed_tokens(input_ids)
        x = self.embed_positions(x)
        x = self.embedding_dropout(x)

        # encoder
        # x = x.transpose(0, 1)
        for layer_idx, layer in enumerate(self.layers, 1):
            x = layer(
                x,
                attention_mask=attention_mask,
                key_padding_mask=key_padding_mask,
                need_weights=output_attentions,
                output_router_logits=output_router_logits,
            )
            if output_attentions:
                x, attn, router_tuple = x
                attn_weights.append(attn)
            else:
                x, router_tuple = x
            router_logits.append(router_tuple[0])
            expert_indexes.append(router_tuple[1])
            if output_hidden_states:
                # hidden_states[layer_idx] = x.transpose(0, 1)
                hidden_states[layer_idx] = x
        x = self.final_norm(x)
        # x = x.transpose(0, 1)

        # Compute the router losses (z_loss + auxiliary loss)
        cat_router_logits = torch.cat(router_logits, dim=1)
        cat_expert_indexes = torch.cat(expert_indexes, dim=1)
        router_probs = nn.Softmax(dim=-1)(cat_router_logits)
        z_loss = router_z_loss(cat_router_logits)
        aux_loss = router_load_balancing_loss(router_probs, cat_expert_indexes)

        # results
        result = MaskedLMOutput(
            last_hidden_state=x,
            router_z_loss=z_loss,
            router_aux_loss=aux_loss,
        )
        if output_attentions:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
        if output_hidden_states:
            result["hidden_states"] = hidden_states
        if output_router_logits:
            result["router_logits"] = cat_router_logits
        if return_dict:
            return result
        return result.as_tuple()


class BalmMoEForMaskedLM(nn.Module):
    """
    BALM Mixture of Experts model for Masked Language Modeling.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_layers: int,
        num_heads: int,
        num_experts: int,
        expert_capacity: int,
        vocab_size: int,
        max_length: int = 320,
        expert_activation: str = "gelu",
        expert_ffn_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        attention_dropout: float = 0.0,
        attention_batch_first: bool = True,
        layer_norm_eps: float = 1e-5,
        router_dtype: str = "float32",
        router_top_k: int = 1,
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.001,
        padding_idx: int = 0,
        router_class: nn.Module = TopKRouter,
        expert_class: nn.Module = Expert,
    ):
        super().__init__()
        self.balm = BalmMoE(
            embed_dim=embed_dim,
            ffn_dim=ffn_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            num_experts=num_experts,
            router_top_k=router_top_k,
            expert_capacity=expert_capacity,
            vocab_size=vocab_size,
            max_length=max_length,
            expert_activation=expert_activation,
            expert_ffn_dropout=expert_ffn_dropout,
            token_embedding_dropout=token_embedding_dropout,
            attention_dropout=attention_dropout,
            attention_batch_first=attention_batch_first,
            layer_norm_eps=layer_norm_eps,
            router_dtype=router_dtype,
            router_bias=router_bias,
            router_jitter=router_jitter,
            router_ignore_padding_tokens=router_ignore_padding_tokens,
            padding_idx=padding_idx,
            router_class=router_class,
            expert_class=expert_class,
        )
        self.lm_head = BalmLMHead(
            embed_dim=embed_dim,
            output_dim=vocab_size,
            # weight=self.balm.embed_tokens.weight,
        )

        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef

    @property
    def num_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: bool = True,
        return_dict: bool = True,
    ):
        """
        Args:
            input_ids (torch.LongTensor): tokenized input IDs
            attention_mask (torch.BoolTensor): attention mask
            return_dict (bool): return a dictionary of outputs
        """
        # encoder
        outputs = self.balm(
            input_ids,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=True,
            return_dict=True,
        )
        x = outputs["last_hidden_state"]
        router_z_loss = outputs["router_z_loss"]
        router_aux_loss = outputs["router_aux_loss"]

        # LM head
        lm_logits = self.lm_head(x)
        outputs["logits"] = lm_logits

        # loss
        if labels is not None:
            # move labels to correct device
            labels = labels.to(lm_logits.device)
            loss = self.criterion(
                lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)
            )
            outputs["lm_loss"] = loss

            if output_router_logits:
                z_loss = self.router_z_loss_coef * (router_z_loss)
                aux_loss = self.router_aux_loss_coef * (router_aux_loss)
                outputs["router_z_loss"] = z_loss
                outputs["router_aux_loss"] = aux_loss
                loss = loss + z_loss + aux_loss
            outputs["loss"] = loss

        if return_dict:
            return outputs
        return outputs.as_tuple()

In [37]:
tokenizer = Tokenizer(vocab="./vocab.json")

In [38]:
def remove_sep(txt):
    return txt.replace("</s>", "<cls><cls>")


data_files = {
    "train": "./balm/test_data/test_1k.txt",
    "test": "./balm/test_data/test_1k.txt",
    "eval": "./balm/test_data/test_1k.txt",
}

dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [39]:
tokenized_dataset = dataset.map(
    lambda x: tokenizer(
        x["text"],
        padding=True,
        truncation=True,
        max_length=320,
    ),
    remove_columns="text",
)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [40]:
collator = DataCollator(tokenizer=tokenizer)

In [41]:
# model = BalmMoERoPEForMaskedLM(
model = BalmMoEForMaskedLM(
    embed_dim=256,
    ffn_dim=1024,
    num_experts=4,
    num_layers=8,
    num_heads=8,
    router_top_k=1,
    router_class=ExpertChoiceRouter,
    expert_capacity=128,
    # expert_capacity=128,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
    vocab_size=tokenizer.vocab_size,
)

In [42]:
model.num_parameters

18982689

In [43]:
trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    epochs=1,
    logging_steps=10,
    eval_steps=50,
    warmup_steps=50,
    per_device_train_batch_size=32,
    # per_device_eval_batch_size=32,
    use_cpu=True,
    # compute_metrics=compute_metrics,
)

In [44]:
trainer.device

device(type='cpu')

In [45]:
trainer.train()

  0%|          | 0/31 [00:00<?, ?it/s]

step 10 | loss: 2.9996 | lm_loss: 2.9614 | router_z_loss: 0.0191 | router_aux_loss: 0.0191 | lr: 0.000080
step 20 | loss: 2.8047 | lm_loss: 2.7731 | router_z_loss: 0.0133 | router_aux_loss: 0.0184 | lr: 0.000160
step 30 | loss: 2.6965 | lm_loss: 2.6716 | router_z_loss: 0.0098 | router_aux_loss: 0.0150 | lr: 0.000240
Training complete
