# Phase 0: Baseline Reproduction (DnD on Qwen2.5-1.5B)

This notebook verifies that the existing DnD framework works end-to-end for Qwen2.5-1.5B LoRA generation before delta supervision is added.

## Goals
- Verify all DnD imports work correctly
- Test the LoRA tokenizer roundtrip
- Run a minimal training loop
- Generate and save a LoRA checkpoint

## Step 1: Environment Setup & Imports

In [4]:
import sys
import os

# Detect environment
# The presence of the `google.colab` module in sys.modules is used as the
# signal that this kernel is a Colab runtime.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("Running in Google Colab")
    
    # Option 1: Clone repo if not present (uncomment and set your repo URL)
    # if not os.path.exists("dnd_repo"):
    #     !git clone https://github.com/YOUR_USERNAME/dnd_repo.git dnd_repo
    
    # Option 2: Mount Google Drive if repo is there
    # from google.colab import drive
    # drive.mount('/content/drive')
    # !cp -r /content/drive/MyDrive/dnd_repo /content/dnd_repo
    
    # Option 3: Upload dnd_repo.zip and extract
    # !unzip dnd_repo.zip
    
    # Install dependencies
    # NOTE(review): versions are unpinned, so installs can drift between runs.
    !pip install -q safetensors accelerate transformers sentence-transformers
    
    # Check if dnd_repo exists
    # This branch only prints guidance; it does not stop execution, so the
    # code below still runs, and later `workspace...` imports are expected to
    # fail until one of the options above has been applied.
    if not os.path.exists("dnd_repo"):
        print("\n" + "="*60)
        print("ERROR: dnd_repo not found!")
        print("Please do one of the following:")
        print("  1. Upload dnd_repo folder to Colab")
        print("  2. Uncomment the git clone line above and set your repo URL")
        print("  3. Mount Google Drive and copy the repo")
        print("="*60)
else:
    print("Running locally")

# Add DnD to path - this allows importing as 'workspace.dnd...'
# The path is resolved against the current working directory and is inserted
# only when absent, so re-running this cell does not duplicate the entry.
DND_PATH = os.path.abspath("dnd_repo")
if DND_PATH not in sys.path:
    sys.path.insert(0, DND_PATH)

print(f"\nWorking directory: {os.getcwd()}")
print(f"DnD path: {DND_PATH}")

# Sanity report: show up to five entries of the repo folder and confirm that
# a `workspace` entry is among them.
if os.path.exists(DND_PATH):
    contents = os.listdir(DND_PATH)
    print(f"DnD contents: {contents[:5]}{'...' if len(contents) > 5 else ''}")
    if 'workspace' in contents:
        print("[OK] workspace folder found")
else:
    print("[ERROR] DnD path does not exist!")

Running in Google Colab

ERROR: dnd_repo not found!
Please do one of the following:
  1. Upload dnd_repo folder to Colab
  2. Uncomment the git clone line above and set your repo URL
  3. Mount Google Drive and copy the repo

Working directory: /content
DnD path: /content/dnd_repo
[ERROR] DnD path does not exist!


In [2]:
# Core imports
import json
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from collections import OrderedDict
from pathlib import Path
from tqdm.auto import tqdm

# Report the PyTorch build and, when a GPU is visible, its name and memory.
cuda_ok = torch.cuda.is_available()
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    gpu_name = torch.cuda.get_device_name(0)
    print(f"CUDA device: {gpu_name}")
    gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"CUDA memory: {gpu_bytes / 1e9:.1f} GB")

PyTorch version: 2.9.0+cpu
CUDA available: False


## Step 2: Verify DnD Imports

In [None]:
def verify_imports():
    """Try each group of DnD imports and report which ones succeed.

    Every group is attempted independently, so one failure does not hide
    the others. A line per group is printed, followed by an overall verdict.

    Returns:
        bool: True only when every import group succeeded.
    """

    def _model_probe():
        from workspace.dnd.model.decoderonly import (
            HyperConvDecoderModel,
            HyperConvDecoderModel_FullCond,
            HyperConvDecoderModel_SuperLarge
        )

    def _tokenizer_probe():
        from workspace.dnd.tokenizer.register import Qwen2515LoRA_Tokenizer2D

    def _dataset_probe():
        from workspace.dnd.dataset.register import (
            Text2Qwen25LoRA_FullCondDataset,
            Text2Qwen25LoRA_CondQ_ADataset
        )

    def _hyperconv_probe():
        from workspace.dnd.module.hyperconv import HyperConvDecoder

    def _tools_probe():
        from workspace.dnd.tools import load_safetensors, save_safetensors

    # Order matters only for the printed report.
    probes = [
        ("Model imports", _model_probe),
        ("Tokenizer imports", _tokenizer_probe),
        ("Dataset imports", _dataset_probe),
        ("HyperConv imports", _hyperconv_probe),
        ("Tools imports", _tools_probe),
    ]

    results = {}
    for label, probe in probes:
        try:
            probe()
        except Exception as exc:
            results[label] = f"FAIL: {exc}"
        else:
            results[label] = "OK"

    # Print results
    for label, outcome in results.items():
        tag = "[OK]" if outcome == "OK" else "[FAIL]"
        print(f"{tag} {label}: {outcome}")

    all_ok = all(outcome == "OK" for outcome in results.values())
    print("\n" + ("All imports successful!" if all_ok else "Some imports failed!"))
    return all_ok

verify_imports()

## Step 3: Prepare Sample Data

Create dummy LoRA checkpoints for testing the pipeline.

In [None]:
from safetensors.torch import save_file, load_file

def create_dummy_lora_checkpoint(output_path: str, rank: int = 16):
    """
    Create a dummy LoRA checkpoint matching Qwen2.5-1.5B structure.

    For each of the 28 decoder layers an A/B pair is written for every
    LoRA target. ``lora_A`` is Gaussian noise scaled by 0.01 and ``lora_B``
    is all zeros; every tensor is stored as bfloat16.

    Args:
        output_path: Full path to the .safetensors file (not directory).
            A bare filename is written to the current working directory.
        rank: LoRA rank

    Returns:
        dict mapping parameter name -> bfloat16 tensor, exactly as saved.

    LoRA targets for Qwen2.5-1.5B:
    - q_proj, k_proj, v_proj, o_proj (attention)
    - gate_proj, up_proj, down_proj (MLP)
    """
    # A bare filename has an empty dirname and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Qwen2.5-1.5B config
    hidden_size = 1536
    intermediate_size = 8960
    num_layers = 28
    num_kv_heads = 2  # GQA: fewer KV heads
    num_heads = 12
    head_dim = hidden_size // num_heads
    kv_dim = num_kv_heads * head_dim  # 256

    # (submodule, projection, in_features, out_features).
    # lora_A is (rank, in_features); lora_B is (out_features, rank).
    # The order is fixed so random draws are consumed in a stable sequence.
    targets = [
        ("self_attn", "q_proj", hidden_size, hidden_size),
        ("self_attn", "o_proj", hidden_size, hidden_size),
        ("self_attn", "k_proj", hidden_size, kv_dim),
        ("self_attn", "v_proj", hidden_size, kv_dim),
        ("mlp", "gate_proj", hidden_size, intermediate_size),
        ("mlp", "up_proj", hidden_size, intermediate_size),
        ("mlp", "down_proj", intermediate_size, hidden_size),
    ]

    lora_weights = {}
    for layer_idx in range(num_layers):
        prefix = f"base_model.model.model.layers.{layer_idx}"
        for submodule, proj, in_features, out_features in targets:
            base = f"{prefix}.{submodule}.{proj}"
            lora_a = torch.randn(rank, in_features) * 0.01
            lora_b = torch.zeros(out_features, rank)
            # Convert to bfloat16
            lora_weights[f"{base}.lora_A.weight"] = lora_a.to(torch.bfloat16)
            lora_weights[f"{base}.lora_B.weight"] = lora_b.to(torch.bfloat16)

    # Save checkpoint
    save_file(lora_weights, output_path)

    total_params = sum(p.numel() for p in lora_weights.values())
    print(f"Created: {output_path}")
    print(f"  Layers: {num_layers}, Rank: {rank}, Params: {total_params:,}")

    return lora_weights

In [None]:
# All dummy checkpoints live in one dedicated directory.
# DnD dataset expects: a folder containing ONLY .safetensors files
DATA_DIR = Path("data/sample_checkpoints/math_loras")
DATA_DIR.mkdir(parents=True, exist_ok=True)

# One file per checkpoint, named by its index: 0.safetensors, 1.safetensors, ...
NUM_CHECKPOINTS = 5
for ckpt_path in (DATA_DIR / f"{idx}.safetensors" for idx in range(NUM_CHECKPOINTS)):
    create_dummy_lora_checkpoint(str(ckpt_path), rank=16)

created_files = list(DATA_DIR.glob('*.safetensors'))
print(f"\nCreated {NUM_CHECKPOINTS} checkpoints in {DATA_DIR}")
print(f"Files: {created_files}")

In [None]:
# Write a small math prompt/response set to disk for the dataset to read.
PROMPTS_DIR = Path("data/sample_prompts")
PROMPTS_DIR.mkdir(parents=True, exist_ok=True)

# (prompt, response) pairs; expanded into dicts and repeated below.
_base_pairs = [
    ("Solve the equation: 2x + 5 = 15", "x = 5"),
    ("What is the derivative of x^2?", "2x"),
    ("Calculate the area of a circle with radius 5", "25*pi"),
    ("Simplify: (3x + 2)(x - 4)", "3x^2 - 10x - 8"),
    ("Find the roots of x^2 - 5x + 6 = 0", "x = 2 or x = 3"),
    ("What is 15% of 80?", "12"),
    ("Convert 3/4 to a decimal", "0.75"),
    ("What is the sum of angles in a triangle?", "180 degrees"),
    ("Calculate: 125 / 5 + 3 * 4", "37"),
    ("Find the GCD of 24 and 36", "12"),
]
# Repeat for more samples
sample_prompts = [
    {"prompt": question, "response": answer} for question, answer in _base_pairs
] * 10

prompts_path = PROMPTS_DIR / "math_prompts.json"
prompts_path.write_text(json.dumps(sample_prompts, indent=2))

print(f"Created {len(sample_prompts)} sample prompts at {prompts_path}")

## Step 4: Tokenizer Sanity Check

Test that the tokenizer can convert LoRA weights to tokens and back.

In [None]:
from workspace.dnd.tokenizer.register import Qwen2515LoRA_Tokenizer2D
from workspace.dnd.tools import load_safetensors

def _load_stripped_lora(checkpoint_path: str) -> OrderedDict:
    """Load a checkpoint as a key-sorted OrderedDict with the `base_model.model.` prefix removed."""
    raw = load_safetensors(checkpoint_path, map_location="cpu", dtype=torch.bfloat16)
    # Remove base_model.model. prefix (matching the dataset's post_process)
    stripped = {k.replace("base_model.model.", ""): v for k, v in raw.items()}
    return OrderedDict(sorted(stripped.items()))


def test_tokenizer_roundtrip(checkpoint_path: str):
    """Test that tokenize -> detokenize recovers original weights.

    The checkpoint is tokenized with ``Qwen2515LoRA_Tokenizer2D`` and
    detokenized into a freshly loaded copy. The mean absolute error per
    tensor is then compared against the original values.

    Args:
        checkpoint_path: Path to a .safetensors LoRA checkpoint.

    Returns:
        tuple: ``(tokens, avg_error)``. ``avg_error`` is the mean of the
        per-tensor errors over the tensors that could be compared, or
        ``inf`` when no tensor could be compared at all.
    """
    weights = _load_stripped_lora(checkpoint_path)

    print(f"Loaded {len(weights)} tensors from checkpoint")
    print(f"Sample keys: {list(weights.keys())[:3]}")

    # Initialize tokenizer with Qwen2.5-1.5B token size
    token_size = (18, 258)  # From the training script
    tokenizer = Qwen2515LoRA_Tokenizer2D(token_size=token_size)

    # Keep an independent copy for the comparison; the dict handed to the
    # tokenizer is not reused afterwards.
    original_weights = {k: v.clone() for k, v in weights.items()}

    # Tokenize (the second return value is not needed for this check)
    tokens, _scales = tokenizer.tokenize(weights)
    print(f"\nTokenized:")
    print(f"  - Tokens shape: {tokens.shape}")
    print(f"  - Token dtype: {tokens.dtype}")
    print(f"  - Has NaN (padding): {torch.isnan(tokens).any().item()}")

    # Load fresh weights to serve as the template passed to detokenize
    fake_weights = _load_stripped_lora(checkpoint_path)

    # Detokenize
    reconstructed = tokenizer.detokenize(fake_weights, tokens)
    print(f"\nReconstructed {len(reconstructed)} tensors")

    # Compare
    total_error = 0.0
    max_error = 0.0
    error_count = 0
    compared = 0
    missing_keys = []

    for key, original in original_weights.items():
        if key not in reconstructed:
            # A tensor that was not reconstructed is a failure, not a zero
            # error, so track it separately instead of skipping silently.
            missing_keys.append(key)
            continue
        compared += 1
        error = torch.abs(original.float() - reconstructed[key].float()).mean().item()
        total_error += error
        max_error = max(max_error, error)
        if error > 1e-3:
            error_count += 1
            if error_count <= 5:  # Only print first 5 errors
                print(f"  High error in {key}: {error:.6f}")

    # Average only over tensors that were actually compared; an empty
    # comparison must not look like a perfect reconstruction.
    avg_error = total_error / compared if compared else float("inf")
    print(f"\n=== Results ===")
    print(f"Average reconstruction error: {avg_error:.8f}")
    print(f"Max reconstruction error: {max_error:.8f}")
    print(f"Keys with error > 1e-3: {error_count}")
    if missing_keys:
        print(f"Keys missing from reconstruction: {len(missing_keys)} (e.g. {missing_keys[:3]})")

    if avg_error < 1e-3 and not missing_keys:
        print("\n[PASS] Tokenizer roundtrip test")
    elif missing_keys:
        print("\n[WARN] Reconstruction is missing tensors")
    else:
        print("\n[WARN] Reconstruction error above threshold")

    return tokens, avg_error

In [None]:
# Roundtrip the first dummy checkpoint; `error` is read again by the
# acceptance summary at the end of the notebook.
test_ckpt = "data/sample_checkpoints/math_loras/0.safetensors"
roundtrip_result = test_tokenizer_roundtrip(test_ckpt)
tokens, error = roundtrip_result

## Step 5: Minimal Training Loop

Set up a minimal training loop to verify the pipeline works end-to-end.

In [None]:
from transformers import AutoModel, AutoTokenizer
from workspace.dnd.model.decoderonly import HyperConvDecoderModel_SuperLarge
from workspace.dnd.tokenizer.register import Qwen2515LoRA_Tokenizer2D
from workspace.dnd.dataset.register import Text2Qwen25LoRA_FullCondDataset

# Configuration: every tunable of the smoke test in one place.
CONFIG = dict(
    # Data settings
    token_size=(18, 258),
    max_text_length=512,  # Reduced for testing
    modified_length=128,  # Reduced for testing
    num_texts=4,  # Reduced for testing
    batch_size=2,
    real_length=5,  # Number of checkpoints to use from the folder

    # Training settings
    total_steps=50,
    learning_rate=1e-4,
    log_interval=10,

    # Model settings
    extractor_type="BERT",
    extractor_model="sentence-transformers/all-MiniLM-L6-v2",  # Small model for testing

    # Paths - single folder containing .safetensors files
    checkpoint_folders=["data/sample_checkpoints/math_loras"],
    prompts_path="data/sample_prompts/math_prompts.json",
    output_dir="outputs/phase0_baseline",
)

print("Configuration:")
print("\n".join(f"  {name}: {value}" for name, value in CONFIG.items()))

In [None]:
# Pick the compute device: GPU when one is visible, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# LoRA tokenizer
lora_tokenizer = Qwen2515LoRA_Tokenizer2D(token_size=CONFIG["token_size"])
print("[OK] LoRA tokenizer initialized")

# Text tokenizer and condition model
extractor_name = CONFIG["extractor_model"]
text_tokenizer = AutoTokenizer.from_pretrained(extractor_name)
condition_model = AutoModel.from_pretrained(extractor_name, torch_dtype=torch.float32)
condition_model = condition_model.to(device)
condition_model.eval()
# The text encoder is used as a fixed feature extractor: no gradients.
condition_model.requires_grad_(False)
print(f"[OK] Text encoder loaded: {CONFIG['extractor_model']}")
print(f"     Hidden size: {condition_model.config.hidden_size}")

In [None]:
# Load prompts written earlier in the notebook
prompts_data = json.loads(Path(CONFIG["prompts_path"]).read_text())

# Every checkpoint folder is paired with the same prompt list.
texts_per_folder = [prompts_data for _ in CONFIG["checkpoint_folders"]]

# The dtype is configured on the class before the dataset is constructed.
Text2Qwen25LoRA_FullCondDataset.dtype = torch.bfloat16

# Twice as many samples as the training run is configured to consume.
expected_iteration = CONFIG["total_steps"] * CONFIG["batch_size"] * 2

dataset = Text2Qwen25LoRA_FullCondDataset(
    checkpoint_folders=CONFIG["checkpoint_folders"],
    tokenizer=lora_tokenizer,
    num_texts=CONFIG["num_texts"],
    texts=texts_per_folder,
    max_text_length=CONFIG["max_text_length"],
    text_tokenizer=text_tokenizer,
    expected_iteration=expected_iteration,
    real_length=CONFIG["real_length"],
)

print(f"[OK] Dataset initialized")
print(f"     Total checkpoints: {dataset.real_length}")
print(f"     Dataset length: {len(dataset)}")

In [None]:
# Pull one item to confirm the dataset yields (tokens, condition, path).
tokens, condition, path = sample = dataset[0]
print(f"Sample loaded from: {path}")
print(f"Tokens shape: {tokens.shape}")
print(f"Condition type: {type(condition)}")
condition_ids = condition.input_ids
print(f"Condition input_ids shape: {condition_ids.shape}")

In [None]:
# Batches are shuffled and collated by the dataset's own training collate
# function; an incomplete final batch is dropped.
loader_options = dict(
    batch_size=CONFIG["batch_size"],
    shuffle=True,
    collate_fn=dataset.collate_fn_train,
    drop_last=True,
)
dataloader = DataLoader(dataset, **loader_options)

print(f"[OK] DataLoader created with batch_size={CONFIG['batch_size']}")

In [None]:
# Get hidden size from condition model
hidden_size = condition_model.config.hidden_size  # 384 for MiniLM-L6

# Number of LoRA tokens per checkpoint, read from the single (unbatched)
# dataset sample. `sample[0]` is used rather than the global `tokens`
# because the training loop rebinds `tokens` to a batched tensor, which
# would make this cell return the batch size when re-run after training.
num_tokens = sample[0].shape[0]

# Calculate criterion weight (uniform for testing)
# In real training, this would be computed from actual LoRA statistics
criterion_weight = torch.ones(num_tokens)

# Shared by the first feature stage and the condition dimension.
condition_shape = (CONFIG["num_texts"], CONFIG["modified_length"], hidden_size)

# Model configuration - scaled down for testing
model_config = {
    "features": [
        condition_shape,  # Input
        (32, 64, 128),   # Intermediate
        (64, 32, 128),   # Intermediate
        (num_tokens, CONFIG["token_size"][0], CONFIG["token_size"][1]),  # Output
    ],
    "condition_dim": condition_shape,
    "kernel_size": 5,
}

print("Model configuration:")
print(f"  Features: {model_config['features']}")
print(f"  Condition dim: {model_config['condition_dim']}")
print(f"  Criterion weight shape: {criterion_weight.shape}")

In [None]:
# Build the hyper-network around the frozen text encoder and move it to
# the selected device.
model = HyperConvDecoderModel_SuperLarge(
    config=model_config,
    criterion_weight=criterion_weight.view(1, -1, 1, 1),
    max_length=CONFIG["max_text_length"],
    modified_length=CONFIG["modified_length"],
    extractor_type=CONFIG["extractor_type"],
    extra_condition_module=condition_model,
    freeze_extra_condition=True,
)
model = model.to(device)

# Count parameters in a single pass over the model.
total_params = 0
trainable_params = 0
for weight in model.parameters():
    total_params += weight.numel()
    if weight.requires_grad:
        trainable_params += weight.numel()

print(f"[OK] Model initialized")
print(f"     Total parameters: {total_params:,}")
print(f"     Trainable parameters: {trainable_params:,}")

In [None]:
# Only parameters that still require gradients are handed to AdamW.
trainable_weights = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.AdamW(
    trainable_weights,
    lr=CONFIG["learning_rate"],
    weight_decay=0.01,
)

print(f"[OK] Optimizer initialized with lr={CONFIG['learning_rate']}")

In [None]:
# Training loop
# Runs until CONFIG["total_steps"] optimizer steps have been taken, cycling
# over the dataloader for at most 100 epochs. The per-step loss is collected
# in `losses` for the plot and the acceptance summary.
print("\n" + "="*50)
print("Starting training...")
print("="*50 + "\n")

model.train()
losses = []
global_step = 0
total_steps = CONFIG["total_steps"]
log_interval = CONFIG["log_interval"]

# The context manager closes the progress bar even if a step raises.
with tqdm(total=total_steps, desc="Training") as pbar:
    for epoch in range(100):  # Max epochs (will break when steps reached)
        # `batch_tokens` is deliberately not named `tokens`, so the
        # single-sample `tokens` from the dataset check is not clobbered.
        for batch_tokens, cond_id, cond_mask in dataloader:
            # Move to device
            batch_tokens = batch_tokens.to(device)
            conditions = {
                "input_ids": cond_id.to(device),
                "attention_mask": cond_mask.to(device),
            }

            # Handle NaN padding: remember which positions are real, then
            # replace the NaNs so they cannot propagate through the model.
            mask = ~torch.isnan(batch_tokens)
            batch_tokens = torch.nan_to_num(batch_tokens, nan=0.0)

            # Forward pass
            optimizer.zero_grad()

            with torch.autocast(device_type=device.type, dtype=torch.bfloat16):
                loss = model(
                    source=None,
                    mask=mask,
                    condition=conditions,
                    target=batch_tokens,
                )

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            # Logging
            loss_val = loss.item()
            losses.append(loss_val)
            global_step += 1

            pbar.update(1)
            pbar.set_postfix({"loss": f"{loss_val:.4f}"})

            if global_step % log_interval == 0:
                # Divide by the real window size, not the nominal one.
                recent = losses[-log_interval:]
                avg_loss = sum(recent) / len(recent)
                print(f"Step {global_step}: loss = {avg_loss:.6f}")

            if global_step >= total_steps:
                break

        if global_step >= total_steps:
            break

if not losses:
    # Without this guard the summary below would fail with an opaque
    # IndexError on losses[-1].
    raise RuntimeError(
        "No training steps ran: the dataloader produced no batches "
        "(check dataset length against batch_size and drop_last)."
    )
if global_step < total_steps:
    print(f"[WARN] Stopped after {global_step} of {total_steps} steps (epoch limit reached)")

tail = losses[-10:]
print(f"\nTraining complete!")
print(f"Final loss: {losses[-1]:.6f}")
print(f"Average loss (last {len(tail)}): {sum(tail)/len(tail):.6f}")

In [None]:
# Plot training curve
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses)
ax.set_xlabel("Step")
ax.set_ylabel("Loss")
ax.set_title("Training Loss")
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# Check if loss decreased.
# The window is capped at the number of recorded steps so that a short run
# is averaged over its real length instead of always being divided by 10.
# With no losses at all both averages are NaN, so the comparison below (and
# the acceptance summary) reports "did not decrease" rather than a fake 0.0.
window = min(10, len(losses))
if window:
    first_10_avg = sum(losses[:window]) / window
    last_10_avg = sum(losses[-window:]) / window
else:
    first_10_avg = last_10_avg = float("nan")
print(f"\nFirst {window} steps avg loss: {first_10_avg:.6f}")
print(f"Last {window} steps avg loss: {last_10_avg:.6f}")
print(f"Loss {'decreased' if last_10_avg < first_10_avg else 'did not decrease'} over training")

## Step 6: Generation Test

Test generating a LoRA checkpoint from the trained model.

In [None]:
def generate_lora(model, text_condition: str, lora_tokenizer, text_tokenizer, device, 
                  num_texts=4, max_length=512):
    """Generate LoRA tokens from a single text condition.

    The prompt is replicated ``num_texts`` times, tokenized to a fixed
    ``max_length``, given a leading batch dimension of 1 and passed to
    ``model.generate``.

    Args:
        model: Object exposing ``eval()`` and
            ``generate(source, mask, condition, target)``.
        text_condition: Prompt text to condition on.
        lora_tokenizer: Unused here; kept so existing callers keep working.
        text_tokenizer: Callable text tokenizer returning an object with
            ``input_ids`` and ``attention_mask`` tensors.
        device: Target device (``torch.device`` or device string).
        num_texts: Number of copies of the prompt fed to the model.
        max_length: Padded/truncated token length of each copy.

    Returns:
        Whatever ``model.generate`` returns (its shape is printed).

    Note:
        The model is switched to eval mode and left there.
    """
    model.eval()

    # Tokenize condition (need num_texts copies)
    inputs = text_tokenizer(
        [text_condition] * num_texts,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    )

    # Add batch dimension
    conditions = {
        "input_ids": inputs.input_ids.unsqueeze(0).to(device),
        "attention_mask": inputs.attention_mask.unsqueeze(0).to(device),
    }

    # Autocast must target the device the tensors actually live on. Deriving
    # it from the `device` argument (not from global CUDA availability)
    # keeps CPU generation correct on machines that also have a GPU.
    autocast_device = torch.device(device).type

    # Generate tokens
    with torch.no_grad(), torch.autocast(device_type=autocast_device, dtype=torch.bfloat16):
        tokens_pred = model.generate(
            source=None,
            mask=None,
            condition=conditions,
            target=None,
        )

    print(f"Generated tokens shape: {tokens_pred.shape}")

    return tokens_pred

In [None]:
# Run one generation and summarize the raw token values.
test_prompt = "Solve mathematical equations step by step"
print(f"Generating LoRA for prompt: '{test_prompt}'")

generation_args = dict(
    model=model,
    text_condition=test_prompt,
    lora_tokenizer=lora_tokenizer,
    text_tokenizer=text_tokenizer,
    device=device,
    num_texts=CONFIG["num_texts"],
    max_length=CONFIG["max_text_length"],
)
generated_tokens = generate_lora(**generation_args)

print(f"\nGenerated tokens stats:")
print(f"  Shape: {generated_tokens.shape}")
# Each reduction is evaluated right before its own line is printed.
for stat_label, reduce_fn in (
    ("Min", generated_tokens.min),
    ("Max", generated_tokens.max),
    ("Mean", generated_tokens.mean),
    ("Std", generated_tokens.std),
):
    print(f"  {stat_label}: {reduce_fn().item():.4f}")

In [None]:
# Turn the generated tokens back into named LoRA tensors.
# A real checkpoint is loaded as the template dict that detokenize receives
# as its first argument; keys are stripped of `base_model.model.` and sorted.
template_path = "data/sample_checkpoints/math_loras/0.safetensors"
template_raw = load_safetensors(template_path, map_location="cpu", dtype=torch.bfloat16)
fake_weights = OrderedDict(
    sorted(
        (name.replace("base_model.model.", ""), tensor)
        for name, tensor in template_raw.items()
    )
)

# Detokenize the first (and only) item of the generated batch, on CPU.
first_generated = generated_tokens[0].cpu()
generated_lora = lora_tokenizer.detokenize(fake_weights, first_generated)

print(f"\nDetokenized LoRA weights:")
print(f"  Number of tensors: {len(generated_lora)}")
print(f"  Total parameters: {sum(p.numel() for p in generated_lora.values()):,}")
print(f"\nSample weights:")
for name, tensor in list(generated_lora.items())[:3]:
    as_float = tensor.float()
    print(f"  {name}: shape={tensor.shape}, mean={as_float.mean():.6f}, std={as_float.std():.6f}")

In [None]:
# Persist the generated LoRA as a .safetensors file.
output_dir = Path(CONFIG["output_dir"])
output_dir.mkdir(parents=True, exist_ok=True)

# Restore the `base_model.model.` key prefix and make every tensor
# contiguous before it is handed to save_file.
save_weights = {}
for name, tensor in generated_lora.items():
    save_weights["base_model.model." + name] = tensor.contiguous()

output_path = output_dir / "generated_lora.safetensors"
save_file(save_weights, str(output_path))

size_kb = output_path.stat().st_size / 1024
print(f"[OK] Saved generated LoRA to {output_path}")
print(f"     File size: {size_kb:.1f} KB")

In [None]:
# Write the trained weights to disk, leaving out the frozen condition module.
model_path = output_dir / "model.pth"

# Drop every entry whose key starts with "condition_module". The key list is
# snapshotted first because the dict is modified while filtering.
state_dict = model.state_dict()
for key in list(state_dict.keys()):
    if key.startswith("condition_module"):
        state_dict.pop(key)

torch.save(state_dict, model_path)

size_mb = model_path.stat().st_size / 1024 / 1024
print(f"[OK] Saved model checkpoint to {model_path}")
print(f"     File size: {size_mb:.1f} MB")

## Summary: Acceptance Criteria Check

In [None]:
# Collect the outcome of each earlier step into one pass/fail table.
banner = "="*60
print(banner)
print("Phase 0 Acceptance Criteria")
print(banner)

criteria = {
    "All imports succeed": verify_imports(),
    "Tokenizer roundtrip error < 1e-3": error < 1e-3,
    "Training loop runs without OOM": True,  # If we got here, it passed
    "Loss decreased over training": last_10_avg < first_10_avg,
    "Can generate and save LoRA checkpoint": output_path.exists(),
}

print()
for criterion, passed in criteria.items():
    print(f"{'[PASS]' if passed else '[FAIL]'} {criterion}")
all_passed = all(criteria.values())

print()
if not all_passed:
    print("Some criteria FAILED. Please review and fix issues before proceeding.")
else:
    print("All acceptance criteria PASSED!")
    print("Ready to proceed to Phase 1.")

## Next Steps

Once Phase 0 is verified, proceed to **Phase 1** to compute delta embeddings for teacher LoRAs.