## Import library

In [1]:
!pip install xlstm

Collecting xlstm
  Downloading xlstm-2.0.1-py3-none-any.whl.metadata (20 kB)
Collecting reportlab (from xlstm)
  Downloading reportlab-4.2.5-py3-none-any.whl.metadata (1.5 kB)
Collecting joypy (from xlstm)
  Downloading joypy-0.2.6-py2.py3-none-any.whl.metadata (812 bytes)
Collecting ftfy (from xlstm)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
INFO: pip is looking at multiple versions of xlstm to determine which version is compatible with other requirements. This could take a while.
Collecting xlstm
  Downloading xlstm-2.0.0-py3-none-any.whl.metadata (20 kB)
Downloading xlstm-2.0.0-py3-none-any.whl (89 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m89.8/89.8 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xlstm
Successfully installed xlstm-2.0.0


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim

import math

import datasets

from tqdm import tqdm
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMLMModel, xLSTMLMModelConfig

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(0)

<torch._C.Generator at 0x7dc8548a4a70>

## Dataset 

In [4]:
### Load the Dataset
dataset = datasets.load_dataset('wikitext', 'wikitext-2-raw-v1')
print(dataset)
print(dataset['train'][88]['text'])

from collections import Counter
import re

class Tokenizer:
    def __init__(self):
        self.pattern = re.compile(r'\b\w+\b|[^\w\s]')
        
    def __call__(self, text):
        return self.pattern.findall(text.lower())

class Vocab:
    def __init__(self, min_freq=3):
        self.stoi = {'<unk>': 0, '<eos>': 1}
        self.itos = ['<unk>', '<eos>']
        self.min_freq = min_freq
        
    def build(self, tokens):
        counter = Counter([t for doc in tokens for t in doc])
        for word, freq in counter.items():
            if freq >= self.min_freq and word not in self.stoi:
                self.stoi[word] = len(self.itos)
                self.itos.append(word)
    
    def __len__(self):
        return len(self.itos)
        
    def __getitem__(self, token):
        return self.stoi.get(token, self.stoi['<unk>'])

tokenizer = Tokenizer()
tokenize_data = lambda example, tokenizer: {'tokens': tokenizer(example['text'])}
tokenized_dataset = dataset.map(tokenize_data, remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})

vocab = Vocab(min_freq=3)
vocab.build(tokenized_dataset['train']['tokens'])

print(len(vocab))                         # total number words in the vocabulary

def get_data(dataset, vocab, batch_size):
    data = []                                                       # Merge everything into one gigantic document that we wish to model (all the tokens)
    for example in dataset:
        if example['tokens']:                                       # if the example has tokens (not empty)
            tokens = example['tokens'].append('<eos>')              # append <eos> at the end of the sentence
            tokens = [vocab[token] for token in example['tokens']]  # convert tokens to indices
            data.extend(tokens)                                     # append tokens to data
    data = torch.LongTensor(data)                                   # convert data to tensor
    num_batches = data.shape[0] // batch_size 
    data = data[:num_batches * batch_size]                         # We only need the first num_batches * batch_size elements
    data = data.view(batch_size, num_batches)            # Perceive the data as a matrix of batch_size rows and num_batches columns
    return data

#Notice that train_data[:, i] is the batch of next tokens for train_data[:, i - 1] 
batch_size = 256
train_data = get_data(tokenized_dataset['train'], vocab, batch_size)
valid_data = get_data(tokenized_dataset['validation'], vocab, batch_size)
test_data = get_data(tokenized_dataset['test'], vocab, batch_size)

def get_batch(data, seq_len, num_batches, idx):
    src = data[:, idx:idx+seq_len]                   
    target = data[:, idx+1:idx+seq_len+1]             # The target is the src shifted by one batch
    return src, target

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/733k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/6.36M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/657k [00:00<?, ?B/s]

Generating test split:   0%|          | 0/4358 [00:00<?, ? examples/s]

Generating train split:   0%|          | 0/36718 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3760 [00:00<?, ? examples/s]

DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})
 This ammunition , and that which I brought with me , was rapidly prepared for use at the Laboratory established at the Little Rock Arsenal for that purpose . As illustrating as the pitiful scarcity of material in the country , the fact may be stated that it was found necessary to use public documents of the State Library for cartridge paper . Gunsmiths were employed or conscripted , tools purchased or impressed , and the repair of the damaged guns I brought with me and about an equal number found at Little Rock commenced at once . But , after inspecting the work and observing the spirit of the men I decided that a garrison 500 strong could hold out against Fitch and that I would lead the remainder - about 1500 - to Gen 'l Rust as 

Map:   0%|          | 0/4358 [00:00<?, ? examples/s]

Map:   0%|          | 0/36718 [00:00<?, ? examples/s]

Map:   0%|          | 0/3760 [00:00<?, ? examples/s]

29482


## xLSTM

In [5]:
### Define xLSTM Configuration
xlstm_cfg = """ 
vocab_size: 29482
mlstm_block:
  mlstm:
    conv1d_kernel_size: 8
    qkv_proj_blocksize: 8
    num_heads: 8
slstm_block:
  slstm:
    backend: cuda
    num_heads: 8
    conv1d_kernel_size: 8
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 8
embedding_dim: 128
slstm_at: [1]
"""
cfg = OmegaConf.create(xlstm_cfg)
cfg = from_dict(data_class=xLSTMLMModelConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMLMModel(cfg)
vocab_size = len(vocab)                 
lr = 1e-3                        # They used 30 and a different optimizer
model = xlstm_stack.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {num_params:,} trainable parameters')

{'verbose': True, 'with_cuda': True, 'extra_ldflags': ['-L/usr/local/cuda/lib', '-lcublas'], 'extra_cflags': ['-DSLSTM_HIDDEN_SIZE=128', '-DSLSTM_BATCH_SIZE=8', '-DSLSTM_NUM_HEADS=8', '-DSLSTM_NUM_STATES=4', '-DSLSTM_DTYPE_B=float', '-DSLSTM_DTYPE_R=__nv_bfloat16', '-DSLSTM_DTYPE_W=__nv_bfloat16', '-DSLSTM_DTYPE_G=__nv_bfloat16', '-DSLSTM_DTYPE_S=__nv_bfloat16', '-DSLSTM_DTYPE_A=float', '-DSLSTM_NUM_GATES=4', '-DSLSTM_SIMPLE_AGG=true', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL_VALID=false', '-DSLSTM_GRADIENT_RECURRENT_CLIPVAL=0.0', '-DSLSTM_FORWARD_CLIPVAL_VALID=false', '-DSLSTM_FORWARD_CLIPVAL=0.0', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', '-U__CUDA_NO_BFLOAT16_OPERATORS__', '-U__CUDA_NO_BFLOAT16_CONVERSIONS__', '-U__CUDA_NO_BFLOAT162_OPERATORS__', '-U__CUDA_NO_BFLOAT162_CONVERSIONS__'], 'extra_cuda_cflags': ['-Xptxas="-v"', '-gencode', 'arch=compute_80,code=compute_80', '-res-usage', '--use_fast_math', '-O3', '-Xptxas -O3', '--extra-device-vectorization', '-DSLSTM_

Using /root/.cache/torch_extensions/py310_cu121 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py310_cu121/slstm_HS128BS8NH8NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py310_cu121/slstm_HS128BS8NH8NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module slstm_HS128BS8NH8NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module slstm_HS128BS8NH8NS4DBfDRbDWbDGbDSbDAfNG4SA1GRCV0GRC0d0FCV0FC0d0...
  def forward(ctx, training, *inputs):
  def backward(ctx, grad_s):


The model has 8,477,552 trainable parameters


## Train

In [7]:
def train(model, data, optimizer, criterion, batch_size, seq_len, clip, device):
    
    epoch_loss = 0
    model.train()
    # drop all batches that are not a multiple of seq_len
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = None  # xLSTM does not require explicit hidden state initialization
    
    for idx in tqdm(range(0, num_batches - 1, seq_len), desc='Training: ',leave=False):  # The last batch can't be a src
        optimizer.zero_grad()

        src, target = get_batch(data, seq_len, num_batches, idx)
        src, target = src.to(device), target.to(device)
        batch_size = src.shape[0]
        prediction = model(src)                 # model output

        prediction = prediction.reshape(batch_size * seq_len, -1)   
        target = target.reshape(-1)
        loss = criterion(prediction, target)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches
def evaluate(model, data, criterion, batch_size, seq_len, device):

    epoch_loss = 0
    model.eval()
    num_batches = data.shape[-1]
    data = data[:, :num_batches - (num_batches -1) % seq_len]
    num_batches = data.shape[-1]

    hidden = None  # xLSTM does not require explicit hidden state initialization

    with torch.no_grad():
        for idx in range(0, num_batches - 1, seq_len):
            src, target = get_batch(data, seq_len, num_batches, idx)
            src, target = src.to(device), target.to(device)
            batch_size= src.shape[0]

            prediction = model(src)
            prediction = prediction.reshape(batch_size * seq_len, -1)
            target = target.reshape(-1)

            loss = criterion(prediction, target)
            epoch_loss += loss.item() * seq_len
    return epoch_loss / num_batches
n_epochs = 50
seq_len = 50
clip = 0.25
saved = False

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=0)

if saved:
    model.load_state_dict(torch.load('best-val-xlstm_lm.pt',  map_location=device))
    test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
    print(f'Test Perplexity: {math.exp(test_loss):.3f}')
else:
    best_valid_loss = float('inf')

    for epoch in range(n_epochs):
        print(f"Epoch {epoch+1}:")
        train_loss = train(model, train_data, optimizer, criterion, batch_size, seq_len, clip, device)
        valid_loss = evaluate(model, valid_data, criterion, batch_size, seq_len, device)
        
        lr_scheduler.step(valid_loss)

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), 'best-xlstm.pt')

        print(f'\tTrain Perplexity: {math.exp(train_loss):.3f}')
        print(f'\tValid Perplexity: {math.exp(valid_loss):.3f}')


Epoch 1:


                                                           

	Train Perplexity: 690.731
	Valid Perplexity: 283.054
Epoch 2:


                                                           

	Train Perplexity: 257.716
	Valid Perplexity: 208.917
Epoch 3:


                                                           

	Train Perplexity: 184.550
	Valid Perplexity: 179.101
Epoch 4:


                                                           

	Train Perplexity: 145.117
	Valid Perplexity: 163.054
Epoch 5:


                                                           

	Train Perplexity: 120.086
	Valid Perplexity: 155.606
Epoch 6:


                                                           

	Train Perplexity: 102.757
	Valid Perplexity: 151.040
Epoch 7:


                                                           

	Train Perplexity: 90.045
	Valid Perplexity: 149.421
Epoch 8:


                                                           

	Train Perplexity: 80.145
	Valid Perplexity: 150.094
Epoch 9:


                                                           

	Train Perplexity: 69.643
	Valid Perplexity: 146.796
Epoch 10:


                                                           

	Train Perplexity: 65.249
	Valid Perplexity: 148.449
Epoch 11:


                                                           

	Train Perplexity: 60.433
	Valid Perplexity: 147.769
Epoch 12:


                                                           

	Train Perplexity: 57.851
	Valid Perplexity: 147.175
Epoch 13:


                                                           

	Train Perplexity: 56.581
	Valid Perplexity: 147.071
Epoch 14:


                                                           

	Train Perplexity: 55.954
	Valid Perplexity: 146.950
Epoch 15:


                                                           

	Train Perplexity: 55.470
	Valid Perplexity: 146.838
Epoch 16:


                                                           

	Train Perplexity: 55.169
	Valid Perplexity: 146.700
Epoch 17:


                                                           

	Train Perplexity: 55.068
	Valid Perplexity: 146.717
Epoch 18:


                                                           

	Train Perplexity: 54.928
	Valid Perplexity: 146.640
Epoch 19:


                                                           

	Train Perplexity: 54.851
	Valid Perplexity: 146.605
Epoch 20:


                                                           

	Train Perplexity: 54.829
	Valid Perplexity: 146.597
Epoch 21:


                                                           

	Train Perplexity: 54.791
	Valid Perplexity: 146.594
Epoch 22:


                                                           

	Train Perplexity: 54.771
	Valid Perplexity: 146.594
Epoch 23:


                                                           

	Train Perplexity: 54.761
	Valid Perplexity: 146.594
Epoch 24:


                                                           

	Train Perplexity: 54.756
	Valid Perplexity: 146.594
Epoch 25:


                                                           

	Train Perplexity: 54.753
	Valid Perplexity: 146.594
Epoch 26:


                                                           

	Train Perplexity: 54.752
	Valid Perplexity: 146.594
Epoch 27:


                                                           

	Train Perplexity: 54.752
	Valid Perplexity: 146.594
Epoch 28:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 29:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 30:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 31:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 32:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 33:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 34:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 35:


                                                           

	Train Perplexity: 54.751
	Valid Perplexity: 146.593
Epoch 36:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.593
Epoch 37:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.593
Epoch 38:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.593
Epoch 39:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.592
Epoch 40:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.592
Epoch 41:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.592
Epoch 42:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.592
Epoch 43:


                                                           

	Train Perplexity: 54.750
	Valid Perplexity: 146.592
Epoch 44:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592
Epoch 45:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592
Epoch 46:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592
Epoch 47:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592
Epoch 48:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592
Epoch 49:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592
Epoch 50:


                                                           

	Train Perplexity: 54.749
	Valid Perplexity: 146.592


## Test


In [9]:
model.load_state_dict(torch.load('/kaggle/working/best-xlstm.pt',  map_location=device))
test_loss = evaluate(model, test_data, criterion, batch_size, seq_len, device)
print(f'Test Perplexity: {math.exp(test_loss):.3f}')

  model.load_state_dict(torch.load('/kaggle/working/best-xlstm.pt',  map_location=device))


Test Perplexity: 137.482
