In [1]:
import sys

sys.path.append("../..")
import torch
import numpy as np

In [2]:
from xlstm.xlstm_large.model import xLSTMLargeConfig, xLSTMLarge
from mlstm_kernels.torch import get_available_mlstm_step_kernels, get_available_mlstm_kernels, get_available_mlstm_sequence_kernels

In [3]:
get_available_mlstm_kernels(), get_available_mlstm_step_kernels(), get_available_mlstm_sequence_kernels()

(['chunkwise--native_autograd',
  'chunkwise--native_custbw',
  'chunkwise--triton_limit_chunk',
  'chunkwise--triton_xl_chunk',
  'chunkwise--triton_xl_chunk_siging',
  'parallel--native_autograd',
  'parallel--native_custbw',
  'parallel--native_stablef_autograd',
  'parallel--native_stablef_custbw',
  'parallel--triton_limit_headdim',
  'parallel--native_siging_autograd',
  'parallel--native_siging_custbw'],
 ['native', 'triton'],
 ['native_sequence__native', 'native_sequence__triton'])

In [4]:
xlstm_config = xLSTMLargeConfig(
    embedding_dim=2048,
    num_heads=8,
    num_blocks=32,
    vocab_size=65536,
    return_last_states=True,
    mode="inference",
    chunkwise_kernel="chunkwise--triton_xl_chunk", # xl_chunk == TFLA kernels
    sequence_kernel="native_sequence__triton",
    step_kernel="triton",
)

In [5]:
xlstm = xLSTMLarge(xlstm_config)

In [6]:
xlstm

xLSTMLarge(
  (embedding): Embedding(65536, 2048)
  (backbone): xLSTMLargeBlockStack(
    (blocks): ModuleList(
      (0-31): 32 x mLSTMBlock(
        (norm_mlstm): RMSNorm()
        (mlstm_layer): mLSTMLayer(
          (q): Linear(in_features=2048, out_features=1024, bias=False)
          (k): Linear(in_features=2048, out_features=1024, bias=False)
          (v): Linear(in_features=2048, out_features=2048, bias=False)
          (ogate_preact): Linear(in_features=2048, out_features=2048, bias=False)
          (igate_preact): Linear(in_features=2048, out_features=8, bias=True)
          (fgate_preact): Linear(in_features=2048, out_features=8, bias=True)
          (ogate_act_fn): Sigmoid()
          (mlstm_backend): mLSTMBackend(mLSTMBackendConfig(chunkwise_kernel='chunkwise--triton_xl_chunk', sequence_kernel='native_sequence__triton', step_kernel='triton', mode='inference', chunk_size=64, return_last_states=True, autocast_kernel_dtype='bfloat16', eps=1e-06, inference_state_dtype='float3

In [7]:
xlstm = xlstm.to("cuda")

In [8]:
f"[model] parameters ≈ {sum(p.numel() for p in xlstm.parameters()) / 1e6:.1f}M"

'[model] parameters ≈ 1888.7M'

In [9]:
input = torch.randint(0, 2048, (3, 256)).to("cuda")
input.shape

torch.Size([3, 256])

In [10]:
import os
if os.path.exists("/usr/local/cuda/bin") and "/usr/local/cuda/bin" not in os.environ["PATH"]:
    os.environ["PATH"] += ":/usr/local/cuda/bin"

In [11]:
out = xlstm(input)

In [12]:
if len(out) == 2:
    out, state = out

In [13]:
out.shape[1:] == (256, 2048)

False

In [14]:
state.keys()

dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])

In [15]:
len(state), len(state[0])

(32, 3)

In [16]:
input[:, 0:1].shape, input.shape

(torch.Size([3, 1]), torch.Size([3, 256]))

In [17]:
step_out, step_state = xlstm(input[:, 0:1], state)

In [18]:
step_out.shape

torch.Size([3, 1, 65536])

In [19]:
out_chunkwise, last_state_chunkwise = xlstm(input)

In [20]:
out_steps = []
state = None
for i in range(input.shape[1]):
    out_step, state = xlstm(input[:, i:i + 1], state)
    out_steps.append(out_step)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 23.43 GiB of which 2.06 MiB is free. Including non-PyTorch memory, this process has 23.42 GiB memory in use. Of the allocated memory 21.88 GiB is allocated by PyTorch, and 1.10 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [21]:
out_steps = torch.cat(out_steps, dim=1)

In [22]:
out_steps.shape, out_chunkwise.shape

(torch.Size([3, 256, 2048]), torch.Size([3, 256, 2048]))

In [23]:
(out_chunkwise - out_steps).abs().max()

tensor(0.0138, device='cuda:0', grad_fn=<MaxBackward1>)

In [24]:
torch.allclose(out_chunkwise, out_steps, atol=7e-2, rtol=1e-3)

True

In [25]:
list(xlstm.state_dict().keys())

['embedding.weight',
 'backbone.blocks.0.norm_mlstm.weight',
 'backbone.blocks.0.mlstm_layer.q.weight',
 'backbone.blocks.0.mlstm_layer.k.weight',
 'backbone.blocks.0.mlstm_layer.v.weight',
 'backbone.blocks.0.mlstm_layer.ogate_preact.weight',
 'backbone.blocks.0.mlstm_layer.igate_preact.weight',
 'backbone.blocks.0.mlstm_layer.igate_preact.bias',
 'backbone.blocks.0.mlstm_layer.fgate_preact.weight',
 'backbone.blocks.0.mlstm_layer.fgate_preact.bias',
 'backbone.blocks.0.mlstm_layer.multihead_norm.weight',
 'backbone.blocks.0.mlstm_layer.out_proj.weight',
 'backbone.blocks.0.norm_ffn.weight',
 'backbone.blocks.0.ffn.proj_up_gate.weight',
 'backbone.blocks.0.ffn.proj_up.weight',
 'backbone.blocks.0.ffn.proj_down.weight',
 'backbone.blocks.1.norm_mlstm.weight',
 'backbone.blocks.1.mlstm_layer.q.weight',
 'backbone.blocks.1.mlstm_layer.k.weight',
 'backbone.blocks.1.mlstm_layer.v.weight',
 'backbone.blocks.1.mlstm_layer.ogate_preact.weight',
 'backbone.blocks.1.mlstm_layer.igate_preact.we