# Assignment 5: Extended Long Short-Term Memory (xLSTM)

*Author:* Philipp Seidl

*Copyright statement:* This  material,  no  matter  whether  in  printed  or  electronic  form,  may  be  used  for  personal  and non-commercial educational use only.  Any reproduction of this manuscript, no matter whether as a whole or in parts, no matter whether in printed or in electronic form, requires explicit prior acceptance of the authors.

In this assignment, we will explore the xLSTM architecture, a novel extension of the classic LSTM model. The paper can be found here: https://arxiv.org/abs/2405.04517

## Background
Recurrent Neural Networks (RNNs), particularly LSTMs, have proven highly effective in various sequence modeling tasks. However, the emergence of Transformers, with their parallel processing capabilities, has shifted the focus away from LSTMs, especially in large-scale language modeling.
The xLSTM architecture aims to bridge this gap by enhancing LSTMs with mechanisms inspired by modern LLMs (e.g. block-strucutre, residual connections, ...).  Further it introduces:
- Exponential gating with normalization and stabilization techniques, which improves gradient flow and memory capacity.
- Modifications to the LSTM memory structure, resulting in two variants:
    - sLSTM: Employs a scalar memory with a scalar update rule and a new memory mixing technique through recurrent connections.
    - mLSTM: Features a matrix memory, employs a covariance update rule, and is fully parallelizable, making it suitable for scaling.

By integrating these extensions into residual block backbones, xLSTM blocks are formed, which can then be residually stacked to create complete xLSTM architectures.

## Exercise 1: Environment Setup

When working with new architectures or specialized frameworks, it's essential to correctly set up the environment to ensure reproducability. This exercise focuses on setting up the environment for working with the `xlstm` repository.

1. Visit and clone the official repository: [https://github.com/NX-AI/xlstm](https://github.com/NX-AI/xlstm).  
2. Set up the environment  
3. Document your setup:  
   - OS, Python version, Environment setup, CUDA version (if applicable), and GPU details.  
   - Note any challenges you faced and how you resolved them. 
4. Submit your setup as a bash script using the IPython `%%bash` magic. Ensure it is reproducible.

Getting mLSTM working only is fine (if you encounter issues with sLSTM cuda kernels)

> **Note**: Depending on your system setup, you may need to adjust the `environment_pt220cu121.yaml` file, such as for the CUDA version. For this assignment, it is recommended to run it on GPUs. If you don't have one, consider using  [Colab](https://colab.research.google.com/notebooks/welcome.ipynb#recent=true) or other online resources.

> **Recommendations**: While the repository suggests using `conda`, we recommend using `mamba` or `micromamba` instead (way faster) (except if you are using colab). Learn more about them here: [https://mamba.readthedocs.io/en/latest/index.html](https://mamba.readthedocs.io/en/latest/index.html).

In [47]:
%%bash
########## SOLUTION BEGIN ##########
pip install git+https://github.com/NX-AI/xlstm@79e463c84cd8bb839bb9a7d81138f1a0184c68a1
pip install omegaconf
pip install dacite
########## YOUR SOLUTION HERE ##########

Collecting git+https://github.com/NX-AI/xlstm@79e463c84cd8bb839bb9a7d81138f1a0184c68a1
  Cloning https://github.com/NX-AI/xlstm (to revision 79e463c84cd8bb839bb9a7d81138f1a0184c68a1) to /private/var/folders/z1/xfwm3xr90019fcl3by7y113c0000gn/T/pip-req-build-en87xn2t


  Running command git clone --filter=blob:none --quiet https://github.com/NX-AI/xlstm /private/var/folders/z1/xfwm3xr90019fcl3by7y113c0000gn/T/pip-req-build-en87xn2t
  Running command git rev-parse -q --verify 'sha^79e463c84cd8bb839bb9a7d81138f1a0184c68a1'
  Running command git fetch -q https://github.com/NX-AI/xlstm 79e463c84cd8bb839bb9a7d81138f1a0184c68a1
  Running command git checkout -q 79e463c84cd8bb839bb9a7d81138f1a0184c68a1


  Resolved https://github.com/NX-AI/xlstm to commit 79e463c84cd8bb839bb9a7d81138f1a0184c68a1
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
  Preparing metadata (pyproject.toml): started
  Preparing metadata (pyproject.toml): finished with status 'done'


In [48]:
# Verify your installation of xLSTM:
from omegaconf import OmegaConf
from dacite import from_dict
from dacite import Config as DaciteConfig
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig
import os
import torch
import time, math

DEVICE = "cuda" if torch.cuda.is_available() else 'mps' if torch.mps.is_available() else "cpu"
print(DEVICE)
use_slstm_kernels = False # set to True if you want to check if sLSTM cuda kernels are working

xlstm_cfg = f"""
mlstm_block:
  mlstm:
    conv1d_kernel_size: 4
    qkv_proj_blocksize: 4
    num_heads: 8
slstm_block:
  slstm:
    backend: {'cuda' if use_slstm_kernels else 'vanilla'}
    num_heads: 4
    conv1d_kernel_size: 4
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 64
num_blocks: 7
embedding_dim: 64
slstm_at: [] # empty = mLSTM only
"""
cfg = OmegaConf.create(xlstm_cfg)
cfg = from_dict(data_class=xLSTMBlockStackConfig, data=OmegaConf.to_container(cfg), config=DaciteConfig(strict=True))
xlstm_stack = xLSTMBlockStack(cfg)

x = torch.randn(4, 64, 64).to(DEVICE)
xlstm_stack = xlstm_stack.to(DEVICE)
y = xlstm_stack(x)
y.shape == x.shape

mps


True

In [49]:
print(xlstm_stack.config)

xLSTMBlockStackConfig(mlstm_block=mLSTMBlockConfig(mlstm=mLSTMLayerConfig(proj_factor=2.0, round_proj_up_dim_up=True, round_proj_up_to_multiple_of=64, _proj_up_dim=128, conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=8, embedding_dim=64, bias=False, dropout=0.0, context_length=64, _num_blocks=7, _inner_embedding_dim=128), _num_blocks=7, _block_idx=None), slstm_block=sLSTMBlockConfig(slstm=sLSTMLayerConfig(hidden_size=64, num_heads=4, num_states=4, backend='vanilla', function='slstm', bias_init='powerlaw_blockdependent', recurrent_weight_init='zeros', _block_idx=None, _num_blocks=7, num_gates=4, gradient_recurrent_cut=False, gradient_recurrent_clipval=None, forward_clipval=None, batch_size=8, input_shape='BSGNH', internal_input_shape='SBNGH', output_shape='BNSH', constants={}, dtype='bfloat16', dtype_b='float32', dtype_r='bfloat16', dtype_w='bfloat16', dtype_g='bfloat16', dtype_s='bfloat16', dtype_a='float32', enable_automatic_mixed_precision=True, initial_val=0.0, embedding_dim=6

## Exercise 2: Understanding xLSTM Hyperparameters
Explain key hyperparameters that influence the performance and behavior of the xLSTM architecture and explain how they influence total parameter count.
The explanation should include: proj_factor, num_heads, act_fn, context_length, num_blocks, embedding_dim, hidden_size, dropout, slstm_at, qkv_proj_blocksize, conv1d_kernel_size. Also include how the matrix memory size of mLSTM is determined.

In [50]:
########## SOLUTION BEGIN ##########

########## YOUR SOLUTION HERE ##########

## Exercise 3: Train an xLSTM model on the Trump Dataset from the previous exercise
Your task is to train an xLSTM model on the Trump Dataset from the previous exercise. 
- The goal is to achieve an average validation loss $\mathcal{L}_{\text{val}} < 1.35$. 
- You do not need to perform an extensive hyperparameter search, but you should document your runs. Log your runs with used hyperparameters using tools like wandb, neptune, mlflow, ... or a similar setup. Log training/validation loss and learning rate over steps as well as total trainable parameters of the model for each run.
- You can use the training setup from the previous exercises or any setup of your choice using high level training libaries.

In [51]:
class SeppGPTResidual(torch.nn.Module):
    def __init__(self, config, hidden_size = 64, vocab_size = 40):
        super().__init__()
        
        # 1 xlstm block
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.xlstm_stack = xLSTMBlockStack(config)
        self.proj = torch.nn.Linear(hidden_size, hidden_size)
        self.ln = torch.nn.LayerNorm(hidden_size)

        # 2 xlstm blocks
        self.xlstm_stack2 = xLSTMBlockStack(config)
        self.proj2 = torch.nn.Linear(hidden_size, vocab_size)
        self.ln2 = torch.nn.LayerNorm(vocab_size)

    def forward(self, x):
        emb = self.embedding(x)
        x = self.xlstm_stack(emb)
        x = self.proj(x)
        x = self.ln(x) + emb
        x = self.xlstm_stack2(x)
        x = self.proj2(x)
        x = self.ln2(x)
        return x

In [52]:
class SeppGPT(torch.nn.Module):
    def __init__(self, config, hidden_size = 64, vocab_size = 40):
        super().__init__()
        self.embedding = torch.nn.Embedding(vocab_size, hidden_size)
        self.xlstm_stack = xLSTMBlockStack(config)
        self.proj = torch.nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        emb = self.embedding(x)
        out1 = self.xlstm_stack(emb)
        out2 = self.proj(out1)
        return out2

In [53]:
model = SeppGPT(xlstm_stack.config).to(DEVICE)
print(model)

SeppGPT(
  (embedding): Embedding(40, 64)
  (xlstm_stack): xLSTMBlockStack(
    (blocks): ModuleList(
      (0-6): 7 x mLSTMBlock(
        (xlstm_norm): LayerNorm()
        (xlstm): mLSTMLayer(
          (proj_up): Linear(in_features=64, out_features=256, bias=False)
          (q_proj): LinearHeadwiseExpand(in_features=128, num_heads=32, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (k_proj): LinearHeadwiseExpand(in_features=128, num_heads=32, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (v_proj): LinearHeadwiseExpand(in_features=128, num_heads=32, expand_factor_up=1, bias=False, trainable_weight=True, trainable_bias=True, )
          (conv1d): CausalConv1d(
            (conv): Conv1d(128, 128, kernel_size=(4,), stride=(1,), padding=(3,), groups=128)
          )
          (conv_act_fn): SiLU()
          (mlstm_cell): mLSTMCell(
            (igate): Linear(in_features=384, out_features=8, bias=True)
   

In [54]:
########## SOLUTION BEGIN ##########
model = SeppGPT(xlstm_stack.config).to(DEVICE)


eval_interval = 200 # validate model every .. iterations
log_interval = 10 # log training loss every .. iterations
eval_iters = 20 # number of batches for loss estimation
gradient_accumulation_steps = 5 # used to simulate larger training batch sizes
batch_size = 6 # if gradient_accumulation_steps > 1, this is the micro-batch size
context_size = 64 # sequence length
vocab = 'abcdefghijklmnopqrstuvwxyz0123456789 .!?' # vocabulary
vocab_size = len(vocab) # 40
n_layer = 8 # number of layers
n_head = 8 # number of attention heads
hidden_size = 64 # layer size
dropout = 1e-5 # for pretraining 0 is good, for finetuning try 0.1+
learning_rate = 1e-3 # max learning rate
max_iters = 3000 # total number of training iterations
weight_decay = 1e-1
beta1 = 0.9 # for AdamW
beta2 = 0.999 # for AdamW
grad_clip = 1.0 # clip gradients at this value, or disable with 0.0
warmup_iters = 100 # how many steps to warm up for
min_lr = 1e-4 # minimum learning rate, usually ~= learning_rate/10

# learning rate decay scheduler (cosine with warmup)
def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > max_iters, return min learning rate
    if it > max_iters:
        return min_lr
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (max_iters - warmup_iters)
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
    return min_lr + coeff * (learning_rate - min_lr)

def load_data(split):
    import re
    
    with open(f'trump_{split}.txt', 'r') as f:
        text = f.read()
    
    text = text.lower() # convert to lower case
    text = re.sub('[^a-z0-9 .!?]', ' ', text) # replace all unknown chars with ' '
    text = re.sub(' +', ' ', text) # reduce multiple blanks to one
    text = [vocab.index(t) for t in text]
    text = torch.tensor(text, dtype=torch.long, device=DEVICE)
    return text
    
def get_batch(split):
    data = train_data if split == 'train' else val_data
    # Random starting indices (shape: [batch_size])
    ix = torch.randint(len(data) - context_size, (batch_size,), device=DEVICE)
    # Create a 2D index tensor of shape batch_size X context_size
    #  For each element in ix, we want to collect [i, i+1, ..., i+context_size-1].
    #  So we broadcast-add a range of length `context_size` to each element of ix.
    x_positions = ix.unsqueeze(-1) + torch.arange(context_size, device=DEVICE)
    y_positions = x_positions + 1  # Shift by 1
    x = data[x_positions]  # batch_size X context_size
    y = data[y_positions]  # batch_size X context_size
    return x, y

# helps estimate an arbitrarily accurate loss over either split using many batches
@torch.no_grad()
def estimate_loss():
    output = {}
    model.eval()
    for split in ['train', 'val']:
        total_loss = 0.0
        for _ in range(eval_iters):
            X, Y = get_batch(split)
            out = model(X)
            loss = torch.nn.functional.cross_entropy(out.view(-1, vocab_size), Y.view(-1))
            total_loss += loss.item()
        output[split] = total_loss / eval_iters
    model.train()
    return output

# data, model, optimizer, etc.
train_data = load_data('train')
val_data = load_data('val')
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(beta1, beta2), weight_decay=weight_decay)
optimizer.param_groups[0]['lr'] = learning_rate
iter_num = 0
best_val_loss = 1e9
X, Y = get_batch('train') # fetch the very first batch
t0 = time.time()
embedding = torch.nn.Embedding(vocab_size, hidden_size).to(DEVICE)
########## YOUR SOLUTION HERE ##########

for iter_num in range(max_iters):
    optimizer.zero_grad()
    out = model(X)
    loss = torch.nn.functional.cross_entropy(out.view(-1, vocab_size), Y.view(-1))
    if grad_clip > 0.0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    loss.backward()
    optimizer.step()
    if iter_num % log_interval == 0:
        print(f'[{iter_num}/{max_iters}] loss={loss.item()}')
    if iter_num % eval_interval == 0:
        val_loss = estimate_loss()['val']
        print(f'[{iter_num}/{max_iters}] val_loss={val_loss}')
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'best_model.pth')
    X, Y = get_batch('train')
    lr = get_lr(iter_num)
    optimizer.param_groups[0]['lr'] = lr

print(f'training took {time.time()-t0} seconds')
########## YOUR SOLUTION HERE ##########

[0/3000] loss=3.729189157485962
[0/3000] val_loss=3.69344357252121
[10/3000] loss=3.6556808948516846
[20/3000] loss=3.610262632369995
[30/3000] loss=3.4089105129241943
[40/3000] loss=3.139468193054199
[50/3000] loss=2.8497982025146484
[60/3000] loss=2.6778461933135986
[70/3000] loss=2.3739306926727295
[80/3000] loss=2.2838406562805176
[90/3000] loss=2.2506320476531982
[100/3000] loss=2.388808012008667
[110/3000] loss=2.178525924682617
[120/3000] loss=2.1315104961395264
[130/3000] loss=2.0674633979797363
[140/3000] loss=1.9370713233947754
[150/3000] loss=2.032170534133911
[160/3000] loss=1.9046087265014648
[170/3000] loss=1.809989333152771
[180/3000] loss=1.7912341356277466
[190/3000] loss=2.0180132389068604
[200/3000] loss=1.929832935333252
[200/3000] val_loss=2.014889633655548
[210/3000] loss=1.8513693809509277
[220/3000] loss=1.9059723615646362
[230/3000] loss=1.8157931566238403
[240/3000] loss=1.958178162574768
[250/3000] loss=1.6897996664047241
[260/3000] loss=1.6835004091262817
[2

## Exercise 4: Utilizing a Pretrained Model (Bonus)

Foundation Models, those pretrained on large amounts of data are more and more important. We can use those models and fine-tune them on our dataset, rather then training them from scratch.
Here are the things to consider:

- Model Selection: Choose a pretrained language model from an online repository. Hint: You can explore platforms like Hugging Face (huggingface.co), which host numerous pretrained models.

- Dataset: Use the Trump dataset with the same training and validation split as in previous exercises. You do not need to use character tokenization.

- Performance Evaluation: Evaluate the performance of the pretrained model on the validation set before and during fine-tuning. Report average-CE-loss as well as an example generated sequence with the same prompt for each epoch.
 
- Fine-tuning: Adjust the learning rate, potentially freeze some layers, train for a few epochs with a framework of your choice (e.g. [lightning](https://lightning.ai/docs/pytorch/stable/), [huggingface](https://huggingface.co/models), ...)

- Computational Resources: Be mindful of the computational demands of pretrained models. You might need access to GPUs. Try to keep the model size at a minimum and go for e.g. distilled versions or other small LMs

- Hyperparameter Tuning: You can experiment with different learning rates and potentially other hyperparameters during fine-tuning but no need to do this in depth

By completing this exercise, you will gain experience with utilizing pretrained models, understanding their capabilities, and the process of fine-tuning. Decreasing the validation loss can be seen a success for this exercise.

> **Note**: This is a standalone exercise and doesn't build upon the previous tasks.

In [55]:
!pip install transformers



In [56]:
########## SOLUTION BEGIN ##########
# Load model directly
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct", trust_remote_code=True)
########## YOUR SOLUTION HERE ##########

Downloading shards:  50%|█████     | 1/2 [00:36<00:36, 36.91s/it]


KeyboardInterrupt: 