# WX-AFD: Fine-Tuning Qwen3-4B for Area Forecast Discussions

This notebook fine-tunes **Qwen3-4B-Instruct-2507** with DoRA + rsLoRA to generate NWS Area Forecast Discussions from structured weather data.

**Google Colab version** — requires A100 GPU + High RAM runtime.

**Pipeline context:** This is step 4 in the WX-AFD pipeline:
1. `01_scrape_afds.py` — Scrape AFDs from IEM
2. `02_fetch_weather.py` — Fetch weather data from Open-Meteo
3. `03_build_dataset.py` — Build training JSONL (messages format)
4. **`04_train_colab.ipynb` — Fine-tune, evaluate, and export model** (this notebook)

---

**Table of Contents**
1. [Environment Setup](#1.-Environment-Setup)
2. [Data Inspection](#2.-Data-Inspection)
3. [Configuration](#3.-Configuration)
4. [Sanity Check](#4.-Sanity-Check)
5. [Training](#5.-Training)
6. [Training Curves](#6.-Training-Curves)
7. [LoRA Merge](#7.-LoRA-Merge)
8. [Inference](#8.-Inference)
9. [Evaluation](#9.-Evaluation)
10. [Results](#10.-Results)
11. [Post-Training Sanity Checks](#11.-Post-Training-Sanity-Checks)
12. [Push to HuggingFace Hub](#12.-Push-to-HuggingFace-Hub)
13. [Next Steps](#13.-Next-Steps)

## 1. Environment Setup

In [None]:
# Install dependencies (torch is pre-installed on Colab — don't reinstall)
!pip install -U packaging setuptools wheel ninja
!pip install --no-build-isolation "axolotl[flash-attn,deepspeed]"
!pip install rouge-score==0.1.2 bert-score sacrebleu

In [None]:
# Mount Google Drive and set up paths
from pathlib import Path

from google.colab import drive
drive.mount('/content/drive')

DRIVE_ROOT = Path("/content/drive/MyDrive/wx-afd")
DATA_DIR = DRIVE_ROOT / "data"
CONFIG_PATH = DRIVE_ROOT / "configs" / "wx-afd-dora.yml"
OUTPUT_DIR = DRIVE_ROOT / "output"
EVAL_DIR = DRIVE_ROOT / "eval"

# Ensure directories exist
for d in [DATA_DIR, CONFIG_PATH.parent, OUTPUT_DIR, EVAL_DIR]:
    d.mkdir(parents=True, exist_ok=True)

print(f"Drive root: {DRIVE_ROOT}")
print(f"Data dir:   {DATA_DIR}")
print(f"Config:     {CONFIG_PATH}")
print(f"Output:     {OUTPUT_DIR}")
print(f"Eval:       {EVAL_DIR}")

In [None]:
import shutil
import sys

# Clone repo (for wx_afd.py, configs, and pipeline scripts)
!git clone https://github.com/ringusTheImp/AFDModel.git /content/AFDModel
sys.path.insert(0, "/content/AFDModel")

# Check for existing data on Drive, run pipeline if missing
if not (DATA_DIR / "train.jsonl").exists():
    print("Data not found on Drive — running pipeline...")
    !cd /content/AFDModel && python 01_scrape_afds.py
    !cd /content/AFDModel && python 02_fetch_weather.py
    !cd /content/AFDModel && python 03_build_dataset.py
    # Copy results to Drive for persistence
    shutil.copytree("/content/AFDModel/data", str(DATA_DIR), dirs_exist_ok=True)
    print(f"Data copied to Drive: {DATA_DIR}")
else:
    print(f"Data found on Drive: {DATA_DIR}")

In [None]:
import yaml

# Write Colab-specific Axolotl config with Drive paths
# Load the template from the cloned repo
with open("/content/AFDModel/configs/wx-afd-dora.yml") as f:
    config_data = yaml.safe_load(f)

# Patch paths for Colab + Google Drive
config_data["datasets"][0]["path"] = str(DATA_DIR / "train.jsonl")
config_data["test_datasets"][0]["path"] = str(DATA_DIR / "val.jsonl")
config_data["output_dir"] = str(OUTPUT_DIR)

with open(CONFIG_PATH, "w") as f:
    yaml.dump(config_data, f, default_flow_style=False, sort_keys=False)

print(f"Config written to: {CONFIG_PATH}")
print(f"  datasets[0].path:      {config_data['datasets'][0]['path']}")
print(f"  test_datasets[0].path: {config_data['test_datasets'][0]['path']}")
print(f"  output_dir:            {config_data['output_dir']}")

In [None]:
import json
import os
import subprocess

import matplotlib.pyplot as plt
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"

# ---- GPU check ----
print(f"PyTorch:  {torch.__version__}")
print(f"CUDA:     {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU:      {torch.cuda.get_device_name(0)}")
    print(f"VRAM:     {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
print(f"Root:     {DRIVE_ROOT}")

## 2. Data Inspection

Verify that our training data (output of `03_build_dataset.py`) is correctly formatted.

In [None]:
def load_jsonl(path):
    """Load a JSONL file into a list of dicts, skipping blank lines."""
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]

train_data = load_jsonl(DATA_DIR / "train.jsonl")
val_data = load_jsonl(DATA_DIR / "val.jsonl")

print(f"Training examples:   {len(train_data)}")
print(f"Validation examples: {len(val_data)}")
print(f"Total:               {len(train_data) + len(val_data)}")

# Verify messages format
ex = train_data[0]
assert "messages" in ex, "Missing 'messages' key"
assert len(ex["messages"]) == 3, f"Expected 3 messages, got {len(ex['messages'])}"
assert ex["messages"][0]["role"] == "system"
assert ex["messages"][1]["role"] == "user"
assert ex["messages"][2]["role"] == "assistant"
print("\nMessages format: OK (system + user + assistant)")
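
The cell above asserts the structure of the first record only. As a minimal sketch, the same check can be applied to every record. `validate_messages` and `find_bad_records` are hypothetical helpers (not part of `wx_afd.py`), and they assume the three-turn system/user/assistant layout asserted above.

```python
EXPECTED_ROLES = ["system", "user", "assistant"]

def validate_messages(example):
    """Return a list of problems found in one record (empty list means OK)."""
    messages = example.get("messages")
    if not isinstance(messages, list):
        return ["missing or non-list 'messages' key"]
    problems = []
    roles = [m.get("role") for m in messages]
    if roles != EXPECTED_ROLES:
        problems.append(f"unexpected roles: {roles}")
    for m in messages:
        # Flag turns whose content is missing or whitespace-only
        if not str(m.get("content", "")).strip():
            problems.append(f"empty content for role {m.get('role')!r}")
    return problems

def find_bad_records(records):
    """Map record index -> list of problems, for records that fail validation."""
    return {i: p for i, ex in enumerate(records) if (p := validate_messages(ex))}
```

In this notebook, `find_bad_records(train_data)` and `find_bad_records(val_data)` would be expected to return empty dicts.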

In [None]:
# Token length distribution
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

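def count_tokens(example):
    """Count tokens in the chat-templated example (the template already inserts special tokens)."""
    text = tokenizer.apply_chat_template(example["messages"], tokenize=False)
    return len(tokenizer.encode(text, add_special_tokens=False))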

train_lengths = [count_tokens(ex) for ex in train_data]
val_lengths = [count_tokens(ex) for ex in val_data]
all_lengths = train_lengths + val_lengths

print(f"Token lengths (all {len(all_lengths)} examples):")
print(f"  Min:    {min(all_lengths)}")
print(f"  Mean:   {sum(all_lengths) // len(all_lengths)}")
print(f"  Median: {sorted(all_lengths)[len(all_lengths)//2]}")
print(f"  Max:    {max(all_lengths)}")
print(f"  >2048:  {sum(1 for l in all_lengths if l > 2048)} "
      f"({sum(1 for l in all_lengths if l > 2048)/len(all_lengths):.1%})")

fig, ax = plt.subplots(figsize=(10, 4))
ax.hist(all_lengths, bins=50, edgecolor="black", alpha=0.7)
ax.axvline(2048, color="red", linestyle="--", label="sequence_len=2048")
ax.set_xlabel("Total tokens per example")
ax.set_ylabel("Count")
ax.set_title("Token Length Distribution")
ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# Display a sample training example
sample = train_data[0]
print("=" * 70)
print("SYSTEM PROMPT:")
print("=" * 70)
print(sample["messages"][0]["content"][:300], "...")
print()
print("=" * 70)
print("WEATHER INPUT (first 500 chars):")
print("=" * 70)
print(sample["messages"][1]["content"][:500], "...")
print()
print("=" * 70)
print("AFD OUTPUT (first 500 chars):")
print("=" * 70)
print(sample["messages"][2]["content"][:500], "...")

## 3. Configuration

Load and validate the Axolotl YAML config.

In [None]:
with open(CONFIG_PATH) as f:
    config = yaml.safe_load(f)

# Validate critical fields
checks = {
    "base_model": config.get("base_model") == "Qwen/Qwen3-4B-Instruct-2507",
    "chat_template": config.get("chat_template") == "qwen3",
    "eos_token": config.get("special_tokens", {}).get("eos_token") == "<|im_end|>",
    "pad_token": config.get("special_tokens", {}).get("pad_token") == "<|endoftext|>",
    "adapter": config.get("adapter") == "lora",
    "lora_r": config.get("lora_r") == 16,
    "lora_alpha": config.get("lora_alpha") == 32,
    "peft_use_dora": config.get("peft_use_dora") is True,
    "peft_use_rslora": config.get("peft_use_rslora") is True,
    "sample_packing": config.get("sample_packing") is True,
    "roles_to_train": config["datasets"][0].get("roles_to_train") == ["assistant"],
    "eot_tokens": config.get("eot_tokens") == ["<|im_end|>"],
    "test_datasets": "test_datasets" in config,
    "sequence_len": config.get("sequence_len") == 2048,
    "bf16": config.get("bf16") is True,
    "early_stopping": config.get("early_stopping_patience") == 5,
}

all_ok = True
for name, ok in checks.items():
    status = "OK" if ok else "FAIL"
    if not ok:
        all_ok = False
    print(f"  [{status}] {name}")

assert all_ok, "Config validation failed — fix issues above before training"
print("\nConfig validation: ALL PASSED")
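
The `lora_r`, `lora_alpha`, and `peft_use_rslora` values checked above jointly determine how strongly the adapter update is scaled. As a rough illustration (the helper name `adapter_scale` is made up for this note), standard LoRA scales the update by `alpha / r`, while rank-stabilized LoRA uses `alpha / sqrt(r)`:

```python
import math

def adapter_scale(alpha, r, rslora=False):
    """Scaling factor applied to the low-rank update B @ A."""
    if r <= 0:
        raise ValueError("rank must be positive")
    # rsLoRA divides by sqrt(r) instead of r
    return alpha / math.sqrt(r) if rslora else alpha / r

# Values validated in the config cell above
print(adapter_scale(32, 16))               # plain LoRA
print(adapter_scale(32, 16, rslora=True))  # rsLoRA
```

With rank 16 and alpha 32, the rsLoRA factor is four times the plain LoRA factor, so the two settings are not interchangeable at a fixed learning rate.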

## 4. Sanity Check

Use Axolotl's internals to validate config normalization, dataset loading, and loss masking
**before** launching a real training job. This catches issues early without GPU time.

Key checks:
- Config normalizes without errors
- Dataset loads and tokenizes correctly
- Labels are `-100` for system/user tokens (loss masking)
- EOS/PAD token IDs are correct
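
To make the loss-masking check concrete, here is a small self-contained sketch of what masking means at the label level. It does not use Axolotl; `mask_labels`, `masked_fraction`, and the toy token IDs are illustrative assumptions. Positions labeled `-100` match the default `ignore_index` of PyTorch's cross-entropy loss, so only assistant tokens contribute to the gradient.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def mask_labels(input_ids, trainable):
    """Copy input_ids to labels, replacing non-trainable positions with -100."""
    if len(input_ids) != len(trainable):
        raise ValueError("input_ids and trainable must have the same length")
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(input_ids, trainable)]

def masked_fraction(labels):
    """Fraction of positions excluded from the loss."""
    return sum(1 for label in labels if label == IGNORE_INDEX) / len(labels)

# Toy sequence: five system/user tokens followed by three assistant tokens
input_ids = [11, 12, 13, 21, 22, 31, 32, 33]
trainable = [False] * 5 + [True] * 3
labels = mask_labels(input_ids, trainable)
print(labels, f"{masked_fraction(labels):.1%} masked")
```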

In [None]:
from axolotl.utils.config import normalize_config, validate_config
from axolotl.common.datasets import load_datasets
from axolotl.utils.dict import DictDefault

# Workaround: Axolotl 0.14.0 bug — SFTDataset pydantic model not subscriptable
from axolotl.utils.schemas.config import SFTDataset
if not hasattr(SFTDataset, '__getitem__'):
    SFTDataset.__getitem__ = lambda self, key: getattr(self, key)

# Normalize and validate
cfg = DictDefault(config)
normalize_config(cfg)
validate_config(cfg)
print("Config normalization: OK")
print("Config validation:    OK")

# Load tokenizer for checks
tok = AutoTokenizer.from_pretrained(cfg.base_model, trust_remote_code=True)
eos_id = tok.convert_tokens_to_ids("<|im_end|>")
pad_id = tok.convert_tokens_to_ids("<|endoftext|>")
print(f"EOS token ID: {eos_id} (expect 151645) — {'OK' if eos_id == 151645 else 'FAIL'}")
print(f"PAD token ID: {pad_id} (expect 151643) — {'OK' if pad_id == 151643 else 'FAIL'}")
assert eos_id != pad_id, "EOS and PAD must differ!"

# Load datasets (this tests the full data pipeline)
print("\nLoading datasets...")
dataset_meta = load_datasets(cfg=cfg)
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
print(f"  Train: {len(train_dataset)} packed sequences")
if eval_dataset:
    print(f"  Eval:  {len(eval_dataset)} packed sequences")

# Verify loss masking: labels should be -100 for system/user tokens
sample = train_dataset[0]
labels = sample["labels"]
n_masked = sum(1 for l in labels if l == -100)
n_total = len(labels)
print(f"\nLoss masking (sample 0):")
print(f"  Total tokens:  {n_total}")
print(f"  Masked (-100): {n_masked} ({n_masked/n_total:.1%})")
print(f"  Trained:       {n_total - n_masked} ({(n_total - n_masked)/n_total:.1%})")
assert n_masked > 0, "No masked tokens — loss masking may be broken"
print("\nSanity check: ALL PASSED")
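
The dataset sizes printed above count packed sequences rather than raw examples, because `sample_packing` concatenates several short examples into one 2048-token sequence. The sketch below shows the general idea with a greedy first-fit-decreasing heuristic. It is a simplification and an assumption on my part: Axolotl's actual packer is more involved, and `pack_lengths` is a hypothetical name.

```python
def pack_lengths(lengths, capacity=2048):
    """Greedy first-fit-decreasing packing of example lengths into bins."""
    bins = []  # each bin is a list of lengths whose sum is <= capacity
    for length in sorted(lengths, reverse=True):
        # Simplification: clamp over-long examples so every item fits somewhere
        length = min(length, capacity)
        for current in bins:
            if sum(current) + length <= capacity:
                current.append(length)
                break
        else:
            # No existing bin had room, so open a new one
            bins.append([length])
    return bins

bins = pack_lengths([1500, 600, 500, 1000, 400])
print(len(bins), "packed sequences:", bins)
```

Passing the `train_lengths` list from the data-inspection section to `pack_lengths` would give a rough estimate of how many packed sequences to expect.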

## 5. Training

Run training directly in this notebook session (requires A100 GPU runtime).

**Expected timeline:** ~429 steps across 3 epochs, ~45-75 minutes on A100 40GB.
Early stopping (patience=5) may terminate after epoch 2.
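
The patience rule mentioned above can be sketched as a standalone function: training stops once the validation loss has failed to improve for `patience` consecutive evaluations. This is an illustration only; `first_stop_index` is a hypothetical helper, and the loss values in the example are invented to show the mechanics, not taken from a real run.

```python
def first_stop_index(eval_losses, patience=5):
    """Return the index of the evaluation that triggers early stopping, or None."""
    best = float("inf")
    evals_without_improvement = 0
    for i, loss in enumerate(eval_losses):
        if loss < best:
            best = loss
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return i
    return None

# Invented losses: best value at index 1, then five evaluations without improvement
print(first_stop_index([1.00, 0.90, 0.95, 0.96, 0.97, 0.98, 0.99], patience=5))
```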

In [None]:
from axolotl.train import train

train(cfg=cfg, dataset_meta=dataset_meta)

In [None]:
# Monitor training logs
log_dir = OUTPUT_DIR / "runs"
logs = sorted(log_dir.glob("**/events.*")) if log_dir.exists() else []
if logs:
    print(f"TensorBoard logs found: {len(logs)} event files")
    print(f"  Directory: {log_dir}")
else:
    print("No TensorBoard logs found yet.")

## 6. Training Curves

Plot the loss curves and learning-rate schedule recorded in `trainer_state.json`.

In [None]:
# Find trainer_state.json. Pick the file with the highest global_step: a plain
# path sort would rank checkpoint-50 after checkpoint-100.
state_files = list(OUTPUT_DIR.glob("**/trainer_state.json"))
if not state_files:
    print("No trainer_state.json found; training may not have completed yet.")
else:
    state_path = max(state_files, key=lambda p: json.loads(p.read_text())["global_step"])
    print(f"Loading: {state_path}")
    state = json.loads(state_path.read_text())

    # Extract metrics
    train_loss, train_steps = [], []
    eval_loss, eval_steps = [], []
    lr_values, lr_steps = [], []

    for entry in state["log_history"]:
        step = entry["step"]
        if "loss" in entry:
            train_loss.append(entry["loss"])
            train_steps.append(step)
        if "eval_loss" in entry:
            eval_loss.append(entry["eval_loss"])
            eval_steps.append(step)
        if "learning_rate" in entry:
            lr_values.append(entry["learning_rate"])
            lr_steps.append(step)

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    # Loss curves
    ax1.plot(train_steps, train_loss, label="Train Loss", alpha=0.7)
    if eval_loss:
        ax1.plot(eval_steps, eval_loss, label="Val Loss", marker="o", markersize=4)
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Loss")
    ax1.set_title("Training & Validation Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    # LR schedule
    if lr_values:
        ax2.plot(lr_steps, lr_values, color="green")
        ax2.set_xlabel("Step")
        ax2.set_ylabel("Learning Rate")
        ax2.set_title("Learning Rate Schedule")
        ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

    # Summary
    print(f"\nTotal steps:      {state['global_step']}")
    print(f"Best checkpoint:  {state.get('best_model_checkpoint', 'N/A')}")
    if eval_loss:
        best_idx = eval_loss.index(min(eval_loss))
        print(f"Best val loss:    {eval_loss[best_idx]:.4f} (step {eval_steps[best_idx]})")
    if train_loss:
        print(f"Final train loss: {train_loss[-1]:.4f}")

## 7. LoRA Merge

Merge the LoRA adapter back into the base model for clean inference.
The merged model is a standard HuggingFace directory — no adapter loading needed.
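
For intuition, merging a plain LoRA adapter amounts to folding the low-rank update into each adapted weight matrix: `W' = W + scale * (B @ A)`. The sketch below shows that arithmetic on nested lists. It is deliberately simplified: `merge_lora` is a hypothetical helper, and the additional magnitude rescaling that DoRA applies is omitted. The real merge is done by the Axolotl command in the next cell.

```python
def merge_lora(W, A, B, scale):
    """Return W + scale * (B @ A) for nested-list matrices.

    W is d_out x d_in, B is d_out x r, A is r x d_in.
    """
    d_out, d_in, r = len(W), len(W[0]), len(A)
    merged = [row[:] for row in W]  # copy so the base weights stay untouched
    for i in range(d_out):
        for j in range(d_in):
            delta = sum(B[i][k] * A[k][j] for k in range(r))
            merged[i][j] += scale * delta
    return merged

# Toy rank-1 example
W = [[1.0, 0.0], [0.0, 1.0]]
B = [[1.0], [0.0]]
A = [[0.0, 1.0]]
print(merge_lora(W, A, B, scale=2.0))
```

After merging, inference needs only `W'`, which is why the merged directory loads like any other Hugging Face model.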

In [None]:
# Merge LoRA adapter into base model
merge_cmd = [
    "accelerate", "launch", "-m", "axolotl.cli.merge_lora",
    str(CONFIG_PATH),
    "--lora_model_dir", str(OUTPUT_DIR),
]
print(f"Running: {' '.join(merge_cmd)}")
result = subprocess.run(merge_cmd, capture_output=True, text=True)
print(result.stdout)
if result.returncode != 0:
    print(f"STDERR:\n{result.stderr}")
    raise RuntimeError("Merge failed")

merged_dir = OUTPUT_DIR / "merged"
assert merged_dir.exists(), f"Merged model not found at {merged_dir}"
print(f"\nMerged model saved to: {merged_dir}")
print(f"Contents: {[f.name for f in sorted(merged_dir.iterdir())]}")

In [None]:
# Merged model is already on Google Drive — no copy needed
print(f"Merged model saved to Google Drive: {OUTPUT_DIR / 'merged'}")

## 8. Inference

Load the merged model and generate AFDs. Compare against reference forecasts.

In [None]:
from wx_afd import generate_afd, load_model

# Load merged model
merged_dir = OUTPUT_DIR / "merged"
print(f"Loading merged model from: {merged_dir}")
model, tokenizer = load_model(str(merged_dir))
print("Model loaded.")

# Generate from a validation example
val_ex = val_data[0]
weather_input = val_ex["messages"][1]["content"]
reference_afd = val_ex["messages"][2]["content"]

generated_afd = generate_afd(model, tokenizer, weather_input)

print("=" * 70)
print("GENERATED AFD:")
print("=" * 70)
print(generated_afd[:1000], "..." if len(generated_afd) > 1000 else "")
print()
print("=" * 70)
print("REFERENCE AFD:")
print("=" * 70)
print(reference_afd[:1000], "..." if len(reference_afd) > 1000 else "")

In [None]:
# EOS behavior check: verify model stops generating
print("EOS Behavior Check — generating 5 examples...")
print()
for i in range(min(5, len(val_data))):
    inp = val_data[i]["messages"][1]["content"]
    out = generate_afd(model, tokenizer, inp)
    terminated = len(out) < 2048 * 4  # rough char-level check
    print(f"  Example {i}: {len(out)} chars — {'OK (terminated)' if terminated else 'WARNING: may not have stopped'}")

print("\nAll examples should terminate cleanly with <|im_end|>.")

## 9. Evaluation

Full evaluation on all validation examples using ROUGE-1/2/L, BERTScore F1, and format compliance.
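
As a reminder of what ROUGE-L measures, the sketch below computes an F1 score from the longest common subsequence (LCS) of reference and prediction tokens. It is a simplified stand-in, assuming whitespace tokenization and no stemming, so its numbers will not match the `rouge_score` package used in the cells below. `lcs_length` and `rouge_l_f1` are hypothetical names.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, 1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_f1(reference, prediction):
    """LCS-based F1 between two strings (whitespace tokens, lowercased)."""
    ref, pred = reference.lower().split(), prediction.lower().split()
    if not ref or not pred:
        return 0.0
    lcs = lcs_length(ref, pred)
    if lcs == 0:
        return 0.0
    precision, recall = lcs / len(pred), lcs / len(ref)
    return 2 * precision * recall / (precision + recall)

print(rouge_l_f1("a b c d", "a c d e"))
```

Because the score rewards shared word order rather than meaning, it is paired with BERTScore in this evaluation.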

In [None]:
from rouge_score import rouge_scorer
from tqdm.notebook import tqdm
from wx_afd import REQUIRED_SECTIONS, compute_rouge, compute_bertscore, format_compliance

# Generate all predictions
predictions = []
references = []

print(f"Generating AFDs for {len(val_data)} validation examples...")
for ex in tqdm(val_data):
    pred = generate_afd(model, tokenizer, ex["messages"][1]["content"])
    predictions.append(pred)
    references.append(ex["messages"][2]["content"])

# Save generated AFDs
ft_eval_dir = EVAL_DIR / "finetuned"
ft_gen_dir = ft_eval_dir / "generated"
ft_scores_dir = ft_eval_dir / "scores"
ft_gen_dir.mkdir(parents=True, exist_ok=True)
ft_scores_dir.mkdir(parents=True, exist_ok=True)

for i, pred in enumerate(predictions):
    (ft_gen_dir / f"example_{i:04d}.txt").write_text(pred)

# ROUGE
rouge_avg = compute_rouge(predictions, references)
print(f"\nROUGE-1: {rouge_avg['rouge1']:.4f}")
print(f"ROUGE-2: {rouge_avg['rouge2']:.4f}")
print(f"ROUGE-L: {rouge_avg['rougeL']:.4f}")

# Per-example ROUGE (for later analysis)
scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
rouge_results = {"rouge1": [], "rouge2": [], "rougeL": []}
for pred, ref in zip(predictions, references):
    result = scorer.score(ref, pred)
    for key in rouge_results:
        rouge_results[key].append(result[key].fmeasure)

# Free model memory before BERTScore
del model
torch.cuda.empty_cache()

# BERTScore
print("\nComputing BERTScore (this may take a minute)...")
bertscore_f1 = compute_bertscore(predictions, references)
print(f"BERTScore F1: {bertscore_f1:.4f}")

# Format compliance
compliance = [format_compliance(pred)["compliance_score"] for pred in predictions]
avg_compliance = sum(compliance) / len(compliance)
print(f"Format compliance: {avg_compliance:.2%}")

# Save metrics
ft_metrics = {
    "tag": "finetuned",
    "num_examples": len(predictions),
    "rouge1": rouge_avg["rouge1"],
    "rouge2": rouge_avg["rouge2"],
    "rougeL": rouge_avg["rougeL"],
    "bertscore_f1": bertscore_f1,
    "format_compliance": avg_compliance,
}
with open(ft_scores_dir / "metrics.json", "w") as f:
    json.dump(ft_metrics, f, indent=2)
print(f"\nMetrics saved to {ft_scores_dir / 'metrics.json'}")

In [None]:
# Zero-shot baseline evaluation
print(f"Loading zero-shot baseline: {MODEL_ID}")
model_zs, tokenizer_zs = load_model(MODEL_ID)

zs_predictions = []
print(f"Generating zero-shot AFDs for {len(val_data)} examples...")
for ex in tqdm(val_data):
    pred = generate_afd(model_zs, tokenizer_zs, ex["messages"][1]["content"])
    zs_predictions.append(pred)

# Save
zs_eval_dir = EVAL_DIR / "zero-shot"
zs_gen_dir = zs_eval_dir / "generated"
zs_scores_dir = zs_eval_dir / "scores"
zs_gen_dir.mkdir(parents=True, exist_ok=True)
zs_scores_dir.mkdir(parents=True, exist_ok=True)
for i, pred in enumerate(zs_predictions):
    (zs_gen_dir / f"example_{i:04d}.txt").write_text(pred)

# ROUGE
zs_rouge_avg = compute_rouge(zs_predictions, references)

# Free model before BERTScore
del model_zs
torch.cuda.empty_cache()

# BERTScore
print("Computing BERTScore for zero-shot...")
zs_bertscore = compute_bertscore(zs_predictions, references)

# Compliance
zs_compliance = [format_compliance(pred)["compliance_score"] for pred in zs_predictions]
zs_avg_compliance = sum(zs_compliance) / len(zs_compliance)

zs_metrics = {
    "tag": "zero-shot",
    "num_examples": len(zs_predictions),
    "rouge1": zs_rouge_avg["rouge1"],
    "rouge2": zs_rouge_avg["rouge2"],
    "rougeL": zs_rouge_avg["rougeL"],
    "bertscore_f1": zs_bertscore,
    "format_compliance": zs_avg_compliance,
}
with open(zs_scores_dir / "metrics.json", "w") as f:
    json.dump(zs_metrics, f, indent=2)

print(f"\nZero-shot ROUGE-1: {zs_rouge_avg['rouge1']:.4f}")
print(f"Zero-shot ROUGE-L: {zs_rouge_avg['rougeL']:.4f}")
print(f"Zero-shot BERTScore F1: {zs_bertscore:.4f}")
print(f"Zero-shot compliance: {zs_avg_compliance:.2%}")

In [None]:
# AFD format compliance detail
print("AFD Section Presence (fine-tuned model):")
print()
for sec in REQUIRED_SECTIONS:
    present = sum(1 for p in predictions if sec.lower() in p.lower())
    print(f"  {sec:<15} {present}/{len(predictions)} ({present/len(predictions):.0%})")

print("\nAFD Section Presence (zero-shot baseline):")
print()
for sec in REQUIRED_SECTIONS:
    present = sum(1 for p in zs_predictions if sec.lower() in p.lower())
    print(f"  {sec:<15} {present}/{len(zs_predictions)} ({present/len(zs_predictions):.0%})")

## 10. Results

Compare fine-tuned model against zero-shot baseline.

In [None]:
# Comparison table
print(f"{'Metric':<20} {'Fine-tuned':>12} {'Zero-shot':>12} {'Delta':>12}")
print("-" * 58)
for k in ["rouge1", "rouge2", "rougeL", "bertscore_f1", "format_compliance"]:
    f_val = ft_metrics[k]
    z_val = zs_metrics[k]
    delta = f_val - z_val
    # The "+" format flag keeps the sign attached to the number and the column aligned
    print(f"{k:<20} {f_val:>12.4f} {z_val:>12.4f} {delta:>+12.4f}")

# Bar chart
import numpy as np

metrics_keys = ["rouge1", "rouge2", "rougeL", "bertscore_f1", "format_compliance"]
labels = ["ROUGE-1", "ROUGE-2", "ROUGE-L", "BERTScore\nF1", "Format\nCompliance"]
ft_vals = [ft_metrics[k] for k in metrics_keys]
zs_vals = [zs_metrics[k] for k in metrics_keys]

x = np.arange(len(labels))
width = 0.35

fig, ax = plt.subplots(figsize=(10, 5))
bars1 = ax.bar(x - width/2, ft_vals, width, label="Fine-tuned", color="#2196F3")
bars2 = ax.bar(x + width/2, zs_vals, width, label="Zero-shot", color="#FF9800")

ax.set_ylabel("Score")
ax.set_title("Fine-tuned vs Zero-shot Evaluation")
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
ax.set_ylim(0, 1)
ax.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Best and worst 3 examples by ROUGE-L
rougeL_scores = rouge_results["rougeL"]
indexed = sorted(enumerate(rougeL_scores), key=lambda x: x[1], reverse=True)

print("=" * 70)
print("TOP 3 (Best ROUGE-L):")
print("=" * 70)
for rank, (idx, score) in enumerate(indexed[:3], 1):
    print(f"\n--- #{rank}: Example {idx} (ROUGE-L = {score:.4f}) ---")
    print("GENERATED (first 300 chars):")
    print(predictions[idx][:300])
    print()

print("\n" + "=" * 70)
print("BOTTOM 3 (Worst ROUGE-L):")
print("=" * 70)
# Reverse the tail so that #1 is the worst-scoring example
for rank, (idx, score) in enumerate(reversed(indexed[-3:]), 1):
    print(f"\n--- #{rank}: Example {idx} (ROUGE-L = {score:.4f}) ---")
    print("GENERATED (first 300 chars):")
    print(predictions[idx][:300])
    print("REFERENCE (first 300 chars):")
    print(references[idx][:300])
    print()

## 11. Post-Training Sanity Checks

Six gates that must **all pass** before publishing the model. They verify that the
merged model is structurally complete, generates properly terminated output,
clears absolute quality thresholds, and outperforms the zero-shot model.

| Gate | Check |
|------|-------|
| 1 | Merged model directory contains required files |
| 2 | Generated outputs terminate before `max_new_tokens` |
| 3 | ROUGE-L > 0.15 |
| 4 | BERTScore F1 > 0.50 |
| 5 | Format compliance > 50% |
| 6 | Fine-tuned ROUGE-L and BERTScore F1 beat the zero-shot baseline |
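
The metric gates (3 to 6) can also be expressed as a single function that collects every failure instead of stopping at the first failed assertion. This is a sketch, not the notebook's implementation: `failed_metric_gates` is a hypothetical helper, the thresholds are copied from the table above, and the inputs are assumed to have the same keys as the `ft_metrics` and `zs_metrics` dicts built during evaluation.

```python
THRESHOLDS = {"rougeL": 0.15, "bertscore_f1": 0.50, "format_compliance": 0.50}
BASELINE_KEYS = ("rougeL", "bertscore_f1")

def failed_metric_gates(ft, zs):
    """Return human-readable descriptions of every failed metric gate."""
    failures = []
    # Gates 3-5: absolute thresholds
    for key, minimum in THRESHOLDS.items():
        if not ft[key] > minimum:
            failures.append(f"{key} = {ft[key]:.4f} (threshold: {minimum})")
    # Gate 6: fine-tuned model must beat the zero-shot baseline
    for key in BASELINE_KEYS:
        if not ft[key] > zs[key]:
            failures.append(f"{key}: fine-tuned {ft[key]:.4f} <= zero-shot {zs[key]:.4f}")
    return failures
```

Calling `failed_metric_gates(ft_metrics, zs_metrics)` after evaluation returns an empty list when all four gates pass, which makes it easy to report several failures at once.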

In [None]:
# Gate 1: Merged model directory check
print("Gate 1: Merged model directory ...")
merged_dir = OUTPUT_DIR / "merged"
assert merged_dir.exists(), f"Merged dir not found: {merged_dir}"

required_files = ["config.json", "tokenizer_config.json", "tokenizer.json"]
for fname in required_files:
    assert (merged_dir / fname).exists(), f"Missing {fname} in merged dir"

safetensors = list(merged_dir.glob("*.safetensors"))
assert len(safetensors) >= 1, "No .safetensors files in merged dir"

print(f"  Directory: {merged_dir}")
print(f"  Required files: {required_files} — all present")
print(f"  Safetensors shards: {len(safetensors)}")
print("Gate 1: PASSED")

In [None]:
# Gate 2: EOS termination check
# Model was del'd during eval for BERTScore memory — reload it
print("Gate 2: EOS termination ...")
model, tokenizer = load_model(str(merged_dir))

MAX_NEW_TOKENS = 2048
MARGIN = 10
test_examples = val_data[:3]

for i, ex in enumerate(test_examples):
    # Reuse the example's own system and user turns so the prompt matches training
    messages = ex["messages"][:2]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            eos_token_id=tokenizer.convert_tokens_to_ids("<|im_end|>"),
            pad_token_id=tokenizer.convert_tokens_to_ids("<|endoftext|>"),
        )
    generated_tokens = output[0][inputs["input_ids"].shape[1]:]
    n_gen = len(generated_tokens)
    assert n_gen < MAX_NEW_TOKENS - MARGIN, (
        f"Example {i}: generated {n_gen} tokens (limit {MAX_NEW_TOKENS}), "
        f"model may not be terminating with EOS"
    )
    print(f"  Example {i}: {n_gen} tokens — OK")

# Free model again
del model
torch.cuda.empty_cache()
print("Gate 2: PASSED")

In [None]:
# Gates 3-6: Metric thresholds
print("Gate 3: ROUGE-L > 0.15 ...")
assert ft_metrics["rougeL"] > 0.15, (
    f"ROUGE-L = {ft_metrics['rougeL']:.4f} (threshold: 0.15)"
)
print(f"  ROUGE-L = {ft_metrics['rougeL']:.4f} — PASSED")

print("Gate 4: BERTScore F1 > 0.50 ...")
assert ft_metrics["bertscore_f1"] > 0.50, (
    f"BERTScore F1 = {ft_metrics['bertscore_f1']:.4f} (threshold: 0.50)"
)
print(f"  BERTScore F1 = {ft_metrics['bertscore_f1']:.4f} — PASSED")

print("Gate 5: Format compliance > 0.50 ...")
assert ft_metrics["format_compliance"] > 0.50, (
    f"Format compliance = {ft_metrics['format_compliance']:.4f} (threshold: 0.50)"
)
print(f"  Format compliance = {ft_metrics['format_compliance']:.4f} — PASSED")

print("Gate 6: Fine-tuned > zero-shot ...")
assert ft_metrics["rougeL"] > zs_metrics["rougeL"], (
    f"ROUGE-L: fine-tuned ({ft_metrics['rougeL']:.4f}) <= "
    f"zero-shot ({zs_metrics['rougeL']:.4f})"
)
assert ft_metrics["bertscore_f1"] > zs_metrics["bertscore_f1"], (
    f"BERTScore F1: fine-tuned ({ft_metrics['bertscore_f1']:.4f}) <= "
    f"zero-shot ({zs_metrics['bertscore_f1']:.4f})"
)
print(f"  ROUGE-L:     {ft_metrics['rougeL']:.4f} > {zs_metrics['rougeL']:.4f} — PASSED")
print(f"  BERTScore:   {ft_metrics['bertscore_f1']:.4f} > {zs_metrics['bertscore_f1']:.4f} — PASSED")

print()
print("=" * 60)
print("ALL SANITY CHECKS PASSED — safe to push to HF Hub")
print("=" * 60)

## 12. Push to HuggingFace Hub

Upload the merged model and GGUF quantizations to HuggingFace:

| Repo | Contents |
|------|----------|
| `ringusTheImp/wx-afd-qwen3-4b` | Merged HF model + tokenizer |
| `ringusTheImp/wx-afd-qwen3-4b-GGUF` | F16, Q8_0, Q4_K_M quantizations |

In [None]:
# Install huggingface_hub and log in
!pip install -U huggingface_hub python-dotenv

from huggingface_hub import HfApi, login

hf_token = None

# 1. Try Colab Secrets
try:
    from google.colab import userdata
    hf_token = userdata.get("HF_TOKEN")
    if hf_token:
        print("Found HF_TOKEN in Colab Secrets.")
except Exception:
    pass

# 2. Try .env file (for local runs)
if hf_token is None:
    try:
        from dotenv import load_dotenv
        load_dotenv()
        hf_token = os.environ.get("HF_TOKEN")
        if hf_token:
            print("Found HF_TOKEN in .env file.")
    except ImportError:
        pass

# 3. Fall back to interactive login
if hf_token:
    login(token=hf_token)
else:
    print("HF_TOKEN not found — using interactive login.")
    login()

api = HfApi()
whoami = api.whoami()
print(f"Logged in as: {whoami['name']}")


In [None]:
# Push merged model to HuggingFace Hub
from huggingface_hub import ModelCard

HF_REPO = "ringusTheImp/wx-afd-qwen3-4b"
GGUF_REPO = "ringusTheImp/wx-afd-qwen3-4b-GGUF"

# Build model card
model_card_text = f"""---
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
  - weather
  - meteorology
  - nws
  - area-forecast-discussion
  - dora
  - rslora
  - axolotl
datasets:
  - custom
language:
  - en
---

# WX-AFD: Qwen3-4B for NWS Area Forecast Discussions

Fine-tuned **Qwen3-4B-Instruct-2507** with DoRA + rsLoRA to generate
NWS Area Forecast Discussions (AFDs) from structured weather model data.
Trained on Louisville, KY (WFO LMK) forecasts.

## Evaluation Results

| Metric | Fine-tuned | Zero-shot | Delta |
|--------|-----------|-----------|-------|
| ROUGE-1 | {ft_metrics['rouge1']:.4f} | {zs_metrics['rouge1']:.4f} | {ft_metrics['rouge1'] - zs_metrics['rouge1']:+.4f} |
| ROUGE-2 | {ft_metrics['rouge2']:.4f} | {zs_metrics['rouge2']:.4f} | {ft_metrics['rouge2'] - zs_metrics['rouge2']:+.4f} |
| ROUGE-L | {ft_metrics['rougeL']:.4f} | {zs_metrics['rougeL']:.4f} | {ft_metrics['rougeL'] - zs_metrics['rougeL']:+.4f} |
| BERTScore F1 | {ft_metrics['bertscore_f1']:.4f} | {zs_metrics['bertscore_f1']:.4f} | {ft_metrics['bertscore_f1'] - zs_metrics['bertscore_f1']:+.4f} |
| Format Compliance | {ft_metrics['format_compliance']:.2%} | {zs_metrics['format_compliance']:.2%} | {ft_metrics['format_compliance'] - zs_metrics['format_compliance']:+.2%} |

## Usage

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model = AutoModelForCausalLM.from_pretrained(
    "{HF_REPO}",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("{HF_REPO}")

messages = [
    {{"role": "system", "content": "You are an expert NWS meteorologist at the Louisville, Kentucky Weather Forecast Office (WFO LMK)."}},
    {{"role": "user", "content": "<your weather data here>"}},
]
text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = tokenizer(text, return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=2048, do_sample=True, temperature=0.7, top_p=0.9)
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```

## GGUF Quantizations

GGUF files for llama.cpp / Ollama / LM Studio are available at
[{GGUF_REPO}](https://huggingface.co/{GGUF_REPO}):
- `wx-afd-qwen3-4b-F16.gguf` — Full FP16
- `wx-afd-qwen3-4b-Q8_0.gguf` — 8-bit quantization
- `wx-afd-qwen3-4b-Q4_K_M.gguf` — 4-bit quantization (recommended for CPU)

## MLX (Apple Silicon)

MLX conversion requires macOS with Apple Silicon — run locally:

```bash
pip install mlx-lm
mlx_lm.convert --hf-path {HF_REPO} -q --upload-repo ringusTheImp/wx-afd-qwen3-4b-MLX
```

## Training Details

- **Base model:** Qwen/Qwen3-4B-Instruct-2507
- **Method:** DoRA + rsLoRA (rank 16, alpha 32)
- **Framework:** Axolotl
- **Data:** {len(train_data)} training and {len(val_data)} validation examples from Louisville, KY WFO (LMK)
- **Sequence length:** 2048 tokens (sample packing)
- **Precision:** bfloat16

## Citation

```bibtex
@misc{{wx-afd-qwen3-4b,
  title={{WX-AFD: Fine-Tuning Qwen3-4B for Area Forecast Discussions}},
  author={{ringusTheImp}},
  year={{2025}},
  url={{https://huggingface.co/{HF_REPO}}}
}}
```
"""

# Create repo and upload model
api.create_repo(HF_REPO, exist_ok=True)
print(f"Uploading merged model to {HF_REPO} ...")
api.upload_folder(
    folder_path=str(merged_dir),
    repo_id=HF_REPO,
    commit_message="Upload merged WX-AFD model",
)
print("Model uploaded.")

# Push model card
card = ModelCard(model_card_text)
card.push_to_hub(HF_REPO, commit_message="Add model card")
print(f"Model card pushed to https://huggingface.co/{HF_REPO}")


In [None]:
# Clone and build llama.cpp for GGUF conversion
!git clone https://github.com/ggml-org/llama.cpp.git /content/llama.cpp
!pip install -r /content/llama.cpp/requirements.txt

# Build only the llama-quantize target (the converter itself is a Python script)
!cd /content/llama.cpp && cmake -B build && cmake --build build --config Release --target llama-quantize -j$(nproc)

quantize_bin = Path("/content/llama.cpp/build/bin/llama-quantize")
assert quantize_bin.exists(), f"llama-quantize not found at {quantize_bin}"
print(f"llama-quantize binary: {quantize_bin}")


In [None]:
# Convert to GGUF and quantize
gguf_dir = OUTPUT_DIR / "gguf"
gguf_dir.mkdir(parents=True, exist_ok=True)

f16_path = gguf_dir / "wx-afd-qwen3-4b-F16.gguf"
q8_path = gguf_dir / "wx-afd-qwen3-4b-Q8_0.gguf"
q4_path = gguf_dir / "wx-afd-qwen3-4b-Q4_K_M.gguf"

# Step 1: HF → F16 GGUF
print("Converting HF model to F16 GGUF ...")
!python /content/llama.cpp/convert_hf_to_gguf.py \
    {str(merged_dir)} \
    --outtype f16 \
    --outfile {str(f16_path)}
assert f16_path.exists(), f"F16 conversion failed: {f16_path}"
print(f"  F16: {f16_path.stat().st_size / 1e9:.2f} GB")

# Step 2: F16 → Q8_0
print("Quantizing F16 → Q8_0 ...")
!{quantize_bin} {str(f16_path)} {str(q8_path)} Q8_0
assert q8_path.exists(), f"Q8_0 quantization failed: {q8_path}"
print(f"  Q8_0: {q8_path.stat().st_size / 1e9:.2f} GB")

# Step 3: F16 → Q4_K_M
print("Quantizing F16 → Q4_K_M ...")
!{quantize_bin} {str(f16_path)} {str(q4_path)} Q4_K_M
assert q4_path.exists(), f"Q4_K_M quantization failed: {q4_path}"
print(f"  Q4_K_M: {q4_path.stat().st_size / 1e9:.2f} GB")

print()
print("GGUF files (saved to Google Drive):")
for p in [f16_path, q8_path, q4_path]:
    print(f"  {p.name}: {p.stat().st_size / 1e9:.2f} GB")

In [None]:
# Push GGUF files to HuggingFace Hub
api.create_repo(GGUF_REPO, exist_ok=True)

gguf_files = [f16_path, q8_path, q4_path]
for gf in gguf_files:
    print(f"Uploading {gf.name} ({gf.stat().st_size / 1e9:.2f} GB) ...")
    api.upload_file(
        path_or_fileobj=str(gf),
        path_in_repo=gf.name,
        repo_id=GGUF_REPO,
        commit_message=f"Upload {gf.name}",
    )
    print(f"  Uploaded: {gf.name}")

# Push GGUF model card
gguf_card_text = f"""---
license: apache-2.0
base_model: Qwen/Qwen3-4B-Instruct-2507
tags:
  - weather
  - meteorology
  - nws
  - gguf
  - llama-cpp
---

# WX-AFD: Qwen3-4B GGUF Quantizations

GGUF quantizations of [{HF_REPO}](https://huggingface.co/{HF_REPO}) for
llama.cpp, Ollama, and LM Studio.

## Available Files

| File | Quant | Size |
|------|-------|------|
| `wx-afd-qwen3-4b-F16.gguf` | F16 | {f16_path.stat().st_size / 1e9:.2f} GB |
| `wx-afd-qwen3-4b-Q8_0.gguf` | Q8_0 | {q8_path.stat().st_size / 1e9:.2f} GB |
| `wx-afd-qwen3-4b-Q4_K_M.gguf` | Q4_K_M | {q4_path.stat().st_size / 1e9:.2f} GB |

## Usage with llama.cpp

```bash
# Download
huggingface-cli download {GGUF_REPO} wx-afd-qwen3-4b-Q4_K_M.gguf --local-dir .

# Run
llama-cli -m wx-afd-qwen3-4b-Q4_K_M.gguf -p "<weather data>" -n 2048
```

## Usage with Ollama

```bash
# Create Modelfile
echo 'FROM ./wx-afd-qwen3-4b-Q4_K_M.gguf' > Modelfile
ollama create wx-afd -f Modelfile
ollama run wx-afd
```

See the [full model card](https://huggingface.co/{HF_REPO}) for evaluation results and training details.
"""

gguf_card = ModelCard(gguf_card_text)
gguf_card.push_to_hub(GGUF_REPO, commit_message="Add GGUF model card")
print(f"\nGGUF repo: https://huggingface.co/{GGUF_REPO}")


## 13. Next Steps

**Completed in this notebook:**
- Environment setup (Colab dependencies + Google Drive)
- Data pipeline (clone repo + run scripts 01-03 if needed)
- Data inspection and validation
- Config validation against the expected hyperparameters
- Sanity check (loss masking, token IDs, dataset loading)
- Model training (DoRA + rsLoRA, 3 epochs with early stopping)
- LoRA merge and export
- Full evaluation: ROUGE, BERTScore, format compliance
- Zero-shot baseline comparison
- Post-training sanity checks (6 gates)
- HuggingFace Hub push: [ringusTheImp/wx-afd-qwen3-4b](https://huggingface.co/ringusTheImp/wx-afd-qwen3-4b)
- GGUF quantizations (F16, Q8_0, Q4_K_M): [ringusTheImp/wx-afd-qwen3-4b-GGUF](https://huggingface.co/ringusTheImp/wx-afd-qwen3-4b-GGUF)

**MLX conversion (run locally on Apple Silicon):**
```bash
pip install mlx-lm
mlx_lm.convert --hf-path ringusTheImp/wx-afd-qwen3-4b -q --upload-repo ringusTheImp/wx-afd-qwen3-4b-MLX
```

**Future work:**
- Multi-WFO training (expand beyond Louisville)
- AlignScore factual consistency evaluation
- GRPO/DPO alignment with AlignScore as reward signal
- Attribute-specific LoRA adapters (synopsis vs aviation)
- vLLM/TGI deployment for real-time inference