# Divide, Triton, and Conquer (Independently)
## Exploring Chunked States in Bamba
### Rafka Daou, Maria Garmonina, Sarah Korb


The following code cells install and import the modules required to run this notebook.





In [None]:
# Install the project in editable mode plus the extra benchmarking dependencies
# (pynvml for GPU utilization, rouge_score for text metrics).
# NOTE(review): `-e .` installs whatever directory this cell runs from, but the
# repository is only cloned in a later cell — confirm this is executed from
# inside the cloned foundation-model-stack checkout.
! pip install -e .

! pip install pynvml rouge_score

# NOTE(review): this also installs the PyPI `ibm-fms` package; a later cell puts
# the cloned repo first on sys.path so its `fms` is the one imported — verify.
! pip install ibm-fms

Clone Repository




In [None]:
# Clone the foundation-model-stack fork; the next cell chdirs into this checkout
# and imports `fms` from it.
! git clone https://github.com/rafkamicheldaou/foundation-model-stack.git

In [None]:
# Set working directory and fix sys.path
import sys, os

os.chdir("/content/foundation-model-stack")
sys.path.insert(0, os.getcwd())

# Confirm you're using the correct `fms`
import fms
print("Using fms from:", fms.__file__)  # should point to /content/foundation-model-stack/fms

In [None]:
import torch
import time
import torch.nn.functional as F
import psutil
import pynvml
import pandas as pd
from torch.profiler import profile, ProfilerActivity

from fms.utils.tokenizers import get_tokenizer
from fms.utils.generation import generate

The following code loads the IBM Bamba model and configures key parameters, such as the number of layers the model uses. To evaluate throughput and latency, we used a 4-layer configuration while varying prompt lengths.
To address the limitations of state-space models, we implemented two distinct module variants: a default optimized variant and an independent variant.
To specify which SSM (State Space Model) module is used during model execution, we manually set the desired module in the statements below.

**The model version you use should depend on whether you are benchmarking performance or accuracy. Please ensure you load the appropriate model based on the specific task.**



In [None]:
# CODE TO LOAD FOR PERFORMANCE BENCHMARKING
# Replace the SSM class referenced by fms.models.bamba with the chosen variant,
# then build a 4-layer Bamba-9B in bfloat16 on CUDA. The patch must happen before
# get_model() constructs the layers.
from fms.modules.ssm import SSM as DefaultSSM # you can change this to fms.modules.default_optimized_ssm or fms.modules.independent_simple_ssm
import fms.models.bamba as _bamba_mod


_bamba_mod.SSM = DefaultSSM # assigning this before fetching the pretrained model to not run into errors

# Now load Bamba
from fms.models import get_model

import torch

# Load trimmed model properly
# NOTE(review): "hf_configured" presumably builds the architecture from the HF
# config without pretrained weights — fine for latency/throughput, not for
# accuracy. Confirm against fms.models.get_model.
model = get_model(
    "hf_configured",
    "ibm-ai-platform/Bamba-9B",
    device_type="cuda",
    data_type=torch.bfloat16,
    nlayers=4,
)
# NOTE(review): this edits the config only after the layers were built above;
# it may not remove attention layers that were already instantiated, so the
# "Attention layers indices" print can disagree with the actual blocks — confirm.
model.config.attn_layer_indices = []


print("Number of layers:", len(model.base_model.layers))
print("Config nlayers:", model.config.nlayers)
print("Attention layers indices:", model.config.attn_layer_indices)

In [None]:
# CODE TO LOAD FOR ACCURACY BENCHMARKING
from fms.modules.default_triton_ssm import SSM as DefaultSSM # you can change this to fms.modules.default_optimized_ssm or fms.modules.independent_simple_ssm
import fms.models.bamba as _bamba_mod


_bamba_mod.SSM = DefaultSSM # assigning this before fetching the pretrained model to not run into errors

# Now load Bamba
from fms.models import get_model

import torch
model = get_model(
    "hf_pretrained",
    "ibm-ai-platform/Bamba-9B-v2",
    device_type="cuda",
)
model.config.attn_layer_indices = []

print("Number of layers:", len(model.base_model.layers), flush=True)
print("Config nlayers:", model.config.nlayers, flush=True)
print("Attention layers indices:", model.config.attn_layer_indices, flush=True)

In [None]:
# confirm model parameters
# Dump the loaded model's configuration to confirm layer count and settings.
print(model.config)

In [None]:
# confirm which SSM module is running
layer = next(
    block for block in model.base_model.layers
    if hasattr(block, "ssm")
)
print(isinstance(layer.ssm, DefaultSSM), # PASS IN THE LOADED SSM MODULE HERE TO VERIFY IT IS BEING LOADED PROPERLY
      layer.ssm.__class__.__module__ + "." + layer.ssm.__class__.__name__)

## Performance Benchmarking

This code benchmarks the performance of different chunked state strategies for the Bamba model. It processes a list of prompts by tokenizing the input, running the model to generate output, and collecting detailed performance metrics. These include total latency, first-token and inter-token generation times, throughput (tokens per second), peak memory usage, memory bandwidth, CPU/GPU utilization, and total FLOPs. The profiler also logs the most time-consuming CUDA operations. All results are recorded for analysis, allowing comparison across different SSM module variants and prompt configurations.

In [None]:
import json

# Load the benchmarking prompts (of varying lengths) for the performance run.
with open(
    "./benchmarking_data/longer_qa_for_benchmarking_performance.json",
    "r",
    encoding="utf-8",
) as f:
    qa_pairs = json.load(f)

In [None]:
def ids_for_prompt(prompt, tokenizer, device):
    """Tokenize *prompt* and return its token ids as a ``torch.long`` tensor on *device*.

    No BOS token is added; the ids are exactly what the tokenizer produces.
    """
    token_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(prompt))
    return torch.tensor(token_ids, device=device, dtype=torch.long)

def decode_ids(ids):
    """Turn a sequence of token ids back into text using the global ``tokenizer``."""
    return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(ids))

device = torch.device("cuda")
tokenizer = get_tokenizer("ibm-ai-platform/Bamba-9B")

pynvml.nvmlInit()
gpu_handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# NOTE(review): there is no warm-up pass after compile(), so the first prompt's
# measurements may include one-off compilation overhead — consider discarding it.
model.compile()

records = []
MAX_NEW_TOKENS = 100
NUM_PROMPTS = 70  # how many prompts, from the start of qa_pairs, to benchmark
log_path = "./benchmarking_data/performance_default.txt" # change the log path name
csv_path = "./benchmarking_data/performance_default.csv"  # change the CSV path name

selected_pairs = qa_pairs[:NUM_PROMPTS]
try:
    with open(log_path, "a", encoding="utf-8") as log_file:
        for idx, item in enumerate(selected_pairs, start=1):
            inputs = ids_for_prompt(item["prompt"], tokenizer, device)

            # system stats before. With interval=None psutil reports usage since
            # the previous call (the very first call returns 0.0); `iowait` is
            # only available on Linux.
            cpu0 = psutil.cpu_percent(None)
            io0 = psutil.cpu_times_percent(None).iowait
            gpu0 = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle).gpu
            torch.cuda.reset_peak_memory_stats()

            # profile the generate step; the wall-clock timing is taken inside
            # the profiler, so it includes profiling overhead
            with profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
                profile_memory=True,
                with_flops=True,
            ) as prof:
                torch.cuda.synchronize()
                t_start = time.time()

                out_ids, times = generate(
                    model,
                    inputs,
                    max_new_tokens=MAX_NEW_TOKENS,
                    use_cache=False,
                    timing="per-token",
                )

                torch.cuda.synchronize()
                t_end = time.time()

            # system stats after
            cpu1 = psutil.cpu_percent(None)
            io1 = psutil.cpu_times_percent(None).iowait
            gpu1 = pynvml.nvmlDeviceGetUtilizationRates(gpu_handle).gpu
            peak_mem = torch.cuda.max_memory_allocated() / 1024**2

            # derive metrics
            t_first = times[0]
            inter_token = times[1:]
            # A single timed step has no inter-token gaps; avoid dividing by zero.
            t_mean = sum(inter_token) / len(inter_token) if inter_token else float("nan")
            total_time = t_end - t_start
            throughput = MAX_NEW_TOKENS / total_time
            # Peak allocation divided by wall time: a rough proxy, not measured DRAM bandwidth.
            mem_bw = peak_mem / total_time
            # Aggregate profiler events once and reuse them for FLOPs and the top-ops table.
            event_stats = prof.key_averages()
            total_flops = sum(evt.flops for evt in event_stats if hasattr(evt, "flops"))
            top_ops = event_stats.table(sort_by="self_cuda_time_total", row_limit=5)

            rec = {
                "id": idx,
                "total_latency_s": total_time,
                "first_token_s": t_first,
                "mean_inter_token_s": t_mean,
                "throughput_tok_s": throughput,
                "peak_mem_MB": peak_mem,
                "mem_bw_MBps": mem_bw,
                "cpu_start_%": cpu0, "cpu_end_%": cpu1,
                "gpu_start_%": gpu0, "gpu_end_%": gpu1,
                "io_wait_diff_%": io1 - io0,
                "total_flops": total_flops,
                "profiler_top_ops": top_ops,
                "output": decode_ids(out_ids),
                "num_chunks": item["num_chunks"],
                "token_len": item["token_len"],
                "prompt": item["prompt"]
            }

            log_file.write(f"{rec}\n")
            records.append(rec)

            # Progress denominator reflects the prompts actually being run.
            print(
                f"{idx}/{len(selected_pairs)} | tot={total_time:.3f}s "
                f"| first={t_first:.3f}s | inter={t_mean:.4f}s | thr={throughput:.1f} tok/s",
                flush=True,
            )
finally:
    # Release NVML even if a generation step raises.
    pynvml.nvmlShutdown()

df = pd.DataFrame(records)
# Path previously misspelled "benchamrking_data", so to_csv failed.
df.to_csv(csv_path, index=False)

## Accuracy Benchmarking

This code performs benchmarking for evaluating the accuracy of different chunked state strategies in the Bamba model. It processes a list of QA pairs by tokenizing the prompts, generating only the predicted answer portion (based on reference answer length), and decoding the generated output. The results—including prompt, reference answer, and model prediction—are logged to a TSV file for later analysis. This setup isolates the model’s generative accuracy, allowing precise comparisons across chunking strategies while controlling for output length.

Please refer to the notebook `GPTScore.ipynb` for accuracy evaluation.

In [None]:
# Load the QA pairs used for the accuracy run.
with open(
    "./benchmarking_data/qa_for_accuracy_256.json",
    "r",
    encoding="utf-8",
) as f:
    qa_pairs = json.load(f)

In [None]:
import os
import csv


# The accuracy run uses the Bamba-9B tokenizer on the CUDA device and compiles
# the currently loaded model before generating.
tokenizer = get_tokenizer("ibm-ai-platform/Bamba-9B")
device = torch.device("cuda")
model.compile()

def ids_for_prompt(prompt: str, tokenizer, device):
    """Tokenize *prompt* into a 1-D ``torch.long`` tensor on *device*.

    A BOS id is prepended when the tokenizer defines one (not ``None``) that
    differs from its EOS id.

    Args:
        prompt: Text to tokenize.
        tokenizer: Object exposing ``tokenize``, ``convert_tokens_to_ids``,
            ``bos_token_id`` and ``eos_token_id``.
        device: Target device for the returned tensor.
    """
    toks = tokenizer.tokenize(prompt)
    ids = tokenizer.convert_tokens_to_ids(toks)
    bos_id = tokenizer.bos_token_id
    # A tokenizer without a BOS token reports None. Prepending it would make
    # torch.tensor() fail on the mixed None/int list.
    if bos_id is not None and bos_id != tokenizer.eos_token_id:
        ids = [bos_id] + ids
    return torch.tensor(ids, dtype=torch.long, device=device)

def decode_ids(ids: torch.Tensor):
    """Decode token ids into text using the module-level ``tokenizer``.

    Accepts a tensor (converted with ``.tolist()``) or any iterable of ints.
    """
    id_list = ids.tolist() if isinstance(ids, torch.Tensor) else list(ids)
    toks = tokenizer.convert_ids_to_tokens(id_list)
    return tokenizer.convert_tokens_to_string(toks)

log_path = "./benchmarking_data/accuracy_default.tsv" # change the name of the log path
write_header = not os.path.exists(log_path)

# Slice of qa_pairs to evaluate (0-based start, exclusive end). Logged ids are
# the 1-based positions in qa_pairs, so rows stay traceable to their source item
# and do not collide when several slices are appended to the same file.
START, END = 63, 73

with open(log_path, "a", newline="", encoding="utf-8") as f:
    writer = csv.writer(f, delimiter="\t")
    if write_header:
        writer.writerow(["id", "prompt", "reference", "prediction"])

    with torch.no_grad():
        for idx, item in enumerate(qa_pairs[START:END], start=START + 1):
            prompt = item["prompt"]
            reference = item["answer"]
            inputs = ids_for_prompt(prompt, tokenizer, device)
            prompt_len = inputs.size(0)
            # Generate as many tokens as the reference answer tokenizes to.
            answer_len = len(tokenizer.tokenize(reference))
            total_len = prompt_len + answer_len

            print(f"Prompt {idx} | Prompt tokens: {prompt_len} | Answer tokens: {answer_len} | Total: {total_len}",flush=True)
            out_ids = generate(
                model,
                inputs,
                max_new_tokens=answer_len,
                use_cache=True,
                timing="",
                eos_token_id=tokenizer.eos_token_id,
            )

            # Keep only the newly generated continuation after the prompt.
            new_ids = out_ids[prompt_len:]
            output_text = decode_ids(new_ids)

            writer.writerow([idx, prompt, reference, output_text])
            print(f"{idx:3d}: {output_text}…")