In [None]:
from mamba.head import MambaClassificationHead
from mamba.model import MambaTextClassification
from dataset import ImdbDataset
from utils import preprocess_function, compute_metrics
from mamba.trainer import MambaTrainer

from transformers import AutoTokenizer, TrainingArguments

import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from tqdm import tqdm
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.metrics import accuracy_score, classification_report
from torch.utils.data import DataLoader


**Hyperparameters**

In [None]:
# Tokenizer: GPT-NeoX-20B vocabulary; the end-of-sequence id is reused as the
# padding id so that `padding="max_length"` has a token to pad with.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model: the "state-spaces/mamba-130m" checkpoint loaded through the
# classification wrapper, then moved to the GPU.
model = MambaTextClassification.from_pretrained("state-spaces/mamba-130m")
model.to("cuda")

# Tokenization length, DataLoader batch size and epoch count.
max_seq_len, batch_size, num_epochs = 256, 32, 2

In [None]:
# Load every 20 Newsgroups post (train and test subsets together) with
# headers, footers and quoted replies stripped.
newsgroups = fetch_20newsgroups(subset='all', remove=('headers', 'footers', 'quotes'))
data = pd.DataFrame({'text_data': newsgroups.data, 'label': newsgroups.target})

# Seed the shuffle as well as the split: train_test_split picks rows by
# position, so an unseeded shuffle would put different posts in each split on
# every run even though the split itself uses random_state=42.
data = data.sample(frac=1, random_state=42).reset_index(drop=True)

# 80 % training / 20 % validation.
train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)

def tokenize_data(data, tokenizer, max_seq_len):
    """Tokenize the ``text_data`` column of ``data`` and collect its labels.

    All texts are encoded in a single batched tokenizer call instead of one
    ``encode_plus`` call per row, which avoids the slow ``iterrows`` loop.

    Args:
        data: DataFrame with a ``text_data`` column (strings) and a ``label``
            column (integer class ids).
        tokenizer: Callable tokenizer that accepts a list of strings plus the
            ``add_special_tokens`` / ``max_length`` / ``padding`` /
            ``truncation`` / ``return_attention_mask`` keyword arguments and
            returns a mapping with ``input_ids`` and ``attention_mask``.
        max_seq_len: Length every sequence is padded or truncated to.

    Returns:
        Tuple ``(input_ids, attention_masks, labels)`` of ``torch.long``
        tensors. The first two have one row per input row; ``labels`` is 1-D.
        An empty ``data`` yields tensors of shape ``(0, max_seq_len)`` and
        ``(0,)`` rather than calling the tokenizer.
    """
    if len(data) == 0:
        # Keep the shapes and dtypes callers expect even with nothing to encode.
        return (
            torch.empty((0, max_seq_len), dtype=torch.long),
            torch.empty((0, max_seq_len), dtype=torch.long),
            torch.empty((0,), dtype=torch.long),
        )

    encoded = tokenizer(
        data["text_data"].tolist(),
        add_special_tokens=True,
        max_length=max_seq_len,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
    )

    # Explicit int64: the labels feed nn.CrossEntropyLoss, which needs long
    # class indices regardless of the dtype stored in the DataFrame.
    input_ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
    attention_masks = torch.tensor(encoded["attention_mask"], dtype=torch.long)
    labels = torch.tensor(data["label"].tolist(), dtype=torch.long)

    return input_ids, attention_masks, labels

# Encode the training split.
train_input_ids, train_attention_masks, train_labels = tokenize_data(
    train_data, tokenizer, max_seq_len
)
# Encode the validation split with the same tokenizer and sequence length.
val_input_ids, val_attention_masks, val_labels = tokenize_data(
    val_data, tokenizer, max_seq_len
)

In [None]:
# Training split: each item is an (input_ids, attention_mask, label) tuple,
# drawn in random order.
train_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels)
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    sampler=train_sampler,
)

# Validation split: same item layout, iterated in its stored order.
val_dataset = TensorDataset(val_input_ids, val_attention_masks, val_labels)
val_sampler = SequentialSampler(val_dataset)
val_dataloader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    sampler=val_sampler,
)

In [None]:
import torch
import torch.nn as nn

class MambaClassificationHead(nn.Module):
    """Single linear layer that maps a ``d_model``-sized vector to class logits.

    Args:
        d_model: Size of the last dimension of the incoming hidden states.
        num_classes: Number of output logits.
        **kwargs: Forwarded unchanged to ``nn.Linear`` (for example ``bias``,
            ``device`` or ``dtype``).
    """

    def __init__(self, d_model, num_classes, **kwargs):
        super().__init__()

        # The attribute name is part of the module's state-dict keys.
        projection = nn.Linear(d_model, num_classes, **kwargs)
        self.classification_head = projection

    def forward(self, hidden_states):
        """Apply the linear projection and return the resulting logits."""
        logits = self.classification_head(hidden_states)
        return logits

In [None]:
import numpy as np
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from mamba_ssm.utils.hf import load_config_hf,load_state_dict_hf
from collections import namedtuple
import torch.nn as nn
import torch

from cfg.config import MambaConfig

class MambaTextClassification(MambaLMHeadModel):
    """Mamba backbone with a linear classification head in place of the LM head.

    The language-model head created by ``MambaLMHeadModel`` is deleted and a
    ``MambaClassificationHead`` is applied to the mean of the backbone's
    hidden states.
    """

    # Output containers, built once instead of on every forward pass. Both keep
    # the type name "ClassificationOutput"; they differ only in their fields.
    _LogitsOutput = namedtuple("ClassificationOutput", ["logits"])
    _LossOutput = namedtuple("ClassificationOutput", ["loss", "logits"])

    def __init__(
        self,
        config: MambaConfig,
        initializer_cfg = None,
        device = None,
        dtype = None,
        num_classes: int = 2,
    ) -> None:
        """Build the backbone and attach the classification head.

        Args:
            config: Model configuration; ``config.d_model`` sizes the head input.
            initializer_cfg: Passed through to ``MambaLMHeadModel``.
            device: Device for the backbone and the head parameters.
            dtype: Dtype for the backbone and the head parameters.
            num_classes: Number of output classes. Defaults to 2, the value
                that was previously hard-coded; pass e.g. ``num_classes=20``
                (also via ``from_pretrained``) for other label sets.
        """
        super().__init__(config, initializer_cfg, device, dtype)

        # Create the head with the same device/dtype as the backbone so that a
        # non-default dtype does not leave the head in float32.
        self.classification_head = MambaClassificationHead(
            d_model=config.d_model,
            num_classes=num_classes,
            device=device,
            dtype=dtype,
        )

        # The LM projection is not used for classification.
        del self.lm_head

    def forward(self, input_ids, attention_mask = None, labels = None):
        """Compute class logits and, when ``labels`` is given, the loss.

        Args:
            input_ids: Token ids fed to the backbone.
            attention_mask: Accepted for interface compatibility but not used;
                the mean below runs over every position, padding included.
            labels: Optional class indices for ``nn.CrossEntropyLoss``.

        Returns:
            A namedtuple with a ``logits`` field when ``labels`` is ``None``,
            otherwise a namedtuple with ``loss`` and ``logits`` fields.
        """
        hidden_states = self.backbone(input_ids)

        # Pool over dimension 1 -- assumed to be the sequence axis of a
        # (batch, seq_len, d_model) tensor; TODO confirm against the backbone.
        mean_hidden_states = hidden_states.mean(dim = 1)

        logits = self.classification_head(mean_hidden_states)

        if labels is None:
            return self._LogitsOutput(logits = logits)

        # The previous code referenced a misspelled, unbound name on this path
        # and raised UnboundLocalError whenever labels were supplied.
        loss_fct = nn.CrossEntropyLoss()
        loss = loss_fct(logits, labels)

        return self._LossOutput(loss = loss, logits = logits)

    def predict(self, text, tokenizer, id2label = None):
        """Classify a single text.

        Args:
            text: Input passed to ``tokenizer``.
            tokenizer: Callable returning a mapping with an ``input_ids`` entry.
            id2label: Optional mapping from class index to label.

        Returns:
            ``id2label[index]`` when ``id2label`` is given, otherwise the
            predicted class index as returned by ``np.argmax``.
        """
        # Put the inputs on the model's own device instead of assuming "cuda".
        device = next(self.parameters()).device
        input_ids = torch.tensor(tokenizer(text)['input_ids'], device = device)[None]
        with torch.no_grad():
            logits = self.forward(input_ids).logits[0]
            # Cast to float32 first: numpy cannot represent bfloat16 tensors.
            label = np.argmax(logits.float().cpu().numpy())

        if id2label is not None:
            return id2label[label]
        else:
            return label

    @classmethod
    def from_pretrained(cls, pretrained_model_name, device = None, dtype = None, **kwargs):
        """Create the model and load pretrained weights into it.

        Args:
            pretrained_model_name: Name handed to ``load_config_hf`` and
                ``load_state_dict_hf``.
            device: Device for the model and the loaded weights.
            dtype: Dtype for the model and the loaded weights.
            **kwargs: Extra constructor arguments, e.g. ``num_classes``.

        Returns:
            The model after ``.to(device)``.
        """
        config_data = load_config_hf(pretrained_model_name)
        config = MambaConfig(**config_data)

        model = cls(config, device = device, dtype = dtype, **kwargs)

        model_state_dict = load_state_dict_hf(pretrained_model_name, device = device, dtype = dtype)
        # strict=False: keys missing from or unexpected in the checkpoint are
        # tolerated instead of raising.
        model.load_state_dict(model_state_dict , strict=False)

        # Report the parameters that did not come from the checkpoint.
        print (" Newly initialized embedding :", 
              set(model.state_dict().keys()) - set(model_state_dict.keys())
        )

        return model.to(device)

In [None]:
import json
import os
from transformers import Trainer
import torch

class MambaTrainer(Trainer):
    """``Trainer`` subclass that feeds only ``input_ids`` and ``labels`` to the model."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Run the model on a batch and return its loss.

        Args:
            model: Model called as ``model(input_ids=..., labels=...)``; its
                output must expose a ``loss`` attribute.
            inputs: Batch mapping containing ``input_ids`` and ``labels``.
                Any other entries (such as ``attention_mask``) are ignored.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted because newer Transformers releases
                pass it to ``compute_loss``; it is not used here.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        # Read rather than pop, so the caller's batch dict is left intact.
        input_ids = inputs["input_ids"]
        labels = inputs["labels"]

        outputs = model(input_ids=input_ids, labels=labels)

        loss = outputs.loss

        return (loss, outputs) if return_outputs else loss

In [None]:
import numpy as np
import evaluate

# Accuracy metric object obtained from `evaluate`; compute_metrics looks this
# module-level name up each time it is called, so it must stay bound to the metric.
accuracy = evaluate.load("accuracy")

def preprocess_function(tokenizer, examples):
    """Tokenize ``examples["text"]`` and drop the attention mask.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, truncation=True)``
            that returns a mutable mapping.
        examples: Mapping with a ``"text"`` entry.

    Returns:
        The tokenizer's output without its ``attention_mask`` entry.
    """
    samples = tokenizer(examples["text"], truncation=True)
    # Default of None: a tokenizer that returns no attention mask must not
    # raise KeyError here.
    samples.pop('attention_mask', None)
    return samples

def compute_metrics(eval_pred):
    """Score predictions with the module-level ``accuracy`` metric.

    Args:
        eval_pred: Pair ``(scores, labels)``; ``scores`` holds one row of
            per-class scores for each example.

    Returns:
        Whatever ``accuracy.compute`` returns for the arg-max predictions.
    """
    scores, references = eval_pred
    # Highest-scoring class along axis 1 is the predicted label.
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

**Model Training**

In [None]:
def collate_tensor_batch(features):
    """Turn a list of ``TensorDataset`` items into the keyword batch the trainer expects.

    ``train_dataset`` / ``val_dataset`` yield ``(input_ids, attention_mask,
    label)`` tuples, while ``MambaTrainer.compute_loss`` reads the batch by
    key, so each position is stacked and given its name here.

    Args:
        features: List of 3-tuples of tensors, one tuple per example.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
    """
    input_ids, attention_mask, labels = (torch.stack(column) for column in zip(*features))
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


# NOTE(review): `batch_size` and `num_epochs` from the hyperparameter cell are
# not used here; the values below are kept exactly as they were -- confirm
# which set is intended.
training_args = TrainingArguments(
    output_dir="mamba_text_classification",  # explicit: required by older Transformers releases
    learning_rate=5e-5,
    per_device_train_batch_size=4, 
    per_device_eval_batch_size=16,  
    num_train_epochs=10,  
    warmup_ratio=0.01,  
    lr_scheduler_type="cosine"
)

trainer = MambaTrainer(
    model=model,  
    train_dataset=train_dataset,  
    eval_dataset=val_dataset, 
    tokenizer=tokenizer,  
    args=training_args,  
    data_collator=collate_tensor_batch,  # tuple items -> keyword batch
    compute_metrics=compute_metrics  
)

trainer.train()

**Evaluating Model Using Performance Metrics**

In [None]:
def get_predictions(model, dataloader, device):
    """Run ``model`` over ``dataloader`` and gather predicted and true labels.

    Args:
        model: Module called as ``model(input_ids, attention_mask=...)``; the
            first element of its output is treated as the logits.
        dataloader: Iterable of ``(input_ids, attention_masks, labels)`` batches.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        Tuple ``(predictions, true_labels)`` of NumPy arrays, in dataloader order.
    """
    model.eval()
    predicted_ids = []
    gold_ids = []

    for batch in tqdm(dataloader, desc="Evaluating"):
        input_ids, attention_masks, labels = (tensor.to(device) for tensor in batch)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_masks)

        batch_logits = outputs[0].detach().cpu().numpy()

        # Arg-max over the last axis gives one class index per example.
        predicted_ids.extend(batch_logits.argmax(axis=-1))
        gold_ids.extend(labels.cpu().numpy())

    return np.array(predicted_ids), np.array(gold_ids)

# No `device` variable is defined in the visible cells, so take it from the
# model itself; the validation batches are then moved to wherever it lives.
device = next(model.parameters()).device

predictions, true_labels = get_predictions(model, val_dataloader, device)

# Stored under its own name: rebinding `accuracy` to a float would overwrite
# the metric object that compute_metrics calls and break any later training
# or evaluation run in the same kernel.
val_accuracy = accuracy_score(true_labels, predictions)

# Per-class precision / recall / F1 with four decimal places.
report = classification_report(true_labels, predictions, digits=4)

print(f"Validation Accuracy: {val_accuracy:.4f}")
print("Classification Report:")
print(report)