# 1. Install Dependencies

In [1]:
# Install required libraries
!pip install xlstm transformers datasets torch
!apt-get install git-lfs

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.3).
0 upgraded, 0 newly installed, 0 to remove and 30 not upgraded.


# 2. Preprocess Data

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import BertTokenizerFast
import time
import torch
import transformers

transformers.set_seed(42)
torch.manual_seed(42)

BATCH_SIZE = 32
MAX_LEN = 256
NUM_BLOCKS = 2
NUM_HEADS = 4
EPOCHS = 20
EMBED_DIM = 128
NUM_CLASSES = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
dataset = load_dataset("imdb")

class IMDBDataset(Dataset):
    def __init__(self, split):
        self.texts = dataset[split]['text']
        self.labels = dataset[split]['label']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        encoded = tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors='pt'
        )
        input_ids = encoded['input_ids'].squeeze(0)
        return input_ids, torch.tensor(self.labels[idx], dtype=torch.long)

train_data = IMDBDataset('train')
test_data = IMDBDataset('test')
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=True)

print(len(train_loader.dataset.labels))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


25000


In [3]:
import torch
import torch.nn as nn
from xlstm.xlstm_large.model import xLSTMLargeConfig, xLSTMLarge

# Define model configuration
config = xLSTMLargeConfig(
    embedding_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_blocks=NUM_BLOCKS,
    vocab_size=tokenizer.vocab_size,
    return_last_states=False,
    mode="inference",
    chunkwise_kernel="chunkwise--native_autograd",
    sequence_kernel="native_sequence__native",
    step_kernel="native",
)

class xLSTMClassifier(nn.Module):
    def __init__(self, config, num_classes=2):
        super().__init__()
        self.backbone = xLSTMLarge(config)

        if hasattr(self.backbone, "lm_head"):
            del self.backbone.lm_head

        self.backbone.lm_head = nn.Linear(EMBED_DIM, num_classes)

    def forward(self, x):
        x = self.backbone(x)
        if x.ndim == 3: x = x[:, -1, :]
        return x

# Initialize model
model = xLSTMClassifier(config).to(device)

model

xLSTMClassifier(
  (backbone): xLSTMLarge(
    (embedding): Embedding(30522, 128)
    (backbone): xLSTMLargeBlockStack(
      (blocks): ModuleList(
        (0-1): 2 x mLSTMBlock(
          (norm_mlstm): RMSNorm()
          (mlstm_layer): mLSTMLayer(
            (q): Linear(in_features=128, out_features=64, bias=False)
            (k): Linear(in_features=128, out_features=64, bias=False)
            (v): Linear(in_features=128, out_features=128, bias=False)
            (ogate_preact): Linear(in_features=128, out_features=128, bias=False)
            (igate_preact): Linear(in_features=128, out_features=4, bias=True)
            (fgate_preact): Linear(in_features=128, out_features=4, bias=True)
            (ogate_act_fn): Sigmoid()
            (mlstm_backend): mLSTMBackend(mLSTMBackendConfig(chunkwise_kernel='chunkwise--native_autograd', sequence_kernel='native_sequence__native', step_kernel='native', mode='inference', chunk_size=64, return_last_states=False, autocast_kernel_dtype='bfloat

In [4]:
pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Number of trainable parameters:", pytorch_total_params)

Number of trainable parameters: 4336018


# 3. Results

In [5]:
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

# Define optimizer and loss function
optimizer = AdamW(model.parameters(), lr=1e-2, betas=(0.9, 0.99), weight_decay=0.1)
criterion = torch.nn.CrossEntropyLoss()

# Define learning rate scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=10, eta_min=1e-5)

def train():
    model.train()
    total, correct = 0, 0
    start = time.time()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        pred = outputs.argmax(1)
        correct += (pred == labels).sum().item()
        total += labels.size(0)
    print(f"Train Acc: {correct/total*100:.2f}% | Time: {time.time() - start:.1f}s")


def evaluate(best_acc, patience):
    model.eval()
    total, correct = 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            pred = outputs.argmax(1)
            correct += (pred == labels).sum().item()
            total += labels.size(0)
    acc = correct/total
    print(f"Test Acc: {acc*100:.2f}%")

    if acc > best_acc:
        best_acc = acc
        patience = 5
    else:
        patience -= 1
    return best_acc, patience

best_acc = 0
patience = 5
EPOCHS = 30
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch+1}/{EPOCHS}")
    train()
    best_acc, patience = evaluate(best_acc, patience)

    if patience == 0:
      print("Early stopping!")
      break



Epoch 1/30
Train Acc: 50.05% | Time: 61.7s
Test Acc: 50.00%

Epoch 2/30
Train Acc: 49.12% | Time: 61.2s
Test Acc: 50.00%

Epoch 3/30
Train Acc: 49.35% | Time: 60.5s
Test Acc: 50.00%

Epoch 4/30
Train Acc: 49.84% | Time: 60.6s
Test Acc: 49.22%

Epoch 5/30
Train Acc: 50.22% | Time: 60.2s
Test Acc: 50.00%

Epoch 6/30
Train Acc: 50.78% | Time: 60.3s
Test Acc: 50.01%

Epoch 7/30
Train Acc: 51.24% | Time: 60.3s
Test Acc: 50.00%

Epoch 8/30
Train Acc: 52.98% | Time: 60.3s
Test Acc: 50.96%

Epoch 9/30
Train Acc: 54.59% | Time: 60.4s
Test Acc: 51.17%

Epoch 10/30
Train Acc: 55.44% | Time: 60.2s
Test Acc: 52.46%

Epoch 11/30
Train Acc: 58.15% | Time: 60.2s
Test Acc: 52.58%

Epoch 12/30
Train Acc: 59.65% | Time: 60.2s
Test Acc: 52.61%

Epoch 13/30
Train Acc: 61.74% | Time: 60.7s
Test Acc: 53.05%

Epoch 14/30
Train Acc: 62.97% | Time: 60.1s
Test Acc: 54.57%

Epoch 15/30
Train Acc: 63.59% | Time: 60.3s
Test Acc: 53.22%

Epoch 16/30
Train Acc: 64.50% | Time: 60.2s
Test Acc: 54.14%

Epoch 17/30
Trai

Best accuracy obtained: 84.38%

In [6]:
peak_memory = torch.cuda.max_memory_allocated() / (2 ** 20)
print(f"Peak memory usage: {peak_memory:.2f} MB")

Peak memory usage: 373.97 MB
