### Dependencies

In [None]:
!pip install mamba-ssm
!pip install causal-conv1d>=1.4.0

In [None]:
!pip install numpy tqdm transformers datasets

In [None]:
!pip install wandb

### Dataset

In [15]:
import torch
from datasets import load_dataset
from transformers import GPT2Tokenizer

def load_and_preprocess_data(max_length=128, stride=64):
    dataset = load_dataset("wikitext", "wikitext-2-v1")
    
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token
    
    def tokenize_function(examples):
        tokenized_inputs = tokenizer(
            examples["text"],
            truncation=True,
            max_length=max_length,
            return_overflowing_tokens=True,
            return_length=True,
            stride=stride,
        )
        
        input_batch = []
        for length, input_ids in zip(tokenized_inputs["length"], tokenized_inputs["input_ids"]):
            if length == max_length:
                input_batch.append(input_ids)
        
        return {"input_ids": input_batch}
    
    tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_dataset.set_format(type="torch")
    
    return tokenized_dataset, tokenizer

def create_dataloaders(dataset, batch_size=32):
    train_dataloader = torch.utils.data.DataLoader(dataset["train"], batch_size=batch_size, shuffle=True)
    val_dataloader = torch.utils.data.DataLoader(dataset["validation"], batch_size=batch_size)
    test_dataloader = torch.utils.data.DataLoader(dataset["test"], batch_size=batch_size)
    
    return train_dataloader, val_dataloader, test_dataloader

In [16]:
dataset, tokenizer = load_and_preprocess_data()

In [17]:
train_dataloader, val_dataloader, test_dataloader = create_dataloaders(dataset)

print(f"Vocabulary size: {len(tokenizer)}")
print(f"Train samples: {len(dataset['train'])}")
print(f"Validation samples: {len(dataset['validation'])}")
print(f"Test samples: {len(dataset['test'])}")

Vocabulary size: 50257
Train samples: 8624
Validation samples: 922
Test samples: 1017


In [20]:
# Check a sample batch
for batch in train_dataloader:
    print("Sample batch shape:", batch["input_ids"].shape)
    print("Sample input:", tokenizer.decode(batch["input_ids"][0]))
    break

Sample batch shape: torch.Size([32, 128])
Sample input:  Several terrestrial starlings, including those in the genus Sturnus, have adaptations of the skull and muscles that help with feeding by probing. This adaptation is most strongly developed in the common starling ( along with the spotless and white @-@ <unk> starlings ), where the <unk> muscles responsible for opening the jaw are enlarged and the skull is narrow, allowing the eye to be moved forward to peer down the length of the bill. This technique involves inserting the bill into the ground and opening it as a way of searching for hidden food items. Common starlings have the physical traits that enable them to use this feeding


### Mamba Implementation

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba

class MambaBlock(nn.Module):
    def __init__(self, d_model, d_state, d_conv, expand):
        super().__init__()
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand
        )
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x):
        return self.norm(x + self.mamba(x))

class MambaLM(nn.Module):
    def __init__(self, vocab_size, d_model, n_layer, d_state, d_conv, expand):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.layers = nn.ModuleList([
            MambaBlock(d_model, d_state, d_conv, expand)
            for _ in range(n_layer)
        ])
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, x):
        x = self.embedding(x)
        for layer in self.layers:
            x = layer(x)
        return self.lm_head(x)

def create_mamba_model(vocab_size, d_model=256, n_layer=4, d_state=16, d_conv=4, expand=2):
    return MambaLM(vocab_size, d_model, n_layer, d_state, d_conv, expand)

### Training

In [27]:
import wandb
wandb.login()

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /teamspace/studios/this_studio/.netrc


True

In [40]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import wandb
import math
# from data_preparation import load_and_preprocess_data, create_dataloaders
# from mamba_model import create_mamba_model

def compute_metrics(logits, targets):
    loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    perplexity = math.exp(loss.item())
    predictions = logits.argmax(dim=-1)
    accuracy = (predictions == targets).float().mean().item()
    return {
        'loss': loss,  # Return the tensor, not the item
        'loss_value': loss.item(),  # Add this for logging
        'perplexity': perplexity,
        'accuracy': accuracy
    }

def train(model, train_dataloader, val_dataloader, num_epochs, lr, device):
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    wandb.init(project="mamba-next-word-prediction", config={
        "learning_rate": lr,
        "epochs": num_epochs,
        "batch_size": train_dataloader.batch_size,
        "model_config": model.config if hasattr(model, 'config') else None
    })

    global_step = 0
    for epoch in range(num_epochs):
        model.train()
        epoch_metrics = {'train_loss': 0, 'train_perplexity': 0, 'train_accuracy': 0}
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            input_ids = batch['input_ids'].to(device)
            targets = input_ids.clone()
            targets[:, :-1] = input_ids[:, 1:]
            targets[:, -1] = input_ids[:, 0]

            optimizer.zero_grad()
            outputs = model(input_ids)
            metrics = compute_metrics(outputs, targets)
            loss = metrics['loss']  # This is now a tensor
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            global_step += 1

            # Update epoch metrics
            epoch_metrics['train_loss'] += metrics['loss_value']
            epoch_metrics['train_perplexity'] += metrics['perplexity']
            epoch_metrics['train_accuracy'] += metrics['accuracy']

            if global_step % 50 == 0:
                wandb.log({
                    "train_loss": metrics['loss_value'],
                    "train_perplexity": metrics['perplexity'],
                    "train_accuracy": metrics['accuracy']
                }, step=global_step)

            if global_step % 250 == 0:
                val_metrics = evaluate(model, val_dataloader, device)
                wandb.log(val_metrics, step=global_step)
                model.train()  # Switch back to train mode after evaluation

        # Log epoch-level metrics
        epoch_metrics = {k: v / len(train_dataloader) for k, v in epoch_metrics.items()}
        wandb.log(epoch_metrics, step=global_step)
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_metrics['train_loss']:.4f}, "
              f"Train Perplexity: {epoch_metrics['train_perplexity']:.4f}, "
              f"Train Accuracy: {epoch_metrics['train_accuracy']:.4f}")

        scheduler.step()

    wandb.finish()
    return model

def evaluate(model, dataloader, device):
    model.eval()
    total_metrics = {'val_loss': 0, 'val_perplexity': 0, 'val_accuracy': 0}
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            targets = input_ids.clone()
            targets[:, :-1] = input_ids[:, 1:]
            targets[:, -1] = input_ids[:, 0]

            outputs = model(input_ids)
            metrics = compute_metrics(outputs, targets)
            total_metrics['val_loss'] += metrics['loss_value']
            total_metrics['val_perplexity'] += metrics['perplexity']
            total_metrics['val_accuracy'] += metrics['accuracy']

    avg_metrics = {k: v / len(dataloader) for k, v in total_metrics.items()}
    print(f"Validation Loss: {avg_metrics['val_loss']:.4f}, "
          f"Validation Perplexity: {avg_metrics['val_perplexity']:.4f}, "
          f"Validation Accuracy: {avg_metrics['val_accuracy']:.4f}")
    return avg_metrics

In [41]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

num_epochs = 10
lr = 1e-4
batch_size = 32
d_model=256
n_layer=4
d_state=16
d_conv=4
expand=2

wandb.init(project="mamba-next-word-prediction", config={
    "d_model": d_model,
    "n_layer": n_layer,
    "d_state": d_state,
    "d_conv": d_conv,
    "expand": expand,
    "learning_rate": lr,
    "epochs": num_epochs,
    "batch_size": batch_size,
})

Using device: cuda


In [42]:
dataset, tokenizer = load_and_preprocess_data()
train_dataloader, val_dataloader, _ = create_dataloaders(dataset, batch_size=batch_size)

In [43]:
vocab_size = len(tokenizer)
model = create_mamba_model(
    vocab_size,
    d_model,
    n_layer,
    d_state,
    d_conv,
    expand
)

print(f"Created Mamba model with parameters:")
print(f"d_model: {d_model}")
print(f"n_layer: {n_layer}")
print(f"d_state: {d_state}")
print(f"d_conv: {d_conv}")
print(f"expand: {expand}")

Created Mamba model with parameters:
d_model: 256
n_layer: 4
d_state: 16
d_conv: 4
expand: 2


In [39]:
trained_model = train(model, train_dataloader, val_dataloader, num_epochs, lr, device)

# Save the trained model
torch.save(trained_model.state_dict(), "mamba_lm.pth")

VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112978888887584, max=1.0…

Epoch 1/10:  94%|█████████▎| 253/270 [00:20<00:02,  6.11it/s]

Validation Loss: 5.2376, Validation Perplexity: 191.1554, Validation Accuracy: 0.2721


Epoch 1/10: 100%|██████████| 270/270 [00:21<00:00, 12.31it/s]


Epoch 1/10, Train Loss: 6.3893, Train Perplexity: 1996.8801, Train Accuracy: 0.1848


Epoch 2/10:  86%|████████▌ | 232/270 [00:18<00:06,  6.11it/s]

Validation Loss: 4.9199, Validation Perplexity: 138.7694, Validation Accuracy: 0.2935


Epoch 2/10: 100%|██████████| 270/270 [00:21<00:00, 12.41it/s]


Epoch 2/10, Train Loss: 5.1983, Train Perplexity: 182.9008, Train Accuracy: 0.2532


Epoch 3/10:  79%|███████▊  | 212/270 [00:17<00:09,  6.07it/s]

Validation Loss: 4.8017, Validation Perplexity: 123.4015, Validation Accuracy: 0.3027


Epoch 3/10: 100%|██████████| 270/270 [00:21<00:00, 12.31it/s]


Epoch 3/10, Train Loss: 4.6097, Train Perplexity: 100.8756, Train Accuracy: 0.2893


Epoch 4/10:  71%|███████   | 192/270 [00:15<00:12,  6.05it/s]

Validation Loss: 4.7992, Validation Perplexity: 123.2223, Validation Accuracy: 0.3042


Epoch 4/10: 100%|██████████| 270/270 [00:22<00:00, 12.21it/s]


Epoch 4/10, Train Loss: 4.1611, Train Perplexity: 64.4575, Train Accuracy: 0.3194


Epoch 5/10:  64%|██████▎   | 172/270 [00:14<00:16,  6.01it/s]

Validation Loss: 4.8403, Validation Perplexity: 128.6397, Validation Accuracy: 0.3030


Epoch 5/10: 100%|██████████| 270/270 [00:22<00:00, 12.15it/s]


Epoch 5/10, Train Loss: 3.7786, Train Perplexity: 43.9582, Train Accuracy: 0.3488


Epoch 6/10:  56%|█████▋    | 152/270 [00:12<00:19,  5.99it/s]

Validation Loss: 4.9257, Validation Perplexity: 140.2911, Validation Accuracy: 0.3024


Epoch 6/10: 100%|██████████| 270/270 [00:22<00:00, 12.07it/s]


Epoch 6/10, Train Loss: 3.4383, Train Perplexity: 31.2624, Train Accuracy: 0.3805


Epoch 7/10:  49%|████▉     | 132/270 [00:11<00:23,  5.97it/s]

Validation Loss: 5.0186, Validation Perplexity: 154.2027, Validation Accuracy: 0.2993


Epoch 7/10: 100%|██████████| 270/270 [00:22<00:00, 12.06it/s]


Epoch 7/10, Train Loss: 3.1443, Train Perplexity: 23.3008, Train Accuracy: 0.4134


Epoch 8/10:  41%|████▏     | 112/270 [00:09<00:26,  5.97it/s]

Validation Loss: 5.0936, Validation Perplexity: 166.3232, Validation Accuracy: 0.2967


Epoch 8/10: 100%|██████████| 270/270 [00:22<00:00, 12.05it/s]


Epoch 8/10, Train Loss: 2.9115, Train Perplexity: 18.4464, Train Accuracy: 0.4438


Epoch 9/10:  34%|███▍      | 92/270 [00:08<00:29,  5.94it/s]

Validation Loss: 5.1552, Validation Perplexity: 177.0880, Validation Accuracy: 0.2943


Epoch 9/10: 100%|██████████| 270/270 [00:22<00:00, 12.05it/s]


Epoch 9/10, Train Loss: 2.7507, Train Perplexity: 15.6890, Train Accuracy: 0.4672


Epoch 10/10:  27%|██▋       | 72/270 [00:06<00:33,  5.95it/s]

Validation Loss: 5.1848, Validation Perplexity: 182.4843, Validation Accuracy: 0.2925


Epoch 10/10: 100%|██████████| 270/270 [00:22<00:00, 12.05it/s]


Epoch 10/10, Train Loss: 2.6628, Train Perplexity: 14.3682, Train Accuracy: 0.4812


VBox(children=(Label(value='0.008 MB of 0.022 MB uploaded (0.004 MB deduped)\r'), FloatProgress(value=0.384911…

0,1
train_accuracy,▁▂▂▃▃▃▃▃▄▄▄▄▄▅▄▄▅▅▅▅▆▆▅▅▇▆▆▆▇▇▆▆▇▇█▇████
train_loss,█▇▆▆▆▅▆▅▄▄▅▅▄▄▄▄▃▃▃▃▂▂▃▃▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁
train_perplexity,█▆▄▄▃▂▃▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_accuracy,▁▆████▇▆▆▅
val_loss,█▃▁▁▂▃▅▆▇▇
val_perplexity,█▃▁▁▂▃▄▅▇▇

0,1
train_accuracy,0.48116
train_loss,2.66282
train_perplexity,14.36821
val_accuracy,0.29254
val_loss,5.18484
val_perplexity,182.48425
