# Train GPT-OSS 20B Rust Agent (v4)

**Combined pipeline** — Strandset-first data + mutation enrichment across ALL phases + GRPO RL.

**v4 features:**
- Widget-based configuration UI (no manual variable editing)
- Mutation data enriches **all** training phases (lang_rust, core_agent, IPO)
- Pipeline progress dashboard with per-phase tracking
- Full pipeline: Strandset → lang_rust → merge → core_agent → IPO → GRPO → eval → export
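The full pipeline above rarely runs end to end. `training_scope` and the toggles in Step 0.3 switch phases off. The sketch below is a hypothetical helper, not part of the repo. It reproduces the skip guards visible in this notebook so you can check a configuration before spending GPU time. It assumes IPO and GRPO are gated only by their `include_*` flags, and it leaves eval and export out.

```python
# Hypothetical helper mirroring the skip guards in this notebook's cells.
# Assumption: IPO/GRPO depend only on their include_* flags; eval/export omitted.
SCOPES = {"full", "quick_test", "lang_adapter_only", "skip_to_rl"}


def phases_to_run(scope, include_mutations=True, include_ipo=True,
                  include_grpo=True, skip_data_generation=False):
    """Return the pipeline phases that would run for a given configuration."""
    if scope not in SCOPES:
        raise ValueError(f"unknown training_scope: {scope!r}")
    # Scope overrides applied at the end of the configuration cell
    if scope == "lang_adapter_only":
        include_ipo = include_grpo = False
    if scope == "skip_to_rl":
        include_mutations = False
    sft_stages = scope != "skip_to_rl"
    mutations = include_mutations and not skip_data_generation
    plan = {
        "strandset": not skip_data_generation,
        "mutations_bg": mutations,
        "lang_rust": sft_stages,
        "merge": sft_stages,
        "enrichment": mutations and sft_stages,  # needs a running mutation job
        "core_agent": scope not in ("lang_adapter_only", "skip_to_rl"),
        "ipo": include_ipo,
        "grpo": include_grpo,
    }
    return [phase for phase, enabled in plan.items() if enabled]


print(phases_to_run("lang_adapter_only"))
```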

**Base model:** [openai/gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) (20.9B MoE, 3.6B active)

**Data sources:**
- [Strandset-Rust-v1](https://huggingface.co/datasets/Fortytwo-Network/Strandset-Rust-v1) (191K examples)
- cargo-mutants generated mutations (optional, enriches all phases)

## Step 0: Environment Setup

### 0.1 Mount Google Drive & Clone Repository

In [None]:
import os

IN_COLAB = "COLAB_GPU" in os.environ or os.path.exists("/content")

DRIVE_BASE = ""
DRIVE_MODE = "local"

if IN_COLAB:
    from google.colab import drive
    drive.mount("/content/drive")
    DRIVE_BASE = "/content/drive/MyDrive/gpt-oss-20b-rust-agent-v4"
    DRIVE_MODE = "mounted"
    os.makedirs(DRIVE_BASE, exist_ok=True)

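    # Clone the repo if needed. On re-runs the cwd is already the repo, so don't nest a second clone.
    if os.path.basename(os.getcwd()) != "llm-training-pipeline":
        if not os.path.exists("llm-training-pipeline"):
            !git clone https://github.com/rmarnold/llm-training-pipeline.git
        os.chdir("llm-training-pipeline")
    !git pull --ff-only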
    print(f"Working directory: {os.getcwd()}")
else:
    print("Running locally (not in Colab).")
    print(f"Working directory: {os.getcwd()}")

### 0.2 Install Dependencies

In [None]:
import os

IN_COLAB = "COLAB_GPU" in os.environ or os.path.exists("/content")

if IN_COLAB:
    # Core + GPT-OSS deps
    !pip install -q -e ".[gpt_oss,rust_eval]"

    # Unsloth (Colab optimised)
    !pip install -q unsloth

    # Rust toolchain (needed for cargo-mutants + eval)
    !curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y
    os.environ["PATH"] = os.path.expanduser("~/.cargo/bin") + ":" + os.environ["PATH"]
    !cargo install cargo-mutants

    # vLLM for fast inference in GRPO
    !pip install -q vllm

    # ipywidgets for config UI
    !pip install -q ipywidgets
    print("\nDependencies installed.")
else:
    print("Assuming local dependencies are already installed.")
    print("Run: pip install -e '.[gpt_oss,rust_eval]'")

### 0.3 Configure Pipeline

Toggle the form view (click the "..." menu on this cell) to see the interactive configuration panel.
Adjust settings, then **run this cell** to apply them.

In [None]:
#@title ### Pipeline Configuration { display-mode: "form" }

#@markdown ---
#@markdown #### Core Settings

training_scope = "quick_test"  #@param ["full", "quick_test", "lang_adapter_only", "skip_to_rl"] {type: "string"}
gpu_tier = "h100_80gb"  #@param ["a100_40gb", "a100_80gb", "h100_80gb"] {type: "string"}
max_steps_override = 0  #@param {type: "integer"}

#@markdown > *Max Steps Override: 0 = use GPU tier defaults; > 0 sets every stage's max_steps to this value (quick_test still caps at 50).*

#@markdown ---
#@markdown #### Mutation Enrichment
#@markdown *Requires Rust toolchain (`cargo`, `cargo-mutants`). Mutations run in background during GPU training.*

include_mutations = True  #@param {type: "boolean"}
enrich_lang_rust = True  #@param {type: "boolean"}
enrich_ipo = True  #@param {type: "boolean"}

#@markdown ---
#@markdown #### Pipeline Phases

include_ipo = True  #@param {type: "boolean"}
include_grpo = True  #@param {type: "boolean"}
skip_data_generation = False  #@param {type: "boolean"}

#@markdown ---
#@markdown #### Export

enable_qat_export = False  #@param {type: "boolean"}

#@markdown ---
#@markdown #### Advanced

max_mutations_per_repo = 50  #@param {type: "slider", min: 10, max: 200, step: 10}
use_service_account = False  #@param {type: "boolean"}
drive_folder_id = "18UpFpUhiNrs2Etha0uFjSGWmj1Ee1SnX"  #@param {type: "string"}

# ======================================================================
# GPU tier presets (auto-selected based on gpu_tier above)
# ======================================================================
import os, sys, json

GPU_CONFIGS = {
    "a100_40gb": {
        "lang_rust_batch": 2, "lang_rust_grad_accum": 16, "lang_rust_max_steps": 3000, "lang_rust_seq_len": 4096,
        "core_agent_batch": 1, "core_agent_grad_accum": 8, "core_agent_max_steps": 2000, "core_agent_seq_len": 8192,
        "ipo_batch": 1, "ipo_grad_accum": 16, "ipo_max_steps": 1000, "ipo_seq_len": 4096,
        "grpo_batch": 1, "grpo_grad_accum": 8, "grpo_max_steps": 3000, "grpo_seq_len": 16384, "grpo_num_gen": 4,
        "eval_num_samples": 100,
        "load_mode": "4bit", "moe_backend": "triton", "fast_inference": False,
    },
    "a100_80gb": {
        "lang_rust_batch": 4, "lang_rust_grad_accum": 8, "lang_rust_max_steps": 3000, "lang_rust_seq_len": 8192,
        "core_agent_batch": 2, "core_agent_grad_accum": 4, "core_agent_max_steps": 2000, "core_agent_seq_len": 16384,
        "ipo_batch": 1, "ipo_grad_accum": 16, "ipo_max_steps": 1000, "ipo_seq_len": 8192,
        "grpo_batch": 1, "grpo_grad_accum": 8, "grpo_max_steps": 5000, "grpo_seq_len": 32768, "grpo_num_gen": 4,
        "eval_num_samples": 200,
        "load_mode": "4bit", "moe_backend": "triton", "fast_inference": False,
    },
    "h100_80gb": {
        "lang_rust_batch": 6, "lang_rust_grad_accum": 8, "lang_rust_max_steps": 3000, "lang_rust_seq_len": 8192,
        "core_agent_batch": 6, "core_agent_grad_accum": 4, "core_agent_max_steps": 2000, "core_agent_seq_len": 16384,
        "ipo_batch": 2, "ipo_grad_accum": 16, "ipo_max_steps": 1000, "ipo_seq_len": 8192,
        "grpo_batch": 2, "grpo_grad_accum": 8, "grpo_max_steps": 5000, "grpo_seq_len": 65536, "grpo_num_gen": 4,
        "eval_num_samples": 200,
        "load_mode": "fp8", "moe_backend": "triton", "fast_inference": True,
    },
}

tier = GPU_CONFIGS[gpu_tier]

# ======================================================================
# Build CONFIG dict from form values
# ======================================================================
CONFIG = {
    "training_scope": training_scope,
    "gpu_tier": gpu_tier,
    **tier,
    # Mutation enrichment
    "include_mutations": include_mutations,
    "enrich_lang_rust": enrich_lang_rust and include_mutations,
    "enrich_ipo": enrich_ipo and include_mutations,
    # Pipeline phases
    "include_ipo": include_ipo,
    "include_grpo": include_grpo,
    "enable_qat_export": enable_qat_export,
    "skip_data_generation": skip_data_generation,
    # Advanced
    "max_mutations_per_repo": max_mutations_per_repo,
    "use_service_account": use_service_account,
    "drive_folder_id": drive_folder_id,
}

# Apply max_steps override
if max_steps_override > 0:
    for key in list(CONFIG.keys()):
        if key.endswith("_max_steps"):
            CONFIG[key] = max_steps_override

# Quick test caps
if CONFIG["training_scope"] == "quick_test":
    for key in list(CONFIG.keys()):
        if key.endswith("_max_steps"):
            CONFIG[key] = min(CONFIG[key], 50)
    CONFIG["eval_num_samples"] = 10

# Auto-detect mutation parallelism
import multiprocessing
cpu_count = multiprocessing.cpu_count()
cpu_jobs = max(1, cpu_count - 2)
try:
    mem_bytes = os.sysconf("SC_PAGE_SIZE") * os.sysconf("SC_PHYS_PAGES")
    ram_jobs = max(1, int(mem_bytes / (1024**3) / 4))
except (ValueError, OSError):
    ram_jobs = cpu_jobs
CONFIG["mutation_jobs"] = min(cpu_jobs, ram_jobs)
CONFIG["mutation_repo_workers"] = max(1, CONFIG["mutation_jobs"] // 4)

# Scope-based overrides (can't do reactive UI in Colab Forms, so enforce here)
if CONFIG["training_scope"] == "lang_adapter_only":
    CONFIG["include_ipo"] = False
    CONFIG["include_grpo"] = False
elif CONFIG["training_scope"] == "skip_to_rl":
    CONFIG["include_mutations"] = False
    CONFIG["enrich_lang_rust"] = False
    CONFIG["enrich_ipo"] = False

# ======================================================================
# Set up DriveHelper
# ======================================================================
sys.path.insert(0, "scripts")
from pipeline_lib.drive_utils import DriveHelper

# DRIVE_BASE / DRIVE_MODE come from step 0.1; fall back to local mode if it was skipped
DRIVE_BASE = globals().get("DRIVE_BASE", "")
DRIVE_MODE = globals().get("DRIVE_MODE", "local")

if CONFIG["use_service_account"] and CONFIG["drive_folder_id"]:
    sa_path = "service_account.json"
    try:
        from google.colab import userdata
        sa_json = userdata.get("SERVICE_ACCOUNT_JSON")
        with open(sa_path, "w") as f:
            f.write(sa_json)
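    except Exception:
        # Not in Colab, or the SERVICE_ACCOUNT_JSON secret is missing: fall back
        # to an existing service_account.json on disk, if there is one
        pass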

    if os.path.exists(sa_path) and os.path.getsize(sa_path) > 10:
        try:
            drive_helper = DriveHelper(
                mode="service_account",
                credentials_path=sa_path,
                folder_id=CONFIG["drive_folder_id"],
            )
            DRIVE_MODE = "service_account"
        except Exception as e:
            print(f"Service account failed: {e}")
            drive_helper = DriveHelper(mode="local")
            DRIVE_MODE = "local"
    else:
        drive_helper = DriveHelper(mode="local")
        DRIVE_MODE = "local"
elif DRIVE_BASE:
    drive_helper = DriveHelper(mode="mounted", drive_base=DRIVE_BASE)
    DRIVE_MODE = "mounted"
else:
    drive_helper = DriveHelper(mode="local")
    DRIVE_MODE = "local"

# Save for persistence across restarts
os.makedirs("data", exist_ok=True)
with open("data/config_v4.json", "w") as f:
    json.dump(CONFIG, f, indent=2)

# ======================================================================
# Print summary
# ======================================================================
print("=" * 55)
print("  PIPELINE CONFIGURATION (v4)")
print("=" * 55)
print(f"  Scope:          {CONFIG['training_scope'].upper()}")
print(f"  GPU tier:       {CONFIG['gpu_tier']}")
print(f"  MoE backend:    {CONFIG['moe_backend']}")
print(f"  Load mode:      {CONFIG['load_mode']}")
print(f"  Drive mode:     {DRIVE_MODE}")
print()
print(f"  Mutations:      {CONFIG['include_mutations']}")
if CONFIG["include_mutations"]:
    print(f"    Enrich LR:    {CONFIG['enrich_lang_rust']}")
    print(f"    Enrich IPO:   {CONFIG['enrich_ipo']}")
    print(f"    Mut/repo:     {CONFIG['max_mutations_per_repo']}")
    print(f"    Jobs:         {CONFIG['mutation_jobs']} ({CONFIG['mutation_repo_workers']} repo workers)")
print()
print(f"  IPO:            {CONFIG['include_ipo']}")
print(f"  GRPO:           {CONFIG['include_grpo']}")
print(f"  QAT export:     {CONFIG['enable_qat_export']}")
if max_steps_override > 0:
    print(f"  Max steps:      {max_steps_override} (override)")
print()
print(f"  Lang Adapter:   batch={CONFIG['lang_rust_batch']} x grad_accum={CONFIG['lang_rust_grad_accum']}, seq={CONFIG['lang_rust_seq_len']}, steps={CONFIG['lang_rust_max_steps']}")
print(f"  Core Agent:     batch={CONFIG['core_agent_batch']} x grad_accum={CONFIG['core_agent_grad_accum']}, seq={CONFIG['core_agent_seq_len']}, steps={CONFIG['core_agent_max_steps']}")
if CONFIG["include_ipo"]:
    print(f"  IPO:            batch={CONFIG['ipo_batch']} x grad_accum={CONFIG['ipo_grad_accum']}, seq={CONFIG['ipo_seq_len']}, steps={CONFIG['ipo_max_steps']}")
if CONFIG["include_grpo"]:
    print(f"  GRPO:           batch={CONFIG['grpo_batch']} x grad_accum={CONFIG['grpo_grad_accum']}, seq={CONFIG['grpo_seq_len']}, steps={CONFIG['grpo_max_steps']}")
print("=" * 55)
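Two rules in this cell interact. The step override is applied first, then `quick_test` clamps every stage to 50 steps, so `quick_test` always wins. Mutation parallelism is the smaller of a CPU budget (all cores but two) and a RAM budget (about 4 GiB per job). Here is the same logic as pure functions with hypothetical names and no Colab globals. The 12-CPU / 83.5 GiB machine is only an illustrative input.

```python
def apply_step_caps(config, max_steps_override=0, scope="full"):
    """Return a copy of config with the step override and quick_test caps applied."""
    capped = dict(config)
    for key in capped:  # only values change, so iterating while assigning is safe
        if key.endswith("_max_steps"):
            if max_steps_override > 0:
                capped[key] = max_steps_override
            if scope == "quick_test":
                capped[key] = min(capped[key], 50)
    if scope == "quick_test":
        capped["eval_num_samples"] = 10
    return capped


def mutation_parallelism(cpu_count, mem_gib, gib_per_job=4):
    """cargo-mutants jobs: keep 2 CPUs free and budget about 4 GiB of RAM per job."""
    cpu_jobs = max(1, cpu_count - 2)
    ram_jobs = max(1, int(mem_gib / gib_per_job))
    jobs = min(cpu_jobs, ram_jobs)
    repo_workers = max(1, jobs // 4)
    return jobs, repo_workers


tier = {"lang_rust_max_steps": 3000, "ipo_max_steps": 1000, "eval_num_samples": 200}
print(apply_step_caps(tier, max_steps_override=0, scope="quick_test"))
print(mutation_parallelism(cpu_count=12, mem_gib=83.5))
```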

### 0.4 Pipeline Dashboard

In [None]:
import ipywidgets as widgets
from IPython.display import display

class PipelineTracker:
    """Track pipeline progress with visual indicators."""

    PHASES = [
        ("strandset", "Strandset Data"),
        ("mutations_bg", "Mutations (Background)"),
        ("lang_rust", "Lang Adapter Training"),
        ("merge", "Merge Adapter"),
        ("enrichment", "Mutation Enrichment"),
        ("core_agent", "Core Agent SFT"),
        ("ipo", "IPO Preference"),
        ("grpo", "GRPO RL"),
        ("eval", "Evaluation"),
        ("export", "Export"),
    ]

    def __init__(self):
        self._bars = {}
        self._labels = {}
        rows = []
        for key, name in self.PHASES:
            label = widgets.HTML(
                value=f"<span style='color:#888'>&#x25CB; {name}</span>",
                layout=widgets.Layout(width="220px"),
            )
            bar = widgets.FloatProgress(
                value=0, min=0, max=1.0,
                bar_style="info",
                layout=widgets.Layout(width="300px", height="18px"),
            )
            self._bars[key] = bar
            self._labels[key] = label
            rows.append(widgets.HBox([label, bar]))
        self._container = widgets.VBox(rows)
        display(widgets.HTML("<b>Pipeline Progress</b>"))
        display(self._container)

    def start(self, phase):
        self._labels[phase].value = (
            f"<span style='color:#2196F3'>&#x25B6; {dict(self.PHASES)[phase]}</span>"
        )
        self._bars[phase].value = 0.1
        self._bars[phase].bar_style = "info"

    def complete(self, phase):
        self._labels[phase].value = (
            f"<span style='color:#4CAF50'>&#x2714; {dict(self.PHASES)[phase]}</span>"
        )
        self._bars[phase].value = 1.0
        self._bars[phase].bar_style = "success"

    def skip(self, phase):
        self._labels[phase].value = (
            f"<span style='color:#9E9E9E'>&#x2014; {dict(self.PHASES)[phase]} (skipped)</span>"
        )
        self._bars[phase].value = 1.0
        self._bars[phase].bar_style = ""

    def fail(self, phase):
        self._labels[phase].value = (
            f"<span style='color:#F44336'>&#x2718; {dict(self.PHASES)[phase]}</span>"
        )
        self._bars[phase].bar_style = "danger"

tracker = PipelineTracker()
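`PipelineTracker` only renders where an ipywidgets frontend exists (Colab, Jupyter). When cells run as a plain Python script, a text stand-in with the same `start` / `complete` / `skip` / `fail` methods lets the rest of the notebook run unchanged. The sketch below uses a hypothetical class name and ASCII status symbols.

```python
class TextTracker:
    """Widget-free stand-in with the same start/complete/skip/fail methods."""

    SYMBOLS = {"running": ">", "done": "+", "skipped": "-", "failed": "x"}

    def __init__(self, phases):
        # phases: list of (key, display_name) pairs, like PipelineTracker.PHASES
        self.names = dict(phases)
        self.status = {key: "pending" for key, _ in phases}

    def _set(self, phase, state):
        if phase not in self.names:
            raise KeyError(f"unknown phase: {phase}")
        self.status[phase] = state
        print(f"{self.SYMBOLS[state]} {self.names[phase]} ({state})")

    def start(self, phase):
        self._set(phase, "running")

    def complete(self, phase):
        self._set(phase, "done")

    def skip(self, phase):
        self._set(phase, "skipped")

    def fail(self, phase):
        self._set(phase, "failed")


demo = TextTracker([("strandset", "Strandset Data"), ("lang_rust", "Lang Adapter Training")])
demo.start("strandset")
demo.complete("strandset")
demo.skip("lang_rust")
```

In a headless run, `tracker = TextTracker(PipelineTracker.PHASES)` reuses the same phase list without instantiating any widgets.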

### 0.5 Set Up Persistent Storage

In [None]:
import os

DRIVE_SUBDIRS = [
    "data/rust/strandset",
    "data/rust/mutations",
    "data/rust/lang_rust",
    "data/rust/core_agent",
    "data/rust/ipo",
    "data/rust/grpo",
    "data/rust/eval",
    "checkpoints/lang_rust",
    "checkpoints/gpt-oss-20b-rust-merged",
    "checkpoints/core_agent",
    "checkpoints/core_agent_ipo",
    "checkpoints/core_agent_grpo",
    "evals",
]

if DRIVE_MODE == "mounted":
    for subdir in DRIVE_SUBDIRS:
        drive_path = os.path.join(DRIVE_BASE, subdir)
        os.makedirs(drive_path, exist_ok=True)
        local_path = subdir
        if not os.path.exists(local_path):
            os.makedirs(os.path.dirname(local_path) or ".", exist_ok=True)
            os.symlink(drive_path, local_path)
            print(f"  Linked: {local_path} -> {drive_path}")
        else:
            print(f"  Exists: {local_path}")
    print(f"\nDrive base: {DRIVE_BASE}")
elif DRIVE_MODE == "service_account":
    for subdir in DRIVE_SUBDIRS:
        os.makedirs(subdir, exist_ok=True)
        drive_helper.ensure_dir(subdir)
    print("Drive directories created (service account mode).")
else:
    for subdir in DRIVE_SUBDIRS:
        os.makedirs(subdir, exist_ok=True)
    print("Local directories created (no Drive backup).")
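In mounted mode, each local directory becomes a symlink into Drive. The repo's relative paths (`checkpoints/...`, `data/rust/...`) keep working, and anything written there lands on Drive directly. Below is the same pattern as a standalone function (hypothetical name), exercised with temporary directories in place of Drive. One deliberate difference: it checks `os.path.lexists`. A dangling symlink left by an unmounted Drive counts as "exists" there, whereas `os.path.exists` reports it as missing and the later `os.symlink` fails.

```python
import os
import tempfile


def link_subdirs(drive_base, subdirs, workdir="."):
    """Create each subdir under drive_base and symlink workdir/subdir to it.

    Paths that already exist locally are left alone. Returns the new links.
    """
    linked = []
    for subdir in subdirs:
        target = os.path.join(drive_base, subdir)
        os.makedirs(target, exist_ok=True)
        local = os.path.join(workdir, subdir)
        # lexists() is also True for dangling symlinks, which exists() misses
        if os.path.lexists(local):
            continue
        os.makedirs(os.path.dirname(local), exist_ok=True)
        os.symlink(target, local)
        linked.append((local, target))
    return linked


with tempfile.TemporaryDirectory() as drive, tempfile.TemporaryDirectory() as work:
    first = link_subdirs(drive, ["data/rust/eval", "evals"], workdir=work)
    second = link_subdirs(drive, ["data/rust/eval", "evals"], workdir=work)
    print(len(first), len(second))
```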

### 0.6 Check GPU & Configure MoE Backend

In [None]:
import torch, os

if torch.cuda.is_available():
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / (1024**3)

    print(f"GPU: {gpu_name}")
    print(f"VRAM: {gpu_mem:.1f} GB")

    # Auto-detect GPU tier override
    detected_tier = None
    if "H100" in gpu_name or "H200" in gpu_name:
        detected_tier = "h100_80gb"
    elif "A100" in gpu_name:
        detected_tier = "a100_80gb" if gpu_mem > 45 else "a100_40gb"

    if detected_tier and detected_tier != CONFIG["gpu_tier"]:
        print(f"\n  Auto-override: {CONFIG['gpu_tier']} -> {detected_tier}")
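        CONFIG["gpu_tier"] = detected_tier
        tier = GPU_CONFIGS[detected_tier]
        CONFIG.update(tier)
        # The presets reset every *_max_steps and eval_num_samples, so re-apply the
        # step override and quick_test caps from step 0.3
        for key in list(CONFIG.keys()):
            if key.endswith("_max_steps"):
                if max_steps_override > 0:
                    CONFIG[key] = max_steps_override
                if CONFIG["training_scope"] == "quick_test":
                    CONFIG[key] = min(CONFIG[key], 50)
        if CONFIG["training_scope"] == "quick_test":
            CONFIG["eval_num_samples"] = 10
        print(f"  Updated CONFIG with {detected_tier} presets.")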

    # Set MoE backend
    os.environ["UNSLOTH_MOE_BACKEND"] = CONFIG.get("moe_backend", "triton")
    print(f"\n  MoE backend: {os.environ['UNSLOTH_MOE_BACKEND']}")
    print(f"  Load mode: {CONFIG['load_mode']}")
    print(f"  Fast inference: {CONFIG.get('fast_inference', False)}")

    # FP8 detection
    if CONFIG["load_mode"] == "fp8":
        try:
            import transformer_engine
            print("  FP8: transformer-engine available")
        except ImportError:
            print("  FP8: transformer-engine not found, falling back to 4bit")
            CONFIG["load_mode"] = "4bit"
else:
    print("No GPU detected! Training will fail.")
    print("Enable GPU: Runtime -> Change runtime type -> GPU")

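print(f"\nFinal config: scope={CONFIG['training_scope']}, tier={CONFIG['gpu_tier']}")

# Re-save: the tier override and FP8 fallback above may have changed CONFIG since step 0.3
import json
with open("data/config_v4.json", "w") as f:
    json.dump(CONFIG, f, indent=2)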
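The auto-detection is string matching on the device name plus a memory threshold. Pulled out as a pure function (hypothetical name), it can be checked without a GPU. The device names below are typical `torch.cuda.get_device_name()` strings, not an exhaustive list.

```python
def detect_tier(gpu_name, total_mem_gib):
    """Map a CUDA device name and memory (GiB) to a GPU_CONFIGS key, or None."""
    if "H100" in gpu_name or "H200" in gpu_name:
        return "h100_80gb"  # H200 reuses the H100 presets
    if "A100" in gpu_name:
        # A100 ships in 40 GB and 80 GB variants; 45 GiB separates them
        return "a100_80gb" if total_mem_gib > 45 else "a100_40gb"
    return None  # e.g. T4 or L4: the tier chosen in the form is kept


for name, mem in [("NVIDIA A100-SXM4-40GB", 39.6),
                  ("NVIDIA H100 80GB HBM3", 79.1),
                  ("Tesla T4", 14.7)]:
    print(name, "->", detect_tier(name, mem))
```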

## Step 1: Data Preparation

### 1.1 Download & Format Strandset

In [None]:
if CONFIG["skip_data_generation"]:
    print("Skipping data generation (skip_data_generation=True)")
    tracker.skip("strandset")
else:
    tracker.start("strandset")

    cmd = "python scripts/20_prepare_strandset.py"

    if CONFIG["training_scope"] == "quick_test":
        cmd += " --max_samples 500"

    print("Downloading and formatting Strandset-Rust-v1...")
    print("=" * 60)

    !{cmd}

    # Copy Strandset outputs into the per-phase training dirs
    import shutil

    mappings = [
        ("data/rust/strandset/lang_rust/train", "data/rust/lang_rust/train"),
        ("data/rust/strandset/core_agent/train", "data/rust/core_agent/train"),
        ("data/rust/strandset/ipo/train", "data/rust/ipo/train"),
    ]
    for src, dst in mappings:
        if not os.path.exists(src):
            print(f"  Missing: {src}")
            continue
        # Copy if dst is a Drive symlink (copy into the link target) or is missing/empty;
        # otherwise keep the existing (possibly enriched) data
        if os.path.islink(dst) or not os.path.exists(dst) or not os.listdir(dst):
            shutil.copytree(src, dst, dirs_exist_ok=True)
            print(f"  Copied: {src} -> {dst}")
        else:
            print(f"  Kept existing: {dst}")

    drive_helper.backup("data/rust/strandset", "data/rust/strandset")
    if DRIVE_MODE != "local":
        print("\nStrandset backed up to Drive.")

    tracker.complete("strandset")

### 1.2 Start Background Mutations

In [None]:
import subprocess

mutation_proc = None

if not CONFIG["include_mutations"]:
    print("Mutations disabled — skipping.")
    tracker.skip("mutations_bg")
elif CONFIG["skip_data_generation"]:
    print("Skipping data generation (skip_data_generation=True)")
    tracker.skip("mutations_bg")
else:
    tracker.start("mutations_bg")

    max_muts = CONFIG["max_mutations_per_repo"]
    backup_dir = "data/rust/mutations/backup"

    cmd = [
        "python", "scripts/16_generate_mutations.py",
        "--max_mutations_per_repo", str(max_muts),
        "--backup-dir", backup_dir,
    ]

    if CONFIG["training_scope"] == "quick_test":
        cmd.extend(["--repos", "ripgrep"])

    print(f"Starting mutations in background (max {max_muts}/repo)...")
    print(f"  Command: {' '.join(cmd)}")
    mutation_proc = subprocess.Popen(
        cmd,
        stdout=open("data/rust/mutations/stdout.log", "w"),
        stderr=subprocess.STDOUT,
    )
    print(f"  PID: {mutation_proc.pid}")
    print("  Mutations will run during lang_rust training (~60 min).")
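This cell returns immediately and only prints the PID. To check on the job from a later cell without blocking (Step 2.5 blocks on `wait()`), `Popen.poll()` plus a look at the log file is enough. The sketch below uses hypothetical helper names, with a short Python command standing in for `16_generate_mutations.py`.

```python
import os
import subprocess
import sys
import tempfile


def launch_logged(cmd, log_path):
    """Start cmd in the background, sending stdout and stderr to log_path."""
    log = open(log_path, "w")
    proc = subprocess.Popen(cmd, stdout=log, stderr=subprocess.STDOUT)
    return proc, log  # close `log` once the process has finished


def status(proc):
    """Non-blocking check: poll() returns None while the process is running."""
    code = proc.poll()
    return "running" if code is None else f"exited with code {code}"


def tail(path, n=5):
    """Last n lines of a log file."""
    with open(path) as f:
        return f.read().splitlines()[-n:]


log_path = os.path.join(tempfile.mkdtemp(), "stdout.log")
proc, log = launch_logged([sys.executable, "-c", "print('mutating ripgrep')"], log_path)
proc.wait()
log.close()
print(status(proc), tail(log_path))
```

In the notebook, the equivalent checks would be `status(mutation_proc)` and `tail("data/rust/mutations/stdout.log")`.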

### 1.3 Verify Data

In [None]:
data_checks = [
    ("Strandset stats", "data/rust/strandset/stats.json"),
    ("Lang Rust train", "data/rust/lang_rust/train"),
    ("Core Agent train", "data/rust/core_agent/train"),
    ("IPO train", "data/rust/ipo/train"),
    ("Mutations dir", "data/rust/mutations"),
]

print("Data Verification:")
print("=" * 60)
for name, path in data_checks:
    exists = os.path.exists(path)
    if exists and os.path.isdir(path):
        items = os.listdir(path)
        print(f"  \u2713 {name}: {path} ({len(items)} items)")
    elif exists:
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f"  \u2713 {name}: {path} ({size_mb:.1f} MB)")
    else:
        print(f"  \u2717 {name}: not found at {path}")
print("=" * 60)

## Step 2: Lang Adapter Training

### 2.0 Pre-enrich lang_rust (Cached Mutations)

If a previous run generated mutation-based debug pairs for lang_rust, merge them
into the training data now. Uses a `.enriched_v4` marker file to avoid duplicates.

In [None]:
if CONFIG["training_scope"] == "skip_to_rl":
    print("Skipping — scope is skip_to_rl")
elif not CONFIG["enrich_lang_rust"]:
    print("Skipping — enrich_lang_rust is disabled")
else:
    cache_dir = "data/rust/mutations/lang_rust_enrichment"
    train_dir = "data/rust/lang_rust/train"
    marker = os.path.join(train_dir, ".enriched_v4")

    if os.path.exists(cache_dir) and not os.path.exists(marker):
        from datasets import load_from_disk, concatenate_datasets

        try:
            cached_ds = load_from_disk(cache_dir)
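            # Load into RAM: datasets refuses to save_to_disk() over the directory
            # a memory-mapped dataset was loaded from, which is what happens below
            base_ds = load_from_disk(train_dir, keep_in_memory=True)
            print("Pre-enriching lang_rust:")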
            print(f"  Base dataset: {len(base_ds):,} examples")
            print(f"  Cached mutations: {len(cached_ds):,} debug pairs")
            merged = concatenate_datasets([base_ds, cached_ds])
            merged.save_to_disk(train_dir)

            # Write marker to prevent re-merging
            with open(marker, "w") as f:
                f.write(f"enriched with {len(cached_ds)} cached examples\n")

            print(f"  Merged: {len(merged):,} total -> {train_dir}")
        except Exception as e:
            print(f"  Warning: could not merge cached enrichment: {e}")
    elif os.path.exists(marker):
        print("Already enriched (marker found). Skipping.")
    else:
        print("No cached mutation enrichment found. Will generate for next run.")
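The marker file makes the merge idempotent: re-running the cell cannot append the same cached examples twice. Below is the pattern in isolation, with a JSONL file standing in for the Hugging Face dataset directory so the sketch has no dependencies. Both the JSONL format and the helper name are assumptions of the sketch.

```python
import json
import os
import tempfile


def enrich_once(train_path, extra_rows, marker_path):
    """Append extra_rows to a JSONL training file unless marker_path exists."""
    if os.path.exists(marker_path):
        return False  # already enriched on a previous run
    with open(train_path, "a") as f:
        for row in extra_rows:
            f.write(json.dumps(row) + "\n")
    # Marker goes last: a crash before this line causes a re-merge next time,
    # never a silently skipped merge
    with open(marker_path, "w") as f:
        f.write(f"enriched with {len(extra_rows)} examples\n")
    return True


workdir = tempfile.mkdtemp()
train = os.path.join(workdir, "train.jsonl")
marker = os.path.join(workdir, ".enriched_v4")
with open(train, "w") as f:
    f.write(json.dumps({"text": "base example"}) + "\n")

print(enrich_once(train, [{"text": "debug pair"}], marker))  # True
print(enrich_once(train, [{"text": "debug pair"}], marker))  # False
```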

### 2.1 Train lang_rust Adapter

In [None]:
if CONFIG["training_scope"] == "skip_to_rl":
    print("Skipping — scope is skip_to_rl")
    tracker.skip("lang_rust")
else:
    tracker.start("lang_rust")

    batch = CONFIG["lang_rust_batch"]
    grad_accum = CONFIG["lang_rust_grad_accum"]
    max_steps = CONFIG["lang_rust_max_steps"]
    seq_len = CONFIG["lang_rust_seq_len"]

    # NOTE: seq_len is only reported below; it is not passed on this command line
    cmd = "python scripts/13_train_lang_adapter.py"
    cmd += " --train_data_path data/rust/lang_rust/train"
    cmd += f" --per_device_train_batch_size {batch}"
    cmd += f" --gradient_accumulation_steps {grad_accum}"
    cmd += f" --max_steps {max_steps}"

    print("Training lang_rust adapter...")
    print("  Data: data/rust/lang_rust/train")
    print(f"  Batch: {batch} x {grad_accum} = {batch * grad_accum}")
    print(f"  Max steps: {max_steps}")
    print(f"  Seq length: {seq_len} (config value; not passed as a CLI flag)")
    print(f"  LoRA rank: 64")
    print(f"  Split LoRA backend: {CONFIG['moe_backend']}")
    print("=" * 60)

    !{cmd}

    drive_helper.backup("checkpoints/lang_rust", "checkpoints/lang_rust")
    if DRIVE_MODE != "local":
        print("\nCheckpoint backed up to Drive.")

    tracker.complete("lang_rust")
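The `Batch: a x b = c` line is the effective batch per optimizer step on a single GPU. Multiplying through by sequence length and `max_steps` gives a rough budget for the run. The sketch below is a hypothetical helper. Its token figures are upper bounds: they assume every sequence is filled to `seq_len` and that the training script actually uses the configured `seq_len`.

```python
def step_budget(batch, grad_accum, seq_len, max_steps, num_gpus=1):
    """Upper bounds on data consumed, assuming every sequence fills seq_len."""
    seqs_per_step = batch * grad_accum * num_gpus
    tokens_per_step = seqs_per_step * seq_len
    return {
        "sequences_per_step": seqs_per_step,
        "max_tokens_per_step": tokens_per_step,
        "sequences_total": seqs_per_step * max_steps,
        "max_tokens_total": tokens_per_step * max_steps,
    }


# h100_80gb lang_rust preset: batch 6, grad_accum 8, seq_len 8192, 3000 steps
print(step_budget(batch=6, grad_accum=8, seq_len=8192, max_steps=3000))
```

Comparing `sequences_total` with `len()` of the training dataset shows roughly how many epochs `max_steps` covers. With packing enabled, one "sequence" may hold several examples.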

### 2.2 Merge lang_rust into Base

In [None]:
if CONFIG["training_scope"] == "skip_to_rl":
    print("Skipping — scope is skip_to_rl")
    tracker.skip("merge")
else:
    tracker.start("merge")

    print("Merging lang_rust adapter into base model...")
    print("=" * 60)

    !python scripts/19_merge_adapter.py

    drive_helper.backup("checkpoints/gpt-oss-20b-rust-merged", "checkpoints/gpt-oss-20b-rust-merged")
    if DRIVE_MODE != "local":
        print("\nMerged model backed up to Drive.")

    tracker.complete("merge")
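Merging folds the trained low-rank update back into the dense weights, so the next stage starts from a plain checkpoint rather than from base plus adapter. For a standard LoRA adapter with rank $r$ and scaling $\alpha$, each adapted weight matrix becomes the expression below. The merge code itself is not shown here, so it is unverified whether `19_merge_adapter.py` treats the MoE expert weights exactly this way.

$$
W_{\text{merged}} = W_0 + \frac{\alpha}{r}\, B A, \qquad B \in \mathbb{R}^{d_{\text{out}} \times r},\quad A \in \mathbb{R}^{r \times d_{\text{in}}}
$$

With rsLoRA the scaling factor is $\alpha/\sqrt{r}$ instead of $\alpha/r$. After merging, the lang_rust update adds no inference cost, and the core_agent adapter is trained on top of $W_{\text{merged}}$.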

### 2.3 Verify Merge

In [None]:
if CONFIG["training_scope"] == "skip_to_rl":
    print("Skipping — scope is skip_to_rl")
else:
    merged_path = "checkpoints/gpt-oss-20b-rust-merged"

    print("Merge Verification:")
    print("=" * 60)

    if os.path.exists(merged_path):
        files = os.listdir(merged_path)
        total_size = sum(
            os.path.getsize(os.path.join(merged_path, f))
            for f in files if os.path.isfile(os.path.join(merged_path, f))
        )
        print(f"  \u2713 Merged model: {merged_path}")
        print(f"    Files: {len(files)}")
        print(f"    Total size: {total_size / (1024**3):.1f} GB")
    else:
        print(f"  \u2717 Merged model not found at {merged_path}")

    print("=" * 60)

## Step 2.5: Mutation Enrichment

Wait for the background mutations to finish, then use them to enrich **all** training phases:
- **(a)** Cache debug pairs for lang_rust. The lang_rust adapter is already trained at this point, so Step 2.0 merges this cache on the *next* run.
- **(b)** Generate trajectories and merge them into the core_agent data.
- **(c)** Format preference pairs and merge them into the IPO data (when `enrich_ipo` is on).
### 2.5 Wait for Mutations & Enrich All Phases

In [None]:
import sys, json
sys.path.insert(0, "scripts")

if not CONFIG["include_mutations"] or mutation_proc is None:
    print("Mutations not running — skipping enrichment.")
    tracker.skip("enrichment")
elif CONFIG["training_scope"] == "skip_to_rl":
    print("Skipping — scope is skip_to_rl")
    tracker.skip("enrichment")
else:
    tracker.start("enrichment")

    # --- Wait for mutation process ---
    print("Waiting for background mutations to complete...")
    returncode = mutation_proc.wait()
    if returncode != 0:
        print(f"  WARNING: mutations exited with code {returncode}")
        print("  Check log: data/rust/mutations/stdout.log")
        tracker.fail("mutations_bg")
    else:
        tracker.complete("mutations_bg")

    mutations_path = "data/rust/mutations/mutations.jsonl"
    if not os.path.exists(mutations_path):
        print(f"  No mutations file found at {mutations_path}")
        tracker.fail("enrichment")
    else:
        # Count mutations
        with open(mutations_path) as f:
            mutations = [json.loads(line) for line in f if line.strip()]
        print(f"  Loaded {len(mutations):,} mutations from {mutations_path}")

        # ============================================================
        # (a) Save debug pairs for lang_rust enrichment cache
        # ============================================================
        print(f"\n--- (a) Caching lang_rust enrichment ---")
        from dataset_formatters.harmony import format_harmony_debug

        debug_examples = []
        for m in mutations:
            result = format_harmony_debug({
                "buggy_code": m.get("mutant_code", ""),
                "error_message": m.get("compiler_error", m.get("test_error", "")),
                "fixed_code": m.get("original_code", ""),
            })
            if result.get("text"):
                debug_examples.append(result)

        if debug_examples:
            from datasets import Dataset
            cache_dir = "data/rust/mutations/lang_rust_enrichment"
            debug_ds = Dataset.from_list(debug_examples)
            debug_ds.save_to_disk(cache_dir)
            print(f"  Saved {len(debug_ds):,} debug pairs -> {cache_dir}")
            print(f"  (Will be used on next run for lang_rust pre-enrichment)")
        else:
            print("  No valid debug pairs generated.")

        # ============================================================
        # (b) Generate trajectories -> merge into core_agent
        # ============================================================
        print(f"\n--- (b) Generating trajectories for core_agent ---")

        traj_cmd = f"python scripts/15_generate_trajectories.py"
        traj_cmd += f" --mutations_path {mutations_path}"
        traj_cmd += f" --output_dir data/rust/core_agent/train_traj"
        traj_cmd += f" --no-strandset"

        if CONFIG["training_scope"] == "quick_test":
            traj_cmd += " --max_samples 100"

        !{traj_cmd}

        # Merge trajectory data into core_agent
        traj_path = "data/rust/core_agent/train_traj"
        core_path = "data/rust/core_agent/train"
        if os.path.exists(traj_path) and os.path.exists(core_path):
            from datasets import load_from_disk, concatenate_datasets

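            traj_ds = load_from_disk(traj_path)
            # Load into RAM so save_to_disk() below can overwrite core_path, its own source dir
            core_ds = load_from_disk(core_path, keep_in_memory=True)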
            print(f"  Enriching core_agent: {len(core_ds):,} Strandset + {len(traj_ds):,} trajectories")
            merged = concatenate_datasets([core_ds, traj_ds])
            merged.save_to_disk(core_path)
            print(f"  Saved enriched dataset: {len(merged):,} total -> {core_path}")

        # ============================================================
        # (c) Format preference pairs -> merge into IPO
        # ============================================================
        if CONFIG["enrich_ipo"]:
            print(f"\n--- (c) Generating preference pairs for IPO ---")
            from dataset_formatters.harmony import format_harmony_preference

            pref_examples = []
            for m in mutations:
                buggy = m.get("mutant_code", "")
                fixed = m.get("original_code", "")
                error = m.get("compiler_error", m.get("test_error", ""))
                if not buggy or not fixed:
                    continue

                prompt = f"Fix the bug in this Rust code"
                if error:
                    prompt += f":\n\nError: {error}"
                prompt += f"\n\n```rust\n{buggy}\n```"

                result = format_harmony_preference({
                    "prompt": prompt,
                    "chosen": f"```rust\n{fixed}\n```",
                    "rejected": f"```rust\n{buggy}\n```",
                })
                if result.get("text"):
                    pref_examples.append(result)

            if pref_examples:
                from datasets import Dataset, load_from_disk, concatenate_datasets

                ipo_path = "data/rust/ipo/train"
                pref_ds = Dataset.from_list(pref_examples)

                if os.path.exists(ipo_path):
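                    # Load into RAM so save_to_disk() below can overwrite ipo_path, its own source dir
                    base_ipo = load_from_disk(ipo_path, keep_in_memory=True)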
                    print(f"  Enriching IPO: {len(base_ipo):,} Strandset + {len(pref_ds):,} mutation pairs")
                    merged_ipo = concatenate_datasets([base_ipo, pref_ds])
                    merged_ipo.save_to_disk(ipo_path)
                    print(f"  Saved enriched IPO: {len(merged_ipo):,} total -> {ipo_path}")
                else:
                    pref_ds.save_to_disk(ipo_path)
                    print(f"  Saved IPO: {len(pref_ds):,} pairs -> {ipo_path}")
            else:
                print("  No valid preference pairs generated.")
        else:
            print("\n--- (c) IPO enrichment disabled ---")

        # Backup enriched data
        drive_helper.backup("data/rust/core_agent", "data/rust/core_agent")
        drive_helper.backup("data/rust/ipo", "data/rust/ipo")
        drive_helper.backup("data/rust/mutations", "data/rust/mutations")
        if DRIVE_MODE != "local":
            print("\nEnriched data backed up to Drive.")

        tracker.complete("enrichment")
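Writing a dataset back to the directory it was read from is risky. Hugging Face `datasets` refuses to `save_to_disk()` a memory-mapped dataset over its own source files, and an interrupted overwrite can leave a half-written directory. When the data is too large to load into RAM, a general alternative is to write to a sibling temp directory and then swap. The sketch below uses a hypothetical helper and assumes `path` is a real directory, not a symlink: renaming a symlink moves the link, not its target.

```python
import os
import shutil
import tempfile


def replace_dir(path, write_fn):
    """Build new contents in a sibling temp dir, then swap it in for path."""
    parent = os.path.dirname(os.path.abspath(path))
    tmp = tempfile.mkdtemp(prefix=".tmp_", dir=parent)  # same filesystem as path
    try:
        write_fn(tmp)
    except Exception:
        shutil.rmtree(tmp, ignore_errors=True)
        raise
    old = path + ".old"
    shutil.rmtree(old, ignore_errors=True)  # leftover from an interrupted swap
    if os.path.exists(path):
        os.rename(path, old)
    os.rename(tmp, path)
    shutil.rmtree(old, ignore_errors=True)


def write_demo(target_dir):
    with open(os.path.join(target_dir, "state.txt"), "w") as f:
        f.write("enriched")


demo_dir = os.path.join(tempfile.mkdtemp(), "train")
os.makedirs(demo_dir)
replace_dir(demo_dir, write_demo)
with open(os.path.join(demo_dir, "state.txt")) as f:
    print(f.read())
```

With `datasets`, `write_fn` could be the bound method `merged.save_to_disk`. Afterwards, reload from `path` instead of reusing `merged`, whose cache files were in the swapped-out directory.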

In [None]:
from datasets import load_from_disk

enriched_checks = [
    ("Lang Rust (training)", "data/rust/lang_rust/train"),
    ("Core Agent (enriched)", "data/rust/core_agent/train"),
    ("IPO (enriched)", "data/rust/ipo/train"),
]

print("Enriched Data Summary:")
print("=" * 60)
for name, path in enriched_checks:
    if os.path.exists(path):
        try:
            ds = load_from_disk(path)
            print(f"  \u2713 {name}: {len(ds):,} examples")
        except Exception as e:
            print(f"  \u2717 {name}: failed to load ({e})")
    else:
        print(f"  \u2014 {name}: not found")
print("=" * 60)

## Step 3: Core Agent SFT

Train a higher-rank LoRA adapter (rank 128, vs. 64 for lang_rust) on agent trajectories with tool use,
using the merged lang_rust model as the base.

**v4:** Data enriched with mutation-based trajectories (Step 2.5b).
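LoRA's trainable parameter count grows linearly with rank. For a linear layer with weight $W \in \mathbb{R}^{d_{out} \times d_{in}}$, LoRA adds $A$ ($r \times d_{in}$) and $B$ ($d_{out} \times r$). Rank 128 therefore doubles the adapter size per adapted layer compared with rank 64. The helper below is hypothetical, and the 4096 x 4096 layer is illustrative, not GPT-OSS's actual shapes.

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters LoRA adds to one d_out x d_in linear layer.

    A is rank x d_in and B is d_out x rank, so the count is rank * (d_in + d_out).
    """
    return rank * (d_in + d_out)


# Illustrative 4096 x 4096 projection
full = 4096 * 4096
for r in (64, 128):
    added = lora_params(4096, 4096, r)
    print(f"rank {r}: {added:,} params ({added / full:.2%} of the dense layer)")
```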

### 3.1 Train core_agent Adapter

In [None]:
if CONFIG["training_scope"] in ("lang_adapter_only", "skip_to_rl"):
    print(f"Skipping \u2014 scope is {CONFIG['training_scope']}")
    tracker.skip("core_agent")
else:
    tracker.start("core_agent")

    batch = CONFIG["core_agent_batch"]
    grad_accum = CONFIG["core_agent_grad_accum"]
    max_steps = CONFIG["core_agent_max_steps"]
    seq_len = CONFIG["core_agent_seq_len"]

    cmd = f"python scripts/14_train_core_agent.py"
    cmd += f" --train_data_path data/rust/core_agent/train"
    cmd += f" --per_device_train_batch_size {batch}"
    cmd += f" --gradient_accumulation_steps {grad_accum}"
    cmd += f" --max_steps {max_steps}"

    print(f"Training core_agent adapter...")
    print(f"  Data: data/rust/core_agent/train")
    print(f"  Batch: {batch} x {grad_accum} = {batch * grad_accum}")
    print(f"  Max steps: {max_steps}")
    print(f"  Seq length: {seq_len} (from config)")
    print(f"  LoRA rank: 128")
    print(f"  Split LoRA backend: {CONFIG['moe_backend']}")
    print(f"  Auto packing: enabled (uncontaminated)")
    print("=" * 60)

    !{cmd}

    drive_helper.backup("checkpoints/core_agent", "checkpoints/core_agent")
    if DRIVE_MODE != "local":
        print("\nCheckpoint backed up to Drive.")

    tracker.complete("core_agent")

In [None]:
if CONFIG["training_scope"] in ("lang_adapter_only", "skip_to_rl"):
    print(f"Skipping \u2014 scope is {CONFIG['training_scope']}")
else:
    ckpt_path = "checkpoints/core_agent/final"

    print("Core Agent Verification:")
    print("=" * 60)

    if os.path.exists(ckpt_path):
        files = os.listdir(ckpt_path)
        print(f"  \u2713 Checkpoint: {ckpt_path} ({len(files)} files)")

        adapter_config = os.path.join(ckpt_path, "adapter_config.json")
        if os.path.exists(adapter_config):
            import json
            with open(adapter_config) as f:
                cfg = json.load(f)
            print(f"    LoRA rank: {cfg.get('r', '?')}")
            print(f"    Alpha: {cfg.get('lora_alpha', '?')}")
            print(f"    Target modules: {cfg.get('target_modules', '?')}")
    else:
        print(f"  \u2717 Checkpoint not found at {ckpt_path}")

    print("=" * 60)

## Step 4: IPO Preference Training

Train with Identity Preference Optimisation on preference pairs.
Very low learning rate (5e-7), 1 epoch only to avoid collapse.
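The per-pair IPO objective can be sketched as follows. It follows the published formulation (Azar et al., 2023) with the `beta=0.1` that the training cell reports. Whether `scripts/17_ipo_preference.py` sums or averages per-token log-probabilities is not shown here, so inputs are treated as already-reduced sequence log-probs.

```python
def ipo_loss(policy_chosen_logp: float, policy_rejected_logp: float,
             ref_chosen_logp: float, ref_rejected_logp: float,
             beta: float = 0.1) -> float:
    """IPO loss for one preference pair.

    h is the difference between the chosen and rejected log-likelihood
    ratios (policy vs. frozen reference). IPO regresses h toward the fixed
    margin 1 / (2 * beta) with a squared loss, so it stops pushing once the
    margin is reached instead of growing h without bound.
    """
    chosen_ratio = policy_chosen_logp - ref_chosen_logp
    rejected_ratio = policy_rejected_logp - ref_rejected_logp
    h = chosen_ratio - rejected_ratio
    return (h - 1.0 / (2.0 * beta)) ** 2


# Policy identical to the reference: h = 0, so the loss is (0 - 5) ** 2
print(ipo_loss(-10.0, -12.0, -10.0, -12.0))
```

With `beta=0.1` the target margin is 5 nats, and overshooting it is penalised as much as undershooting. This bounded target is part of why IPO tolerates a single low-LR epoch better than an unbounded objective.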

**v4:** IPO data enriched with mutation-based preference pairs (Step 2.5c).
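The IPO cell bounds training with `ipo_max_steps` rather than an epoch count. Staying at roughly one epoch therefore means deriving the step count from the number of preference pairs, and that number changes when mutation enrichment is on. The sketch below assumes no sequence packing. `num_devices` is a hypothetical parameter for multi-GPU runs.

```python
import math


def steps_per_epoch(num_pairs: int, per_device_batch: int,
                    grad_accum: int, num_devices: int = 1) -> int:
    """Optimizer steps needed for one full pass over the preference pairs."""
    effective_batch = per_device_batch * grad_accum * num_devices
    return math.ceil(num_pairs / effective_batch)


# Example: 10,000 pairs, batch 1 x grad_accum 16 on one GPU
print(steps_per_epoch(10_000, per_device_batch=1, grad_accum=16))
```

In the notebook, `num_pairs` would be `len(load_from_disk("data/rust/ipo/train"))`, and the batch values would be `CONFIG["ipo_batch"]` and `CONFIG["ipo_grad_accum"]`.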

### 4.1 Train with IPO

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
    tracker.skip("ipo")
elif not CONFIG["include_ipo"]:
    print("Skipping \u2014 IPO disabled (include_ipo=False)")
    tracker.skip("ipo")
else:
    tracker.start("ipo")

    batch = CONFIG["ipo_batch"]
    grad_accum = CONFIG["ipo_grad_accum"]
    max_steps = CONFIG["ipo_max_steps"]

    ipo_checkpoint = "checkpoints/core_agent/final"
    if CONFIG["training_scope"] == "skip_to_rl":
        print("Using existing core_agent checkpoint (skip_to_rl mode)")

    cmd = f"python scripts/17_ipo_preference.py"
    cmd += f" --checkpoint {ipo_checkpoint}"
    cmd += f" --train_data_path data/rust/ipo/train"
    cmd += f" --per_device_train_batch_size {batch}"
    cmd += f" --gradient_accumulation_steps {grad_accum}"
    cmd += f" --max_steps {max_steps}"

    print(f"Training with IPO (enriched preferences)...")
    print(f"  Checkpoint: {ipo_checkpoint}")
    print(f"  Data: data/rust/ipo/train")
    print(f"  Batch: {batch} x {grad_accum} = {batch * grad_accum}")
    print(f"  Max steps: {max_steps}")
    print(f"  Loss: IPO (beta=0.1)")
    print(f"  Load mode: {CONFIG['load_mode']}")
    print(f"  Split LoRA backend: {CONFIG['moe_backend']}")
    print("=" * 60)

    !{cmd}

    drive_helper.backup("checkpoints/core_agent_ipo", "checkpoints/core_agent_ipo")
    if DRIVE_MODE != "local":
        print("\nCheckpoint backed up to Drive.")

    tracker.complete("ipo")

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
elif not CONFIG["include_ipo"]:
    print("Skipping \u2014 IPO disabled")
else:
    ckpt_path = "checkpoints/core_agent_ipo/final"

    print("IPO Verification:")
    print("=" * 60)

    if os.path.exists(ckpt_path):
        files = os.listdir(ckpt_path)
        print(f"  \u2713 IPO checkpoint: {ckpt_path} ({len(files)} files)")
    else:
        print(f"  \u2717 IPO checkpoint not found at {ckpt_path}")

    tb_dir = "checkpoints/core_agent_ipo"
    tb_files = []
    if os.path.exists(tb_dir):
        for root, dirs, fnames in os.walk(tb_dir):
            for fn in fnames:
                if fn.startswith("events.out.tfevents"):
                    tb_files.append(os.path.join(root, fn))
    if tb_files:
        print(f"  \u2713 TensorBoard logs found ({len(tb_files)} event files)")
        print(f"    Monitor KL divergence: warn >0.3, abort >0.5")
    else:
        print(f"  \u2014 No TensorBoard logs found")

    print("=" * 60)

## Step 5: GRPO RL

Group Relative Policy Optimisation with execution-based rewards.
Generates N completions per prompt, runs `cargo check/test/clippy`, computes group-relative advantages.
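The group-relative advantage step can be sketched as follows. The sketch assumes the common GRPO normalisation: mean-centre each reward within its prompt group, then divide by the group's standard deviation. Whether `scripts/18_grpo_rl.py` uses population or sample standard deviation, and this epsilon, is an assumption. Each reward stands for whatever scalar the cargo check/test/clippy reward assigns to one completion.

```python
import statistics


def group_relative_advantages(rewards, eps=1e-4):
    """Normalise each completion's reward against its own prompt group.

    A_i = (r_i - mean(r)) / (std(r) + eps), using the population std here.
    Completions better than the group average get positive advantages;
    a group with identical rewards gets all-zero advantages (no signal).
    """
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# One prompt, four generations: two compile and pass tests, two do not
print(group_relative_advantages([1.0, 0.0, 0.0, 1.0]))
```

Since no learned value function is needed, the group size (`CONFIG["grpo_num_gen"]`) directly controls how stable the baseline is. Groups whose completions all pass or all fail contribute no gradient.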

**Optimisations:**
- FP8 RL with vLLM inference on H100 (1.6x throughput)
- Chunked batching for longer context
- Extended curriculum: 65K context on H100
- Harmony format compliance reward

### 5.1 Train with GRPO

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
    tracker.skip("grpo")
elif not CONFIG["include_grpo"]:
    print("Skipping \u2014 GRPO disabled (include_grpo=False)")
    tracker.skip("grpo")
else:
    tracker.start("grpo")

    batch = CONFIG["grpo_batch"]
    grad_accum = CONFIG["grpo_grad_accum"]
    max_steps = CONFIG["grpo_max_steps"]
    max_seq = CONFIG["grpo_seq_len"]

    grpo_checkpoint = "checkpoints/core_agent_ipo/final"
    if not os.path.exists(grpo_checkpoint):
        grpo_checkpoint = "checkpoints/core_agent/final"
        print(f"  IPO checkpoint not found, using: {grpo_checkpoint}")

    cmd = f"python scripts/18_grpo_rl.py"
    cmd += f" --checkpoint {grpo_checkpoint}"
    cmd += f" --per_device_train_batch_size {batch}"
    cmd += f" --gradient_accumulation_steps {grad_accum}"
    cmd += f" --max_steps {max_steps}"

    v4_features = []
    v4_features.append(f"Split LoRA ({CONFIG['moe_backend']})")
    if CONFIG["load_mode"] == "fp8":
        v4_features.append("FP8 weights")
    if CONFIG.get("fast_inference"):
        v4_features.append("vLLM inference")
    v4_features.append("Chunked batching (auto)")
    v4_features.append("Auto packing")

    if CONFIG["gpu_tier"] == "a100_40gb":
        print("NOTE: 40GB GPU \u2014 GRPO sequence length capped at 16384")

    print(f"Training with GRPO...")
    print(f"  Checkpoint: {grpo_checkpoint}")
    print(f"  Batch: {batch} x {grad_accum} = {batch * grad_accum}")
    print(f"  Max steps: {max_steps}")
    print(f"  Max seq length: {max_seq}")
    print(f"  Generations per prompt: {CONFIG['grpo_num_gen']}")
    print(f"\n  Features active:")
    for feat in v4_features:
        print(f"    \u2713 {feat}")
    print("=" * 60)

    !{cmd}

    drive_helper.backup("checkpoints/core_agent_grpo", "checkpoints/core_agent_grpo")
    if DRIVE_MODE != "local":
        print("\nCheckpoint backed up to Drive.")

    tracker.complete("grpo")

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
elif not CONFIG["include_grpo"]:
    print("Skipping \u2014 GRPO disabled")
else:
    ckpt_path = "checkpoints/core_agent_grpo/final"

    print("GRPO Verification:")
    print("=" * 60)

    if os.path.exists(ckpt_path):
        files = os.listdir(ckpt_path)
        print(f"  \u2713 GRPO checkpoint: {ckpt_path} ({len(files)} files)")
    else:
        print(f"  \u2717 GRPO checkpoint not found at {ckpt_path}")

    print("=" * 60)

## Step 6: Evaluation

Evaluate the latest-stage checkpoint available (GRPO, then IPO, then core_agent, then lang_rust) on held-out Rust tasks using execution-based metrics
(cargo check, cargo test, clippy).
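The rate and median metrics in `evals/rust_agent/metrics.json` could be aggregated from per-sample results along these lines. The per-sample record format (`check`, `test`, `clippy`, `iterations`) is hypothetical. This sketch also takes the median over green samples only, and how `scripts/eval_rust_agent.py` treats samples that never go green is not shown here.

```python
import statistics


def summarise(results):
    """Aggregate per-sample execution results into gate metrics.

    Each result is a dict with hypothetical keys: booleans 'check', 'test',
    'clippy', and an int 'iterations' (attempts until tests passed).
    """
    n = len(results)
    green = [r["iterations"] for r in results if r["test"]]
    return {
        "cargo_check_pass_rate": sum(r["check"] for r in results) / n,
        "cargo_test_pass_rate": sum(r["test"] for r in results) / n,
        "clippy_clean_rate": sum(r["clippy"] for r in results) / n,
        # Median only over samples that reached green; None if none did
        "iterations_to_green_median": statistics.median(green) if green else None,
    }


samples = [
    {"check": True, "test": True, "clippy": True, "iterations": 1},
    {"check": True, "test": False, "clippy": True, "iterations": 3},
    {"check": True, "test": True, "clippy": False, "iterations": 2},
    {"check": False, "test": False, "clippy": False, "iterations": 5},
]
print(summarise(samples))
```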

### 6.1 Run Rust Evaluation

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
    tracker.skip("eval")
else:
    tracker.start("eval")

    # Pick the most advanced stage that has a checkpoint (GRPO > IPO > core_agent > lang_rust)
    if CONFIG["include_grpo"] and os.path.exists("checkpoints/core_agent_grpo/final"):
        eval_checkpoint = "checkpoints/core_agent_grpo/final"
    elif CONFIG["include_ipo"] and os.path.exists("checkpoints/core_agent_ipo/final"):
        eval_checkpoint = "checkpoints/core_agent_ipo/final"
    elif os.path.exists("checkpoints/core_agent/final"):
        eval_checkpoint = "checkpoints/core_agent/final"
    else:
        eval_checkpoint = "checkpoints/lang_rust/final"

    num_samples = CONFIG["eval_num_samples"]

    print(f"Evaluating checkpoint: {eval_checkpoint}")
    print(f"Samples: {num_samples}")
    print("=" * 60)

    !python scripts/eval_rust_agent.py \
        --checkpoint {eval_checkpoint} \
        --num_samples {num_samples}

    drive_helper.backup("evals/rust_agent", "evals/rust_agent")
    if DRIVE_MODE != "local":
        print("\nResults backed up to Drive.")

    tracker.complete("eval")

### 6.2 Check Promotion Gates

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
else:
    print("Checking promotion gates...")
    print("=" * 60)

    !python scripts/12_check_gates.py rust_agent

In [None]:
if CONFIG["training_scope"] == "lang_adapter_only":
    print("Skipping \u2014 scope is lang_adapter_only")
else:
    import json

    metrics_path = "evals/rust_agent/metrics.json"

    if os.path.exists(metrics_path):
        with open(metrics_path) as f:
            metrics = json.load(f)

        targets = {
            "cargo_check_pass_rate": (0.85, "higher"),
            "cargo_test_pass_rate": (0.70, "higher"),
            "clippy_clean_rate": (0.80, "higher"),
            "iterations_to_green_median": (3, "lower"),
            "diff_size_median": (50, "lower"),
            "tool_call_format_accuracy": (0.99, "higher"),
            "hallucinated_api_rate": (0.05, "lower"),
        }

        print("=" * 60)
        print("EVALUATION RESULTS")
        print("=" * 60)
        print(f"{'Metric':<32} {'Value':>8} {'Target':>8} {'Status':>8}")
        print("-" * 60)

        na = "\u2014"  # escapes aren't allowed inside f-string braces before Python 3.12
        for key, (target, direction) in targets.items():
            value = metrics.get(key)
            # Float targets are rates (shown as %); int targets are counts/medians
            is_rate = isinstance(target, float)
            fmt_tgt = f"{target:.0%}" if is_rate else f"{target}"
            if value is None:
                print(f"{key:<32} {'N/A':>8} {fmt_tgt:>8} {na:>8}")
                continue

            if direction == "higher":
                passed = value >= target
            else:
                passed = value <= target

            status = "\u2713 PASS" if passed else "\u2717 FAIL"
            fmt_val = f"{value:.2%}" if is_rate else f"{value}"
            print(f"{key:<32} {fmt_val:>8} {fmt_tgt:>8} {status:>8}")

        print("=" * 60)
    else:
        print(f"\u2717 Metrics file not found at {metrics_path}")
        print("Run evaluation (6.1) first.")

## Step 7: Test Model

Load the trained model and generate Rust code interactively.

### 7.1 Load Model

In [None]:
from unsloth import FastLanguageModel
from peft import PeftModel
import torch

CHECKPOINT_PRIORITY = [
    "checkpoints/core_agent_grpo/final",
    "checkpoints/core_agent_ipo/final",
    "checkpoints/core_agent/final",
    "checkpoints/lang_rust/final",
]

# Merged model can be loaded directly (no adapter needed)
MERGED_PATH = "checkpoints/gpt-oss-20b-rust-merged"

MODEL_PATH = None
is_adapter = False
for path in CHECKPOINT_PRIORITY:
    if os.path.exists(path) and os.path.exists(os.path.join(path, "adapter_config.json")):
        MODEL_PATH = path
        is_adapter = True
        break

if MODEL_PATH is None and os.path.exists(MERGED_PATH):
    MODEL_PATH = MERGED_PATH
    is_adapter = False

if MODEL_PATH is None:
    print("\u2717 No checkpoint found. Train the model first.")
else:
    print(f"Loading model from: {MODEL_PATH}")
    print(f"  Type: {'LoRA adapter' if is_adapter else 'merged model'}")

    # GPT-OSS fused MoE experts (GptOssExperts) don't support BNB 4-bit
    # quantization at load time (no `down_projs` sub-module for traversal).
    # See: https://github.com/unslothai/unsloth/issues/3775
    #
    # Strategy: for adapter checkpoints, use Unsloth's pre-quantized BNB base
    # (already 4-bit, no quantizer runs at load), falling back to bfloat16.
    # A merged checkpoint carries its own weights, so it is loaded directly
    # in bfloat16 instead of the base model.
    model = None
    base_name = "openai/gpt-oss-20b"

    # Try 1: Pre-quantized BNB 4-bit base (adapter checkpoints only)
    if is_adapter:
        try:
            print("  Loading pre-quantized BNB 4-bit model...")
            model, tokenizer = FastLanguageModel.from_pretrained(
                "unsloth/gpt-oss-20b-unsloth-bnb-4bit",
                max_seq_length=4096,
                dtype=None,
                load_in_4bit=False,  # Already quantized
            )
            print("  Mode: BNB 4-bit (pre-quantized)")
        except Exception as e:
            print(f"  Pre-quantized BNB failed: {e}")

    # Try 2: bfloat16 without quantization (~40GB, fits H100 80GB)
    if model is None:
        load_path = base_name if is_adapter else MODEL_PATH
        print(f"  Loading {load_path} in bfloat16 (no quantization)...")
        model, tokenizer = FastLanguageModel.from_pretrained(
            load_path,
            max_seq_length=4096,
            dtype=torch.bfloat16,
            load_in_4bit=False,
        )
        print("  Mode: bfloat16 (no quantization)")

    print("=" * 60)

    if is_adapter:
        print(f"  Applying LoRA adapter from {MODEL_PATH}...")
        model = PeftModel.from_pretrained(model, MODEL_PATH)

    FastLanguageModel.for_inference(model)
    print("\u2713 Model loaded!")

### 7.2 Generate Rust Code

In [None]:
import sys
sys.path.insert(0, "scripts")
from dataset_formatters.harmony import encode_harmony_messages

TEST_PROMPTS = [
    "Write a Rust function `fn merge_sorted(a: &[i32], b: &[i32]) -> Vec<i32>` that merges two sorted slices into a single sorted vector.",
    "This Rust code fails the borrow checker. Fix it:\n```rust\nfn main() {\n    let mut v = vec![1, 2, 3];\n    let first = &v[0];\n    v.push(4);\n    println!(\"{}\", first);\n}\n```",
    "Write an async Rust function using tokio that fetches a URL with reqwest, retries up to 3 times on failure, and returns the response body as a String.",
]

def generate_rust(prompt, max_tokens=1024):
    messages = [{"role": "user", "content": prompt}]
    formatted = encode_harmony_messages(
        messages,
        developer_instructions="You are a Rust programming expert. Write correct, idiomatic code.",
        add_generation_prompt=True,
    )
    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.3,
            do_sample=True,
            top_p=0.9,
        )
    return tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)

for i, prompt in enumerate(TEST_PROMPTS, 1):
    print(f"\n{'=' * 60}")
    print(f"Test {i}: {prompt[:80]}...")
    print("=" * 60)
    response = generate_rust(prompt)
    print(response)
    print()

In [None]:
CUSTOM_PROMPT = "Write a Rust function that reads a CSV file and returns the sum of a specified column."

print(f"Prompt: {CUSTOM_PROMPT}")
print("=" * 60)
print(generate_rust(CUSTOM_PROMPT))

## Step 8: Export

Merge the final adapter and export to HuggingFace + GGUF formats.

**v4:** Optional QAT export for 97-100% MXFP4 quality retention.
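Merging folds the low-rank update into the base weight so inference needs no adapter. The standard LoRA form is W' = W + (alpha / r) · B A; rsLoRA and other variants scale differently. The pure-Python sketch below applies this generic formula to tiny matrices. It is not a description of what `scripts/19_merge_adapter.py` does internally, for example how it handles quantised or fused MoE expert weights.

```python
def matmul(x, y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*y)] for row in x]


def merge_lora(W, A, B, alpha, rank):
    """Return W + (alpha / rank) * (B @ A).

    W: d_out x d_in base weight, A: rank x d_in, B: d_out x rank.
    """
    scale = alpha / rank
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


W = [[1.0, 0.0], [0.0, 1.0]]   # 2x2 identity base weight
A = [[1.0, 2.0]]               # rank 1 x d_in 2
B = [[1.0], [0.0]]             # d_out 2 x rank 1
print(merge_lora(W, A, B, alpha=2.0, rank=1))
```

After merging, the exported model has the same shape and cost as the base model, which is what allows the HF and GGUF exports to be used without PEFT.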

### 8.1 Export to GGUF

In [None]:
tracker.start("export")

ADAPTER_PRIORITY = [
    "checkpoints/core_agent_grpo/final",
    "checkpoints/core_agent_ipo/final",
    "checkpoints/core_agent/final",
    "checkpoints/lang_rust/final",
]

adapter_path = None
for path in ADAPTER_PRIORITY:
    if os.path.exists(path):
        adapter_path = path
        break

if adapter_path is None:
    print("\u2717 No adapter checkpoint found.")
    tracker.fail("export")
else:
    export_dir = "checkpoints/gpt-oss-20b-rust-export-v4"
    print(f"Exporting adapter: {adapter_path}")
    print(f"Output: {export_dir}")
    print("=" * 60)

    !python scripts/19_merge_adapter.py \
        --adapter_path {adapter_path} \
        --output_dir {export_dir} \
        --export_formats hf gguf_q4

    drive_helper.backup(export_dir, "checkpoints/gpt-oss-20b-rust-export-v4")
    if DRIVE_MODE != "local":
        print("\nExport backed up to Drive.")

    tracker.complete("export")

### 8.2 QAT Export (Optional)

Quantisation-Aware Training for MXFP4 deployment.
Recovers 97-100% quality vs 59-89% with post-training quantisation.
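The sketch below shows where post-training quantisation loses precision in MXFP4's block scheme, in simplified form. Each block shares a power-of-two scale. Every element is stored as FP4 E2M1, whose magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6. The shared exponent follows the OCP Microscaling rule X = floor(log2(max|x|)) - 2. The sketch ignores E8M0 scale range limits, and it breaks rounding ties toward the smaller magnitude. QAT runs a quantise-dequantise step like this during fine-tuning so the weights adapt to the grid.

```python
import math

# Non-negative magnitudes representable by FP4 E2M1 (the MXFP4 element type)
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_EMAX = 2  # exponent of the largest E2M1 value (6 = 1.5 * 2**2)


def mxfp4_fake_quant(block):
    """Quantise then dequantise one block (MXFP4 blocks hold 32 elements).

    The block shares a scale 2**X with X = floor(log2(max|x|)) - E2M1_EMAX.
    Each x / 2**X is rounded to the nearest E2M1 magnitude, saturating at 6.
    """
    amax = max(abs(x) for x in block)
    if amax == 0:
        return [0.0 for _ in block]
    scale = 2.0 ** (math.floor(math.log2(amax)) - E2M1_EMAX)
    out = []
    for x in block:
        mag = min(abs(x) / scale, E2M1_GRID[-1])  # saturate to the largest value
        q = min(E2M1_GRID, key=lambda g: abs(g - mag))  # first (smaller) wins ties
        out.append(math.copysign(q * scale, x))
    return out


print(mxfp4_fake_quant([1.0, 0.5, -3.0, 6.0]))  # all exactly representable
print(mxfp4_fake_quant([0.1, 0.2]))             # both values lose precision
```

In the second call the shared scale is 2**-5, so 0.1 rounds to 0.09375 and 0.2 saturates to 0.1875. Small per-element errors of this kind across billions of weights are what QAT trains the model to tolerate.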

In [None]:
if not CONFIG.get("enable_qat_export"):
    print("QAT export disabled. Enable via widget toggle in Step 0.3.")
    print("\nQAT recovers 97-100% quality when deploying to MXFP4,")
    print("vs 59-89% with standard post-training quantisation (PTQ).")
else:
    export_dir = "checkpoints/gpt-oss-20b-rust-export-v4"
    qat_dir = "checkpoints/gpt-oss-20b-rust-qat"

    if not os.path.exists(export_dir):
        print("\u2717 Run standard export (8.1) first.")
    else:
        print("Checking QAT prerequisites for the merged model...")
        print("  QAT fine-tunes with MXFP4-aware quantisation at a reduced LR (1e-5).")
        print("=" * 60)

        try:
            import modelopt.torch.quantization as mtq
            print("\u2713 nvidia-modelopt available")

            print("\nQAT pipeline (manual steps):")
            print(f"  1. Load merged BF16 model from {export_dir}")
            print(f"  2. mtq.quantize(model, config=mtq.MXFP4_DEFAULT_CFG)")
            print(f"  3. Fine-tune for ~100 steps at LR 1e-5")
            print(f"  4. Export to {qat_dir}")
        except ImportError:
            print("\u2717 nvidia-modelopt not installed.")
            print("  Install: pip install nvidia-modelopt")

### 8.3 Download GGUF

In [None]:
if IN_COLAB:
    from google.colab import files
    import glob

    export_dir = "checkpoints/gpt-oss-20b-rust-export-v4"
    gguf_files = glob.glob(os.path.join(export_dir, "**/*.gguf"), recursive=True)

    if gguf_files:
        gguf_path = gguf_files[0]
        size_gb = os.path.getsize(gguf_path) / (1024**3)
        print(f"Downloading: {os.path.basename(gguf_path)} ({size_gb:.1f} GB)")
        files.download(gguf_path)
    else:
        print("\u2717 No GGUF file found. Run export (8.1) first.")
else:
    print("Download not available outside Colab.")
    print("GGUF file is at: checkpoints/gpt-oss-20b-rust-export-v4/")

### 8.4 Upload to HuggingFace Hub

In [None]:
# --- Configuration ---
HF_REPO_ID = ""  # e.g. "your-username/gpt-oss-20b-rust-agent-v4"
HF_PRIVATE = True

assert HF_REPO_ID, "Set HF_REPO_ID above before running this cell."

import glob
from huggingface_hub import HfApi

# Authenticate: try Colab Secrets first, then interactive login
try:
    from google.colab import userdata
    hf_token = userdata.get("HF_TOKEN")
    print("Using HF_TOKEN from Colab Secrets.")
except Exception:
    from huggingface_hub import login
    login()
    hf_token = None  # login() stores token globally

api = HfApi(token=hf_token)

# Create repo (no-op if it already exists)
api.create_repo(repo_id=HF_REPO_ID, private=HF_PRIVATE, exist_ok=True)
print(f"Repo ready: https://huggingface.co/{HF_REPO_ID}")

# --- Model card ---
export_dir = "checkpoints/gpt-oss-20b-rust-export-v4"
hf_dir = os.path.join(export_dir, "hf")

model_card = """\
---
base_model: openai/gpt-oss-20b
tags:
  - rust
  - code-agent
  - gpt-oss
  - qlora
  - unsloth
  - grpo
license: apache-2.0
pipeline_tag: text-generation
---

# GPT-OSS 20B Rust Agent (v4)

Fine-tuned from [openai/gpt-oss-20b](https://huggingface.co/openai/gpt-oss-20b) using
Strandset-Rust-v1 + mutation enrichment with execution-grounded GRPO reinforcement learning.

## Training Pipeline

1. **Strandset Data** \u2014 191K verified Rust examples (code gen, debug, review, preferences)
2. **Mutation Generation** \u2014 cargo-mutants based code mutations
3. **Lang Adapter** \u2014 Rust domain specialisation (enriched with cached mutation debug pairs)
4. **Core Agent SFT** \u2014 Debug/review training (enriched with mutation trajectories)
5. **IPO** \u2014 Preference optimisation (enriched with mutation preference pairs)
6. **GRPO RL** \u2014 Execution-grounded reinforcement learning

Trained with [Unsloth](https://github.com/unslothai/unsloth) QLoRA on NVIDIA H100 80GB.

## v4 Features

- Widget-based configuration UI
- Mutation enrichment across ALL training phases
- Pipeline progress dashboard
- Split LoRA + FP8 RL on H100

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("{repo_id}")
tokenizer = AutoTokenizer.from_pretrained("{repo_id}")
```

## GGUF

A quantised GGUF file is included for use with llama.cpp or Ollama.
""".format(repo_id=HF_REPO_ID)

# Fail early if the HF export is missing (creating hf_dir here would mask it)
assert os.path.isdir(hf_dir), f"HF export dir not found: {hf_dir}. Run export (8.1) first."
readme_path = os.path.join(hf_dir, "README.md")
with open(readme_path, "w") as f:
    f.write(model_card)
print(f"Wrote model card to {readme_path}")

# --- Upload HF safetensors model ---
print(f"Uploading HF model from {hf_dir} ...")
api.upload_folder(
    folder_path=hf_dir,
    repo_id=HF_REPO_ID,
    commit_message="Upload merged HF model (v4 \u2014 Strandset + mutation enrichment pipeline)",
    token=hf_token,
)
print("HF model uploaded.")

# --- Upload GGUF file ---
gguf_files = glob.glob(os.path.join(export_dir, "**/*.gguf"), recursive=True)
if gguf_files:
    gguf_path = gguf_files[0]
    gguf_name = os.path.basename(gguf_path)
    size_gb = os.path.getsize(gguf_path) / (1024**3)
    print(f"Uploading GGUF: {gguf_name} ({size_gb:.1f} GB) ...")
    api.upload_file(
        path_or_fileobj=gguf_path,
        path_in_repo=gguf_name,
        repo_id=HF_REPO_ID,
        commit_message=f"Upload GGUF quantisation ({gguf_name})",
        token=hf_token,
    )
    print("GGUF uploaded.")
else:
    print("No GGUF file found \u2014 skipping. Run export (8.1) to generate one.")

print(f"\nDone! View your model at: https://huggingface.co/{HF_REPO_ID}")

---
## Training Complete!

Your GPT-OSS 20B Rust coding agent (v4) is trained and ready to use.

**v4 Pipeline:**
1. Strandset-Rust-v1: 191K verified examples across all phases
2. Mutation Enrichment: cargo-mutants data enriches lang_rust (cached), core_agent (trajectories), and IPO (preference pairs)
3. Lang Adapter: Rust domain specialisation (LoRA rank 64)
4. Core Agent SFT: Debug and review training (LoRA rank 128)
5. IPO: Preference optimisation with enriched pairs
6. GRPO RL: Execution-grounded reinforcement learning

**Outputs:**
- Checkpoints: `checkpoints/core_agent_{ipo,grpo}/final`
- Evaluation: `evals/rust_agent/metrics.json`
- Exported model: `checkpoints/gpt-oss-20b-rust-export-v4/`
- All backed up to Google Drive: `gpt-oss-20b-rust-agent-v4/`

**Next steps:**
- Review evaluation metrics in Step 6
- Test interactively in Step 7
- Deploy the GGUF file with llama.cpp or Ollama
- For MXFP4 deployment, enable QAT export in Step 8.2

In [None]:
# Disconnect and release GPU runtime to stop billing (Colab only)
if IN_COLAB:
    from google.colab import runtime
    runtime.unassign()