<a href="https://colab.research.google.com/github/vaibhav803/express/blob/master/examples/language_modeling_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import random
import numpy as np

import torch
from torch.utils.data import Dataset, Sampler

import lance


def apply_fim(sample, fim_prefix, fim_middle, fim_suffix, fim_pad, mode, np_rng):
    """
    Apply a fill-in-the-middle (FIM) rearrangement to a single token sequence.

    Two cut points are drawn from ``np_rng`` and split ``sample`` into
    prefix / middle / suffix. The pieces are reassembled around the sentinel
    tensors in SPM order when ``mode == "spm"`` and in PSM order otherwise.

    Args:
        sample: 1-D tensor of token ids.
        fim_prefix, fim_middle, fim_suffix, fim_pad: 1-element tensors holding
            the sentinel / padding token ids.
        mode: "spm" for suffix-prefix-middle; any other value gives
            prefix-suffix-middle.
        np_rng: ``np.random.RandomState`` used to draw the two cut points.

    Returns:
        The concatenated tensor. Three sentinels are inserted, so up to three
        tokens are trimmed from the end of the suffix to compensate.
    """
    n_tokens = len(sample)
    lo, hi = sorted(np_rng.randint(low=0, high=n_tokens + 1, size=2))

    head = sample[:lo]
    centre = sample[lo:hi]
    tail = sample[hi:]

    # head/centre/tail partition the sample, so the overshoot is always 3
    # (one per sentinel). The padding branch is kept for parity even though
    # it cannot trigger with this split.
    overshoot = (len(head) + len(centre) + len(tail) + 3) - n_tokens
    if overshoot > 0:
        tail = tail[: max(0, len(tail) - overshoot)]
    elif overshoot < 0:
        padding = torch.cat([fim_pad] * (-overshoot))
        tail = torch.cat([tail, padding])

    if mode == "spm":
        pieces = [fim_prefix, fim_suffix, tail, fim_middle, head, centre]
    else:
        pieces = [fim_prefix, head, fim_suffix, tail, fim_middle, centre]
    return torch.cat(pieces)


class MambaDataset(Dataset):
    """
    Sliding-window token dataset read from Lance, with optional FIM augmentation.

    Item ``idx`` is built from rows ``idx .. idx + context_len`` (inclusive) of
    the Lance dataset's ``"value"`` column. With probability ``fim_rate`` the
    window is passed through :meth:`apply_fim`. The result is then split into
    next-token-prediction ``tokens`` and ``labels`` of length ``context_len``.
    """

    def __init__(
        self,
        dataset_path,
        context_len,
        fim_prefix,
        fim_middle,
        fim_suffix,
        fim_pad,
        fim_rate=0.5,
        mode="psm",
        rng_seed=42,
    ):
        """
        Args:
            dataset_path: path passed to ``lance.dataset``.
            context_len: number of input tokens per item.
            fim_prefix, fim_middle, fim_suffix: sentinel token ids.
            fim_pad: padding token id, or ``None`` when no pad token is used.
            fim_rate: probability that an item is FIM-transformed.
            mode: ``"spm"`` for suffix-prefix-middle; any other value is PSM.
            rng_seed: seed of the dataset's own ``np.random.RandomState``.
        """
        # Load the lance dataset from the saved path
        self.ds = lance.dataset(dataset_path)
        self.context_len = context_len

        # Each item reads context_len + 1 rows starting at idx, so the last
        # valid start index is count_rows() - context_len - 1.
        self.length = self.ds.count_rows() - context_len

        self.np_rng = np.random.RandomState(seed=rng_seed)

        self.fim_prefix = torch.tensor([fim_prefix])
        self.fim_middle = torch.tensor([fim_middle])
        self.fim_suffix = torch.tensor([fim_suffix])
        # Callers disable padding by passing None, and torch.tensor([None])
        # raises, so keep None instead of building a tensor.
        self.fim_pad = self._optional_token_tensor(fim_pad)
        self.fim_rate = fim_rate
        self.mode = mode

    @staticmethod
    def _optional_token_tensor(token_id):
        """Return ``torch.tensor([token_id])``, or ``None`` if ``token_id`` is ``None``."""
        return None if token_id is None else torch.tensor([token_id])

    def __len__(self):
        """Number of valid window start indices."""
        return self.length

    def from_idxs(self, idxs):
        """
        Fetch the rows at ``idxs`` from Lance and return their ``"value"``
        fields as a 1-D tensor.
        """
        rows = self.ds.take(idxs).to_pylist()
        return torch.tensor([row["value"] for row in rows])

    def apply_fim(self, sample):
        """
        Apply a FIM transformation to one sample.

        Two cut points drawn from ``self.np_rng`` split ``sample`` into
        prefix / middle / suffix. These are reassembled in SPM order when
        ``self.mode == "spm"`` and in PSM order otherwise.

        Raises:
            ValueError: if padding is required but no pad token was configured.
        """
        boundaries = sorted(self.np_rng.randint(low=0, high=len(sample) + 1, size=2))

        prefix = sample[: boundaries[0]]
        middle = sample[boundaries[0] : boundaries[1]]
        suffix = sample[boundaries[1] :]

        # Inserting three sentinels lengthens the sequence by 3. Trim that
        # many tokens off the suffix (as far as it allows) to compensate.
        total_length = len(prefix) + len(middle) + len(suffix) + 3
        diff = total_length - len(sample)
        if diff > 0:
            suffix = suffix[: max(0, len(suffix) - diff)]
        elif diff < 0:
            if self.fim_pad is None:
                raise ValueError("FIM padding is required but fim_pad is None")
            extend = torch.cat([self.fim_pad] * (-diff))
            suffix = torch.cat([suffix, extend])

        if self.mode == "spm":
            # Apply SPM
            pieces = [
                self.fim_prefix,
                self.fim_suffix,
                suffix,
                self.fim_middle,
                prefix,
                middle,
            ]
        else:
            # Apply PSM
            pieces = [
                self.fim_prefix,
                prefix,
                self.fim_suffix,
                suffix,
                self.fim_middle,
                middle,
            ]
        return torch.cat(pieces)

    def __getitem__(self, idx):
        """
        Read rows idx .. idx + context_len, optionally FIM-transform them, and
        return ``{"tokens": ..., "labels": ...}``. Labels are the tokens
        shifted one position ahead.
        """
        current_window_idxs = np.arange(idx, idx + self.context_len + 1)
        sample = self.from_idxs(current_window_idxs)

        # Apply FIM transformation depending on the rate
        if self.np_rng.binomial(1, self.fim_rate):
            sample = self.apply_fim(sample)

        # +1 in labels because it is 1 step ahead of input tokens
        tokens = sample[0 : self.context_len]
        labels = sample[1 : self.context_len + 1]
        return {"tokens": tokens, "labels": labels}


class MambaSampler(Sampler):
    r"""Samples tokens randomly but `k` indices apart where `k` is generally the context length of the LLM.

    The candidate start indices are ``0, k, 2k, ...`` below ``len(data_source)``.
    They are shuffled at construction. Each pass over the sampler yields the
    current permutation and then reshuffles it, so later passes (epochs) see
    a different order instead of repeating the first one.

    Args:
        data_source (Dataset): dataset to sample from (only ``len()`` is used)
        k (int): minimum index distance between each random sample; must be >= 1

    Raises:
        ValueError: if ``k < 1``.
    """

    def __init__(self, data_source, k=16):
        # range() rejects k == 0 with an opaque message and silently yields
        # nothing for negative k, so validate up front.
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        self.data_source = data_source
        self.num_samples = len(self.data_source)
        self.available_indices = list(range(0, self.num_samples, k))
        random.shuffle(self.available_indices)

    def __iter__(self):
        # Snapshot this pass's order, then reshuffle the stored list so the
        # next pass uses a fresh permutation.
        order = list(self.available_indices)
        random.shuffle(self.available_indices)
        yield from order

    def __len__(self) -> int:
        return len(self.available_indices)


In [12]:
# Copyright (c) 2023, Albert Gu, Tri Dao.

import math
from functools import partial
import json
import os

from collections import namedtuple

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.modules.mamba_simple import Mamba, Block
from mamba_ssm.utils.generation import GenerationMixin
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


def create_block(
    d_model,
    ssm_cfg=None,
    norm_epsilon=1e-5,
    rms_norm=False,
    residual_in_fp32=False,
    fused_add_norm=False,
    layer_idx=None,
    device=None,
    dtype=None,
):
    """
    Build one residual ``Block`` wrapping a ``Mamba`` mixer.

    Args:
        d_model: hidden size of the block.
        ssm_cfg: extra keyword arguments forwarded to ``Mamba`` (``None`` = none).
        norm_epsilon: epsilon of the block's normalization layer.
        rms_norm: use ``RMSNorm`` instead of ``nn.LayerNorm``.
        residual_in_fp32: forwarded to ``Block``.
        fused_add_norm: forwarded to ``Block``.
        layer_idx: index of this layer; passed to ``Mamba`` and stored on the block.
        device, dtype: factory kwargs for the created parameters.

    Returns:
        The constructed ``Block`` with ``layer_idx`` attached.

    Raises:
        ImportError: if ``rms_norm`` is requested but the Triton ``RMSNorm``
            import failed. Without this check, ``partial(None, ...)`` would
            raise an unrelated-looking ``TypeError``.
    """
    if rms_norm and RMSNorm is None:
        raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")
    if ssm_cfg is None:
        ssm_cfg = {}
    factory_kwargs = {"device": device, "dtype": dtype}
    mixer_cls = partial(Mamba, layer_idx=layer_idx, **ssm_cfg, **factory_kwargs)
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )
    block = Block(
        d_model,
        mixer_cls,
        norm_cls=norm_cls,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32,
    )
    block.layer_idx = layer_idx
    return block


# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
def _init_weights(
    module,
    n_layer,
    initializer_range=0.02,  # Now only used for embedding layer.
    rescale_prenorm_residual=True,
    n_residuals_per_layer=1,  # Change to 2 if we have MLP
):
    if isinstance(module, nn.Linear):
        if module.bias is not None:
            if not getattr(module.bias, "_no_reinit", False):
                nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, std=initializer_range)

    if rescale_prenorm_residual:
        # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
        #   > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
        #   > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
        #   >   -- GPT-2 :: https://openai.com/blog/better-language-models/
        #
        # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
        for name, p in module.named_parameters():
            if name in ["out_proj.weight", "fc2.weight"]:
                # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
                # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
                # We need to reinit p since this code could be called multiple times
                # Having just p *= scale would repeatedly scale it down
                nn.init.kaiming_uniform_(p, a=math.sqrt(5))
                with torch.no_grad():
                    p /= math.sqrt(n_residuals_per_layer * n_layer)


class MixerModel(nn.Module):
    """
    Mamba backbone: token embedding -> stack of ``Block``s -> final norm.

    ``forward`` maps ``input_ids`` to normalized hidden states. There is no LM
    head here; ``MambaLMHeadModel`` adds one on top.
    """

    def __init__(
        self,
        d_model: int,
        n_layer: int,
        vocab_size: int,
        ssm_cfg=None,
        norm_epsilon: float = 1e-5,
        rms_norm: bool = False,
        initializer_cfg=None,
        fused_add_norm=False,
        residual_in_fp32=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            d_model: hidden size.
            n_layer: number of blocks.
            vocab_size: number of rows in the embedding table.
            ssm_cfg: extra kwargs forwarded to each block's ``Mamba`` mixer.
            norm_epsilon: epsilon for every normalization layer.
            rms_norm: use ``RMSNorm`` instead of ``nn.LayerNorm``.
            initializer_cfg: extra kwargs for ``_init_weights`` (``None`` = defaults).
            fused_add_norm: use the Triton fused add + norm kernels.
            residual_in_fp32: forwarded to each block and to the fused final norm.
            device, dtype: factory kwargs for the created parameters.

        Raises:
            ImportError: if ``fused_add_norm`` is set but the Triton kernels
                could not be imported.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32

        self.embedding = nn.Embedding(vocab_size, d_model, **factory_kwargs)

        # We change the order of residual and layer norm:
        # Instead of LN -> Attn / MLP -> Add, we do:
        # Add -> LN -> Attn / MLP / Mixer, returning both the residual branch (output of Add) and
        # the main branch (output of MLP / Mixer). The model definition is unchanged.
        # This is for performance reason: we can fuse add + layer_norm.
        self.fused_add_norm = fused_add_norm
        if self.fused_add_norm:
            if layer_norm_fn is None or rms_norm_fn is None:
                raise ImportError("Failed to import Triton LayerNorm / RMSNorm kernels")

        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    **factory_kwargs,
                )
                for i in range(n_layer)
            ]
        )

        # Final normalization applied after the last block's residual add.
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
            d_model, eps=norm_epsilon, **factory_kwargs
        )

        # Re-initialize every submodule created above (embedding, blocks, final
        # norm) with the scheme in _init_weights.
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate each block's inference cache, returned as ``{layer_index: cache}``."""
        return {
            i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
            for i, layer in enumerate(self.layers)
        }

    def forward(self, input_ids, inference_params=None):
        """
        Run the backbone and return the final normalized hidden states.

        Each block receives the previous hidden states and running residual,
        and returns updated versions of both (see the comment in ``__init__``).
        """
        hidden_states = self.embedding(input_ids)
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                hidden_states, residual, inference_params=inference_params
            )
        if not self.fused_add_norm:
            # Unfused path: do the last residual add explicitly, then normalize
            # in the norm layer's parameter dtype.
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
        else:
            # Set prenorm=False here since we don't need the residual
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                hidden_states,
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                prenorm=False,
                residual_in_fp32=self.residual_in_fp32,
            )
        return hidden_states


class MambaLMHeadModel(nn.Module, GenerationMixin):
    """
    Causal language model: a ``MixerModel`` backbone followed by a linear LM
    head whose weight is tied to the backbone's input embedding.

    ``forward`` returns the raw logits tensor rather than a namedtuple.
    NOTE(review): upstream ``GenerationMixin`` decoding reads ``.logits`` from
    the forward output; confirm before calling ``generate`` on this class.
    """

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg=None,
        device=None,
        dtype=None,
    ) -> None:
        """
        Args:
            config: model hyper-parameters (reads ``d_model``, ``n_layer``,
                ``vocab_size``, ``ssm_cfg``, ``rms_norm``, ``residual_in_fp32``,
                ``fused_add_norm`` and ``pad_vocab_size_multiple``).
            initializer_cfg: extra kwargs for ``_init_weights``.
            device, dtype: factory kwargs for the created parameters.
        """
        self.config = config
        d_model = config.d_model
        n_layer = config.n_layer
        vocab_size = config.vocab_size
        ssm_cfg = config.ssm_cfg
        rms_norm = config.rms_norm
        residual_in_fp32 = config.residual_in_fp32
        fused_add_norm = config.fused_add_norm
        pad_vocab_size_multiple = config.pad_vocab_size_multiple
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__()
        # Round the vocabulary size up to a multiple of pad_vocab_size_multiple.
        if vocab_size % pad_vocab_size_multiple != 0:
            vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
        self.backbone = MixerModel(
            d_model=d_model,
            n_layer=n_layer,
            vocab_size=vocab_size,
            ssm_cfg=ssm_cfg,
            rms_norm=rms_norm,
            initializer_cfg=initializer_cfg,
            fused_add_norm=fused_add_norm,
            residual_in_fp32=residual_in_fp32,
            **factory_kwargs,
        )
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)

        # Initialize weights and apply final processing
        self.apply(
            partial(
                _init_weights,
                n_layer=n_layer,
                **(initializer_cfg if initializer_cfg is not None else {}),
            )
        )
        self.tie_weights()

    def tie_weights(self):
        """Share the input embedding's weight tensor with the LM head."""
        self.lm_head.weight = self.backbone.embedding.weight

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the backbone."""
        return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.backbone.embedding

    def set_input_embeddings(self, new_embeddings):
        """Replace the token embedding module and re-tie the LM head to it."""
        self.backbone.embedding = new_embeddings
        self.tie_weights()

    def resize_token_embeddings(self, vocab_size):
        """
        Replace the embedding with one of ``vocab_size`` rows.

        Rows are drawn from N(0, 0.02**2); the first ``min(old, new)`` rows are
        copied from the old embedding. The LM head is re-tied afterwards.
        """
        old_embeddings = self.backbone.embedding
        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        new_embeddings = nn.Embedding(
            vocab_size,
            old_embedding_dim,
            device=old_embeddings.weight.device,
            dtype=old_embeddings.weight.dtype,
        )
        nn.init.normal_(new_embeddings.weight, std=0.02)
        n = min(old_num_tokens, vocab_size)
        new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
        self.backbone.embedding = new_embeddings

        self.tie_weights()

    def forward(self, input_ids, position_ids=None, inference_params=None, num_last_tokens=0):
        """
        Changing this function from the original Mamba implementation to make it work
        with my training scripts (-Tanay)

        "position_ids" is just to be compatible with Transformer generation. We don't use it.
        num_last_tokens: if > 0, only return the logits for the last n tokens

        Returns:
            The logits tensor produced by ``lm_head``.
        """
        hidden_states = self.backbone(input_ids, inference_params=inference_params)
        if num_last_tokens > 0:
            # Slice (not index) so the sequence axis is kept: the last n
            # positions, assuming hidden_states is (batch, seq, d_model).
            hidden_states = hidden_states[:, -num_last_tokens:]
        return self.lm_head(hidden_states)

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
        """Build the model from a Hugging Face Hub config and load its state dict."""
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)
        model = cls(config, device=device, dtype=dtype, **kwargs)
        model.load_state_dict(load_state_dict_hf(pretrained_model_name, device=device, dtype=dtype))
        return model

    def save_pretrained(self, save_directory):
        """
        Minimal implementation of save_pretrained for MambaLMHeadModel.

        Writes ``pytorch_model.bin`` (the state dict) and ``config.json`` (the
        config's ``__dict__``) into ``save_directory``, creating it if needed.
        """
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(save_directory, exist_ok=True)

        # Save the model's state_dict
        model_path = os.path.join(save_directory, "pytorch_model.bin")
        torch.save(self.state_dict(), model_path)

        # Save the configuration of the model
        config_path = os.path.join(save_directory, "config.json")
        with open(config_path, "w") as f:
            json.dump(self.config.__dict__, f)

ModuleNotFoundError: No module named 'mamba_ssm'

In [22]:
pip install "causal-conv1d>=1.1.0"

In [23]:
pip install mamba-ssm


Collecting mamba-ssm
  Downloading mamba_ssm-2.2.5.tar.gz (113 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/113.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m113.8/113.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: mamba-ssm
  Building wheel for mamba-ssm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for mamba-ssm: filename=mamba_ssm-2.2.5-cp312-cp312-linux_x86_64.whl size=532566033 sha256=c8b65fcabfb49a94456c9971619007218e4073f19a84fb6b3894f33d43bee4a1
  Stored in directory: /root/.cache/pip/wheels/21/55/c4/85b634055d6a9b599d27f5cbeacf353c6c532d8e2d8769960b
Successfully built mamba-ssm
Installing collected packages: mamba-ssm
Successfully installed mamba-ssm-2.2.5


In [24]:
pip install numpy



In [25]:
pip install torch



In [26]:
pip install transformers



In [27]:
pip install wandb




In [8]:
# Single GPU training script using FIM
import os
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

import transformers

from mamba_ssm import MambaLMHeadModel

import lance
import pyarrow as pa

from tqdm.auto import tqdm

from data import MambaDataset, MambaSampler

import wandb


# Params (replace with Arg parser later)
class Args:
    """Training configuration (a stand-in for a command-line argument parser)."""

    # Experiment tracking
    wandb = False

    # Tokenizer, checkpoint and data locations
    tokenizer_model = "EleutherAI/gpt-neox-20b"
    model_name = "state-spaces/mamba-790m"
    dataset_path = (
        "/teamspace/studios/codeparrot-dataset-lance/code_parrot_github_python.lance"
    )
    eval_dataset_path = "fim_data_eval.lance"
    dataset = lance.dataset(dataset_path)
    low_cpu_mem_usage = False

    # Fill-in-the-middle settings
    fim_training = True
    fim_rate = 0.9
    truncate_or_pad = True
    fim_prefix_token = "<fim_prefix>"
    fim_middle_token = "<fim_middle_token>"
    fim_suffix_token = "<fim_suffix_token>"
    fim_pad_token = "<fim_pad>"
    pad_factor = 8

    # Optimisation
    lr = 1e-4
    epochs = 10
    context_len = 384
    train_batch_size = 8
    valid_batch_size = 8

    # CosineAnnealingWarmRestarts schedule
    T_0 = 1000
    T_mult = 1
    eta_min = 1e-5

    device = torch.device("cuda:0")

    # Total chunks of context_len + 1 size we can get (matching the
    # k=context_len + 1 stride given to MambaSampler), divided by 4.
    # Parenthesised so the division is by (context_len + 1), as the comment
    # intends, rather than adding 1 to the quotient.
    # NOTE(review): the reason for the extra // 4 is not stated; confirm intent.
    steps_per_epoch = (dataset.count_rows() // (context_len + 1)) // 4


# Define Tokenizer and Model
tokenizer = transformers.AutoTokenizer.from_pretrained(Args.tokenizer_model)
tokenizer.pad_token = tokenizer.eos_token

model = MambaLMHeadModel.from_pretrained(
    Args.model_name,
).to(Args.device)

# Register the FIM sentinel tokens (prefix, middle, suffix, pad). Each
# sentinel must map to its own id so the model can tell the markers apart.
tokenizer.add_tokens(
    [
        Args.fim_prefix_token,
        Args.fim_middle_token,
        Args.fim_suffix_token,
        Args.fim_pad_token,
    ]
)
prefix_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_prefix_token)
middle_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_middle_token)
suffix_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_suffix_token)
pad_tok_id = None

fim_tokens = [prefix_tok_id, middle_tok_id, suffix_tok_id]

# If truncate_or_pad is on, also get pad token id
if Args.truncate_or_pad:
    pad_tok_id = tokenizer.convert_tokens_to_ids(Args.fim_pad_token)
    fim_tokens.append(pad_tok_id)

# Make sure every token id has an embedding row. Never shrink: the pretrained
# matrix may already be padded past len(tokenizer).
original_embeddings = model.get_input_embeddings().weight
model.resize_token_embeddings(max(len(tokenizer), original_embeddings.shape[0]))

# Initialise the new tokens' rows from a multivariate normal fitted to the
# pretrained embeddings (their mean and a 1e-5-scaled covariance). This is
# pure bookkeeping, so no autograd graph is built.
with torch.no_grad():
    mean = original_embeddings.mean(dim=0)
    centered = original_embeddings - mean
    sigma = (centered.T @ centered) / original_embeddings.shape[0]
    dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=1e-5 * sigma)
    new_token_embeddings = dist.sample((len(fim_tokens),))

    # Copy the updated embedding weights and write each sampled vector at its
    # token's own id. With the non-shrinking resize above, the new ids are not
    # guaranteed to be the last rows of the matrix.
    embeddings = model.get_input_embeddings()
    new_embeddings = embeddings.weight.clone()
    new_embeddings[fim_tokens] = new_token_embeddings

# Update the model's embeddings with the new embeddings (re-ties the LM head)
embeddings.weight = torch.nn.Parameter(new_embeddings)
model.set_input_embeddings(embeddings)

# Make train dataset and train dataloader
train_dataset = MambaDataset(
    Args.dataset_path,
    context_len=Args.context_len,
    fim_prefix=prefix_tok_id,
    fim_middle=middle_tok_id,
    fim_suffix=suffix_tok_id,
    fim_pad=pad_tok_id,
    fim_rate=Args.fim_rate,
    mode="psm",
)

train_dataloader = iter(
    DataLoader(
        train_dataset,
        batch_size=Args.train_batch_size,
        sampler=MambaSampler(train_dataset, k=Args.context_len + 1),
        shuffle=False,
        pin_memory=True,
    )
)

# Optimizer and Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=Args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=Args.T_0, T_mult=Args.T_mult, eta_min=Args.eta_min
)

# Start training
print(f"{'*'*8} Starting training {'*'*8}")
# Reuse the dataset handle opened in Args instead of reopening it.
print(f"Total training tokens: {Args.dataset.count_rows():,}")
print(f"Epochs to train: {Args.epochs}")
print(f"Training steps per epoch: {Args.steps_per_epoch:,}\n")


def wandb_log(**kwargs):
    """
    Log all keyword arguments to wandb in a single ``wandb.log`` call.

    Sending every metric together keeps them on the same wandb step. The
    previous approach of one call per metric advanced the step counter once
    per metric. Does nothing when called without metrics.
    """
    if kwargs:
        wandb.log(dict(kwargs))


if Args.wandb:
    # Convert the Args class to a dict for logging. Skip Python's own dunder
    # entries (__module__, __dict__, __doc__, ...), whose exact set can vary
    # between Python versions.
    config_dict = {k: v for k, v in vars(Args).items() if not k.startswith("__")}

    from dotenv import load_dotenv

    load_dotenv()
    wandb.login()
    run = wandb.init(
        project="pytorch",
        config=config_dict,
        group="mamba-train",
        job_type="train",
    )
    wandb.watch(model)


def _fresh_train_iterator():
    """Start a new pass over the training data with a newly shuffled sampler."""
    return iter(
        DataLoader(
            train_dataset,
            batch_size=Args.train_batch_size,
            sampler=MambaSampler(train_dataset, k=Args.context_len + 1),
            shuffle=False,
            pin_memory=True,
        )
    )


total_steps = Args.steps_per_epoch * Args.epochs
prog_bar = tqdm(range(total_steps), total=total_steps)
for epoch in range(Args.epochs):
    model.train()
    total_loss = []
    for step in range(Args.steps_per_epoch):
        # Get the next batch. steps_per_epoch * epochs can exceed the batches
        # in one pass over the sampler, so start a new pass rather than
        # crashing with StopIteration.
        try:
            batch = next(train_dataloader)
        except StopIteration:
            train_dataloader = _fresh_train_iterator()
            batch = next(train_dataloader)
        batch = {k: v.to(Args.device) for k, v in batch.items()}

        # Get predictions
        predictions = model(batch["tokens"])

        # Flatten (batch, seq, vocab) logits and (batch, seq) labels for CE
        B, C, V = predictions.shape
        loss = torch.nn.functional.cross_entropy(
            predictions.reshape(B * C, V), batch["labels"].reshape(B * C)
        )
        # One device->host sync per step; reused for all logging below.
        loss_value = loss.item()
        prog_bar.set_description(f"loss: {loss_value:.4f}")

        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad(set_to_none=True)
        prog_bar.update(1)

        total_loss.append(loss_value)
        if Args.wandb:
            wandb_log(step_loss=loss_value)

    # Perplexity for the epoch. np.exp overflows to inf (with a warning)
    # rather than raising, so a diverged run reports +inf. An epoch with no
    # steps reports nan.
    if total_loss:
        with np.errstate(over="ignore"):
            perplexity = float(np.exp(np.mean(total_loss)))
    else:
        perplexity = float("nan")

    if Args.wandb:
        wandb_log(train_perplexity=perplexity)

    print(f"epoch: {epoch} | train perplexity: {perplexity:.4f}")

prog_bar.close()

# Save the model after training
model_name = Args.model_name.split("/")[-1]
torch.save(model.state_dict(), f"{model_name}-fim.bin")
print("Saved the model!")

if Args.wandb:
    wandb.finish()


ModuleNotFoundError: No module named 'mamba_ssm'

In [None]:
pip install mamba_ssm


Collecting mamba_ssm
  Downloading mamba_ssm-2.2.5.tar.gz (113 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/113.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m112.6/113.8 kB[0m [31m4.5 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m113.8/113.8 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [6]:
pip install "causal-conv1d>=1.4.0"