In [1]:
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union

from balm.config import BalmMoEConfig
from balm.data import load_dataset, DataCollator
from balm.embedding import RotaryPositionalEmbedding, RelativePositionalEmbedding
from balm.loss import router_load_balancing_loss, router_z_loss
from balm.models.base import BalmBase
from balm.modules import Expert, BalmLMHead, DenseTransformerLayer, MaskedLMOutput
from balm.router import TopKRouter, ExpertChoiceRouter
from balm.tokenizer import Tokenizer
from balm.train import Trainer

In [11]:
class SparseMLP(nn.Module):
    """
    Implementation of a Sparse MLP module, for use in Mixture-of-Experts models.

    A router assigns tokens to experts. Each expert processes only the tokens
    selected for it, and the per-expert outputs are combined, weighted by the
    router's expert mask.

    Parameters:
    -----------
    embed_dim : int
        Embedding dimension.

    ffn_dim : int
        Feedforward dimension of each expert.

    num_experts : int
        Number of experts (one ``expert_class`` instance is created per expert).

    expert_capacity : int
        Capacity of each expert (forwarded to the router).

    num_shared_experts : int, optional
        Number of shared experts, forwarded to the router. No separate shared-expert
        modules are created by this class. The default is 0.

    top_k : int, optional
        Top k for the router. The default is 1.

    activation : str, optional
        Activation function to use. The default is "swiglu".

    expert_ffn_dropout : float, optional
        Dropout rate for the expert feedforward layer. The default is 0.0.

    router_dtype : str, optional
        Dtype for the router. The default is "float32".

    router_bias : bool, optional
        Whether to use bias for the router. The default is False.

    router_jitter : float, optional
        Jitter for the router. The default is 0.0.

    router_ignore_padding_tokens : bool, optional
        Whether to ignore padding tokens for the router. The default is True.

    router_class : nn.Module, optional
        Router class to use. The default is ``TopKRouter``.

    expert_class : nn.Module, optional
        Expert class to use. The default is ``Expert``.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_experts: int,
        expert_capacity: int,
        num_shared_experts: int = 0,
        top_k: int = 1,
        activation: str = "swiglu",
        expert_ffn_dropout: float = 0.0,
        router_dtype: str = "float32",
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        router_class: nn.Module = TopKRouter,
        expert_class: nn.Module = Expert,
    ):
        super().__init__()
        self.router = router_class(
            embed_dim=embed_dim,
            num_experts=num_experts,
            expert_capacity=expert_capacity,
            top_k=top_k,
            num_shared_experts=num_shared_experts,
            dtype=router_dtype,
            bias=router_bias,
            jitter=router_jitter,
            ignore_padding_tokens=router_ignore_padding_tokens,
        )
        # experts are registered as "expert_0" ... "expert_{num_experts - 1}";
        # forward() recovers each expert's column in the router mask from this suffix
        self.experts = nn.ModuleDict()
        for idx in range(num_experts):
            self.experts[f"expert_{idx}"] = expert_class(
                embed_dim=embed_dim,
                ffn_dim=ffn_dim,
                dropout=expert_ffn_dropout,
                activation=activation,
            )

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Tuple]:
        """
        Route tokens to experts and process them.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        Returns:
        --------
        x : torch.Tensor
            Output tensor of shape (batch_size, sequence_length, embed_dim).

        router_outputs : Tuple[torch.Tensor, torch.Tensor]
            ``(router_logits, expert_mask)`` as produced by the router. ``expert_mask``
            is indexed as ``(batch_size, sequence_length, num_experts)``.
        """
        # router
        # NOTE(review): router_probs is not used to weight the expert outputs below.
        # If the router's mask is binary, the LM loss sends no gradient to the
        # router (only the auxiliary router losses do) — confirm this is intended.
        expert_mask, router_probs, router_logits = self.router(x)

        # Accumulate each expert's contribution in turn instead of stacking
        # num_experts full-size copies of x and summing them afterwards.
        output = torch.zeros_like(x)
        for name, expert in self.experts.items():
            expert_idx = int(name.split("_")[-1])
            expert_weights = expert_mask[..., expert_idx]
            token_indices = expert_weights.bool()

            # only the tokens routed to this expert are processed
            expert_output = expert(x[token_indices]).to(x.dtype)
            expanded_output = torch.zeros_like(x)
            expanded_output[token_indices] = expert_output

            # weight by this expert's mask column (broadcast over embed_dim)
            output = output + expanded_output * expert_weights.unsqueeze(-1)

        return output, (router_logits, expert_mask)

In [12]:
class SparseTransformerLayer(nn.Module):
    """
    BALM transformer layer with a Mixture-of-Experts (sparse) feedforward block.

    The layer applies self-attention followed by a ``SparseMLP``. Each sub-block is
    wrapped in a residual connection with dropout. Layer normalization is applied
    before each sub-block when ``pre_norm`` is ``True`` (the default), and after each
    residual sum otherwise. A positional embedding module (rotary by default,
    relative otherwise) is applied to the input of the attention block.

    Parameters:
    -----------
    embed_dim : int
        Embedding dimension.

    ffn_dim : int
        Feedforward dimension of each expert.

    num_heads : int
        Number of attention heads.

    max_length : int
        Maximum sequence length (used by the positional embeddings).

    num_experts : int
        Number of experts in the sparse feedforward block.

    expert_capacity : int
        Capacity of each expert.

    num_shared_experts : int, optional
        Number of shared experts (forwarded to the router). The default is 0.

    top_k : int, optional
        Top k for the router. The default is 1.

    dropout : float, optional
        Dropout applied to the output of each sub-block. The default is 0.1.

    attention_dropout : float, optional
        Dropout inside ``nn.MultiheadAttention``. The default is 0.0.

    expert_ffn_dropout : float, optional
        Dropout inside each expert. The default is 0.0.

    token_embedding_dropout : float, optional
        Dropout applied after the positional embeddings. The default is 0.0.

    layer_norm_eps : float, optional
        Epsilon for the layer norms. The default is 1e-5.

    activation : str, optional
        Expert activation function. The default is "swiglu".

    positional_embedding_type : str, optional
        ``"rotary"`` for rotary embeddings; any other value selects relative
        embeddings. The default is "rotary".

    pre_norm : bool, optional
        Normalize before (``True``) or after (``False``) each sub-block. The default is True.

    router_dtype, router_bias, router_jitter, router_ignore_padding_tokens
        Router options, forwarded to ``SparseMLP``.

    expert_choice_router : bool, optional
        Use ``ExpertChoiceRouter`` instead of ``TopKRouter``. The default is False.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        num_heads: int,
        max_length: int,
        num_experts: int,
        expert_capacity: int,
        num_shared_experts: int = 0,
        top_k: int = 1,
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        expert_ffn_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        activation: str = "swiglu",
        positional_embedding_type: str = "rotary",
        pre_norm: bool = True,
        router_dtype: str = "float32",
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        expert_choice_router: bool = False,
    ):
        super().__init__()
        self.pre_norm = pre_norm

        # embeddings
        if positional_embedding_type.lower() == "rotary":
            self.positional_embeddings = RotaryPositionalEmbedding(
                embed_dim, max_length
            )
        else:
            self.positional_embeddings = RelativePositionalEmbedding(
                embed_dim, max_length
            )

        # norm
        self.norm1 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        # attention
        self.self_attn = nn.MultiheadAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=attention_dropout,
            batch_first=True,
        )

        # sparse feedforward
        self.mlp = SparseMLP(
            embed_dim=embed_dim,
            ffn_dim=ffn_dim,
            num_experts=num_experts,
            num_shared_experts=num_shared_experts,
            top_k=top_k,
            expert_capacity=expert_capacity,
            activation=activation,
            expert_ffn_dropout=expert_ffn_dropout,
            router_dtype=router_dtype,
            router_bias=router_bias,
            router_jitter=router_jitter,
            router_ignore_padding_tokens=router_ignore_padding_tokens,
            router_class=ExpertChoiceRouter if expert_choice_router else TopKRouter,
            expert_class=Expert,
        )

        # dropout
        self.dropout = nn.Dropout(dropout)
        self.embedding_dropout = nn.Dropout(token_embedding_dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        output_router_logits: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
        """
        Process the input hidden states.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        attention_mask : torch.Tensor, optional
            Passed to ``nn.MultiheadAttention`` as ``attn_mask``, e.g. of shape
            (batch_size * num_heads, sequence_length, sequence_length). The default is None.

        key_padding_mask : torch.Tensor, optional
            Mask of shape (batch_size, sequence_length). The default is None.

        need_weights : bool, optional
            Whether to return attention weights. The default is False.

            .. note::
                if `need_weights` is ``True``, the output will include the attention weights. Also,
                nn.MultiHeadAttention will not be able to use the optimized torch implementation
                of ``scaled_dot_product_attention``. See `here`_ for more details.

        output_router_logits : bool, optional
            Whether to output the router tuple. The default is True.

        Returns:
        --------
        x : torch.Tensor or Tuple

            Output tensor of shape (batch_size, sequence_length, embed_dim). If `output_router_logits`
            is ``True``, the output is ``(x, router_tuple)`` or ``(x, attn, router_tuple)`` depending on
            `need_weights`, where ``router_tuple`` is the second value returned by ``self.mlp``.
            Otherwise the output is ``x`` or ``(x, attn)``.


        .. _here:
            https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward
        """
        # ---- attention block ----
        residual = x
        if self.pre_norm:
            x = self.norm1(x)

        # positional embeddings
        x = self.embedding_dropout(self.positional_embeddings(x))

        # nn.MultiheadAttention returns (output, weights); weights is None
        # when need_weights is False
        x, attn = self.self_attn(
            x,
            x,
            x,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
            attn_mask=attention_mask,
        )
        x = residual + self.dropout(x)

        # post-norm
        if not self.pre_norm:
            x = self.norm1(x)

        # ---- sparse feedforward block ----
        residual = x
        if self.pre_norm:
            x = self.norm2(x)

        x, router_tuple = self.mlp(x)
        x = residual + self.dropout(x)

        # post-norm: x already includes the residual, so normalize it directly
        # (adding the residual again would double-count it)
        if not self.pre_norm:
            x = self.norm2(x)

        # outputs
        if output_router_logits and router_tuple is not None:
            if need_weights:
                return (x, attn, router_tuple)
            return (x, router_tuple)
        if need_weights:
            return (x, attn)
        return x

In [26]:
# self,
# embed_dim: int,
# ffn_dim: int,
# num_heads: int,
# max_length: int,
# num_experts: int,
# expert_capacity: int,
# num_shared_experts: int = 0,
# top_k: int = 1,
# dropout: float = 0.1,
# attention_dropout: float = 0.0,
# expert_ffn_dropout: float = 0.0,
# token_embedding_dropout: float = 0.0,
# layer_norm_eps: float = 1e-5,
# activation: str = "swiglu",
# positional_embedding_type: str = "rotary",
# pre_norm: bool = True,
# router_dtype: str = "float32",
# router_bias: bool = False,
# router_jitter: float = 0.0,
# router_ignore_padding_tokens: bool = True,
# expert_choice_router: bool = False,


class BalmMoEModel(BalmBase):
    """
    BALM Mixture of Experts model.

    Embeds token IDs and passes them through ``config.num_layers`` transformer layers.
    If ``config.alternate_sparsity`` is ``True``, layers at even 0-based indices are
    ``DenseTransformerLayer`` and layers at odd indices are ``SparseTransformerLayer``;
    otherwise every layer is sparse. The router z-loss and (for top-k routing) the
    router load-balancing loss are computed from the router outputs of all sparse layers.

    Parameters:
    -----------
    config : BalmMoEConfig
        Model configuration.
    """

    def __init__(
        self,
        config: BalmMoEConfig,
    ):
        super().__init__(config)
        self.alternate_sparsity = self.config.alternate_sparsity
        self.embed_tokens = nn.Embedding(
            self.config.vocab_size,
            self.config.embed_dim,
            padding_idx=self.config.padding_idx,
        )
        self.layers = nn.ModuleList(
            [
                self._build_dense_layer()
                if self._is_dense_layer(layer_num)
                else self._build_sparse_layer()
                for layer_num in range(self.config.num_layers)
            ]
        )
        self.embedding_dropout = nn.Dropout(self.config.token_embedding_dropout)
        # use the same eps as the per-layer norms
        self.final_norm = nn.LayerNorm(
            self.config.embed_dim, eps=self.config.layer_norm_eps
        )

    def _is_dense_layer(self, layer_num: int) -> bool:
        """Return ``True`` if the layer at 0-based index ``layer_num`` is dense."""
        return bool(self.alternate_sparsity) and layer_num % 2 == 0

    def _build_dense_layer(self) -> nn.Module:
        """Build one ``DenseTransformerLayer`` from ``self.config``."""
        return DenseTransformerLayer(
            embed_dim=self.config.embed_dim,
            ffn_dim=self.config.ffn_dim,
            num_heads=self.config.num_heads,
            max_length=self.config.max_length,
            dropout=self.config.dropout,
            attention_dropout=self.config.attention_dropout,
            token_embedding_dropout=self.config.token_embedding_dropout,
            layer_norm_eps=self.config.layer_norm_eps,
            activation=self.config.activation,
            positional_embedding_type=self.config.positional_embedding_type,
            pre_norm=self.config.pre_norm,
        )

    def _build_sparse_layer(self) -> nn.Module:
        """Build one ``SparseTransformerLayer`` from ``self.config``."""
        return SparseTransformerLayer(
            embed_dim=self.config.embed_dim,
            ffn_dim=self.config.ffn_dim,
            num_heads=self.config.num_heads,
            max_length=self.config.max_length,
            num_experts=self.config.num_experts,
            expert_capacity=self.config.expert_capacity,
            num_shared_experts=self.config.num_shared_experts,
            top_k=self.config.router_top_k,
            dropout=self.config.dropout,
            attention_dropout=self.config.attention_dropout,
            expert_ffn_dropout=self.config.expert_ffn_dropout,
            token_embedding_dropout=self.config.token_embedding_dropout,
            layer_norm_eps=self.config.layer_norm_eps,
            activation=self.config.activation,
            positional_embedding_type=self.config.positional_embedding_type,
            pre_norm=self.config.pre_norm,
            router_dtype=self.config.router_dtype,
            router_bias=self.config.router_bias,
            router_jitter=self.config.router_jitter,
            router_ignore_padding_tokens=self.config.router_ignore_padding_tokens,
            expert_choice_router=self.config.expert_choice_router,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: bool = False,
        output_expert_indexes: bool = False,
        return_dict: bool = True,
    ):
        """
        Parameters:
        -----------

        x: torch.LongTensor
            Tokenized input IDs.

        attention_mask: torch.Tensor, optional
            Attention mask, passed to every layer.

        key_padding_mask: torch.Tensor, optional
            Key padding mask, passed to every layer.

        output_attentions: bool
            Whether to output attention weights.

        output_hidden_states: bool
            Whether to output hidden states (a dict keyed by 1-based layer number).

        output_router_logits: bool
            Whether to include the concatenated router logits in the output. Router
            losses are always computed, regardless of this flag.

        output_expert_indexes: bool
            Whether to include the concatenated expert masks in the output.

        return_dict: bool
            Whether to return a dictionary of outputs (otherwise a tuple).


        Returns:
        --------
        output (tuple or dict):
            If `return_dict` is ``True``, the output is a ``MaskedLMOutput`` containing:
                - last_hidden_state (torch.FloatTensor): last hidden state
                - router_z_loss (torch.FloatTensor): router z loss
                - router_aux_loss (torch.FloatTensor or None): router load-balancing loss
                  (``None`` when using the expert-choice router)
                - attentions (torch.FloatTensor): attention weights, if requested
                - hidden_states (dict): hidden states, if requested
                - router_logits (torch.FloatTensor): router logits, if requested
                - expert_indexes (torch.Tensor): expert masks, if requested
            If `return_dict` is ``False``, the output is ``MaskedLMOutput.as_tuple()``.
        """
        # init
        attn_weights = []
        hidden_states = {}
        router_logits = []
        expert_indexes = []

        # embeddings
        x = self.embed_tokens(x)

        # layers
        for layer_num, layer in enumerate(self.layers):
            if self._is_dense_layer(layer_num):
                # dense layer, no router info
                # assumes DenseTransformerLayer.forward accepts key_padding_mask the
                # same way SparseTransformerLayer does — confirm against balm.modules
                layer_output = layer(
                    x,
                    attention_mask=attention_mask,
                    key_padding_mask=key_padding_mask,
                    need_weights=output_attentions,
                )
                if output_attentions:
                    x, attn = layer_output
                    attn_weights.append(attn)
                else:
                    x = layer_output
            else:
                # sparse layer: router outputs are always requested because the
                # router losses below need them, whatever output_router_logits says
                layer_output = layer(
                    x,
                    attention_mask=attention_mask,
                    key_padding_mask=key_padding_mask,
                    need_weights=output_attentions,
                    output_router_logits=True,
                )
                if output_attentions:
                    x, attn, router_tuple = layer_output
                    attn_weights.append(attn)
                else:
                    x, router_tuple = layer_output
                router_logits.append(router_tuple[0])
                expert_indexes.append(router_tuple[1])

            if output_hidden_states:
                # keys are 1-based layer numbers
                hidden_states[layer_num + 1] = x

        if self.config.pre_norm:
            x = self.final_norm(x)

        # router losses
        cat_router_logits = torch.cat(router_logits, dim=1)
        cat_expert_indexes = torch.cat(expert_indexes, dim=1)
        router_probs = torch.softmax(cat_router_logits, dim=-1)
        z_loss = router_z_loss(cat_router_logits)
        if self.config.expert_choice_router:
            aux_loss = None
        else:
            aux_loss = router_load_balancing_loss(router_probs, cat_expert_indexes)

        # results
        result = MaskedLMOutput(
            last_hidden_state=x,
            router_z_loss=z_loss,
            router_aux_loss=aux_loss,
        )
        if output_attentions:
            # stacked along a new layer dimension at dim 1
            attentions = torch.stack(attn_weights, 1)
            if attention_mask is not None:
                # NOTE(review): assumes attention_mask is 3-D (batch, T, T) so that
                # it broadcasts as (batch, 1, 1, T, T) — confirm against the collator
                attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
        if output_hidden_states:
            result["hidden_states"] = hidden_states
        if output_router_logits:
            result["router_logits"] = cat_router_logits
        if output_expert_indexes:
            result["expert_indexes"] = cat_expert_indexes
        if return_dict:
            return result
        return result.as_tuple()



In [27]:
class BalmMoEForMaskedLM(BalmBase):
    """
    BALM Mixture of Experts model for Masked Language Modeling.

    Wraps a ``BalmMoEModel`` encoder with a ``BalmLMHead`` that projects hidden states
    to vocabulary logits. When ``labels`` are given, the loss is the cross-entropy MLM
    loss (positions labeled ``-100`` are ignored). If ``output_router_logits`` is also
    ``True``, the loss further includes the router z-loss scaled by
    ``config.router_z_loss_coef`` and, for top-k routing, the load-balancing loss
    scaled by ``config.router_aux_loss_coef``.

    Parameters:
    -----------
    config : BalmMoEConfig
        Model configuration.
    """

    def __init__(
        self,
        config: BalmMoEConfig,
    ):
        super().__init__(config)
        self.balm = BalmMoEModel(
            config=self.config,
        )
        self.lm_head = BalmLMHead(
            embed_dim=self.config.embed_dim,
            output_dim=self.config.vocab_size,
        )

        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
        self.router_z_loss_coef = self.config.router_z_loss_coef
        self.router_aux_loss_coef = self.config.router_aux_loss_coef

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: bool = True,
        output_expert_indexes: bool = False,
        return_dict: bool = True,
    ):
        """
        Args:
            input_ids (torch.LongTensor): tokenized input IDs
            attention_mask (torch.Tensor, optional): attention mask passed to the encoder
            key_padding_mask (torch.Tensor, optional): key padding mask passed to the encoder
            labels (torch.LongTensor, optional): MLM targets; ``-100`` marks ignored positions
            output_attentions (bool): include attention weights in the output
            output_hidden_states (bool): include per-layer hidden states in the output
            output_router_logits (bool): include router logits in the output and add the
                scaled router losses to ``loss``
            output_expert_indexes (bool): include expert masks in the output
            return_dict (bool): return a dictionary of outputs (otherwise a tuple)

        Returns:
            dict or tuple: encoder outputs plus ``logits`` and, when ``labels`` is given,
            ``lm_loss`` and ``loss``. The ``router_z_loss`` and ``router_aux_loss``
            entries are replaced with their coefficient-scaled values when the router
            losses are added to ``loss``.
        """
        # encoder
        outputs = self.balm(
            input_ids,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            output_expert_indexes=output_expert_indexes,
            return_dict=True,
        )
        x = outputs["last_hidden_state"]
        # unscaled router losses from the encoder (renamed to avoid shadowing the
        # imported router_z_loss function)
        raw_z_loss = outputs["router_z_loss"]
        raw_aux_loss = outputs["router_aux_loss"]

        # LM head
        lm_logits = self.lm_head(x)
        outputs["logits"] = lm_logits

        # loss
        if labels is not None:
            labels = labels.to(lm_logits.device)
            # reshape (not view) also works for non-contiguous tensors, where view raises
            loss = self.criterion(
                lm_logits.reshape(-1, lm_logits.size(-1)), labels.reshape(-1)
            )
            outputs["lm_loss"] = loss

            if output_router_logits:
                z_loss = self.router_z_loss_coef * raw_z_loss
                outputs["router_z_loss"] = z_loss
                if self.config.expert_choice_router:
                    # the encoder reports no load-balancing loss for expert-choice routing
                    loss = loss + z_loss
                else:
                    aux_loss = self.router_aux_loss_coef * raw_aux_loss
                    outputs["router_aux_loss"] = aux_loss
                    loss = loss + z_loss + aux_loss
            outputs["loss"] = loss

        # outputs
        if return_dict:
            return outputs.as_dict()
        return outputs.as_tuple()

In [28]:
# Build the tokenizer from the vocabulary file next to this notebook.
vocab_file = "./vocab.json"
tokenizer = Tokenizer(vocab=vocab_file)

In [29]:
def remove_sep(txt):
    """
    Replace every ``</s>`` separator in ``txt`` with two ``<cls>`` tokens.

    Passed below as ``preprocess_fn`` to ``load_dataset``.
    """
    return "<cls><cls>".join(txt.split("</s>"))


# Every split reads the same small test file so the notebook runs quickly.
# For a larger training run, point "train" at "./balm/test_data/test.txt".
data_files = {
    split: "./balm/test_data/test_1k.txt" for split in ("train", "test", "eval")
}

dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [30]:
def tokenize_example(example):
    """Tokenize the ``text`` field (padding on, truncation at 320 tokens)."""
    return tokenizer(
        example["text"],
        padding=True,
        truncation=True,
        max_length=320,
    )


tokenized_dataset = dataset.map(tokenize_example, remove_columns="text")

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

In [31]:
# Collator used by the Trainer below to batch tokenized examples.
collator = DataCollator(
    tokenizer=tokenizer,
)

In [32]:

# Seed torch's global RNG before building the model so weight initialization
# is reproducible under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

EMBED_DIM = 320
MAX_LENGTH = 320  # keep in sync with max_length used when tokenizing above

config = BalmMoEConfig(
    embed_dim=EMBED_DIM,
    ffn_dim=4 * EMBED_DIM,
    num_experts=8,
    num_shared_experts=0,
    num_layers=6,
    num_heads=20,
    alternate_sparsity=True,
    router_top_k=1,
    expert_choice_router=False,
    max_length=MAX_LENGTH,
    # optional overrides; the config's defaults apply while these stay commented out
    # expert_capacity=128,
    # router_z_loss_coef=0.001,
    # router_aux_loss_coef=0.001,
    vocab_size=tokenizer.vocab_size,
)

model = BalmMoEForMaskedLM(config=config)



In [33]:
# Display the model's parameter count (num_parameters is read as an attribute).
model.num_parameters

19199393

In [34]:
# Run settings: one epoch on CPU. use_wandb is commented out, so the Trainer's
# default applies; the wandb_project / run_name values are still passed through.
training_kwargs = dict(
    epochs=1,
    logging_steps=5,
    eval_steps=100,
    warmup_steps=10,
    per_device_train_batch_size=32,
    use_cpu=True,
    # use_wandb=True,
    wandb_project="test_wandb_logging",
    # wandb_entity="bryanbriney",
    run_name="save_test_001",
)

trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    **training_kwargs,
)

In [35]:
# Confirm which device the Trainer selected (use_cpu=True was requested above).
trainer.device

device(type='cpu')

In [36]:
# Run training; the progress bar and periodic loss / router-loss logs print below.
trainer.train()

Training:   0%|          | 0/31 [00:00<?, ?step/s]

step 5     | loss: 3.2322 | MLM loss: 3.2256 | router z-loss: 0.0050 | router aux loss: 0.0016 | lr: 0.000200
step 10    | loss: 2.7818 | MLM loss: 2.7761 | router z-loss: 0.0042 | router aux loss: 0.0015 | lr: 0.000400
step 15    | loss: 2.6787 | MLM loss: 2.6742 | router z-loss: 0.0033 | router aux loss: 0.0013 | lr: 0.000305
step 20    | loss: 2.6706 | MLM loss: 2.6668 | router z-loss: 0.0025 | router aux loss: 0.0012 | lr: 0.000210
step 25    | loss: 2.6406 | MLM loss: 2.6373 | router z-loss: 0.0021 | router aux loss: 0.0012 | lr: 0.000114
step 30    | loss: 2.6175 | MLM loss: 2.6144 | router z-loss: 0.0019 | router aux loss: 0.0012 | lr: 0.000019
<< SAVING FINAL MODEL >>

Training complete
