In [None]:
import sys
sys.path.append('..')
from mamba_ssm import Mamba2
import torch
import torch.nn as nn
import configs.common as cc
class MambaStack(nn.Module):
    """Token/metadata embeddings -> stack of Mamba2 blocks -> LayerNorm -> vocab projection.

    Metadata ids are embedded with their own table and prepended to the token
    embeddings along the sequence axis, so every Mamba2 layer sees them as a
    prefix. The positions belonging to that prefix are dropped before the
    logits are returned.

    Device placement is left to the caller (e.g. ``MambaStack(...).to("cuda")``)
    so that all submodules always live on the same device.

    Args:
        d_model: Width of the embeddings and of every Mamba2 block.
        n_layers: Number of stacked Mamba2 blocks.
        d_state: SSM state expansion factor passed to Mamba2 (typically 64 or 128).
        d_conv: Local convolution width passed to Mamba2.
        expand: Block expansion factor passed to Mamba2.
    """

    def __init__(self, d_model=512, n_layers=12, d_state=64, d_conv=4, expand=2):
        super().__init__()

        self.token_embedding = nn.Embedding(cc.vocab_size, d_model)
        self.metadata_embedding = nn.Embedding(cc.metadata_vocab_size, d_model)
        self.output_layer = nn.Linear(d_model, cc.vocab_size)

        # Each Mamba2 block uses roughly 3 * expand * d_model^2 parameters.
        # No per-layer .to("cuda") here: moving only the layers left the
        # embeddings/norm/head on a different device than the layers.
        self.layers = nn.ModuleList([
            Mamba2(
                d_model=d_model,  # Model dimension d_model
                d_state=d_state,  # SSM state expansion factor
                d_conv=d_conv,    # Local convolution width
                expand=expand,    # Block expansion factor
            )
            for _ in range(n_layers)
        ])
        self.norm = nn.LayerNorm(d_model)

    def forward(self, tokens, meta):
        """Return vocabulary logits for the token positions.

        Args:
            tokens: Integer token ids; presumably (batch, seq_len) — TODO confirm.
            meta: Integer metadata ids; presumably (batch, n_meta) — TODO confirm.

        Returns:
            Logits over ``cc.vocab_size`` for every token position, with the
            metadata prefix positions removed.
        """
        # Length of the metadata prefix that gets prepended below.
        n_meta = meta.shape[-1]
        x = self.token_embedding(tokens)
        # Prepend metadata embeddings along the sequence axis (dim -2).
        x = torch.cat((self.metadata_embedding(meta), x), dim=-2)
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        # Strip the metadata prefix along the same axis it was concatenated on.
        # This was previously hard-coded as 6, which misaligned the logits
        # whenever the metadata length differed.
        return self.output_layer(x[..., n_meta:, :])

# Build a 512-wide, 12-layer stack, then move all of its parameters to the GPU.
model = MambaStack(d_model=512, n_layers=12)
model = model.to("cuda")

In [None]:
import sys
sys.path.append('..')
import processing
import configs.common as cc

# Location of the preprocessed NumPy dataset; adjust for your machine.
DATASET_DIR = '/mnt/e/github/dataset/np_dataset/ABBA'

loader = processing.DatasetLoader(DATASET_DIR)
train_dataloader, test_dataloader = loader.get_dataloaders()

# Take a single training batch for the shape/forward checks below.
# next(iter(...)) raises StopIteration on an empty loader, instead of
# leaving src/trg/meta undefined the way a `for ...: break` loop would.
src, trg, meta = next(iter(train_dataloader))

[45,
 16585,
 16665,
 16768,
 16960,
 84,
 16596,
 16664,
 16741,
 16960,
 48,
 16575,
 16641,
 16740,
 16960,
 76,
 16574,
 16654,
 16960,
 52,
 16581,
 16664,
 16747,
 16960,
 81,
 16590,
 16664,
 16740,
 16960,
 76,
 16580,
 16650,
 16747,
 16960,
 57,
 16572,
 16665,
 16741,
 16960,
 84,
 16599,
 16666,
 16749,
 16960,
 45,
 16591,
 16666,
 16740,
 16960,
 76,
 16579,
 16655,
 16741,
 16960,
 81,
 16591,
 16658,
 16747,
 16960,
 52,
 16572,
 16665,
 16740,
 16960,
 57,
 16581,
 16665,
 16748,
 16960,
 76,
 16585,
 16650,
 16740,
 16960,
 84,
 16602,
 16665,
 16750,
 16960,
 45,
 16584,
 16665,
 16740,
 16960,
 48,
 16577,
 16641,
 16960,
 76,
 16575,
 16654,
 16741,
 16960,
 81,
 16587,
 16646,
 16740,
 16960,
 81,
 16603,
 16665,
 16747,
 16960,
 52,
 16579,
 16665,
 16740,
 16960,
 57,
 16582,
 16666,
 16747,
 16960,
 76,
 16583,
 16665,
 16741,
 16960,
 69,
 16573,
 16641,
 16960,
 84,
 16603,
 16739,
 16748,
 16960,
 45,
 16591,
 16739,
 16741,
 16960,
 93,
 16593,
 16641,
 167

In [None]:
# Tally the number of scalar elements in every parameter tensor of the model.
total_params = 0
for param in model.parameters():
    total_params += param.numel()
print(f"Total parameters: {total_params:,}")

tensor([  127, 16585, 16665,  ..., 16739, 16750, 16960])

In [None]:
# Smoke-test the forward pass on one training batch.
# The model was explicitly moved to the GPU, but this notebook never moves the
# batch tensors, so send them to the model's device first (.to is a no-op if
# they are already there).
device = next(model.parameters()).device
with torch.no_grad():  # shape check only; no autograd graph needed
    output = model(src.to(device), meta.to(device))
output.shape

tensor([   45, 16585, 16665,  ..., 16739, 16750, 16960])