# H2C Bridge - Colab Development Notebook

In [None]:
# Clone repository
REPO_URL = "https://github.com/YOUR_USERNAME/h2c-bridge.git"  # UPDATE THIS

import os
if not os.path.exists("/content/h2c-bridge"):
    !git clone $REPO_URL /content/h2c-bridge
    print("✅ Repository cloned")
else:
    print("Repository already exists, pulling latest changes...")
    !cd /content/h2c-bridge && git pull

%cd /content/h2c-bridge
print(f"Working directory: {os.getcwd()}")

In [None]:
# Install the package from the current directory in editable mode (-e), so edits
# to the source tree are used without reinstalling; then upgrade bitsandbytes.
!pip install -q -e .
!pip install -q -U bitsandbytes
# NOTE(review): this message prints regardless of whether the pip commands above
# succeeded — check their output if imports fail later.
print("✅ Package installed")

In [None]:
# Enable IPython's autoreload extension so edits to imported modules (e.g. the
# editable-installed h2c_bridge package) are picked up without restarting the
# kernel. Mode 2 reloads modules before executing each cell.
%load_ext autoreload
%autoreload 2
print("✅ Autoreload enabled")

In [None]:
# Verify WandB authentication (training checkpoints are uploaded to WandB).
import wandb

# Set up once by running 'wandb login'
try:
    # wandb.login() can signal failure with a falsy return value instead of
    # raising, so treat that as a failure rather than reporting success.
    if not wandb.login():
        raise RuntimeError("wandb.login() returned False (no API key configured)")
    print("✅ WandB authenticated")
    print(f"   Logged in as: {wandb.api.viewer()['entity']}")
except Exception as e:  # notebook setup cell: report the problem, don't abort the run
    print("⚠️ WandB login failed. Run 'wandb login' in your local terminal.")
    print(f"   Error: {e}")

## 1. Initialize Components

Set up models, data, and configuration. Edit `config.py` or `factory.py` and re-run this cell to test changes.

In [None]:
from h2c_bridge.config import get_default_config
from h2c_bridge.factory import H2CModelFactory
from h2c_bridge.data.datamodule import H2CDataModule
from h2c_bridge.utils import set_seed, clear_gpu

# Seed everything up front so development runs are reproducible.
set_seed(42)

# Development-time overrides layered on top of the project defaults.
dev_overrides = {
    "SHARER_ID": "meta-llama/Llama-3.1-8B-Instruct",
    "RECEIVER_ID": "Qwen/Qwen2.5-0.5B-Instruct",
    "MAX_SAMPLES": 10_000,   # Smaller for faster iteration
    "BATCH_SIZE": 4,
    "epochs": 1,
    "eval_every": 100,
    "verbose": True,
}
config = get_default_config()
config.update(dev_overrides)

# Build the model factory for the configured sharer/receiver pair and load
# both tokenizers from it.
print("Initializing Factory...")
factory = H2CModelFactory(config["SHARER_ID"], config["RECEIVER_ID"])
tok_sharer, tok_receiver = factory.load_tokenizers()

# The data module needs both tokenizers plus the (overridden) config.
print("Initializing Data Module...")
dm = H2CDataModule(tok_sharer, tok_receiver, config)
print("✅ Setup complete")

## 2. Quick Baseline Check (Optional)

Verify the evaluation pipeline works before training.

In [None]:
from h2c_bridge.training.engine import H2CEngine

# Build an engine for the baseline check. eval_every is read from config so the
# value set in the setup cell is the one actually used (it was duplicated here).
engine = H2CEngine(factory, dm, config, lr=1e-4, eval_every=config["eval_every"])

# Evaluate the baselines on the engine's MMLU loader and store the results in
# config["BASELINES"]; the training cell passes this same config to its engine.
print("Running baseline evaluation...")
baseline_results = engine.mmlu_evaluator.evaluate_baselines(engine.mmlu_loader)
config["BASELINES"] = baseline_results

## 3. Training

Train the bridge. Checkpoints are automatically uploaded to WandB as artifacts with the following aliases:
- `latest`: the most recent checkpoint
- `best`: the checkpoint with the highest accuracy
- `final`: the checkpoint at the end of training

In [None]:
# Imported here too: the baseline cell (the only other import site) is optional.
from h2c_bridge.training.engine import H2CEngine

# Clean training run. Drop the reference to any engine built in an earlier cell
# first — clear_gpu() cannot reclaim memory still held by a live object.
engine = None
clear_gpu()
engine = H2CEngine(factory, dm, config, lr=1e-4, eval_every=config["eval_every"])

# Train for the number of epochs configured in the setup cell.
engine.run(epochs=config["epochs"])

## 4. Resume from Checkpoint (Optional)

Load a checkpoint from WandB artifacts to continue training or run inference.

In [None]:
# Example: Load the best checkpoint
# Format: "entity/project/artifact_name:alias"
# You can find this in your WandB UI under Artifacts
# The alias can be any of those listed in the Training section (latest/best/final).
# Requires an existing `engine` (e.g. the one created in the Training cell).

# ARTIFACT_PATH = "your-entity/nlp_project/bridge_Llama-3-1-8B-Instruct_TO_Qwen2-5-0-5B-Instruct_checkpoint:best"
# engine.load_checkpoint(ARTIFACT_PATH)
# print("Checkpoint loaded! You can now continue training or run inference.")

## 5. Debugging & Analysis

Test specific components or run ad-hoc experiments.

In [None]:
# Smoke-test generation on one hand-written prompt. Being the cell's last
# expression, whatever generate_demo returns is echoed by the notebook.
prompt = "Explain how a CPU works."
engine.evaluator.generate_demo(
    prompt,
    max_new_tokens=100,
)

In [None]:
# Report the bridge's average key/value gate values (4 decimal places).
stats = engine.bridge.get_gate_stats()
for _label, _stat_key in (("Key Gate Avg", "key_avg"), ("Value Gate Avg", "value_avg")):
    print(f"{_label}: {stats[_stat_key]:.4f}")

## 6. Visualizations (Optional)

Generate publication-ready visualizations and upload them to WandB.

In [None]:
# Uncomment to run full visualization suite
# Requires `engine` and `config` from the cells above; the `themes` argument
# requests figures in both the dark and light themes.
# from h2c_bridge.visualization import run_all_visualizations
# run_all_visualizations(engine, config, themes=("dark", "light"))