<a href="https://colab.research.google.com/github/p20230445-bits/crux-inductions-2025/blob/main/notebooks/Task_2_MAMBA_sentiment_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install dependencies

In [1]:
!pip install datasets --quiet


Imports & device setup

In [2]:
import random
import time
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from datasets import load_dataset

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)


Using device: cuda


Load IMDb dataset

In [3]:
dataset = load_dataset("imdb")

train_raw = list(zip(dataset["train"]["label"], dataset["train"]["text"]))
test_raw  = list(zip(dataset["test"]["label"], dataset["test"]["text"]))

# Create validation split (10% of training data)
random.shuffle(train_raw)
split = int(0.9 * len(train_raw))
train_pairs = train_raw[:split]
valid_pairs = train_raw[split:]

print(f"Train: {len(train_pairs)}, Valid: {len(valid_pairs)}, Test: {len(test_raw)}")


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

plain_text/test-00000-of-00001.parquet:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

plain_text/unsupervised-00000-of-00001.p(…):   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Train: 22500, Valid: 2500, Test: 25000


Build vocabulary

In [4]:
MIN_FREQ = 2
PAD_TOKEN = "<pad>"
UNK_TOKEN = "<unk>"

def tokenize(text):
    return text.lower().split()

counter = Counter()
for _, text in train_pairs:
    counter.update(tokenize(text))

vocab = {PAD_TOKEN: 0, UNK_TOKEN: 1}
for word, freq in counter.items():
    if freq >= MIN_FREQ:
        vocab[word] = len(vocab)

PAD_IDX = vocab[PAD_TOKEN]
UNK_IDX = vocab[UNK_TOKEN]

print(f"Vocab size: {len(vocab)}")


Vocab size: 94032


Encode & collate

In [12]:
MAX_LEN = 100

def encode(text):
    tokens = tokenize(text)
    ids = [vocab.get(tok, UNK_IDX) for tok in tokens]
    return ids[:MAX_LEN]

class IMDbDataset(Dataset):
    def __init__(self, pairs):
        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        label, text = self.pairs[idx]
        ids = encode(text)
        return torch.tensor(ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)

def collate_batch(batch):
    ids_list, labels = zip(*batch)
    lengths = torch.tensor([len(x) for x in ids_list], dtype=torch.long)
    max_len = max(lengths).item()
    padded = torch.full((len(ids_list), max_len), PAD_IDX, dtype=torch.long)
    for i, ids in enumerate(ids_list):
        padded[i, :len(ids)] = ids
    return padded, torch.stack(labels), lengths


Data loaders

In [13]:
BATCH_SIZE = 64

train_loader = DataLoader(IMDbDataset(train_pairs), batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
valid_loader = DataLoader(IMDbDataset(valid_pairs), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)
test_loader  = DataLoader(IMDbDataset(test_raw), batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_batch)


Simplified MAMBA-style block

In [14]:
class SimpleMambaBlock(nn.Module):
    def __init__(self, d_model, hidden_mult=2, dropout=0.1):
        super().__init__()
        inner = hidden_mult * d_model
        self.in_proj = nn.Linear(d_model, inner)
        self.gate_proj = nn.Linear(d_model, inner)
        self.alpha_proj = nn.Linear(d_model, inner)
        self.beta_proj = nn.Linear(d_model, inner)
        self.out_proj = nn.Linear(inner, d_model)
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        B, T, D = x.shape
        h = torch.zeros(B, self.in_proj.out_features, device=x.device)
        x_in = self.in_proj(x)
        x_gate = self.gate_proj(x)
        x_alpha = self.alpha_proj(x)
        x_beta = self.beta_proj(x)
        outs = []
        for t in range(T):
            gate_t = torch.sigmoid(x_gate[:, t, :])
            alpha_t = torch.sigmoid(x_alpha[:, t, :])
            beta_t = F.softplus(x_beta[:, t, :])
            inp_t = x_in[:, t, :]
            h = gate_t * (alpha_t * h + beta_t * inp_t) + (1 - gate_t) * inp_t
            outs.append(h)
        y = torch.stack(outs, dim=1)
        y = self.out_proj(y)
        y = self.dropout(y)
        y = self.norm(x + y)
        return y


Mamba sentiment classifier

In [15]:
class MambaSentimentClassifier(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_layers=2, num_classes=2, pad_idx=0, dropout=0.1):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.blocks = nn.ModuleList([SimpleMambaBlock(d_model, hidden_mult=2, dropout=dropout) for _ in range(n_layers)])
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, ids, lengths):
        x = self.embed(ids)
        for blk in self.blocks:
            x = blk(x, lengths)
        mask = (ids != PAD_IDX).unsqueeze(-1)
        pooled = (x * mask).sum(dim=1) / lengths.clamp(min=1).unsqueeze(-1)
        return self.fc(self.dropout(pooled))


Early stopping

In [16]:
class EarlyStopping:
    def __init__(self, patience=3):
        self.patience = patience
        self.counter = 0
        self.best_acc = 0.0
        self.best_state = None
        self.stop = False

    def step(self, val_acc, model):
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self.best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True


Train & evaluate

In [17]:
@torch.no_grad()
def run_eval(model, loader):
    model.eval()
    total, correct = 0, 0
    for ids, labels, lengths in loader:
        ids, labels, lengths = ids.to(DEVICE), labels.to(DEVICE), lengths.to(DEVICE)
        logits = model(ids, lengths)
        preds = logits.argmax(dim=-1)
        total += labels.numel()
        correct += (preds == labels).sum().item()
    return correct / total

def train_model(model, train_loader, valid_loader, epochs=10, lr=2e-3, weight_decay=1e-4, patience=3):
    model.to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()
    early = EarlyStopping(patience=patience)

    for ep in range(1, epochs+1):
        model.train()
        running_loss, running_acc, steps = 0, 0, 0
        for ids, labels, lengths in train_loader:
            ids, labels, lengths = ids.to(DEVICE), labels.to(DEVICE), lengths.to(DEVICE)
            logits = model(ids, lengths)
            loss = criterion(logits, labels)
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            running_loss += loss.item()
            running_acc += (logits.argmax(dim=-1) == labels).float().mean().item()
            steps += 1
        val_acc = run_eval(model, valid_loader)
        print(f"Epoch {ep} | Train Loss: {running_loss/steps:.4f} | Train Acc: {running_acc/steps:.4f} | Val Acc: {val_acc:.4f}")
        early.step(val_acc, model)
        if early.stop:
            print(f"Early stopping triggered at epoch {ep}")
            break
    if early.best_state:
        model.load_state_dict(early.best_state)
    return model


Run training & test

In [18]:
VOCAB_SIZE = len(vocab)
model = MambaSentimentClassifier(VOCAB_SIZE, d_model=128, n_layers=2, num_classes=2, pad_idx=PAD_IDX, dropout=0.2)

trained_model = train_model(model, train_loader, valid_loader, epochs=3, lr=2e-3, patience=2)

test_acc = run_eval(trained_model, test_loader)
print(f"Test Accuracy: {test_acc:.4f}")


Epoch 1 | Train Loss: 0.5532 | Train Acc: 0.7087 | Val Acc: 0.7932
Epoch 2 | Train Loss: 0.3301 | Train Acc: 0.8586 | Val Acc: 0.8180
Epoch 3 | Train Loss: 0.1652 | Train Acc: 0.9357 | Val Acc: 0.8076
Test Accuracy: 0.7959
