In [1]:
from typing import Optional

import torch
import torch.nn as nn

from balm.config import BaseConfig
from balm.data import load_dataset, DataCollator
from balm.embedding import RelativePositionalEmbedding
from balm.modules import (
    BalmClassificationHead,
    BalmLMHead,
    ClassifierOutput,
    MaskedLMOutput,
    RoformerLayer,
    TransformerLayer,
    DenseTransformerLayer,
)
from balm.models.base import BalmBase

from balm.tokenizer import Tokenizer
from balm.train import Trainer

In [2]:
class BalmConfig(BaseConfig):
    def __init__(
        self,
        embed_dim: int = 320,
        ffn_dim: int = 1280,
        num_layers: int = 6,
        num_heads: int = 20,
        num_experts: int = 8,
        max_length: int = 320,
        vocab_size: int = 33,
        dropout: float = 0.1,
        attention_dropout: float = 0.0,
        token_embedding_dropout: float = 0.0,
        positional_embedding_type: str = "rotary",
        pre_norm: bool = True,
        activation: str = "swiglu",
        layer_norm_eps: float = 1e-5,
        padding_idx: int = 0,
        num_labels: int = 2,
        classifier_dropout: Optional[float] = None,
        classifier_activation: Optional[str] = None,
    ):
        """
        Configuration for the Balm model. Default parameters are similar to the 8M parameter ESM-2 model.

        Parameters
        ----------
        embed_dim : int, default=320
            The dimension of the token embeddings.

        ffn_dim : int, default=1280
            The dimension of the feed-forward network.

        num_layers : int, default=6
            The number of layers in the transformer.

        num_heads : int, default=20
            The number of attention heads in each transformer layer.

        num_experts : int, default=8
            The number of experts in the transformer. Stored on the config but not
            read by the dense model classes defined in this notebook.

        max_length : int, default=320
            The maximum length of the input sequence.

        vocab_size : int, default=33
            The vocabulary size.

        dropout : float, default=0.1
            The dropout rate. Applied immediately before adding the residual connection.

        attention_dropout : float, default=0.0
            The dropout rate for the attention layer.

        token_embedding_dropout : float, default=0.0
            The dropout rate for the token embedding layer.

        positional_embedding_type : str, default="rotary"
            The type of positional embedding to use. Options are "rotary" or "relative"
            (case-insensitive; stored lowercased).

        pre_norm : bool, default=True
            Whether to use pre-norm or post-norm.

        activation : str, default="swiglu"
            The activation function to use in the feed-forward network. Options are
            "swiglu", "relu", or "gelu" (case-insensitive; stored lowercased).

        layer_norm_eps : float, default=1e-5
            The epsilon value for the layer normalization.

        padding_idx : int, default=0
            The index of the padding token.

        num_labels : int, default=2
            Number of output classes, used by ``BalmForSequenceClassification``.

        classifier_dropout : float, optional
            Dropout rate for the classification head. If ``None``, the sequence
            classification model falls back to ``dropout``.

        classifier_activation : str, optional
            Activation for the classification head. If ``None``, the sequence
            classification model falls back to ``"tanh"``.

        Raises
        ------
        ValueError
            If ``positional_embedding_type`` or ``activation`` is not one of the
            supported options.
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.ffn_dim = ffn_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_experts = num_experts
        self.max_length = max_length
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.attention_dropout = attention_dropout
        self.token_embedding_dropout = token_embedding_dropout

        # normalize once, validate, then store the normalized value
        pos_embedding = positional_embedding_type.lower()
        if pos_embedding not in ("rotary", "relative"):
            raise ValueError(
                f"Invalid positional embedding type: {positional_embedding_type}. Options are 'rotary' or 'relative'."
            )
        self.positional_embedding_type = pos_embedding

        ffn_activation = activation.lower()
        if ffn_activation not in ("swiglu", "relu", "gelu"):
            raise ValueError(
                f"Invalid FFN activation: {activation}. Options are 'swiglu', 'relu', or 'gelu'."
            )
        self.activation = ffn_activation

        self.pre_norm = pre_norm
        self.layer_norm_eps = layer_norm_eps
        self.padding_idx = padding_idx

        # classification head settings (read by BalmForSequenceClassification)
        self.num_labels = num_labels
        self.classifier_dropout = classifier_dropout
        self.classifier_activation = classifier_activation



In [11]:
class BalmModel(BalmBase):
    config_cls = BalmConfig

    def __init__(
        self,
        config: BalmConfig,
    ):
        """
        Base BALM transformer encoder: a token embedding, a stack of
        ``DenseTransformerLayer`` blocks, and a final layer norm.

        Positional embedding type, pre-/post-norm, and FFN activation come from
        ``config``. The ``BalmConfig`` defaults are rotary embeddings, pre-norm,
        and SwiGLU.

        Parameters
        ----------
        config : BalmConfig
            The configuration object defining model architecture and hyperparameters.

        """
        super().__init__(config)
        # embedding (``padding_idx`` is forwarded to nn.Embedding)
        self.embed_tokens = nn.Embedding(
            self.config.vocab_size,
            self.config.embed_dim,
            padding_idx=self.config.padding_idx,
        )

        # layers
        self.layers = nn.ModuleList(
            [
                DenseTransformerLayer(
                    self.config.embed_dim,
                    self.config.ffn_dim,
                    self.config.num_heads,
                    self.config.max_length,
                    dropout=self.config.dropout,
                    attention_dropout=self.config.attention_dropout,
                    token_embedding_dropout=self.config.token_embedding_dropout,
                    layer_norm_eps=self.config.layer_norm_eps,
                    activation=self.config.activation,
                    positional_embedding_type=self.config.positional_embedding_type,
                    pre_norm=self.config.pre_norm,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        # only applied in forward() when config.pre_norm is True
        self.final_layer_norm = nn.LayerNorm(
            self.config.embed_dim, eps=self.config.layer_norm_eps
        )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : torch.Tensor
            The input token ids. Expected shape is (batch_size, sequence_length).

        attention_mask : torch.Tensor, optional
            Passed unchanged to every transformer layer.

        key_padding_mask : torch.Tensor, optional
            Passed unchanged to every transformer layer.

        need_weights : bool, default=False
            If True, each layer is called with ``need_weights=True`` and its output is
            unpacked as ``(hidden_states, attention_weights)``.

        Returns
        -------
        torch.Tensor or tuple
            The output tensor. The shape is (batch_size, sequence_length, embed_dim).
            If ``need_weights`` is True, returns ``(output, attn)``. ``attn`` holds the
            attention weights of the LAST layer only, or ``None`` if the model has no
            layers.
        """
        x = self.embed_tokens(x)
        # initialized up front so need_weights=True cannot hit an UnboundLocalError
        # when num_layers == 0
        attn = None
        for layer in self.layers:
            layer_out = layer(
                x,
                attention_mask=attention_mask,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
            )
            if need_weights:
                x, attn = layer_out
            else:
                x = layer_out
        # Final norm only for pre-norm models. Presumably post-norm layers already
        # normalize their outputs; confirm in DenseTransformerLayer.
        if self.config.pre_norm:
            x = self.final_layer_norm(x)
        if need_weights:
            return x, attn
        return x


class BalmForMaskedLM(BalmBase):
    config_cls = BalmConfig

    def __init__(
        self,
        config: BalmConfig,
    ):
        """
        BALM model for masked language modeling. It wraps the base ``BalmModel``
        (rotary embeddings, pre-norm and SwiGLU by default, all configurable) and
        adds a ``BalmLMHead`` that projects hidden states to ``vocab_size`` logits.

        Parameters
        ----------
        config : BalmConfig
            The configuration object defining model architecture and hyperparameters.

        """
        super().__init__(config)
        self.balm = BalmModel(config=self.config)
        self.lm_head = BalmLMHead(self.config.embed_dim, self.config.vocab_size)
        # Label positions equal to -100 are ignored by the loss. Presumably the data
        # collator sets them for unmasked tokens; confirm in balm.data.
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> MaskedLMOutput:
        """
        Parameters
        ----------
        input_ids : torch.Tensor
            The input token ids. Expected shape is (batch_size, seq_len).

        labels : torch.Tensor, optional
            Target token ids with one entry per input position. Entries equal to
            -100 are ignored. If omitted, ``loss`` is ``None``.

        attention_mask : torch.Tensor, optional
            Forwarded to the base model.

        key_padding_mask : torch.Tensor, optional
            Forwarded to the base model.

        output_attentions : bool, default=False
            If True, attach the last layer's attention weights as ``attentions``.

        output_hidden_states : bool, default=False
            If True, attach the final-layer hidden states (not per-layer states)
            as ``hidden_states``.

        return_dict : bool, default=True
            If True, return ``output.as_dict()``; otherwise ``output.as_tuple()``.

        Returns
        -------
        dict or tuple
            The ``MaskedLMOutput`` converted via ``as_dict()`` / ``as_tuple()``.
            ``logits`` has shape (batch_size, seq_len, vocab_size).
        """
        x = self.balm(
            input_ids,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            need_weights=output_attentions,
        )
        if output_attentions:
            x, attn = x
        logits = self.lm_head(x)

        masked_lm_loss = None
        if labels is not None:
            # flatten to (N, vocab_size) / (N,); reshape (unlike view) also
            # accepts non-contiguous tensors
            masked_lm_loss = self.criterion(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        output = MaskedLMOutput(
            logits=logits,
            loss=masked_lm_loss,
        )
        if output_attentions:
            output.attentions = attn
        if output_hidden_states:
            output.hidden_states = x
        if return_dict:
            return output.as_dict()
        return output.as_tuple()


class BalmForSequenceClassification(BalmBase):
    config_cls = BalmConfig

    def __init__(
        self,
        config: BalmConfig,
    ):
        """
        BALM model for sequence classification. It wraps the base ``BalmModel``
        (rotary embeddings, pre-norm and SwiGLU by default, all configurable) and
        adds a ``BalmClassificationHead`` with ``config.num_labels`` outputs.

        Parameters
        ----------
        config : BalmConfig
            The configuration object defining model architecture and hyperparameters.
            ``config.num_labels`` must be set. ``config.classifier_dropout`` and
            ``config.classifier_activation`` are optional. When missing or ``None``
            they fall back to ``config.dropout`` and ``"tanh"``.

        """
        super().__init__(config)
        # model
        self.balm = BalmModel(config=self.config)

        # classifier: optional settings may be absent from the config entirely,
        # so read them with getattr and apply the same fallbacks as for None
        classifier_dropout = getattr(self.config, "classifier_dropout", None)
        if classifier_dropout is None:
            classifier_dropout = self.config.dropout
        classifier_activation = getattr(self.config, "classifier_activation", None)
        if classifier_activation is None:
            classifier_activation = "tanh"
        self.classifier = BalmClassificationHead(
            embed_dim=self.config.embed_dim,
            num_labels=self.config.num_labels,
            dropout=classifier_dropout,
            activation=classifier_activation,
        )

        # loss
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(
        self,
        input_ids: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        key_padding_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> ClassifierOutput:
        """
        Parameters
        ----------
        input_ids : torch.Tensor
            The input token ids. Expected shape is (batch_size, seq_len).

        labels : torch.Tensor, optional
            Class labels. Entries equal to -100 are ignored. If omitted, ``loss``
            is ``None``.

        attention_mask : torch.Tensor, optional
            Forwarded to the base model.

        key_padding_mask : torch.Tensor, optional
            Forwarded to the base model.

        output_attentions : bool, default=False
            If True, attach the last layer's attention weights as ``attentions``.

        output_hidden_states : bool, default=False
            If True, attach the final-layer hidden states (not per-layer states)
            as ``hidden_states``.

        return_dict : bool, default=True
            If True, return ``output.as_dict()``; otherwise ``output.as_tuple()``.

        Returns
        -------
        dict or tuple
            The ``ClassifierOutput`` converted via ``as_dict()`` / ``as_tuple()``.
            The shape of ``logits`` is set by ``BalmClassificationHead``. It is presumably
            (batch_size, num_labels); confirm.
        """
        x = self.balm(
            input_ids,
            attention_mask=attention_mask,
            key_padding_mask=key_padding_mask,
            need_weights=output_attentions,
        )
        if output_attentions:
            x, attn = x
        logits = self.classifier(x)

        classifier_loss = None
        if labels is not None:
            # flatten to (N, num_classes) / (N,); reshape (unlike view) also
            # accepts non-contiguous tensors
            classifier_loss = self.criterion(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
            )

        output = ClassifierOutput(
            logits=logits,
            loss=classifier_loss,
        )
        if output_attentions:
            output.attentions = attn
        if output_hidden_states:
            output.hidden_states = x
        if return_dict:
            return output.as_dict()
        return output.as_tuple()

In [12]:
# Vocabulary file used to build the tokenizer (path relative to the notebook).
VOCAB_PATH = "./vocab.json"
tokenizer = Tokenizer(vocab=VOCAB_PATH)

In [13]:
def remove_sep(txt):
    """
    Replace every ``</s>`` separator in ``txt`` with two ``<cls>`` tokens.

    Despite the name, the separator is substituted rather than deleted.
    Text without a ``</s>`` is returned unchanged.
    """
    separator, replacement = "</s>", "<cls><cls>"
    return replacement.join(txt.split(separator))


# Each option maps split name -> text file. "test_1k" is a smoke test that reuses
# the same 1k-sequence file for every split, so its eval/test numbers say nothing
# about generalization. The other entries are alternative split datasets.
SPLITS = ("train", "test", "eval")
DATA_FILE_OPTIONS = {
    "test_1k": {split: "./balm/test_data/test_1k.txt" for split in SPLITS},
    "paired": {split: f"../train-test-eval_paired/{split}.txt" for split in SPLITS},
    "jaffe_plusHD": {
        split: f"../jaffe-plusHD_clust0.9_split/{split}.txt" for split in SPLITS
    },
}
DATASET_NAME = "test_1k"  # one of DATA_FILE_OPTIONS' keys

data_files = DATA_FILE_OPTIONS[DATASET_NAME]
dataset = load_dataset("text", data_files=data_files, preprocess_fn=remove_sep)

In [14]:
# Keep in sync with ``max_length`` in the model config; longer inputs are truncated.
MAX_LENGTH = 320


def tokenize_batch(examples):
    """Encode the ``text`` field with padding and truncation at ``MAX_LENGTH``."""
    return tokenizer(
        examples["text"],
        padding=True,
        truncation=True,
        max_length=MAX_LENGTH,
    )


tokenized_dataset = dataset.map(tokenize_batch, remove_columns="text")

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

Encoding:   0%|          | 0/1000 [00:00<?, ?it/s]

In [15]:
# Collates tokenized examples into training batches. It presumably also applies
# the MLM masking / -100 labels; confirm in balm.data.DataCollator.
collator = DataCollator(tokenizer=tokenizer)

In [16]:
# Architecture matched to ESM-2 8M (6 layers, 320-dim embeddings, 20 heads).
SEED = 42
EMBED_DIM = 320

# Seed torch's global RNG before building the model so weight initialization is
# reproducible on Restart & Run All.
torch.manual_seed(SEED)

config = BalmConfig(
    embed_dim=EMBED_DIM,
    ffn_dim=EMBED_DIM * 4,
    num_layers=6,
    num_heads=20,
    # keep in sync with the max_length used when tokenizing the dataset
    max_length=320,
    vocab_size=tokenizer.vocab_size,
)
model = BalmForMaskedLM(config=config)

In [17]:
# Run settings for the save/logging smoke test. The commented-out entries are
# optional toggles that can be re-enabled.
# NOTE(review): eval_steps=100 exceeds the 31 steps this run took (see the
# training log), so no periodic evaluation fires. Lower it to exercise eval.
trainer_kwargs = dict(
    output_dir="./training_runs/save_tests",
    epochs=1,
    logging_steps=5,
    eval_steps=100,
    warmup_steps=10,
    # save_steps=15,
    per_device_train_batch_size=32,
    # use_cpu=True,
    # use_wandb=True,
    wandb_project="test_wandb_logging",
    # wandb_entity="bryanbriney",
    run_name="save_test_001",
)

trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    **trainer_kwargs,
)

In [18]:
# Train with the settings above. Per the log below, loss/lr are reported every
# `logging_steps` steps and the final model is saved at the end.
trainer.train()

Training:   0%|          | 0/31 [00:00<?, ?step/s]

step 5     | loss: 3.1664 | lr: 0.000200
step 10    | loss: 2.7392 | lr: 0.000400
step 15    | loss: 2.6835 | lr: 0.000305
step 20    | loss: 2.6565 | lr: 0.000210
step 25    | loss: 2.6667 | lr: 0.000114
step 30    | loss: 2.6235 | lr: 0.000019
<< SAVING FINAL MODEL >>

Training complete
