In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from typing import Optional, Tuple, Union

from balm.data import load_dataset, DataCollator
from balm.embedding import RelativePositionalEmbedding
from balm.loss import router_z_loss, router_load_balancing_loss
from balm.models.base import BalmBase
from balm.modules import (
    Expert,
    BalmLMHead,
    MaskedLMOutput,
    DenseTransformerLayer,
    # SparseTransformerLayer,
    SparseMLP,
)
from balm.router import TopKRouter, ExpertChoiceRouter
from balm.tokenizer import Tokenizer
from balm.train.trainer import Trainer

In [35]:
class HybridSparseTransformerLayer(nn.Module):
    """
    Hybrid sparse transformer layer. Inspired by Snowflake's `Arctic model`_.

    .. _Arctic model:
        https://www.snowflake.com/blog/arctic-open-efficient-foundation-language-models-snowflake/
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        residual_ffn_dim: int,
        num_heads: int,
        num_experts: int,
        expert_capacity: int,
        max_length: int,
        num_shared_experts: int = 0,
        top_k: int = 2,
        activation: str = "swiglu",
        expert_activation: str = "swiglu",
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        expert_ffn_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        router_dtype: str = "float32",
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        expert_choice_router: bool = False,
        pre_norm: bool = True,
        positional_embedding_type: str = "rotary",
    ):
        super().__init__()

        # dense transformer
        self.dense_transformer = DenseTransformerLayer(
            embed_dim=embed_dim,
            ffn_dim=ffn_dim,
            num_heads=num_heads,
            max_length=max_length,
            dropout=dropout,
            attention_dropout=attention_dropout,
            token_embedding_dropout=token_embedding_dropout,
            layer_norm_eps=layer_norm_eps,
            activation=activation,
            positional_embedding_type=positional_embedding_type,
            pre_norm=pre_norm,
        )

        # sparse residual connection
        self.residual_norm = nn.LayerNorm(embed_dim, eps=layer_norm_eps)
        self.sparse_residual = SparseMLP(
            embed_dim=embed_dim,
            ffn_dim=residual_ffn_dim,
            num_experts=num_experts,
            expert_capacity=expert_capacity,
            num_shared_experts=num_shared_experts,
            top_k=top_k,
            expert_ffn_dropout=expert_ffn_dropout,
            activation=expert_activation,
            router_dtype=router_dtype,
            router_bias=router_bias,
            router_jitter=router_jitter,
            router_ignore_padding_tokens=router_ignore_padding_tokens,
            router_class=ExpertChoiceRouter if expert_choice_router else TopKRouter,
            expert_class=Expert,
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        output_router_logits: bool = True,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Tuple]]:
        """
        Process the input hidden states.

        Parameters:
        -----------
        x : torch.Tensor
            Input tensor of shape (batch_size, sequence_length, embed_dim).

        attn_mask : torch.Tensor, optional
            Attention mask of shape (batch_size * num_heads, sequence_length, sequence_length). The default is None.

        key_padding_mask : torch.Tensor, optional
            Mask of shape (batch_size, sequence_length). The default is None.

        need_weights : bool, optional
            Whether to return attention weights. The default is False.

            .. note::
                if `need_weights` is ``True``, the output will be a tuple of (x, attn). Also,
                nn.MultiHeadAttention will not be able to use the optimized torch implementation
                of ``scaled_dot_product_attention``. See `here`_ for more details.

        output_router_logits : bool, optional
            Whether to output router logits. The default is True.

        Returns:
        --------
        x : torch.Tensor or Tuple

            Output tensor of shape (batch_size, sequence_length, embed_dim). If `need_weights`, is ``True``,
            output is a tuple of (x, attn). If `output_router_logits` is ``True``, the output will be a tuple
            of (x, router_logits) or (x, attn, router_logts) depending on the value of `need_weights`.


        .. _here:
            https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward
        """
        # residual connection
        residual, (router_logits, _) = self.sparse_residual(self.residual_norm(x))

        # dense transformer
        x = self.dense_transformer(
            x,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            need_weights=need_weights,
        )
        if need_weights:
            x, attn = x
        else:
            x = x[0]

        # add residual
        x = x + residual

        # outputs
        if need_weights:
            if output_router_logits:
                return (x, attn, router_logits)
            return (x, attn)
        if output_router_logits:
            return (x, router_logits)
        return x

In [36]:
class BalmHybridMoEModel(nn.Module):
    """
    BALM Mixture of Experts model.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        residual_ffn_dim: int,
        num_layers: int,
        num_heads: int,
        num_experts: int,
        expert_capacity: int,
        vocab_size: int,
        max_length: int = 320,
        num_shared_experts: int = 0,
        activation: str = "swiglu",
        expert_activation: str = "swiglu",
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        expert_ffn_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        router_dtype: str = "float32",
        router_top_k: int = 1,
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_ignore_padding_tokens: bool = True,
        padding_idx: int = 0,
        expert_choice_router: bool = False,
        pre_norm: bool = True,
        positional_embedding_type: str = "rotary",
    ):
        super().__init__()
        self.expert_choice_router = expert_choice_router
        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.layers = nn.ModuleList(
            [
                HybridSparseTransformerLayer(
                    embed_dim=embed_dim,
                    ffn_dim=ffn_dim,
                    residual_ffn_dim=residual_ffn_dim,
                    num_heads=num_heads,
                    max_length=max_length,
                    num_experts=num_experts,
                    expert_capacity=expert_capacity,
                    num_shared_experts=num_shared_experts,
                    top_k=router_top_k,
                    activation=activation,
                    expert_activation=expert_activation,
                    dropout=dropout,
                    attention_dropout=attention_dropout,
                    expert_ffn_dropout=expert_ffn_dropout,
                    token_embedding_dropout=token_embedding_dropout,
                    layer_norm_eps=layer_norm_eps,
                    router_dtype=router_dtype,
                    router_bias=router_bias,
                    router_jitter=router_jitter,
                    router_ignore_padding_tokens=router_ignore_padding_tokens,
                    expert_choice_router=expert_choice_router,
                    pre_norm=pre_norm,
                    positional_embedding_type=positional_embedding_type,
                )
                for _ in range(num_layers)
            ]
        )
        self.final_norm = nn.LayerNorm(embed_dim)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: bool = True,
        output_expert_indices: bool = False,
        return_dict: bool = True,
    ):
        """
        Parameters:
        -----------

        input_ids: torch.LomgTensor
            Tokenized input IDs

        attention_mask: torch.BoolTensor
            Attention mask

        output_attentions: bool
            Whether to output attention weights

        output_hidden_states: bool
            Whether to output hidden states

        output_router_logits: bool
            Whether to output router logits

        return_dict: bool
            Whether to return a dictionary of outputs (returns a tuple by default)


        Returns:
        --------
        output (tuple or dict):
            If `return_dict` is ``True``, the output is a ``dict`` of outputs:
                - last_hidden_state (torch.FloatTensor): last hidden state
                - router_z_loss (torch.FloatTensor): router z loss
                - router_aux_loss (torch.FloatTensor): router auxiliary loss
                - attentions (torch.FloatTensor): attention weights
                - hidden_states (torch.FloatTensor): hidden states
                - router_logits (torch.FloatTensor): router logits
            If `return_dict` is ``False``, the output is a ``tuple`` with the f0llowing elements:
                - last_hidden_state (torch.FloatTensor): last hidden state
                - attentions (torch.FloatTensor): attention weights
                - hidden_states (torch.FloatTensor): hidden states
                - router_logits (torch.FloatTensor): router logits
        """
        # init
        attn_weights = []
        hidden_states = {}
        router_logits = []
        expert_indexes = []

        # embeddings
        x = self.embed_tokens(input_ids)

        # encoder
        for layer_idx, layer in enumerate(self.layers, 1):
            # because of the sparse residual, we need to collect router/expert info
            x = layer(
                x,
                attention_mask=attention_mask,
                key_padding_mask=key_padding_mask,
                need_weights=output_attentions,
                output_router_logits=output_router_logits,
            )
            if output_attentions:
                x, attn, router_tuple = x
                attn_weights.append(attn)
            else:
                x, router_tuple = x
            router_logits.append(router_tuple[0])
            expert_indexes.append(router_tuple[1])
            if output_hidden_states:
                hidden_states[layer_idx] = x
        x = self.final_norm(x)

        # Compute the router losses (only z_loss for expert choice MoEs)
        cat_router_logits = torch.cat(router_logits, dim=1)
        cat_expert_indexes = torch.cat(expert_indexes, dim=1)
        router_probs = nn.Softmax(dim=-1)(cat_router_logits)
        z_loss = router_z_loss(cat_router_logits)
        if self.expert_choice_router:
            aux_loss = None
        else:
            aux_loss = router_load_balancing_loss(router_probs, cat_expert_indexes)

        # results
        result = MaskedLMOutput(
            last_hidden_state=x,
            router_z_loss=z_loss,
            router_aux_loss=aux_loss,
        )
        if output_attentions:
            # attentions: B x L x H x T x T
            attentions = torch.stack(attn_weights, 1)
            attentions = attentions * attention_mask[:, None, None, :, :]
            result["attentions"] = attentions
        if output_hidden_states:
            result["hidden_states"] = hidden_states
        if output_router_logits:
            result["router_logits"] = cat_router_logits
        if output_expert_indices:
            result["expert_indices"] = cat_expert_indexes
        if return_dict:
            return result
        return result.as_tuple()


class BalmHybridMoEForMaskedLM(nn.Module):
    """
    BALM Hybrid Mixture of Experts model for Masked Language Modeling. Inspired by Snowflake's `Arctic model`_.

    .. _Arctic model:
        https://www.snowflake.com/blog/arctic-open-efficient-foundation-language-models-snowflake/
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        residual_ffn_dim: int,
        num_layers: int,
        num_heads: int,
        num_experts: int,
        expert_capacity: int,
        vocab_size: int,
        max_length: int = 320,
        num_shared_experts: int = 0,
        activation: str = "swiglu",
        expert_activation: str = "swiglu",
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        expert_ffn_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        layer_norm_eps: float = 1e-5,
        router_dtype: str = "float32",
        router_top_k: int = 1,
        router_bias: bool = False,
        router_jitter: float = 0.0,
        router_z_loss_coef: float = 0.001,
        router_aux_loss_coef: float = 0.001,
        router_ignore_padding_tokens: bool = True,
        padding_idx: int = 0,
        expert_choice_router: bool = False,
        pre_norm: bool = True,
        positional_embedding_type: str = "rotary",
    ):
        super().__init__()
        self.balm = BalmHybridMoEModel(
            embed_dim=embed_dim,
            ffn_dim=ffn_dim,
            residual_ffn_dim=residual_ffn_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            num_experts=num_experts,
            num_shared_experts=num_shared_experts,
            router_top_k=router_top_k,
            expert_capacity=expert_capacity,
            vocab_size=vocab_size,
            max_length=max_length,
            activation=activation,
            expert_activation=expert_activation,
            dropout=dropout,
            expert_ffn_dropout=expert_ffn_dropout,
            token_embedding_dropout=token_embedding_dropout,
            attention_dropout=attention_dropout,
            layer_norm_eps=layer_norm_eps,
            router_dtype=router_dtype,
            router_bias=router_bias,
            router_jitter=router_jitter,
            router_ignore_padding_tokens=router_ignore_padding_tokens,
            padding_idx=padding_idx,
            expert_choice_router=expert_choice_router,
            pre_norm=pre_norm,
            positional_embedding_type=positional_embedding_type,
        )
        self.lm_head = BalmLMHead(
            embed_dim=embed_dim,
            output_dim=vocab_size,
        )

        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)
        self.router_z_loss_coef = router_z_loss_coef
        self.router_aux_loss_coef = router_aux_loss_coef

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        output_router_logits: bool = True,
        output_expert_indices: bool = False,
        return_dict: bool = True,
    ):
        """
        Args:
            input_ids (torch.LongTensor): tokenized input IDs
            attention_mask (torch.BoolTensor): attention mask
            return_dict (bool): return a dictionary of outputs
        """
        # encoder
        outputs = self.balm(
            input_ids,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            output_router_logits=output_router_logits,
            output_expert_indices=output_expert_indices,
            return_dict=True,
        )
        x = outputs["last_hidden_state"]
        router_z_loss = outputs["router_z_loss"]
        router_aux_loss = outputs["router_aux_loss"]

        # LM head
        lm_logits = self.lm_head(x)
        outputs["logits"] = lm_logits

        # loss
        if labels is not None:
            # move labels to correct device
            labels = labels.to(lm_logits.device)
            loss = self.criterion(
                lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)
            )
            outputs["lm_loss"] = loss

            if output_router_logits:
                z_loss = self.router_z_loss_coef * (router_z_loss)
                aux_loss = self.router_aux_loss_coef * (router_aux_loss)
                outputs["router_z_loss"] = z_loss
                outputs["router_aux_loss"] = aux_loss
                loss = loss + z_loss + aux_loss
            outputs["loss"] = loss

        if return_dict:
            return outputs.to_dict()
        return outputs.as_tuple()

In [25]:
tokenizer = Tokenizer(vocab="./vocab.json")

In [26]:
def remove_sep(txt):
    return txt.replace("</s>", "<cls><cls>")


data_files = {
    "train": "./balm/test_data/test_1k.txt",
    "test": "./balm/test_data/test_1k.txt",
    "eval": "./balm/test_data/test_1k.txt",
}

dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [27]:
tokenized_dataset = dataset.map(
    lambda x: tokenizer(
        x["text"],
        padding=True,
        truncation=True,
        max_length=320,
    ),
    remove_columns="text",
)

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

In [28]:
collator = DataCollator(tokenizer=tokenizer)

In [37]:
# model = BalmMoERoPEForMaskedLM(
model = BalmHybridMoEForMaskedLM(
    embed_dim=320,
    ffn_dim=320*4,
    residual_ffn_dim=1024,
    num_experts=16,
    max_length=320,
    num_shared_experts=0,
    expert_capacity=320/16*1.5,
    num_layers=6,
    num_heads=20,
    router_top_k=2,
    router_z_loss_coef=0.001,
    router_aux_loss_coef=0.01,
    vocab_size=tokenizer.vocab_size,
)

In [38]:
trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    output_dir="./training_runs/save_tests",
    epochs=1,
    logging_steps=5,
    eval_steps=100,
    warmup_steps=10,
    # save_steps=15,
    per_device_train_batch_size=32,
    # use_cpu=True,
    # use_wandb=True,
    wandb_project="test_wandb_logging",
    # wandb_entity="bryanbriney",
    run_name="save_test_001",
)

In [39]:
trainer.train()

Training:   0%|          | 0/31 [00:00<?, ?step/s]

ValueError: not enough values to unpack (expected 3, got 2)