<a href="https://colab.research.google.com/github/subhajitchatterjee07/Implementation/blob/main/Mamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Learning about Mamba (LSTM with Selective state space)

Mamba is a recent model and is thought to be a better model compared to Transformers. Just to remember the existing models:

1. Transformers: have an attention mechanism where any part of the sequence can dynamically interact with any other. The ones with causal attention are good at handling individual elements of a sequence.
Problem with them is that they come with a significant computational and memory cost, scaling with the square of the sequence length(L^2).
2. Recurrent Neural Networks: RNNs update a hidden state sequentially, considering only the current input and the last hidden state. This approach allows them to potentially handle infinite sequence lengths and that too with constant memory requirements. Yet, this simplicity might cause problems like limiting their ability to remember long term dependencies. Additionally, backpropagation through time (BPTT) in RNNs can be memory intensive and may suffer from vanishing/exploding gradients, despite innovations like LSTM.
3. State-Space Models: These models have given promising results. They have both, the ability to capture long range dependencies more effectively than RNNs, and are also more memory-efficient than transformers.


## What Mamba does:
1. Mamba builds upon the concept of SSMs but it leverages selective state spaces to enable more efficient and effective capture of relevant information across long sequences.
2. Linear Time complexity: Mamba operates in linear time w.r.t. seq. length. This property mmakes it suitable for tasks involving long sequences.


## Selective State Spaces
Mamba has a slightly different approach than the traditional state space models, making it more adaptible and flexible, somewhat akin to LSTMs. However it retains the efficient computation trait of state space models, enabling it to perform the forward passes of an entire sequence in one sweep- a feature present in transformers.

## Training and inference with Mamba
During training, Mamba behaves like transformers, processing the entire sequence in one go. This approach contrasts with LSTMs, where forward pass must be computed step by step, even if all inputs are known. In inference, Mamba's behaviour aligns nmore with traditional recurrence models, offering efficient processing of sequences.

## Limitations of prior Models
A key limitation of prior SSMs is their rigid, input-invariant structure. Typically, these models employ a set of fixed parameters(let's call them A and B) for the entire sequence. This structure is even more restrictive than models like LSTMs, where the transformation of signals can depend on the previous hidden state and the input.

## The Mamba approach
What Mamba introduces is a shift in how the transition to the next hidden state is computed. In Mamba's architecture, the transition can be dependent on the current input. This approach strikes a balance between the fixed computational backbone of traditional SSMs and the input-dependent dynamism of RNNs.
### Key components:
1. Fixed Backbone:
The transition from one hidden state(defined by A matrix), allowing for precomputation across the sequence.
2. Input-dependent transformation: The way the input influences the next hidden state(defined by B matrix) is dependent on the current input, not on the previous hidden state. This input dependency allows for more flexibility compared to traditional SSMs.

### Overcoming computational challenges
To address the commputational demands of the approach, mamba utilizes a hardware-aware algorithm. The algorithm performs computations recurrently using a scen operation instead of convolution, making it highly efficient for GPUs. This efficiency is crucial for maintaining high performance despite the algorithmic complexity introduced by the input dependent transitions.

### Mamba vs Selective State Space
It is important to clarify that Mamba and selectively state space models are not synonymous. Mamba is an implementation that uses the concept of selective state spaces.  This distinction is crucial because it highlights Mamba's unique contribution: adapting the SSM framework to be more flexible and input-responsive while retaining computational efficiency.


Mamba promises to be the bridge between highly flexible but computationally efficient Transformers and efficient but rigid traditional SSMs. This balance could potentially unlock new capabilities in processing long sequences across various domains, from NLPs to genomic sequencing.

### GPU Memory: SRAM and HBM
GPUs contain 2 primary types of memory: High bandwidth memory (HBM) and Static Random Access Memory(SRAM). HBM, though high bandwidth, has a relatively slower access time compared to much faster but smaller SRAM. Understanding this, Mamba strategically uses SRAM for rapid access during matrix multiplications, which form the crux of its computations.

### Overcoming Data movement bottlenecks:
The primary bottleneck in computation is often not the calculations themselves but the movement of various data between the memory types. Mamba addresses this by significantly reducing the need to transfer large amounts of data. It does so by executing the critical parts of the algorithm, like discretization and recurrence computations, directly in SRAM, thus reducing latency.
### The fused selective scan layer:
Mamba introduces a fused selective scan layer, which brings its memory requirements on par with  optimized transformer implementations using Flash attention. This layer is crucial for maintaining efficiency, especially when dealing with input-dependent elements in the model.
### Efficient computation with prefix sums/parallel scans
Mamba utilizes prefix sums or parallel scans for efficient computation. Unlike convolutions, which require a constant kernel, prefix sums can handle the varying elements introduced by Mamba's input-dependence. This method is essential for computing the cumulative multiplications of the matrices at different timesteps.
### Experimental Results and Scaling
Mamba has shown promising results in language modeling and DNA sequencing, areas where long sequences are prevalent. Its scalability and efficiency, even at longer sequence lengths, make it a strong candidate for a general sequence model backbone.

In [50]:
!pip install einops



In [51]:
!python -m pip show zip_files

Name: zip-files
Version: 0.4.1
Summary: Command line utilities for creating zip files
Home-page: https://github.com/goerz/zip_files
Author: Michael Goerz
Author-email: mail@michaelgoerz.net
License: BSD license
Location: /usr/local/lib/python3.10/dist-packages
Requires: click
Required-by: 


In [52]:
!pip install zip_files



In [53]:
import zipfile

In [54]:
#step 1
#import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
from einops import rearrange
from tqdm import tqdm

import math
import os
import urllib.request
from zipfile import ZipFile

from transformers import AutoTokenizer
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x7db075f9f340>

In [55]:
USE_MAMBA = 1
DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM = 0

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Setting up flags and hyperparameters and flags.
Also configured the GPU for training.

In [56]:
#user_defined_hyperparameters
d_model = 8 #dimensions of model
state_size = 128 #example state size
seq_len = 100 #example sequence length
batch_size = 256 #example batch size
last_batch_size = 81 #only for the very last batch of the dataset
current_batch_size = batch_size
different_batch_size = False
h_new = None
temp_buffer = None

## Defining the S6 module
The s6 class represents a sophisticated component within the Mamba architecture, responsible for processing input sequences through a series of linear transformations and a discretization process. It plays a critical role in capturing the temporal dynamics of sequences, a key aspect of sequence modelling tasks such as language modelling. The class showcases advanced techniques such as tensor operations and custom discretization methods to handle the complex requirements of sequence data.

Discretization function is defined based on the Mamba paper's description using ZOH on page 28, in section : Mechanics on selective SSMs.
See also "Zero order hold discretization" maths proof inside https://studywolf.wordpress.com/tag/zero-order-hold/


Here is an explanation for the mathematical rationale for the formulation of Δt used in Mamba:

The key idea is that Δt contros the discretization rate of the continous SSM dynamics. By making Δt input-independent, it introduces selectivity into discrete transition matrices.
Specifically, in Mamba they parametrize Δt as:
Δt = τΔ(Parameter + sΔ(xt))

Here is an explanation of the mathematical rationale for the formulation of Δt used in Mamba:
  The key idea is that Δt controls the discretization rate of the continous SSM dynamics. By making Δt input dependent, it introduces selectivity into discrete transition matrices. Specifically, in Mamba they parametrize Δt as:
   Δt = τΔ(Parameter + sΔ(xt))
   Where :
   - Parameter is a learned scalar parameter that controls the baseline discretization rate
   - sΔ(xt) is a projection that makes Δt input-dependent by computing a value based on xt
   - τΔ(x) = softplus(x) transforms the result to be positive through the softplus nonlinearity.
   The rationale for this formulation is:
   - Parameter provides a reasonable default discretization rate.
   - sΔ(xt) injects input-dependence through projection
   - softplus ensures Δt is positive as required to be a valid timestep
   - the projection sΔ allows the model to learn modulate Δ
   - This modulation creates selectivity in how rapidly or slowly the states update
   So the summary, the learned input-dependent projection allows Δt, and thus the discrete dynamics, to become selective. The softplus and scalar parameter provide useful inductive biases on the top of this flexibility.
   The end result is discrete transition matrices that are selective on the input, enabling powerful sequence meodelling capabilities.

In [57]:
class S6(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(S6, self).__init__()

        self.fc1 = nn.Linear(d_model, d_model, device=device)
        self.fc2 = nn.Linear(d_model, state_size, device=device)
        self.fc3 = nn.Linear(d_model, state_size, device=device)

        self.seq_len = seq_len
        self.d_model = d_model
        self.state_size = state_size

        #self.A = nn.Parameter(torch.ones(d_model, state_size, device=device))
        self.A = nn.Parameter(F.normalize(torch.ones(d_model, state_size, device=device), p=2, dim=-1))
        nn.init.xavier_uniform_(self.A)

        self.B = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)
        self.C = torch.zeros(batch_size, self.seq_len, self.state_size, device=device)

        self.delta = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)
        self.dA = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.dB = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)

        # h should have dimensions [batch_size, seq_len, d_model, state_size]
        self.h = torch.zeros(batch_size, self.seq_len, self.d_model, self.state_size, device=device)
        self.y = torch.zeros(batch_size, self.seq_len, self.d_model, device=device)


    def discretization(self):
        # discretization function is defined based on the MAMBA paper's description using ZOH on page 28
        # in Section C : Mechanics on Selective SSMs
        # See also "Zero-order hold discretization" maths proof inside https://studywolf.wordpress.com/tag/zero-order-hold/
        """
        Here is an explanation of the mathematical rationale for the formulation of Δt used in Mamba:
        The key idea is that Δt controls the discretization rate of the continuous SSM dynamics. By making Δt input-dependent, it introduces selectivity into the discrete transition matrices.
        Specifically, in Mamba they parameterize Δt as:
        Δt = τΔ(Parameter + sΔ(xt))
        Where:
        - Parameter is a learned scalar parameter that controls the baseline discretization rate
        - sΔ(xt) is a projection that makes Δt input-dependent by computing a value based on xt
        - τΔ(x) = softplus(x) transforms the result to be positive through the softplus nonlinearity
        The rationale for this formulation is:
        - Parameter provides a reasonable default discretization rate
        - sΔ(xt) injects input-dependence through the projection
        - softplus ensures Δt is positive as required to be a valid timestep
        - The projection sΔ allows the model to learn to modulate Δt based on the input xt
        - This modulation creates selectivity in how rapidly or slowly the states update
        So in summary, the learned input-dependent projection allows Δt, and thus the discrete dynamics, to become selective. The softplus and scalar parameter provide useful inductive biases on top of this flexibility.
        The end result is discrete transition matrices that are selective on the input, enabling powerful sequence modeling capabilities.
        Credit: Claude2 AI chatbot
        """

        # inverse() only supports square matrix
        #dB = torch.matmul(torch.inverse(A * delta), torch.matmul(dA - torch.eye(A.shape[0]), B))
        self.dB = torch.einsum("bld,bln->bldn", self.delta, self.B)

        # https://github.com/state-spaces/mamba/blob/0131c1e94a46fc9f70bcfc9d57962963bb2f0b9e/mamba_ssm/modules/mamba_simple.py#L240
        #dA = torch.matrix_exp(A * delta)  # matrix_exp() only supports square matrix
        self.dA = torch.exp(torch.einsum("bld,dn->bldn", self.delta, self.A))
        #print(f"self.dA.shape = {self.dA.shape}")
        #print(f"self.dA.requires_grad = {self.dA.requires_grad}")

        return self.dA, self.dB

    def forward(self, x):
        # Refer to Algorithm 2 in the MAMBA paper
        self.B = self.fc2(x)
        self.C = self.fc3(x)
        self.delta = F.softplus(self.fc1(x))

        # Uses ZOH as in MAMBA, Hungry Hippo still uses bilinear transform for discretization
        self.discretization()

        if DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:  # this will trigger in-place runtime error if without using `h_new`

            global current_batch_size
            current_batch_size = x.shape[0]

            if self.h.shape[0] != current_batch_size:
                #print("Adjusting h_new for the different batch size of input data `x`")
                different_batch_size = True

                # Resize self.h to match the current batch size
                h_new =  torch.einsum('bldn,bldn->bldn', self.dA, self.h[:current_batch_size, ...]) + rearrange(x, "b l d -> b l d 1") * self.dB

            else:
                different_batch_size = False
                h_new =  torch.einsum('bldn,bldn->bldn', self.dA, self.h) + rearrange(x, "b l d -> b l d 1") * self.dB

            # y needs to have a shape of [batch_size, seq_len, d_model]
            self.y = torch.einsum('bln,bldn->bld', self.C, h_new)

            # Update self.h with the detached state of h_new
            # Only do this if retaining gradients for self.h is not necessary for backprop
            # Otherwise, store h_new in a temporary list and update self.h after the loop
            global temp_buffer
            temp_buffer = h_new.detach().clone() if not self.h.requires_grad else h_new.clone()

            return self.y

        else:  # this will not trigger in-place runtime error
            # h should have dimensions [batch_size, seq_len, d_model, state_size]
            h = torch.zeros(x.size(0), self.seq_len, self.d_model, self.state_size, device=x.device)
            y = torch.zeros_like(x)

            h =  torch.einsum('bldn,bldn->bldn', self.dA, h) + rearrange(x, "b l d -> b l d 1") * self.dB

            # y needs to have a shape of [batch_size, seq_len, d_model]
            y = torch.einsum('bln,bldn->bld', self.C, h)

            return y



In [58]:
class MambaBlock(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(MambaBlock, self).__init__()

        self.inp_proj = nn.Linear(d_model, 2*d_model, device=device)
        self.out_proj = nn.Linear(2*d_model, d_model, device=device)

        # For residual skip connection
        self.D = nn.Linear(d_model, 2*d_model, device=device)

        # Set _no_weight_decay attribute on bias
        self.out_proj.bias._no_weight_decay = True

        # Initialize bias to a small constant value
        nn.init.constant_(self.out_proj.bias, 1.0)

        self.S6 = S6(seq_len, 2*d_model, state_size, device)

        # Add 1D convolution with kernel size 3
        self.conv = nn.Conv1d(seq_len, seq_len, kernel_size=3, padding=1, device=device)

        # Add linear layer for conv output
        self.conv_linear = nn.Linear(2*d_model, 2*d_model, device=device)

        # rmsnorm
        self.norm = RMSNorm(d_model, device=device)

    def forward(self, x):
        """
        x_proj.shape = torch.Size([batch_size, seq_len, 2*d_model])
        x_conv.shape = torch.Size([batch_size, seq_len, 2*d_model])
        x_conv_act.shape = torch.Size([batch_size, seq_len, 2*d_model])
        """
        # Refer to Figure 3 in the MAMBA paper

        x = self.norm(x)

        x_proj = self.inp_proj(x)
        #print(f"x_proj.shape = {x_proj.shape}")

        # Add 1D convolution with kernel size 3
        x_conv = self.conv(x_proj)
        #print(f"x_conv.shape = {x_conv.shape}")

        x_conv_act = F.silu(x_conv)
        #print(f"x_conv_act.shape = {x_conv_act.shape}")

        # Add linear layer for conv output
        x_conv_out = self.conv_linear(x_conv_act)
        #print(f"x_conv_out.shape = {x_conv_out.shape}")

        x_ssm = self.S6(x_conv_out)
        x_act = F.silu(x_ssm)  # Swish activation can be implemented as x * sigmoid(x)
        #print(f"x_act.shape = {x_act.shape}")

        # residual skip connection with nonlinearity introduced by multiplication
        x_residual = F.silu(self.D(x))
        #print(f"x_residual.shape = {x_residual.shape}")
        x_combined = x_act * x_residual
        #print(f"x_combined.shape = {x_combined.shape}")

        x_out = self.out_proj(x_combined)
        #print(f"x_out.shape = {x_out.shape}")

        return x_out

In [59]:
class Mamba(nn.Module):
    def __init__(self, seq_len, d_model, state_size, device):
        super(Mamba, self).__init__()
        self.mamba_block1 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block2 = MambaBlock(seq_len, d_model, state_size, device)
        self.mamba_block3 = MambaBlock(seq_len, d_model, state_size, device)

    def forward(self, x):
        x = self.mamba_block1(x)
        x = self.mamba_block2(x)
        x = self.mamba_block3(x)
        return x

In [60]:
class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5,
                 device: str ='cuda'):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model, device=device))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

In [61]:
x = torch.rand(batch_size, seq_len, d_model, device=device)
# Create the Mamba model
mamba = Mamba(seq_len, d_model, state_size, device)

# rmsnorm
norm = RMSNorm(d_model)
x = norm(x)

# Forward pass
test_output = mamba(x)
print(f"test_output.shape = {test_output.shape}")  # Should be [batch_size, seq_len, d_model]

test_output.shape = torch.Size([256, 100, 8])


In [62]:
class Enwiki8Dataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data['input_ids'])

    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.data.items()}
        return item

In [63]:
# Define a function for padding
def pad_sequences_3d(sequences, max_len=None, pad_value=0):
    # Assuming sequences is a tensor of shape (batch_size, seq_len, feature_size)
    batch_size, seq_len, feature_size = sequences.shape

    if max_len is None:
        max_len = seq_len + 1


    # Initialize padded_sequences with the pad_value
    padded_sequences = torch.full((batch_size, max_len, feature_size), fill_value=pad_value, dtype=sequences.dtype, device=sequences.device)
    # Pad each sequence to the max_len
    padded_sequences[:, :seq_len, :] = sequences

    return padded_sequences

In [64]:
def train(model, tokenizer, data_loader, optimizer, criterion, device, max_grad_norm=1.0, DEBUGGING_IS_ON=False):
    model.train()
    total_loss = 0
    for batch in data_loader:
        optimizer.zero_grad()

        input_data = batch['input_ids'].clone().to(device)
        attention_mask = batch['attention_mask'].clone().to(device)

        # In most sequence modeling tasks, like language modeling, the target should be the next token
        # in the sequence rather than the input token itself.
        # This is because the model's goal is to predict the next word given the previous words.
        # Shift the input data by one position to get the target, so that each target token
        # is the next token following the input token.
        target = input_data[:, 1:]
        input_data = input_data[:, :-1]

        # Pad all the sequences in the batch:
        input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
        target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

        if USE_MAMBA:
            output = model(input_data)
            loss = criterion(output, target)

        loss.backward(retain_graph=True)

        # Clip gradients: gradients are modified in place
        #torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        for name, param in model.named_parameters():
           if 'out_proj.bias' not in name:
               # clip weights but not bias for out_proj
               torch.nn.utils.clip_grad_norm_(param, max_norm=max_grad_norm)

        if DEBUGGING_IS_ON:
            for name, parameter in model.named_parameters():
                if parameter.grad is not None:
                    print(f"{name} gradient: {parameter.grad.data.norm(2)}")
                else:
                    print(f"{name} has no gradient")

        if USE_MAMBA and DIFFERENT_H_STATES_RECURRENT_UPDATE_MECHANISM:
            model.S6.h[:current_batch_size, ...].copy_(temp_buffer)

        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(data_loader)

In [65]:
def evaluate(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in data_loader:
            input_data = batch['input_ids'].clone().detach().to(device)
            attention_mask = batch['attention_mask'].clone().detach().to(device)

            # In most sequence modeling tasks, like language modeling, the target should be the next token
            # in the sequence rather than the input token itself.
            # This is because the model's goal is to predict the next word given the previous words.
            # Shift the input data by one position to get the target, so that each target token
            # is the next token following the input token.
            target = input_data[:, 1:]
            input_data = input_data[:, :-1]

            # Pad all the sequences in the batch:
            input_data = pad_sequences_3d(input_data, pad_value=tokenizer.pad_token_id)
            target = pad_sequences_3d(target, max_len=input_data.size(1), pad_value=tokenizer.pad_token_id)

            if USE_MAMBA:
                output = model(input_data)
                loss = criterion(output, target)
            total_loss += loss.item()
    return total_loss / len(data_loader)

In [66]:
def calculate_perplexity(loss):
    return math.exp(loss)

In [67]:
def load_enwiki8_dataset():
    print(f"Download and extract enwiki8 data")
    url = "http://mattmahoney.net/dc/enwik8.zip"
    urllib.request.urlretrieve(url, "enwik8.zip")

    with ZipFile("enwik8.zip") as f:
        data = f.read("enwik8").decode("utf-8")

    return data

In [68]:
# Tokenize and encode the dataset
def encode_dataset(tokenizer, text_data):
    def batch_encode(tokenizer, text_data, batch_size=1000):
        # Tokenize in batches
        batched_input_ids = []
        for i in range(0, len(text_data), batch_size):
            batch = text_data[i:i+batch_size]
            inputs = tokenizer(batch, add_special_tokens=True, truncation=True,
                               padding='max_length', max_length=seq_len,
                               return_tensors='pt')
            batched_input_ids.append(inputs['input_ids'])
        return torch.cat(batched_input_ids)

    # Assuming enwiki8_data is a list of sentences
    input_ids = batch_encode(tokenizer, enwiki8_data)

    # vocab_size is the number of unique tokens in the tokenizer's vocabulary
    global vocab_size
    vocab_size = len(tokenizer.vocab)  # Note that for some tokenizers, we might access the vocab directly
    print(f"vocab_size = {vocab_size}")

    # Create an embedding layer
    # embedding_dim is the size of the embedding vectors (MAMBA model's D)
    embedding_layer = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model)

    # Pass `input_ids` through the embedding layer
    # This will change `input_ids` from shape [B, L] to [B, L, D]
    #encoded_input = embedding_layer(input_ids)   ## this eats memory, so use batched_embedding_calls instead
    def batch_embedding_calls(input_ids, embedding_layer, batch_size=256):
        # Check if input_ids is already a tensor, if not convert it
        if not isinstance(input_ids, torch.Tensor):
            input_ids = torch.tensor(input_ids, dtype=torch.long)

        # Calculate the number of batches needed
        num_batches = math.ceil(input_ids.size(0) / batch_size)

        # List to hold the output embeddings
        output_embeddings = []

        # Process each batch
        for i in range(num_batches):
            # Calculate start and end indices for the current batch
            start_idx = i * batch_size
            end_idx = start_idx + batch_size

            # Get the batch
            input_id_batch = input_ids[start_idx:end_idx]

            # Call the embedding layer
            with torch.no_grad():  # No need gradients for this operation
                batch_embeddings = embedding_layer(input_id_batch)

            # Append the result to the list
            output_embeddings.append(batch_embeddings)

        # Concatenate the embeddings from each batch into a single tensor
        all_embeddings = torch.cat(output_embeddings, dim=0)

        return all_embeddings

    # `input_ids` is a list or tensor of the input IDs and `embedding_layer` is model's embedding layer
    if USE_MAMBA:
        # Set `batch_size` to a value that works for memory constraints
        encoded_inputs = batch_embedding_calls(input_ids, embedding_layer, batch_size=1).float()

    attention_mask = (input_ids != tokenizer.pad_token_id).type(input_ids.dtype)

    return encoded_inputs, attention_mask

In [69]:
# Load a pretrained tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

In [70]:
# Assuming encoded_inputs is a preprocessed tensor of shape [num_samples, seq_len, d_model]
encoded_inputs_file = 'encoded_inputs_mamba.pt'


if os.path.exists(encoded_inputs_file):
    print("Loading pre-tokenized data...")
    encoded_inputs = torch.load(encoded_inputs_file)
else:
    print("Tokenizing raw data...")
    enwiki8_data = load_enwiki8_dataset()
    encoded_inputs, attention_mask = encode_dataset(tokenizer, enwiki8_data)
    torch.save(encoded_inputs, encoded_inputs_file)
    print(f"finished tokenizing data")


# Combine into a single dictionary
data = {
    'input_ids': encoded_inputs,
    'attention_mask': attention_mask
}

# Split the data into train and validation sets
total_size = len(data['input_ids'])
train_size = int(total_size * 0.8)

train_data = {key: val[:train_size] for key, val in data.items()}
val_data = {key: val[train_size:] for key, val in data.items()}

train_dataset = Enwiki8Dataset(train_data)
val_dataset = Enwiki8Dataset(val_data)


train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


# Initialize the model

model = Mamba(seq_len, d_model, state_size, device).to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=5e-6)

# Training loop
num_epochs = 25  # Number of epochs to train for

for epoch in tqdm(range(num_epochs)):  # loop over the dataset multiple times
    train_loss = train(model, tokenizer, train_loader, optimizer, criterion, device, max_grad_norm=10.0, DEBUGGING_IS_ON=False)
    val_loss = evaluate(model, val_loader, criterion, device)
    val_perplexity = calculate_perplexity(val_loss)
    print(f'Epoch: {epoch+1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}, Validation Perplexity: {val_perplexity:.4f}')

Tokenizing raw data...
Download and extract enwiki8 data
vocab_size = 30522
finished tokenizing data


  4%|▍         | 1/25 [01:20<32:02, 80.11s/it]

Epoch: 1, Training Loss: -4.4533, Validation Loss: -5.7294, Validation Perplexity: 0.0032


  8%|▊         | 2/25 [02:39<30:36, 79.87s/it]

Epoch: 2, Training Loss: -4.4611, Validation Loss: -5.7470, Validation Perplexity: 0.0032


 12%|█▏        | 3/25 [03:59<29:12, 79.67s/it]

Epoch: 3, Training Loss: -4.5059, Validation Loss: -5.8712, Validation Perplexity: 0.0028


 16%|█▌        | 4/25 [05:18<27:50, 79.55s/it]

Epoch: 4, Training Loss: -8.0943, Validation Loss: -21.4852, Validation Perplexity: 0.0000


 20%|██        | 5/25 [06:38<26:33, 79.69s/it]

Epoch: 5, Training Loss: -49.7272, Validation Loss: -82.6938, Validation Perplexity: 0.0000


 24%|██▍       | 6/25 [07:58<25:16, 79.80s/it]

Epoch: 6, Training Loss: -139.8954, Validation Loss: -197.8762, Validation Perplexity: 0.0000


 28%|██▊       | 7/25 [09:18<23:56, 79.80s/it]

Epoch: 7, Training Loss: -307.8996, Validation Loss: -409.9386, Validation Perplexity: 0.0000


 32%|███▏      | 8/25 [10:39<22:41, 80.07s/it]

Epoch: 8, Training Loss: -607.3140, Validation Loss: -776.2809, Validation Perplexity: 0.0000


 36%|███▌      | 9/25 [11:59<21:25, 80.35s/it]

Epoch: 9, Training Loss: -1120.4014, Validation Loss: -1399.4222, Validation Perplexity: 0.0000


 40%|████      | 10/25 [13:20<20:07, 80.50s/it]

Epoch: 10, Training Loss: -1985.1196, Validation Loss: -2438.0631, Validation Perplexity: 0.0000


 44%|████▍     | 11/25 [14:41<18:45, 80.42s/it]

Epoch: 11, Training Loss: -3396.0594, Validation Loss: -4121.0314, Validation Perplexity: 0.0000


 48%|████▊     | 12/25 [16:00<17:23, 80.27s/it]

Epoch: 12, Training Loss: -5671.3287, Validation Loss: -6808.0584, Validation Perplexity: 0.0000


 52%|█████▏    | 13/25 [17:20<16:02, 80.18s/it]

Epoch: 13, Training Loss: -9285.9885, Validation Loss: -11079.8674, Validation Perplexity: 0.0000


 56%|█████▌    | 14/25 [18:41<14:42, 80.24s/it]

Epoch: 14, Training Loss: -15003.7683, Validation Loss: -17821.7826, Validation Perplexity: 0.0000


 60%|██████    | 15/25 [20:01<13:21, 80.15s/it]

Epoch: 15, Training Loss: -23983.7659, Validation Loss: -28335.5159, Validation Perplexity: 0.0000


 64%|██████▍   | 16/25 [21:21<12:01, 80.13s/it]

Epoch: 16, Training Loss: -37907.8720, Validation Loss: -44572.5927, Validation Perplexity: 0.0000


 68%|██████▊   | 17/25 [22:41<10:41, 80.22s/it]

Epoch: 17, Training Loss: -59144.7030, Validation Loss: -69052.8578, Validation Perplexity: 0.0000


 72%|███████▏  | 18/25 [24:01<09:20, 80.13s/it]

Epoch: 18, Training Loss: -90554.4735, Validation Loss: -104812.1774, Validation Perplexity: 0.0000


 76%|███████▌  | 19/25 [25:22<08:01, 80.24s/it]

Epoch: 19, Training Loss: -136133.1153, Validation Loss: -156402.3925, Validation Perplexity: 0.0000


 80%|████████  | 20/25 [26:44<06:43, 80.72s/it]

Epoch: 20, Training Loss: -200947.0880, Validation Loss: -228849.0625, Validation Perplexity: 0.0000


 84%|████████▍ | 21/25 [28:06<05:25, 81.38s/it]

Epoch: 21, Training Loss: -291866.5336, Validation Loss: -330356.0304, Validation Perplexity: 0.0000


 88%|████████▊ | 22/25 [29:29<04:05, 81.74s/it]

Epoch: 22, Training Loss: -417620.5732, Validation Loss: -469720.2248, Validation Perplexity: 0.0000


 92%|█████████▏| 23/25 [30:52<02:44, 82.04s/it]

Epoch: 23, Training Loss: -590254.9548, Validation Loss: -661680.9423, Validation Perplexity: 0.0000


 96%|█████████▌| 24/25 [32:14<01:22, 82.23s/it]

Epoch: 24, Training Loss: -826559.9042, Validation Loss: -922239.3726, Validation Perplexity: 0.0000


100%|██████████| 25/25 [33:35<00:00, 80.62s/it]

Epoch: 25, Training Loss: -1147515.0503, Validation Loss: -1274747.3814, Validation Perplexity: 0.0000



