# Split LoRA MoE Benchmark: 4-bit QLoRA vs bfloat16 Split LoRA

## Purpose

This notebook benchmarks two LoRA training strategies for the GPT-OSS 20B MoE model side-by-side:

- **4-bit QLoRA (baseline)**: Loads the model with 4-bit NF4 quantization. This is the standard production pipeline configuration. It has a lower VRAM footprint for weights, but cannot use Unsloth's Faster MoE optimization.
- **bfloat16 Split LoRA**: Loads the model in bfloat16. This enables Unsloth's Faster MoE (Split LoRA), which reorders LoRA matrix operations to compute only on routed token-expert pairs, with a claimed **7-12x speedup** and **35%+ VRAM savings** during the forward/backward pass despite using unquantized weights.

## How Split LoRA Works

In a standard MoE layer, all tokens are processed by the full LoRA adapter even though each token only routes to top-K experts. Unsloth's Faster MoE reorders the LoRA ops so they execute *after* routing: only the subset of (token, expert) pairs that are actually selected compute LoRA deltas. This eliminates ~(1 - top_k/num_experts) of the LoRA flops per MoE layer.
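
As a rough worked example of the fraction above, the sketch below evaluates `1 - top_k/num_experts` for a few routing configurations. The 4-of-32 setting is an assumption about GPT-OSS 20B's published configuration, not something this notebook verifies; `detect_moe_experts` in Step 2 reports the actual expert count.

```python
def skipped_lora_fraction(top_k: int, num_experts: int) -> float:
    """Fraction of per-expert LoRA work skipped when only routed pairs are computed."""
    if not 0 < top_k <= num_experts:
        raise ValueError("top_k must be between 1 and num_experts")
    return 1.0 - top_k / num_experts


# (top_k, num_experts) pairs; 4-of-32 is the assumed GPT-OSS 20B setting.
for top_k, num_experts in [(2, 8), (4, 32), (4, 128)]:
    frac = skipped_lora_fraction(top_k, num_experts)
    print(f"top_k={top_k}, experts={num_experts}: {frac:.1%} of expert LoRA FLOPs skipped")
```

The fraction covers only the LoRA work inside MoE layers, so it is an upper bound on the savings for that component and not a prediction of end-to-end speedup.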

## Expected Results

| Metric | 4-bit QLoRA | bf16 Split LoRA |
|--------|------------|------------------|
| Step time | Baseline | **7-12x faster** |
| Peak VRAM | Baseline | **35%+ lower** |
| Expert LoRA | Yes (via fix) | Yes |
| Weight precision | 4-bit NF4 | bfloat16 |

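The hypothesis under test is that Split LoRA delivers faster step times and lower peak VRAM even with bfloat16 base weights, because the memory saved by skipping unrouted expert LoRA ops during the forward/backward pass outweighs the cost of storing unquantized weights. Steps 2 through 5 measure whether this holds on the selected GPU.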
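
A back-of-the-envelope estimate of weight memory alone helps frame this trade-off. The sketch assumes a nominal 20 billion parameters and ignores quantization metadata, layers kept in higher precision, activations, and optimizer state, so treat the numbers as order-of-magnitude only.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NUM_PARAMS = 20e9  # assumption: nominal "20B" parameter count

for label, bits in [("4-bit NF4", 4), ("bfloat16", 16)]:
    print(f"{label}: ~{weight_memory_gb(NUM_PARAMS, bits):.0f} GB of weights")
```

On this estimate the bfloat16 weights take roughly 30 GB more than the 4-bit weights, so any peak-VRAM advantage for Split LoRA would have to come from memory saved during the forward/backward pass.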

## Reference
https://unsloth.ai/docs/new/faster-moe

## Step 0: Environment Setup

### 0.1 Mount Google Drive & Install Dependencies

In [None]:
# Colab detection and setup
import os
import subprocess
import sys

IN_COLAB = "google.colab" in sys.modules or os.path.exists("/content")

if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive", force_remount=False)
    # Install dependencies
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "-q",
         "unsloth", "transformers", "datasets", "trl", "peft",
         "accelerate", "ipywidgets", "matplotlib"],
        check=True,
    )

# Environment variables - must be set before torch import
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["WANDB_MODE"] = "offline"

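# Path setup: switch to the repo on Drive first, then expose scripts/ for imports
if IN_COLAB and os.path.exists("/content/drive/MyDrive/llm-training-pipeline"):
    os.chdir("/content/drive/MyDrive/llm-training-pipeline")
# Use an absolute path so later directory changes do not break imports
sys.path.insert(0, os.path.abspath("scripts"))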

print(f"Working directory: {os.getcwd()}")
print(f"Colab: {IN_COLAB}")


### 0.2 Configure Benchmark

In [None]:
import json

# @param forms for Colab UI
gpu_tier = "h100_80gb" #@param ["a100_40gb", "a100_80gb", "h100_80gb"] {type: "string"}
benchmark_steps = 50 #@param {type: "integer"}
lora_rank = 64 #@param {type: "integer"}
max_seq_length = 4096 #@param {type: "integer"}
include_training_run = True #@param {type: "boolean"}

GPU_PROFILES = {
    "a100_40gb": {
        "moe_backend": "unsloth_triton",
        "baseline_batch_size": 1,
        "split_batch_size": 1,
        "expected_speedup": "7-10x",
        "bf16_viable": False,  # 40GB too small for bf16 20B
    },
    "a100_80gb": {
        "moe_backend": "unsloth_triton",
        "baseline_batch_size": 2,
        "split_batch_size": 2,
        "expected_speedup": "7-10x",
        "bf16_viable": True,
    },
    "h100_80gb": {
        "moe_backend": "grouped_mm",
        "baseline_batch_size": 2,
        "split_batch_size": 2,
        "expected_speedup": "10-12x",
        "bf16_viable": True,
    },
}

BENCH_CONFIG = {
    "gpu_tier": gpu_tier,
    "benchmark_steps": benchmark_steps,
    "lora_rank": lora_rank,
    "max_seq_length": max_seq_length,
    "include_training_run": include_training_run,
    **GPU_PROFILES[gpu_tier],
}

# Shared LoRA config for fair comparison
LORA_CONFIG = {
    "r": lora_rank,
    "lora_alpha": lora_rank,  # alpha == rank is standard
    "lora_dropout": 0,
    "bias": "none",
    "use_gradient_checkpointing": "unsloth",
    "random_state": 42,
}

os.makedirs("/tmp/split_lora_bench", exist_ok=True)
with open("/tmp/split_lora_bench/config.json", "w") as f:
    json.dump(BENCH_CONFIG, f, indent=2)

print("Benchmark Configuration:")
for k, v in BENCH_CONFIG.items():
    print(f"  {k}: {v}")


### 0.3 Benchmark Tracker

In [None]:
import ipywidgets as widgets
from IPython.display import display


class BenchmarkTracker:
    """Visual progress tracker for benchmark phases."""

    PHASES = [
        ("gpu_detect", "GPU Detection"),
        ("synthetic_data", "Synthetic Data"),
        ("baseline_load", "Baseline: Load Model"),
        ("baseline_lora", "Baseline: Apply LoRA"),
        ("baseline_bench", "Baseline: Benchmark"),
        ("cleanup", "Cleanup"),
        ("split_load", "Split LoRA: Load Model"),
        ("split_lora", "Split LoRA: Apply LoRA"),
        ("split_bench", "Split LoRA: Benchmark"),
        ("training_run", "Training Validation"),
        ("comparison", "Results Comparison"),
        ("recommendation", "Recommendation"),
    ]

    def __init__(self):
        self._bars = {}
        self._labels = {}
        rows = []
        for key, name in self.PHASES:
            label = widgets.HTML(
                value=f"<span style='color:#888'>&#x25CB; {name}</span>",
                layout=widgets.Layout(width="260px"),
            )
            bar = widgets.FloatProgress(
                value=0,
                min=0,
                max=1.0,
                bar_style="info",
                layout=widgets.Layout(width="300px", height="18px"),
            )
            self._bars[key] = bar
            self._labels[key] = label
            rows.append(widgets.HBox([label, bar]))
        self._container = widgets.VBox(rows)
        display(widgets.HTML("<b>Benchmark Progress</b>"))
        display(self._container)

    def start(self, phase):
        self._bars[phase].value = 0.1
        self._bars[phase].bar_style = "info"
        name = dict(self.PHASES)[phase]
        self._labels[phase].value = f"<span style='color:#2196F3'>&#x25B6; {name}</span>"

    def complete(self, phase):
        self._bars[phase].value = 1.0
        self._bars[phase].bar_style = "success"
        name = dict(self.PHASES)[phase]
        self._labels[phase].value = f"<span style='color:#4CAF50'>&#x2714; {name}</span>"

    def skip(self, phase):
        self._bars[phase].value = 1.0
        self._bars[phase].bar_style = ""
        name = dict(self.PHASES)[phase]
        self._labels[phase].value = f"<span style='color:#888'>&#x23ED; {name} (skipped)</span>"

    def fail(self, phase):
        self._bars[phase].value = 1.0
        self._bars[phase].bar_style = "danger"
        name = dict(self.PHASES)[phase]
        self._labels[phase].value = f"<span style='color:#F44336'>&#x2718; {name}</span>"


tracker = BenchmarkTracker()

# Results accumulator
results = {
    "baseline": {},
    "split_lora": {},
    "training_validation": {},
    "verdict": None,
}


### 0.4 Check GPU & Set MoE Backend

In [None]:
import torch

tracker.start("gpu_detect")

if not torch.cuda.is_available():
    raise RuntimeError("CUDA not available. This notebook requires a GPU.")

gpu_name = torch.cuda.get_device_name(0)
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
bf16_supported = torch.cuda.is_bf16_supported()

print(f"GPU: {gpu_name}")
print(f"VRAM: {gpu_mem_gb:.1f} GB")
print(f"BF16 supported: {bf16_supported}")

# Auto-override tier based on detected GPU
if "H100" in gpu_name or "H200" in gpu_name:
    detected_tier = "h100_80gb"
elif "A100" in gpu_name:
    detected_tier = "a100_80gb" if gpu_mem_gb > 50 else "a100_40gb"
else:
    detected_tier = gpu_tier  # Keep user selection
    print(f"  Unknown GPU, keeping user-selected tier: {gpu_tier}")

if detected_tier != gpu_tier:
    print(f"  Auto-overriding tier: {gpu_tier} -> {detected_tier}")
    gpu_tier = detected_tier
    BENCH_CONFIG.update(GPU_PROFILES[detected_tier])
    BENCH_CONFIG["gpu_tier"] = detected_tier

# CRITICAL: Set MoE backend BEFORE any unsloth import
moe_backend = BENCH_CONFIG["moe_backend"]
os.environ["UNSLOTH_MOE_BACKEND"] = moe_backend
print(f"MoE backend: {moe_backend}")

if not bf16_supported:
    raise RuntimeError("BF16 not supported on this GPU. Split LoRA requires BF16.")

if not BENCH_CONFIG["bf16_viable"]:
    print(f"\nWARNING: {detected_tier} may not have enough VRAM for bf16 GPT-OSS 20B.")
    print("Baseline (4-bit) will run but Split LoRA (bf16) may OOM.")

tracker.complete("gpu_detect")


## Step 1: Prepare Synthetic Benchmark Data

We generate 200 synthetic multi-turn coding agent conversations in Harmony format. These simulate real agent sessions with tool calls, thinking traces, and patch operations — the same distribution used in TUI training.

The pipeline's `encode_harmony_messages` formatter is used if available in the path. If the import fails (e.g. running in a fresh Colab without the repo), a manual ChatML-style fallback is used instead. Both produce valid training text, but the Harmony formatter adds structured tool call tokens and `<think>` blocks matching the model's training distribution.
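
The templates in the next cell are filled in with `str.format`, which treats every single `{` or `}` as a placeholder delimiter. Literal braces, such as those in JSON tool-call arguments or Rust source, therefore have to be doubled. A minimal illustration, using made-up templates rather than ones from the notebook:

```python
import json

# Doubled braces survive formatting as literal braces; {path} is substituted.
template = '{{"path": "{path}"}}'
rendered = template.format(path="main.rs")
print(rendered)

# The rendered string parses as JSON.
print(json.loads(rendered)["path"])

# Unescaped JSON braces fail: str.format reads '"cmd"' as a field name
# and looks it up among the keyword arguments.
try:
    '{"cmd": "ls"}'.format(path="main.rs")
except KeyError as exc:
    print(f"KeyError: {exc}")
```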

In [None]:
import copy
import json
import random

from datasets import Dataset

tracker.start("synthetic_data")

NUM_SAMPLES = 200

# Synthetic conversation templates
TEMPLATES = [
    {
        "messages": [
            {"role": "developer", "content": "You are a coding assistant. Use tools to read and modify files."},
            {"role": "user", "content": "There's a bug in {file}. The {feature} isn't working correctly."},
            {
                "role": "assistant",
                "thinking": "Let me read the file to understand the issue.",
                "tool_calls": [
                    {"id": "1", "type": "function", "function": {"name": "read_file", "arguments": '{{"path": "{file}"}}'}}  # noqa: E501
                ],
            },
            {"role": "tool", "tool_call_id": "1", "content": "fn main() {{\n    let data = vec![1, 2, 3];\n    println!(\"{{:?}}\", data);\n}}"},
            {
                "role": "assistant",
                "thinking": "I see the issue. The {feature} needs to be fixed by {fix}.",
                "content": "I found the bug. The {feature} in `{file}` needs {fix}. Let me apply the patch.",
                "tool_calls": [
                    {"id": "2", "type": "function", "function": {"name": "write_file", "arguments": '{{"path": "{file}", "content": "fn main() {{\\n    let data = vec![1, 2, 3, 4];\\n    println!(\\"{{:?}}\\", data);\\n}}"}}'}}  # noqa: E501
                ],
            },
            {"role": "tool", "tool_call_id": "2", "content": "File written successfully."},
            {"role": "assistant", "content": "I've fixed the {feature} in `{file}`. The issue was {fix}."},
        ],
    },
    {
        "messages": [
            {"role": "developer", "content": "You are a debugging assistant with access to shell and file tools."},
            {"role": "user", "content": "Run the tests for {project} and fix any failures."},
            {
                "role": "assistant",
                "thinking": "I'll run the test suite first to identify failures.",
                "tool_calls": [
                    {"id": "1", "type": "function", "function": {"name": "run_command", "arguments": '{{"command": "cargo test"}}'}}  # braces doubled: the string is passed through str.format
                ],
            },
            {"role": "tool", "tool_call_id": "1", "content": "test {test_name} ... FAILED\n\nfailures:\n    {test_name}: assertion failed: expected {expected}, got {actual}"},
            {
                "role": "assistant",
                "content": "Test `{test_name}` failed. The assertion expected `{expected}` but got `{actual}`. Let me fix this.",
            },
        ],
    },
]

random.seed(42)

files = ["main.rs", "lib.rs", "utils.rs", "handler.rs", "config.rs", "parser.rs", "server.rs", "client.rs"]
features = ["error handling", "input validation", "data parsing", "connection pooling", "cache invalidation", "rate limiting"]
fixes = ["adding a bounds check", "handling the None case", "fixing the off-by-one error", "adding timeout logic", "correcting the regex pattern"]
projects = ["web-server", "cli-tool", "data-pipeline", "api-gateway", "task-runner"]
test_names = ["test_parse_input", "test_handle_request", "test_validate_config", "test_connection", "test_cache_hit"]

samples = []
for i in range(NUM_SAMPLES):
    template = random.choice(TEMPLATES)
    messages = copy.deepcopy(template["messages"])

    # String substitution in all message fields
    subs = {
        "file": random.choice(files),
        "feature": random.choice(features),
        "fix": random.choice(fixes),
        "project": random.choice(projects),
        "test_name": random.choice(test_names),
        "expected": str(random.randint(1, 100)),
        "actual": str(random.randint(1, 100)),
    }

    for msg in messages:
        if "content" in msg and msg["content"]:
            msg["content"] = msg["content"].format(**subs)
        if "thinking" in msg and msg["thinking"]:
            msg["thinking"] = msg["thinking"].format(**subs)
        if "tool_calls" in msg:
            for tc in msg["tool_calls"]:
                tc["function"]["arguments"] = tc["function"]["arguments"].format(**subs)

    # Try Harmony formatter, fall back to manual
    try:
        from dataset_formatters.harmony import encode_harmony_messages
        text = encode_harmony_messages(messages, developer_instructions=None, reasoning_effort="medium")
    except ImportError:
        # Manual fallback: simple chatml-style
        parts = []
        for msg in messages:
            role = msg["role"]
            content = msg.get("content", "")
            if msg.get("thinking"):
                content = f"<think>{msg['thinking']}</think>\n{content}"
            if msg.get("tool_calls"):
                tc_str = json.dumps(msg["tool_calls"])
                content = f"{content}\n<tool_calls>{tc_str}</tool_calls>" if content else f"<tool_calls>{tc_str}</tool_calls>"
            parts.append(f"<|im_start|>{role}\n{content}<|im_end|>")
        text = "\n".join(parts) + "<|endoftext|>"

    samples.append({"text": text})

benchmark_dataset = Dataset.from_list(samples)
print(f"Created {len(benchmark_dataset)} synthetic samples")
print(f"Sample length (chars): {len(samples[0]['text'])}")

tracker.complete("synthetic_data")


## Step 2: Baseline Measurement (4-bit QLoRA)

The baseline uses the standard 4-bit NF4 quantized approach with QLoRA, matching the current production pipeline configuration in `configs/gpt_oss_20b.py`. This is the reference point for all speedup and VRAM savings calculations.

### 2.0 Benchmark Utilities

In [None]:
import gc
import time


def get_vram_gb():
    """Current VRAM usage in GB."""
    return torch.cuda.memory_allocated() / (1024**3)


def get_vram_peak_gb():
    """Peak VRAM usage in GB."""
    return torch.cuda.max_memory_allocated() / (1024**3)


def reset_vram_stats():
    """Reset peak VRAM tracking."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()


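def cleanup_model(*objects):
    """Run garbage collection and release cached VRAM.

    Dropping the references held here does not remove the caller's own
    variables, so callers must also ``del`` their names for memory to be freed.
    """
    del objects
    gc.collect()
    torch.cuda.empty_cache()
    print(f"VRAM after cleanup: {get_vram_gb():.2f} GB")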


def print_measurement(label, data):
    """Print a benchmark measurement."""
    print(f"\n{'=' * 50}")
    print(f"  {label}")
    print(f"{'=' * 50}")
    for k, v in data.items():
        if isinstance(v, float):
            print(f"  {k}: {v:.4f}")
        else:
            print(f"  {k}: {v}")


### 2.1 Load Baseline Model (4-bit QLoRA)

In [None]:
from pipeline_lib.unsloth_utils import (
    apply_lora_config,
    detect_moe_experts,
    load_unsloth_model,
    print_trainable_params,
    verify_expert_lora,
)

tracker.start("baseline_load")
reset_vram_stats()

print("Loading model with 4-bit quantization (QLoRA baseline)...")
t0 = time.perf_counter()
baseline_model, tokenizer = load_unsloth_model(
    max_seq_length=BENCH_CONFIG["max_seq_length"],
    load_in_4bit=True,
    tiled_mlp=True,
    offload_embedding=True,
)
load_time = time.perf_counter() - t0

results["baseline"]["load_time_s"] = load_time
results["baseline"]["vram_after_load_gb"] = get_vram_gb()
print(f"Load time: {load_time:.1f}s")
print(f"VRAM after load: {get_vram_gb():.2f} GB")

# Detect MoE structure
moe_info = detect_moe_experts(baseline_model)
print(f"MoE detected: {moe_info['is_moe']}")
if moe_info["is_moe"]:
    print(f"  Experts: {moe_info['num_experts']}")
    print(f"  Expert params: {moe_info['expert_param_count']:,}")

tracker.complete("baseline_load")

# Apply LoRA
tracker.start("baseline_lora")
print("\nApplying LoRA config...")
baseline_model = apply_lora_config(baseline_model, LORA_CONFIG, auto_detect_moe=True)

lora_result = verify_expert_lora(baseline_model)
results["baseline"]["has_expert_lora"] = lora_result["has_expert_lora"]
results["baseline"]["expert_lora_params"] = lora_result["expert_lora_params"]
results["baseline"]["attention_lora_params"] = lora_result["attention_lora_params"]
results["baseline"]["vram_after_lora_gb"] = get_vram_gb()

print(f"Expert LoRA active: {lora_result['has_expert_lora']}")
print(f"Expert LoRA params: {lora_result['expert_lora_params']:,}")
print(f"Attention LoRA params: {lora_result['attention_lora_params']:,}")
print_trainable_params(baseline_model)

tracker.complete("baseline_lora")


### 2.2 Benchmark Baseline

In [None]:
from trl import SFTConfig, SFTTrainer

tracker.start("baseline_bench")
reset_vram_stats()

print(f"Running {BENCH_CONFIG['benchmark_steps']} benchmark steps (4-bit QLoRA)...")
print("First few steps may be slower (compilation warmup).\n")

baseline_training_args = SFTConfig(
    output_dir="/tmp/split_lora_bench/baseline",
    max_steps=BENCH_CONFIG["benchmark_steps"],
    per_device_train_batch_size=BENCH_CONFIG["baseline_batch_size"],
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_steps=5,
    logging_steps=5,
    bf16=True,
    optim="adamw_8bit",
    seed=42,
    max_seq_length=BENCH_CONFIG["max_seq_length"],
    packing=True,
    dataset_text_field="text",
    report_to="none",
    save_strategy="no",
)

baseline_trainer = SFTTrainer(
    model=baseline_model,
    tokenizer=tokenizer,
    train_dataset=benchmark_dataset,
    args=baseline_training_args,
)

# Warmup: 3 steps not timed
print("Warmup (3 steps)...")
baseline_training_args.max_steps = 3
baseline_trainer.train()

# Reset for actual benchmark
reset_vram_stats()
baseline_training_args.max_steps = BENCH_CONFIG["benchmark_steps"]
baseline_trainer.args.max_steps = BENCH_CONFIG["benchmark_steps"]

print(f"Benchmarking ({BENCH_CONFIG['benchmark_steps']} steps)...")
t0 = time.perf_counter()
baseline_trainer.train()
wall_time = time.perf_counter() - t0

# Extract metrics
log_history = baseline_trainer.state.log_history
losses = [e["loss"] for e in log_history if "loss" in e]

results["baseline"]["wall_time_s"] = wall_time
results["baseline"]["avg_step_time_s"] = wall_time / BENCH_CONFIG["benchmark_steps"]
results["baseline"]["peak_vram_gb"] = get_vram_peak_gb()
results["baseline"]["loss_start"] = losses[0] if losses else None
results["baseline"]["loss_end"] = losses[-1] if losses else None

print_measurement("Baseline (4-bit QLoRA)", results["baseline"])

tracker.complete("baseline_bench")


### 2.3 Cleanup Baseline Model

In [None]:
tracker.start("cleanup")
# Drop the notebook-level references first; otherwise nothing can be freed.
del baseline_trainer, baseline_model
cleanup_model()
print(f"VRAM after full cleanup: {get_vram_gb():.2f} GB")
tracker.complete("cleanup")


## Step 3: Split LoRA Measurement (bfloat16)

The key difference from the baseline is `load_in_4bit=False` — the model loads in bfloat16 with Split LoRA. Unsloth's Faster MoE reorders LoRA matrix operations to only compute on the routed (token, expert) pairs selected during MoE dispatch, rather than broadcasting across all experts.
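
The difference between broadcasting across all experts and computing only routed pairs can be seen in a toy, pure-Python sketch. This illustrates the idea only and is not Unsloth's Triton or grouped-GEMM kernels; the dimensions, the random router, and all helper names are made up, and router gate weights and base expert weights are omitted.

```python
import random

random.seed(0)
NUM_TOKENS, NUM_EXPERTS, TOP_K, DIM, RANK = 6, 4, 2, 3, 2


def rand_matrix(rows, cols):
    return [[random.uniform(-1, 1) for _ in range(cols)] for _ in range(rows)]


def matvec(matrix, vec):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(m * v for m, v in zip(row, vec)) for row in matrix]


# One LoRA pair per expert: A is RANK x DIM, B is DIM x RANK.
lora_a = [rand_matrix(RANK, DIM) for _ in range(NUM_EXPERTS)]
lora_b = [rand_matrix(DIM, RANK) for _ in range(NUM_EXPERTS)]
tokens = rand_matrix(NUM_TOKENS, DIM)
# Toy router: each token picks TOP_K distinct experts at random.
routes = [random.sample(range(NUM_EXPERTS), TOP_K) for _ in range(NUM_TOKENS)]


def lora_delta(expert, x):
    return matvec(lora_b[expert], matvec(lora_a[expert], x))


def dense_then_mask():
    """Compute every (token, expert) delta, then keep only the routed ones."""
    out, computed = [], 0
    for t, x in enumerate(tokens):
        acc = [0.0] * DIM
        for e in range(NUM_EXPERTS):
            delta = lora_delta(e, x)
            computed += 1
            if e in routes[t]:
                acc = [a + d for a, d in zip(acc, delta)]
        out.append(acc)
    return out, computed


def routed_only():
    """Compute deltas only for the (token, expert) pairs the router selected."""
    out, computed = [], 0
    for t, x in enumerate(tokens):
        acc = [0.0] * DIM
        for e in sorted(routes[t]):
            acc = [a + d for a, d in zip(acc, lora_delta(e, x))]
            computed += 1
        out.append(acc)
    return out, computed


dense_out, dense_pairs = dense_then_mask()
routed_out, routed_pairs = routed_only()
print(f"pairs computed: dense={dense_pairs}, routed={routed_pairs}")
print("outputs match:", dense_out == routed_out)
```

Both variants produce the same per-token result; the routed variant simply evaluates `NUM_TOKENS * TOP_K` deltas instead of `NUM_TOKENS * NUM_EXPERTS`.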

**First run triggers Triton kernel autotuning (~2 minutes)**, which is excluded from benchmark timing by running 3 warmup steps before the timed window. Subsequent runs on the same Colab session will skip autotuning.
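
The warmup-then-time pattern can be sketched independently of the trainer. The helper and the fake step below are hypothetical stand-ins, with a slow first call simulating one-off autotuning cost.

```python
import time


def timed_after_warmup(step_fn, warmup_steps: int, timed_steps: int) -> float:
    """Return average seconds per step, excluding the warmup steps."""
    for _ in range(warmup_steps):
        step_fn()  # one-off compilation/autotuning cost lands here
    start = time.perf_counter()
    for _ in range(timed_steps):
        step_fn()
    return (time.perf_counter() - start) / timed_steps


calls = {"n": 0}


def fake_step():
    calls["n"] += 1
    # The first call is deliberately slow to mimic autotuning.
    time.sleep(0.05 if calls["n"] == 1 else 0.001)


avg = timed_after_warmup(fake_step, warmup_steps=3, timed_steps=10)
print(f"avg step time: {avg * 1000:.2f} ms over 10 timed steps")
```

When timing raw GPU work outside a trainer, calling `torch.cuda.synchronize()` before reading the clock avoids measuring only the kernel launch.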

### 3.1 Load Model with bfloat16

In [None]:
tracker.start("split_load")
reset_vram_stats()

print("Loading model in bfloat16 (Split LoRA mode)...")
print("This uses more VRAM than 4-bit but enables Faster MoE optimization.\n")

try:
    t0 = time.perf_counter()
    split_model, tokenizer = load_unsloth_model(
        max_seq_length=BENCH_CONFIG["max_seq_length"],
        load_in_4bit=False,
        dtype=torch.bfloat16,
        tiled_mlp=True,
        offload_embedding=True,
    )
    load_time = time.perf_counter() - t0

    results["split_lora"]["load_time_s"] = load_time
    results["split_lora"]["vram_after_load_gb"] = get_vram_gb()
    print(f"Load time: {load_time:.1f}s")
    print(f"VRAM after load: {get_vram_gb():.2f} GB")

    tracker.complete("split_load")
    split_load_ok = True

except torch.cuda.OutOfMemoryError:
    print("\nOOM: Not enough VRAM for bfloat16 model.")
    print(f"GPU has {gpu_mem_gb:.0f} GB but bf16 GPT-OSS 20B needs ~40 GB for weights alone.")
    print("Split LoRA benchmark cannot proceed on this GPU.")
    results["split_lora"]["error"] = "OOM on model load"
    tracker.fail("split_load")
    tracker.skip("split_lora")
    tracker.skip("split_bench")
    split_load_ok = False

except Exception as e:
    print(f"\nError loading model: {e}")
    results["split_lora"]["error"] = str(e)
    tracker.fail("split_load")
    tracker.skip("split_lora")
    tracker.skip("split_bench")
    split_load_ok = False


### 3.2 Benchmark Split LoRA

In [None]:
if split_load_ok:
    # Apply LoRA (same config as baseline for fair comparison)
    tracker.start("split_lora")
    print("Applying LoRA config (same as baseline)...")
    split_model = apply_lora_config(split_model, LORA_CONFIG, auto_detect_moe=True)

    lora_result = verify_expert_lora(split_model)
    results["split_lora"]["has_expert_lora"] = lora_result["has_expert_lora"]
    results["split_lora"]["expert_lora_params"] = lora_result["expert_lora_params"]
    results["split_lora"]["attention_lora_params"] = lora_result["attention_lora_params"]
    results["split_lora"]["vram_after_lora_gb"] = get_vram_gb()

    print(f"Expert LoRA active: {lora_result['has_expert_lora']}")
    print(f"Expert LoRA params: {lora_result['expert_lora_params']:,}")
    print_trainable_params(split_model)
    tracker.complete("split_lora")

    # Benchmark
    tracker.start("split_bench")
    reset_vram_stats()

    print(f"\nRunning {BENCH_CONFIG['benchmark_steps']} benchmark steps (bfloat16 Split LoRA)...")
    print("First run triggers Triton autotuning (~2 min), excluded from timing.\n")

    split_training_args = SFTConfig(
        output_dir="/tmp/split_lora_bench/split_lora",
        max_steps=BENCH_CONFIG["benchmark_steps"],
        per_device_train_batch_size=BENCH_CONFIG["split_batch_size"],
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=5,
        logging_steps=5,
        bf16=True,
        optim="adamw_8bit",
        seed=42,
        max_seq_length=BENCH_CONFIG["max_seq_length"],
        packing=True,
        dataset_text_field="text",
        report_to="none",
        save_strategy="no",
    )

    split_trainer = SFTTrainer(
        model=split_model,
        tokenizer=tokenizer,
        train_dataset=benchmark_dataset,
        args=split_training_args,
    )

    # Warmup: includes Triton autotune, not timed
    print("Warmup + Triton autotune (3 steps, not timed)...")
    split_training_args.max_steps = 3
    split_trainer.train()

    # Reset for actual benchmark
    reset_vram_stats()
    split_training_args.max_steps = BENCH_CONFIG["benchmark_steps"]
    split_trainer.args.max_steps = BENCH_CONFIG["benchmark_steps"]

    print(f"Benchmarking ({BENCH_CONFIG['benchmark_steps']} steps)...")
    t0 = time.perf_counter()
    split_trainer.train()
    wall_time = time.perf_counter() - t0

    log_history = split_trainer.state.log_history
    losses = [e["loss"] for e in log_history if "loss" in e]

    results["split_lora"]["wall_time_s"] = wall_time
    results["split_lora"]["avg_step_time_s"] = wall_time / BENCH_CONFIG["benchmark_steps"]
    results["split_lora"]["peak_vram_gb"] = get_vram_peak_gb()
    results["split_lora"]["loss_start"] = losses[0] if losses else None
    results["split_lora"]["loss_end"] = losses[-1] if losses else None

    print_measurement("Split LoRA (bfloat16)", results["split_lora"])
    tracker.complete("split_bench")
else:
    print("Skipping Split LoRA benchmark (model load failed).")


## Step 4: Short SFT Training Run (Validation)

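Runs 100 steps with `logging_steps=1` to confirm that the Split LoRA model trains stably. The run reuses the `split_model` from Step 3, so it continues from adapter weights already updated during the benchmark. Three checks are applied; the first two decide the verdict and the third is advisory:

1. **No NaN/Inf losses**: Training instability check. Any NaN or Inf loss fails the gate.
2. **Loss decreasing**: Compares the average of the first 10 and last 10 logged losses. Expected: last_10 < first_10. The check is skipped if fewer than 20 losses are logged.
3. **Expert gradient flow**: Looks for non-zero gradients on trainable expert parameters, which would confirm that Split LoRA is training the MoE experts and not just the attention layers. Gradients may already be cleared once training finishes, so a miss is reported as a warning and does not fail the gate.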
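
The two loss-based checks can be sketched as a standalone function. This is a simplified restatement for illustration, with a hypothetical function name and synthetic loss curves; it omits the gradient check, which needs a live model.

```python
import math


def loss_quality_gate(losses, window=10):
    """Apply the NaN/Inf and loss-trend checks to a list of logged losses."""
    checks = {"no_nan_inf": all(math.isfinite(x) for x in losses)}
    if len(losses) >= 2 * window:
        first = sum(losses[:window]) / window
        last = sum(losses[-window:]) / window
        checks["loss_decreasing"] = last < first
    else:
        checks["loss_decreasing"] = None  # too few points to judge a trend
    # A skipped trend check (None) does not fail the gate; False does.
    checks["passed"] = checks["no_nan_inf"] and checks["loss_decreasing"] is not False
    return checks


healthy = [2.0 - 0.01 * i for i in range(100)]
diverged = healthy[:50] + [float("nan")] + healthy[51:]
print(loss_quality_gate(healthy))
print(loss_quality_gate(diverged))
```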

In [None]:
import math

first_10_avg = None
last_10_avg = None

if split_load_ok and BENCH_CONFIG["include_training_run"]:
    tracker.start("training_run")

    VALIDATION_STEPS = 100
    print(f"Running {VALIDATION_STEPS}-step training validation...")

    val_args = SFTConfig(
        output_dir="/tmp/split_lora_bench/validation",
        max_steps=VALIDATION_STEPS,
        per_device_train_batch_size=BENCH_CONFIG["split_batch_size"],
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        lr_scheduler_type="cosine",
        warmup_steps=10,
        logging_steps=1,
        bf16=True,
        optim="adamw_8bit",
        seed=42,
        max_seq_length=BENCH_CONFIG["max_seq_length"],
        packing=True,
        dataset_text_field="text",
        report_to="none",
        save_strategy="no",
    )

    # Reuse split_model if still loaded, otherwise note skip
    val_trainer = SFTTrainer(
        model=split_model,
        tokenizer=tokenizer,
        train_dataset=benchmark_dataset,
        args=val_args,
    )

    val_trainer.train()

    # Extract losses
    val_losses = [e["loss"] for e in val_trainer.state.log_history if "loss" in e]

    print("\n" + "=" * 60)
    print("  QUALITY GATE: Training Validation")
    print("=" * 60)

    checks = {}

    # Check 1: No NaN/Inf
    nan_count = sum(1 for loss in val_losses if not math.isfinite(loss))
    checks["no_nan_inf"] = nan_count == 0
    print(f"  {'PASS' if checks['no_nan_inf'] else 'FAIL'} - No NaN/Inf losses ({nan_count} found)")

    # Check 2: Loss decreasing (first 10 avg vs last 10 avg)
    if len(val_losses) >= 20:
        first_10_avg = sum(val_losses[:10]) / 10
        last_10_avg = sum(val_losses[-10:]) / 10
        loss_decreased = last_10_avg < first_10_avg
        checks["loss_decreasing"] = loss_decreased
        improvement_pct = (first_10_avg - last_10_avg) / first_10_avg * 100
        print(f"  {'PASS' if loss_decreased else 'FAIL'} - Loss decreasing: {first_10_avg:.4f} -> {last_10_avg:.4f} ({improvement_pct:+.1f}%)")
    else:
        checks["loss_decreasing"] = None
        print(f"  SKIP - Not enough loss entries ({len(val_losses)}) for trend analysis")

    # Check 3: Expert gradient flow
    expert_has_grad = False
    for name, param in split_model.named_parameters():
        if "expert" in name.lower() and param.requires_grad and param.grad is not None:
            if param.grad.abs().sum() > 0:
                expert_has_grad = True
                break
    checks["expert_gradient_flow"] = expert_has_grad
    print(f"  {'PASS' if expert_has_grad else 'WARN'} - Expert gradient flow {'detected' if expert_has_grad else 'not detected (may be cleared after step)'}")

    # Overall verdict
    critical_passed = checks["no_nan_inf"] and (checks["loss_decreasing"] is not False)
    results["training_validation"] = {
        "steps": VALIDATION_STEPS,
        "losses": val_losses,
        "checks": checks,
        "passed": critical_passed,
        "first_10_avg": first_10_avg,
        "last_10_avg": last_10_avg,
    }

    print(f"\n  VERDICT: {'PASS' if critical_passed else 'FAIL'}")

    if critical_passed:
        tracker.complete("training_run")
    else:
        tracker.fail("training_run")

    # Cleanup validation trainer
    del val_trainer

elif not BENCH_CONFIG["include_training_run"]:
    tracker.skip("training_run")
    print("Training validation skipped (include_training_run=False)")
else:
    tracker.skip("training_run")
    print("Training validation skipped (Split LoRA model load failed)")


## Step 5: Results Comparison

In [None]:
tracker.start("comparison")

print("\n" + "=" * 70)
print("  BENCHMARK RESULTS: 4-bit QLoRA vs bfloat16 Split LoRA")
print("=" * 70)

# Side-by-side table
metrics = [
    ("Peak VRAM (GB)", "peak_vram_gb", ".2f"),
    ("Avg Step Time (s)", "avg_step_time_s", ".3f"),
    ("VRAM After Load (GB)", "vram_after_load_gb", ".2f"),
    ("VRAM After LoRA (GB)", "vram_after_lora_gb", ".2f"),
    ("Load Time (s)", "load_time_s", ".1f"),
    ("Expert LoRA Params", "expert_lora_params", ","),
    ("Attention LoRA Params", "attention_lora_params", ","),
    ("Expert LoRA Active", "has_expert_lora", ""),
    ("Loss (start)", "loss_start", ".4f"),
    ("Loss (end)", "loss_end", ".4f"),
]

header = f"{'Metric':<30} {'4-bit QLoRA':>18} {'bf16 Split LoRA':>18} {'Delta':>12}"
print(header)
print("-" * len(header))

for label, key, fmt in metrics:
    b_val = results["baseline"].get(key)
    s_val = results["split_lora"].get(key)

    if b_val is None:
        b_str = "N/A"
    elif fmt == ",":
        b_str = f"{b_val:,}"
    elif fmt == "":
        b_str = str(b_val)
    else:
        b_str = f"{b_val:{fmt}}"

    if s_val is None:
        s_str = "N/A"
    elif fmt == ",":
        s_str = f"{s_val:,}"
    elif fmt == "":
        s_str = str(s_val)
    else:
        s_str = f"{s_val:{fmt}}"

    # Calculate delta for numeric values
    delta_str = ""
    if isinstance(b_val, (int, float)) and isinstance(s_val, (int, float)) and b_val != 0:
        if "time" in key.lower() or "vram" in key.lower():
            ratio = s_val / b_val
            delta_str = f"{ratio:.2f}x"

    print(f"{label:<30} {b_str:>18} {s_str:>18} {delta_str:>12}")

# Speedup headline
b_step = results["baseline"].get("avg_step_time_s")
s_step = results["split_lora"].get("avg_step_time_s")
if b_step and s_step and s_step > 0:
    speedup = b_step / s_step
    print(f"\nSpeedup: {speedup:.1f}x {'faster' if speedup > 1 else 'slower'} with Split LoRA")
    results["speedup"] = speedup

b_vram = results["baseline"].get("peak_vram_gb")
s_vram = results["split_lora"].get("peak_vram_gb")
if b_vram and s_vram:
    vram_savings = (b_vram - s_vram) / b_vram * 100
    print(f"VRAM savings: {vram_savings:+.1f}% with Split LoRA")
    results["vram_savings_pct"] = vram_savings

# Optional matplotlib chart
try:
    # Keep the notebook's default backend: forcing "Agg" would stop plt.show() from rendering.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))

    # Step time comparison
    if b_step and s_step:
        bars = axes[0].bar(["4-bit QLoRA", "bf16 Split LoRA"], [b_step, s_step], color=["#2196F3", "#4CAF50"])
        axes[0].set_ylabel("Avg Step Time (s)")
        axes[0].set_title("Training Step Time")
        for bar, val in zip(bars, [b_step, s_step]):
            axes[0].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01, f"{val:.3f}s", ha="center", va="bottom")

    # VRAM comparison
    if b_vram and s_vram:
        bars = axes[1].bar(["4-bit QLoRA", "bf16 Split LoRA"], [b_vram, s_vram], color=["#2196F3", "#4CAF50"])
        axes[1].set_ylabel("Peak VRAM (GB)")
        axes[1].set_title("Peak VRAM Usage")
        for bar, val in zip(bars, [b_vram, s_vram]):
            axes[1].text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.2, f"{val:.1f} GB", ha="center", va="bottom")

    plt.tight_layout()
    plt.savefig("/tmp/split_lora_bench/comparison.png", dpi=150, bbox_inches="tight")
    plt.show()
    print("Chart saved to /tmp/split_lora_bench/comparison.png")
except ImportError:
    print("(matplotlib not available, skipping chart)")

tracker.complete("comparison")


## Step 6: Recommendation

In [None]:
tracker.start("recommendation")

print("\n" + "=" * 70)
print("  RECOMMENDATION")
print("=" * 70)

# Determine viability
speedup = results.get("speedup", 0)
expert_lora_ok = results.get("split_lora", {}).get("has_expert_lora", False)
training_ok = results.get("training_validation", {}).get("passed", None)
has_error = "error" in results.get("split_lora", {})

viable = (
    speedup >= 1.5
    and expert_lora_ok
    and training_ok is not False  # None (skipped) is acceptable
    and not has_error
)

results["verdict"] = "VIABLE" if viable else "NOT VIABLE"

if viable:
    print("\n  VERDICT: VIABLE")
    print(f"  Split LoRA achieves {speedup:.1f}x speedup with working expert LoRA.\n")

    print("  Suggested pipeline changes:")
    print("  " + "-" * 40)
    print()
    print("  1. configs/gpt_oss_20b.py:")
    print("     load_in_4bit: False  # Was: True")
    print('     dtype: "bfloat16"')
    print()
    print("  2. scripts/pipeline_lib/unsloth_utils.py:")
    print("     Add to load_unsloth_model():")
    print("       # Auto-detect MoE backend for Split LoRA")
    print('       if not load_in_4bit and os.environ.get("UNSLOTH_MOE_BACKEND") is None:')
    print('           gpu_name = torch.cuda.get_device_name(0)')
    print('           backend = "grouped_mm" if "H100" in gpu_name or "H200" in gpu_name else "unsloth_triton"')
    print('           os.environ["UNSLOTH_MOE_BACKEND"] = backend')
    print()
    print("  3. GPU_BASE config updates:")
    print(f'     moe_backend: "{BENCH_CONFIG["moe_backend"]}"')
    print(f"     split_lora_speedup: {speedup:.1f}x")
    print(f'     peak_vram_gb: {results["split_lora"].get("peak_vram_gb", "N/A")}')

    if results.get("vram_savings_pct"):
        vram_note = (
            f"saves {abs(results['vram_savings_pct']):.0f}% VRAM"
            if results["vram_savings_pct"] > 0
            else f"uses {abs(results['vram_savings_pct']):.0f}% more VRAM"
        )
        print(f"\n  Note: bfloat16 Split LoRA {vram_note} vs 4-bit QLoRA")

else:
    print("\n  VERDICT: NOT VIABLE")

    reasons = []
    if has_error:
        reasons.append(f"Model load error: {results['split_lora'].get('error')}")
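    if speedup < 1.5:
        if "speedup" in results:
            reasons.append(f"Insufficient speedup: {speedup:.1f}x (need >= 1.5x)")
        else:
            # No speedup was recorded, so avoid reporting a misleading 0.0x
            reasons.append("Speedup not measured (missing step-time data)")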
    if not expert_lora_ok:
        reasons.append("Expert LoRA not working")
    if training_ok is False:
        reasons.append("Training validation failed")

    print("  Reasons:")
    for r in reasons:
        print(f"    - {r}")

    print("\n  Recommendation: Continue using 4-bit QLoRA for production pipeline.")

# Save results. Build the payload before opening the file so filtering errors cannot truncate it.
serializable = {}
for k, v in results.items():
    if isinstance(v, dict):
        # Drop very long lists (e.g. per-step loss histories) to keep the file small
        serializable[k] = {
            sk: sv for sk, sv in v.items()
            if not isinstance(sv, list) or len(sv) < 200
        }
    else:
        serializable[k] = v

with open("/tmp/split_lora_bench/results.json", "w") as f:
    # default=str stringifies any value json cannot encode natively
    json.dump(serializable, f, indent=2, default=str)

print("\nResults saved to /tmp/split_lora_bench/results.json")
tracker.complete("recommendation")


## Benchmark Complete!

### Next Steps

**If VIABLE:**
- Update `configs/gpt_oss_20b.py`: set `load_in_4bit=False` and `dtype="bfloat16"`
- Update `scripts/pipeline_lib/unsloth_utils.py`: add MoE backend auto-detection in `load_unsloth_model()`
- Re-run `notebooks/train_gpt_oss_coding_tui.ipynb` with Split LoRA enabled
- Validate on real TUI training data (proxy logs + function-calling datasets)
- Monitor per-expert gradient norms in wandb to confirm expert utilization
- Compare final eval metrics (tool call accuracy, agent trajectory quality) against QLoRA baseline
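
The backend auto-detection change could look roughly like the sketch below. It mirrors the snippet this notebook prints in Step 6. `pick_moe_backend` and `ensure_moe_backend` are hypothetical helper names, and the GPU-name substring check is a heuristic carried over from that snippet rather than a documented Unsloth rule. Inside `load_unsloth_model()`, `gpu_name` would come from `torch.cuda.get_device_name(0)`.

```python
import os
from typing import Optional


def pick_moe_backend(gpu_name: str) -> str:
    """Return the MoE backend string for a GPU name (heuristic from Step 6)."""
    # Assumption: H100/H200 use "grouped_mm"; every other GPU falls back to "unsloth_triton".
    if "H100" in gpu_name or "H200" in gpu_name:
        return "grouped_mm"
    return "unsloth_triton"


def ensure_moe_backend(gpu_name: str, load_in_4bit: bool, env=None) -> Optional[str]:
    """Set UNSLOTH_MOE_BACKEND only for non-4-bit loads, and only if it is unset."""
    env = os.environ if env is None else env
    if load_in_4bit:
        # Leave 4-bit runs untouched; report whatever is already configured.
        return env.get("UNSLOTH_MOE_BACKEND")
    # setdefault keeps any value the user exported explicitly.
    return env.setdefault("UNSLOTH_MOE_BACKEND", pick_moe_backend(gpu_name))
```

Passing a plain dict as `env` keeps the helper easy to unit test without touching the process environment.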

**If NOT VIABLE:**
- Continue with 4-bit QLoRA as the production configuration
- Revisit when Unsloth adds Split LoRA support for 4-bit quantized models
- Check https://unsloth.ai/docs/new/faster-moe for updates on quantization support
- Consider profiling with `torch.profiler` to identify the actual bottleneck

### Output Artifacts
- `/tmp/split_lora_bench/results.json` — Full benchmark results (portable)
- `/tmp/split_lora_bench/comparison.png` — Side-by-side bar charts (if matplotlib available)
- `/tmp/split_lora_bench/config.json` — Benchmark configuration snapshot
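
To consume `results.json` outside this notebook, something like the following minimal sketch may help. `summarize_results` is a hypothetical helper, and the demo values are made up for illustration, not real benchmark output. It relies only on the top-level keys `verdict`, `speedup`, and `vram_savings_pct` written in Steps 5 and 6.

```python
import json
import tempfile
from pathlib import Path


def summarize_results(path):
    """Load a results.json file and return a one-line summary string."""
    data = json.loads(Path(path).read_text())
    parts = [f"verdict={data.get('verdict', 'UNKNOWN')}"]
    speedup = data.get("speedup")
    savings = data.get("vram_savings_pct")
    # Both metrics are optional: they are only written when both runs produced data.
    if speedup is not None:
        parts.append(f"speedup={speedup:.1f}x")
    if savings is not None:
        parts.append(f"vram_savings={savings:+.1f}%")
    return ", ".join(parts)


# Demo with made-up numbers written to a temporary file.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "results.json"
    demo_path.write_text(
        json.dumps({"verdict": "VIABLE", "speedup": 2.0, "vram_savings_pct": 10.0})
    )
    demo_summary = summarize_results(demo_path)

print(demo_summary)
```

For a real run, point `summarize_results` at `/tmp/split_lora_bench/results.json` before the runtime is released.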

In [None]:
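# Final cleanup
try:
    # Drop the large objects if they exist, then release cached CUDA memory
    for _name in ("split_model", "split_trainer", "tokenizer"):
        globals().pop(_name, None)
    gc.collect()
    torch.cuda.empty_cache()
    print(f"Final VRAM: {get_vram_gb():.2f} GB")
except Exception as e:
    # Cleanup is best-effort, but report why it was skipped instead of failing silently
    print(f"(cleanup skipped: {e})")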

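print("\nBenchmark complete. Results at /tmp/split_lora_bench/")

# Colab: release GPU. This runs last because unassigning terminates the runtime,
# which also discards /tmp; copy any artifacts you need (e.g. to Drive) first.
try:
    if IN_COLAB:
        from google.colab import runtime
        print("Releasing Colab GPU runtime...")
        runtime.unassign()
except Exception:
    pass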
