In [None]:
!pip install torchinfo -q

In [None]:
!pip install xlstm --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/117.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.0/117.0 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m91.5/91.5 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m349.0/349.0 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.8/422.8 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m69.5 MB/s[0m eta [36m0:00

In [None]:
import torch
from torch import nn
from torchinfo import summary

from xlstm import (
    xLSTMBlockStack,
    xLSTMBlockStackConfig,
    mLSTMBlockConfig,
    mLSTMLayerConfig,
    sLSTMBlockConfig,
    sLSTMLayerConfig,
    FeedForwardConfig,
)

# Presumably the sequence/context length for xLSTM experiments.
# NOTE(review): no cell visible in this part of the notebook reads it — confirm it is still needed.
SEQ_LENGTH_XLSTM = 150

In [None]:
def create_xlstm_model(seq_length, num_blocks, slstm_pos, num_heads=2, conv1d_kernel_size=2, proj_factor=1.1,
                       input_size=1, embedding_dim=64, output_size=1, device="cuda"):
    """Build an xLSTM block stack together with its input and output linear projections.

    Args:
        seq_length: Passed to the stack config as ``context_length``.
        num_blocks: Number of blocks in the stack.
        slstm_pos: Passed to the stack config as ``slstm_at`` (block indices that use sLSTM).
        num_heads: Head count used for both the mLSTM and the sLSTM layer configs.
        conv1d_kernel_size: Kernel size used for both the mLSTM and the sLSTM layer configs.
        proj_factor: Projection factor of the sLSTM block's feed-forward config.
        input_size: Number of input features per timestep (default 1, the previous fixed value).
        embedding_dim: Width of the stack's embeddings (default 64, the previous fixed value).
        output_size: Number of output features per timestep (default 1, the previous fixed value).
        device: Device the three modules are moved to (default "cuda", the previous fixed value).

    Returns:
        Tuple ``(xlstm_stack, input_projection, output_projection)``, all placed on ``device``.
    """
    # The sLSTM layer has its own backend switch. Keep "cuda" for CUDA devices (the original
    # behavior) and fall back to the pure-PyTorch "vanilla" backend everywhere else.
    slstm_backend = "cuda" if str(device).startswith("cuda") else "vanilla"

    # Define the xLSTM configuration
    cfg = xLSTMBlockStackConfig(
        mlstm_block=mLSTMBlockConfig(
            mlstm=mLSTMLayerConfig(
                conv1d_kernel_size=conv1d_kernel_size, qkv_proj_blocksize=2, num_heads=num_heads
            )
        ),
        slstm_block=sLSTMBlockConfig(
            slstm=sLSTMLayerConfig(
                backend=slstm_backend,
                num_heads=num_heads,
                conv1d_kernel_size=conv1d_kernel_size,
                bias_init="powerlaw_blockdependent",
            ),
            feedforward=FeedForwardConfig(proj_factor=proj_factor, act_fn="gelu"),
        ),
        context_length=seq_length,
        num_blocks=num_blocks,
        embedding_dim=embedding_dim,
        slstm_at=slstm_pos,
    )

    # Instantiate the xLSTM stack
    xlstm_stack = xLSTMBlockStack(cfg).to(device)

    # Linear layer that lifts the raw features to the embedding width the stack expects
    input_projection = nn.Linear(input_size, embedding_dim).to(device)

    # Linear layer that maps the stack's output back to the desired output size
    output_projection = nn.Linear(embedding_dim, output_size).to(device)

    return xlstm_stack, input_projection, output_projection

In [None]:
import time
from functools import wraps

def timeit(func):
    """Decorator that prints how long each call of ``func`` took.

    The printed line is labelled with the call's ``architecture`` argument. The label is
    resolved from the keyword arguments first, then from positional arguments and the
    function's own default for ``architecture``; if none of these exist the label is
    'Unknown Architecture'.

    Returns:
        A wrapper with the same signature as ``func`` that returns ``func``'s result unchanged.
    """
    import inspect  # local import: only needed to resolve the label from the signature

    try:
        signature = inspect.signature(func)
    except (TypeError, ValueError):
        # Some callables expose no introspectable signature; fall back to kwargs only.
        signature = None

    def _architecture_label(args, kwargs):
        """Return the value `architecture` will have inside the call, if it can be determined."""
        if 'architecture' in kwargs:
            return kwargs['architecture']
        if signature is not None:
            try:
                bound = signature.bind_partial(*args, **kwargs)
            except TypeError:
                # Arguments do not fit the signature; the call itself will report that.
                return 'Unknown Architecture'
            bound.apply_defaults()
            if 'architecture' in bound.arguments:
                return bound.arguments['architecture']
        return 'Unknown Architecture'

    @wraps(func)
    def wrapper(*args, **kwargs):
        architecture = _architecture_label(args, kwargs)
        # perf_counter is monotonic, so the interval cannot be skewed by system clock changes
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed_time = time.perf_counter() - start_time
        print(f"\nArchitecture - {architecture} took {elapsed_time:.2f} seconds to train.")
        return result
    return wrapper

In [None]:
import torch
from tqdm import tqdm

@timeit
def train_model(epochs, model, input_projection, output_projection, train_data, optimizer, criterion, architecture="Unnamed"):
    """Train the input projection -> model -> output projection pipeline.

    Args:
        epochs: Number of passes over ``train_data``.
        model: Module applied to the projected inputs; switched to train mode each epoch.
        input_projection: Callable applied to each input batch before ``model``.
        output_projection: Callable applied to the output of ``model``.
        train_data: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        criterion: Loss called as ``criterion(predictions, targets)`` after both are squeezed.
        architecture: Label used only by the timing decorator's printout.

    Returns:
        Tuple ``(losses, elapsed)``. ``losses`` holds every batch loss in order, with the
        epoch's average loss appended after the batches of each epoch. ``elapsed`` is the
        training time in seconds.
    """
    # Clip exactly the parameters the optimizer updates. Clipping only model.parameters()
    # would leave any other optimized tensors (e.g. the projections) unclipped.
    clipped_params = [p for group in optimizer.param_groups for p in group['params']]

    losses = []
    start = time.perf_counter()
    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        progress = tqdm(train_data, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for inputs, targets in progress:
            projected_input_data = input_projection(inputs)
            xlstm_output = model(projected_input_data)
            # No timestep is selected here: the projection is applied to the model's full output
            predictions = output_projection(xlstm_output)

            predictions = predictions.squeeze()
            batch_y = targets.squeeze()

            loss = criterion(predictions, batch_y)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(clipped_params, max_norm=1.0)
            optimizer.step()

            # Read the scalar once; every .item() call forces a device synchronisation
            batch_loss = loss.item()
            losses.append(batch_loss)
            epoch_loss += batch_loss
            num_batches += 1
            progress.set_postfix(loss=batch_loss)

        # Counting batches avoids a ZeroDivisionError when train_data yields nothing
        avg_epoch_loss = epoch_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}: Avg Loss = {avg_epoch_loss:.6f}")
        losses.append(avg_epoch_loss)

    end = time.perf_counter()
    elapsed = end - start
    return losses, elapsed


In [None]:
# Mini-batch size used when building DataLoaders in this notebook.
BATCH_SIZE = 3
# Default window length for create_dataset (inputs plus the target element).
SEQUENCE_LENGTH = 8

In [None]:
def train_test_val_split(dataset, train_percent=0.7, val_percent=0.15):
    """Cut a sliceable dataset into contiguous train / validation / test parts.

    The first ``int(len * train_percent)`` items form the training part, the next
    ``int(len * val_percent)`` items the validation part, and whatever remains is the
    test part. No shuffling is performed.

    Returns:
        Tuple ``(train_data, val_data, test_data)``.
    """
    n_items = len(dataset)
    train_end = int(n_items * train_percent)
    val_end = train_end + int(n_items * val_percent)

    return dataset[:train_end], dataset[train_end:val_end], dataset[val_end:]

In [None]:
import numpy as np

def create_dataset(dataset, seq_len=SEQUENCE_LENGTH, device='cuda'):
    """Slice a series into sliding windows of ``seq_len`` elements.

    Each window contributes its first ``seq_len - 1`` elements as the input and its last
    element as the target.

    Args:
        dataset: Indexable, sliceable sequence of values.
        seq_len: Total window length (inputs plus target); must be at least 1.
        device: Device the returned tensors are moved to (default 'cuda', as before).

    Returns:
        Tuple ``(inputs, targets)`` of float tensors on ``device``.

    Raises:
        ValueError: If ``seq_len`` is smaller than 1.
    """
    if seq_len < 1:
        raise ValueError(f'seq_len must be >= 1, got {seq_len}')

    window_inputs, window_targets = [], []

    # A window covers indices i .. i + seq_len - 1, so the last valid start is
    # len(dataset) - seq_len. The "+ 1" keeps that final window, which was dropped before.
    for i in range(len(dataset) - seq_len + 1):
        target_idx = i + seq_len - 1
        window_inputs.append(dataset[i:target_idx])
        window_targets.append(dataset[target_idx])

    return torch.Tensor(np.array(window_inputs)).to(device), torch.Tensor(np.array(window_targets)).to(device)

In [None]:
from torch.utils.data import TensorDataset, DataLoader

def create_dataloader(x, y, is_train=True, batch_size=None):
    """Wrap paired tensors in a DataLoader.

    Args:
        x: Input tensor; first dimension indexes samples.
        y: Target tensor with the same first dimension as ``x``.
        is_train: When truthy the loader shuffles; otherwise it keeps the sample order.
        batch_size: Samples per batch. ``None`` (the default) uses the notebook-level
            ``BATCH_SIZE``, which is what this helper always used before.

    Returns:
        A ``DataLoader`` over ``TensorDataset(x, y)``.
    """
    if batch_size is None:
        batch_size = BATCH_SIZE
    return DataLoader(TensorDataset(x, y), batch_size=batch_size, shuffle=bool(is_train))

In [None]:
import matplotlib.pyplot as plt

def plot_losses(losses, title=None):
    """Show a line chart of the given loss values, one point per list entry.

    Args:
        losses: Sequence of loss values, plotted against their position in the sequence.
        title: Optional figure title.
    """
    positions = range(len(losses))

    plt.figure(figsize=(10, 8))
    plt.plot(positions, losses)
    plt.ylabel('Loss')
    plt.xlabel('Num. Batch Epochs')
    plt.title(title)
    plt.show()

In [None]:
import numpy as np

# ─── Helper Generators for Modular Tasks ───────────────────────────────────

def generate_one_expression_and_result(modulus: int, length: int, mult: bool = False):
    """Generate a random bracketed modular-arithmetic expression and its value.

    ``length`` is the number of characters of the expression. That bookkeeping treats every
    terminal as a single character, so it only holds when all terminals print as one
    digit (modulus <= 10).

    Args:
        modulus: Terminals are drawn from ``0 .. modulus - 1`` and results are reduced mod it.
        length: Desired expression length in characters; must be at least 1.
        mult: When True, ``*`` may be used in addition to ``+`` and ``-``.

    Returns:
        Tuple ``(expression, value)`` with the expression as a string and its value mod
        ``modulus``.

    Raises:
        ValueError: If ``length`` is smaller than 1.
    """
    def gen_terminal():
        terminal = np.random.randint(low=0, high=modulus)
        return str(terminal), terminal

    # Base cases: the four shortest shapes are d, -d, (d) and (-d)
    if length < 1:
        raise ValueError(f'Length must be ≥1, got {length}')
    if length == 1:
        return gen_terminal()
    if length == 2:
        s, v = gen_terminal()
        return f'-{s}', (-v) % modulus
    if length == 3:
        s, v = gen_terminal()
        return f'({s})', v % modulus
    if length == 4:
        s, v = gen_terminal()
        return f'(-{s})', (-v) % modulus

    # Otherwise split into "(left op right)". The two brackets and the operator take three
    # characters, so left_len + right_len == length - 3. randint's upper bound is
    # exclusive: capping it at length - 3 keeps left_len <= length - 4 and therefore
    # right_len >= 1. The previous bound (length - 2) allowed right_len == 0, which made
    # the recursive call raise ValueError at random.
    left_len  = np.random.randint(1, length-3)
    right_len = length - (left_len + 3)
    ls, lv = generate_one_expression_and_result(modulus, left_len,  mult)
    rs, rv = generate_one_expression_and_result(modulus, right_len, mult)
    maxop = 3 if mult else 2
    op = np.random.randint(0, maxop)
    if op == 0:
        return f'({ls}+{rs})', (lv + rv) % modulus
    if op == 1:
        return f'({ls}-{rs})', (lv - rv) % modulus
    return f'({ls}*{rs})', (lv * rv) % modulus

def generate_raw_dataset(n: int, lengths: list, modulus: int, mult: bool=False):
    """
    Build integer-encoded modular-arithmetic-with-brackets samples.

    For every entry of ``lengths``, ``n // len(lengths)`` expressions of that length are
    generated and encoded character by character: digits map to themselves, and
    ``+ - * ( )`` map to ``modulus .. modulus + 4``.

    Returns dict: length → {'expressions': [...], 'results': [...]}
    """
    token_ids = {str(digit): digit for digit in range(modulus)}
    token_ids.update({'+': modulus, '-': modulus+1, '*': modulus+2, '(': modulus+3, ')': modulus+4})

    dataset = {}
    for seq_len in lengths:
        per_length = n // len(lengths)
        samples = [
            generate_one_expression_and_result(modulus, seq_len, mult)
            for _ in range(per_length)
        ]
        dataset[seq_len] = {
            'expressions': [[token_ids[ch] for ch in text] for text, _ in samples],
            'results': [value for _, value in samples],
        }
    return dataset

def generate_equation_and_solution(modulus: int, length: int):
    """Create a modular equation in one unknown ``x`` together with the hidden digit.

    An expression of ``length - 2`` characters is generated (without multiplication); one
    of its digits is replaced by ``x`` and ``=<value>`` is appended.

    Returns:
        Tuple ``(equation, solution)`` where ``solution`` is the digit that ``x`` replaced.
    """
    expression, value = generate_one_expression_and_result(modulus, length-2, mult=False)

    # Start at a random position and walk right (wrapping around) to the first digit
    pos = np.random.randint(0, len(expression))
    digit_chars = [str(d) for d in range(modulus)]
    while expression[pos] not in digit_chars:
        pos = (pos + 1) % len(expression)

    solution = int(expression[pos])
    equation = expression[:pos] + 'x' + expression[pos+1:] + '=' + f"{value}"
    return equation, solution

def generate_raw_equation_dataset(n: int, lengths: list, modulus: int):
    """
    Build integer-encoded modular-equation-solving samples.

    For every entry of ``lengths``, ``n // len(lengths)`` equations of that length are
    generated and encoded character by character: digits map to themselves, and
    ``+ - ( ) x =`` map to ``modulus .. modulus + 5``.

    Returns dict: length → {'equations': [...], 'solutions': [...]}
    """
    token_ids = {str(digit): digit for digit in range(modulus)}
    token_ids.update({'+': modulus, '-': modulus+1, '(': modulus+2, ')': modulus+3, 'x': modulus+4, '=': modulus+5})

    dataset = {}
    for seq_len in lengths:
        per_length = n // len(lengths)
        samples = [generate_equation_and_solution(modulus, seq_len) for _ in range(per_length)]
        dataset[seq_len] = {
            'equations': [[token_ids[ch] for ch in text] for text, _ in samples],
            'solutions': [solution for _, solution in samples],
        }
    return dataset

# ─── 1. Bucket Sort ───────────────────────────────────────────────────────

def bucket_sort_dataset(n: int, length: int):
    """
    Random sequences paired with their ascending sort.

    X: (n, length, 1) uniform random floats, y: (n, length, 1) the same values sorted
    along the sequence axis.
    """
    unsorted = np.random.random((n, length, 1))
    # axis 1 is the sequence axis, so every sample is sorted independently
    return unsorted, np.sort(unsorted, axis=1)

# ─── 2. Missing Duplicates ────────────────────────────────────────────────

def missing_duplicates_dataset(n: int, length: int):
    """
    Random sequences paired with their distinct values.

    X: (n, length, 1) random floats; y: (n, length, 1) holding each sequence's unique
    values in ascending order at the front, zero-padded up to ``length``.
    """
    X = np.random.random((n, length, 1))
    y = np.zeros_like(X)
    # NOTE(review): the inputs are continuous random floats, so repeated values look
    # extremely unlikely and y is then just the sorted sequence — confirm this is intended.
    for row in range(n):
        distinct = np.unique(X[row, :, 0])
        y[row, :distinct.shape[0], 0] = distinct
    return X, y

# ─── 3. Modular Arithmetic with Brackets ──────────────────────────────────

def modular_arith_dataset(n: int, lengths: list, modulus: int, mult: bool=False):
    """One-hot encoded modular-arithmetic expressions and their results, grouped by length.

    Returns dict: length → (X, y), where X one-hot encodes the token ids over
    ``modulus + 5`` classes and y is the array of results.
    """
    raw = generate_raw_dataset(n, lengths, modulus, mult)

    def encode(seq_len):
        group = raw[seq_len]
        # one-hot size = digits + +,-,*,(,)
        features = np.eye(modulus+5)[group['expressions']]
        return features, np.array(group['results'])

    return {seq_len: encode(seq_len) for seq_len in lengths}

# ─── 4. Solve Equation for x ──────────────────────────────────────────────

def solve_equation_dataset(n: int, lengths: list, modulus: int):
    """One-hot encoded modular equations and their solutions, grouped by length.

    Returns dict: length → (X, y), where X one-hot encodes the token ids over
    ``modulus + 6`` classes and y is the array of solutions.
    """
    raw = generate_raw_equation_dataset(n, lengths, modulus)

    def encode(seq_len):
        group = raw[seq_len]
        # one-hot size = digits + +,-,(,),x,=
        features = np.eye(modulus+6)[group['equations']]
        return features, np.array(group['solutions'])

    return {seq_len: encode(seq_len) for seq_len in lengths}

# ─── 5. Cycle Navigation ──────────────────────────────────────────────────

def cycle_navigation_dataset(n: int, length: int, cycle_length: int=5):
    """Random walks on a cycle, labelled with the final position.

    Each action id in {0, 1, 2} moves by ``id - 1`` (i.e. -1, 0 or +1). The label is the
    sum of moves modulo ``cycle_length``.

    Returns:
        X: (n, length, 3) one-hot actions, y: (n, cycle_length) one-hot final position.
    """
    action_ids = np.random.randint(0, 3, size=(n, length))
    moves = action_ids - 1
    end_position = moves.sum(axis=1) % cycle_length
    return np.eye(3)[action_ids], np.eye(cycle_length)[end_position]

# ─── 6. Even Pairs ────────────────────────────────────────────────────────

def even_pairs_dataset(n: int, length: int):
    """Random bit strings labelled by the parity of their adjacent unequal pairs.

    Returns:
        X: (n, length, 2) one-hot bits, y: (n, 2) one-hot of
        (number of positions where neighbouring bits differ) mod 2.
    """
    bits = np.random.randint(0, 2, size=(n, length))
    # True wherever a bit differs from its right-hand neighbour
    transitions = bits[:, 1:] != bits[:, :-1]
    labels = transitions.sum(axis=1) % 2
    one_hot = np.eye(2)
    return one_hot[bits], one_hot[labels]

# ─── 7. Parity Check ──────────────────────────────────────────────────────

def parity_check_dataset(n: int, length: int):
    """Random bit strings labelled by the parity of their number of ones.

    Returns:
        X: (n, length, 2) one-hot bits, y: (n, 2) one-hot of (count of ones) mod 2.
    """
    bits = np.random.randint(0, 2, size=(n, length))
    ones_parity = bits.sum(axis=1) % 2
    one_hot = np.eye(2)
    return one_hot[bits], one_hot[ones_parity]


## Missing Duplicates

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Number of synthetic sequences generated for this experiment (before the train/test split).
N_SAMPLES = 10000
# Length of every generated sequence; also passed to the model as its seq_length.
SEQ_LENGTH = 8

In [None]:
# 1. Generate the synthetic dataset (numpy arrays of shape (N_SAMPLES, SEQ_LENGTH, 1))
X, y = missing_duplicates_dataset(N_SAMPLES, SEQ_LENGTH)

# 2. Train/test split (80/20), taken in generation order
split_idx = int(0.8 * N_SAMPLES)
X_train, X_test = X[:split_idx], X[split_idx:]
y_train, y_test = y[:split_idx], y[split_idx:]

# 3. Convert to float32 torch tensors on the GPU
X_train = torch.tensor(X_train, dtype=torch.float32).to('cuda')
y_train = torch.tensor(y_train, dtype=torch.float32).to('cuda')
X_test  = torch.tensor(X_test,  dtype=torch.float32).to('cuda')
y_test  = torch.tensor(y_test,  dtype=torch.float32).to('cuda')

# 4. Create DataLoaders through the shared helper (BATCH_SIZE batches; only the training loader shuffles)
train_loader = create_dataloader(X_train, y_train, is_train=True)
test_loader  = create_dataloader(X_test, y_test, is_train=False)

In [None]:
# Build a two-block xLSTM stack plus its input/output projections for this task.
# slstm_pos=[1] is forwarded as `slstm_at`; presumably that makes block index 1 the sLSTM
# block and block 0 an mLSTM block — confirm against the xlstm documentation.
xlstm_stack, input_proj, output_proj = create_xlstm_model(
    seq_length=SEQ_LENGTH,
    num_blocks=2,
    slstm_pos=[1],
    num_heads=2,
    conv1d_kernel_size=2,
    proj_factor=1.1
)

  @conditional_decorator(
  @conditional_decorator(


In [None]:
# One Adam optimizer over every trainable piece: the xLSTM stack and both linear projections.
optimizer = optim.Adam(
    list(xlstm_stack.parameters()) +
    list(input_proj.parameters()) +
    list(output_proj.parameters()),
    lr=1e-3
)

# The targets are real-valued sequences (each input's unique values in ascending order,
# zero-padded), so this is treated as regression and scored with MSE.
criterion = nn.MSELoss()

In [None]:
# Peek at one batch to sanity-check shapes. The loader wraps TensorDataset(X_train, y_train),
# so every batch unpacks as (inputs, targets) in that order.
inp, targ = next(iter(train_loader))
inp.shape, targ.shape

(torch.Size([3, 8, 1]), torch.Size([3, 8, 1]))

In [None]:
# Train for 20 epochs. train_model returns the recorded loss values and its own elapsed-time
# measurement; its timing decorator also prints the duration labelled with `architecture`.
losses, duration = train_model(
    epochs=20,
    model=xlstm_stack,
    input_projection=input_proj,
    output_projection=output_proj,
    train_data=train_loader,
    optimizer=optimizer,
    criterion=criterion,
    architecture="xLSTM-MissingDuplicates"
)



Epoch 1: Avg Loss = 0.012612




Epoch 2: Avg Loss = 0.011576




Epoch 3: Avg Loss = 0.011052




Epoch 4: Avg Loss = 0.010697




Epoch 5: Avg Loss = 0.010448




Epoch 6: Avg Loss = 0.010312




Epoch 7: Avg Loss = 0.010259




Epoch 8: Avg Loss = 0.010151




Epoch 9: Avg Loss = 0.010070




Epoch 10: Avg Loss = 0.010027




Epoch 11: Avg Loss = 0.010027




Epoch 12: Avg Loss = 0.009935




Epoch 13: Avg Loss = 0.009946




Epoch 14: Avg Loss = 0.009936




Epoch 15: Avg Loss = 0.009888




Epoch 16: Avg Loss = 0.009878




Epoch 17: Avg Loss = 0.009859




Epoch 18: Avg Loss = 0.009849




Epoch 19: Avg Loss = 0.009800


                                                                             

Epoch 20: Avg Loss = 0.009819

Architecture - xLSTM-MissingDuplicates took 752.65 seconds to train.


