# H2C Bridge - Colab Development Notebook

In [None]:
# Clone the h2c-bridge repository into /content/h2c-bridge, or pull the latest
# changes if a previous run already cloned it.
# NOTE: `!` and `%cd` are IPython/Colab shell escapes and magics; this cell only
# runs inside a notebook, not as a plain Python script.
REPO_URL = "https://github.com/parkerpettit/h2c-bridge.git"

import os
if not os.path.exists("/content/h2c-bridge"):
    # IPython expands `$REPO_URL` from the Python variable above.
    !git clone $REPO_URL /content/h2c-bridge
    # `!` does not raise on a non-zero exit status, so this prints even if the clone failed.
    print("Repository cloned")
else:
    print("Repository already exists, pulling latest changes...")
    # Runs in a subshell: this `cd` does not change the notebook's working directory.
    !cd /content/h2c-bridge && git pull

# `%cd` (unlike `!cd`) changes the kernel's working directory persistently,
# so subsequent cells run from the repository root.
%cd /content/h2c-bridge
print(f"Working directory: {os.getcwd()}")

In [None]:
# Install the cloned package in editable mode (-e) so source edits are used without
# reinstalling. Already-imported modules still need a reload or a runtime restart.
# -q keeps pip output quiet.
!pip install -q -e .
# Upgrade bitsandbytes to the latest release (-U); presumably needed for quantized
# model loading in the factory — confirm.
!pip install -q -U bitsandbytes
print("Package installed")

In [None]:
import os
from google.colab import userdata
from huggingface_hub import login
import wandb


HF_SECRET_NAME = "HF_TOKEN"   # change if needed
WANDB_SECRET_NAME = "wandb"   # change if needed


def _require_secret(name, description):
    """Return the value of Colab secret *name*, or raise ValueError.

    ``userdata.get`` does not return None for an unavailable secret: it raises
    ``SecretNotFoundError`` when the secret does not exist and
    ``NotebookAccessError`` when this notebook has not been granted access.
    Both are converted into a friendly ValueError. Empty values are treated as
    missing too, since logging in with an empty token can only fail later.

    Args:
        name: Secret name as shown in Colab's "Secrets" panel.
        description: Human-readable label used in the error message.

    Returns:
        The secret's (non-empty) string value.

    Raises:
        ValueError: if the secret is missing, inaccessible, or empty.
    """
    message = f"No {description} found in Colab secrets under '{name}'."
    try:
        value = userdata.get(name)
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError) as err:
        raise ValueError(
            f"{message} Add it in the Secrets panel and enable notebook access."
        ) from err
    if not value:
        raise ValueError(message)
    return value


# Hugging Face: export the token for libraries that read HF_TOKEN, then log in.
hf_token = _require_secret(HF_SECRET_NAME, "Hugging Face token")
os.environ["HF_TOKEN"] = hf_token
login(token=hf_token)
print("Logged in to HuggingFace")

# Weights & Biases: same pattern with WANDB_API_KEY.
wandb_key = _require_secret(WANDB_SECRET_NAME, "Weights & Biases API key")
os.environ["WANDB_API_KEY"] = wandb_key
wandb.login(key=wandb_key)
print(f"Logged in as: {wandb.api.viewer()['entity']}")

## 1. Initialize Components

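Set up the configuration, model factory, tokenizers, and data module. After editing `config.py` or `factory.py`, restart the runtime (or reload the modules, e.g. with `importlib.reload` or `%autoreload`) before re-running this cell: modules that are already imported are not reloaded automatically.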

In [None]:
from h2c_bridge.config import get_default_config
from h2c_bridge.factory import H2CModelFactory
from h2c_bridge.data.datamodule import H2CDataModule
from h2c_bridge.utils import set_seed, clear_gpu

# Set seed for reproducibility
set_seed(42)

# Start from the package defaults and override the settings for this
# development run; later cells read these values from `config`.
config = get_default_config()
config.update({
    # Hugging Face model IDs passed to H2CModelFactory below.
    "SHARER_ID": "meta-llama/Llama-3.1-8B-Instruct",
    "RECEIVER_ID": "Qwen/Qwen2.5-0.5B-Instruct",

    # Dataset size and optimization
    "MAX_SAMPLES": 250_000,  # max samples of OpenHermes to pretrain bridge on
    "BATCH_SIZE": 12,
    "lr": 1e-4,

    # Evaluation frequency (in steps)
    "eval_every": 500,
    "log_bridge_every": 50,  # log bridge gate stats to wandb

    # Training
    "epochs": 1,
    "gate_warmup_steps": 0,

    # MMLU evaluation
    # Samples drawn per subject category (MMLU has 57 categories); set to None
    # to use every sample in the split.
    # NOTE(review): the validation split has only ~1.5k samples over 57
    # subjects (~27 each), so 100 per category likely covers the whole split —
    # confirm against the evaluator's sampling logic.
    "mmlu_sample_size": 100,
    "mmlu_eval_split": "validation",  # ~1.5k samples, or set to "test" for ~14k

    # Logging
    "wandb_log_examples": 10,  # number of examples to log to WandB per eval mode

    # Precomputed baseline results (accuracy fraction, latency in seconds),
    # used instead of re-running the baseline evaluation.
    # NOTE(review): sharer_only latency here (0.6213s) differs from the 0.3454s
    # printed by a previous baseline run; confirm which is current.
    "BASELINES": {
        "receiver_only": {"acc": 0.3566, "latency_s": 0.3349},
        "sharer_only": {"acc": 0.6242, "latency_s": 0.6213},
        "text_to_text": {"acc": 0.3354, "latency_s": 1.2960},
    },
})

print("Initializing Factory...")
factory = H2CModelFactory(config["SHARER_ID"], config["RECEIVER_ID"])
tok_sharer, tok_receiver = factory.load_tokenizers()

print("Initializing Data Module...")
dm = H2CDataModule(tok_sharer, tok_receiver, config)
print("Setup complete")

[receiver_only] Acc: 35.66% | Err: 0.82% | Latency: 0.3349s  
[sharer_only] Acc: 62.42% | Err: 0.55% | Latency: 0.3454s  
[text_to_text] Acc: 33.54% | Err: 0.48% | Latency: 1.2960s  


## 2. Quick Baseline Check (Optional)

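Create the engine and, optionally, verify that the evaluation pipeline works before training. The baseline evaluation in the next cell is commented out; uncomment it to run it.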

In [None]:
from h2c_bridge.training.engine import H2CEngine

# Create the engine from the factory, data module and config; the learning rate
# and eval frequency are passed explicitly from the config.
engine = H2CEngine(factory, dm, config, config["lr"], config["eval_every"])

# Run quick baseline check (uncomment the evaluation in the next cell to run it)
# print("Running baseline evaluation...")

In [None]:
# baseline_results = engine.mmlu_evaluator.evaluate_baselines(engine.mmlu_loader)
# config["BASELINES"] = baseline_results

## 3. Training

Train the bridge. The training call below is commented out; uncomment `engine.run(...)` to start. Checkpoints are automatically uploaded to WandB as artifacts with these aliases:
- `latest`: most recent checkpoint
- `best`: highest accuracy
- `final`: end of training

In [None]:
# Drop the engine created in the baseline cell (if it was run) *before* clearing
# the GPU. clear_gpu() cannot reclaim memory still referenced by a live `engine`,
# so otherwise the old engine stays resident until after the new one is built.
# NOTE(review): if the factory itself caches loaded models, those are unaffected.
globals().pop("engine", None)
clear_gpu()
engine = H2CEngine(factory, dm, config, config["lr"], config["eval_every"])
# Uncomment to run an initial evaluation before training:
# engine._perform_eval()
# Uncomment to start training:
# engine.run(epochs=config["epochs"])

## 4. Resume from Checkpoint (Optional)

Load a checkpoint from WandB artifacts to continue training or run inference.

In [None]:
# Load a bridge checkpoint logged as a WandB artifact (here: the "best" alias;
# "latest" and "final" are also available).
# Path format: "<entity>/<project>/<artifact_name>:<alias>" — copy it from the
# Artifacts tab of the WandB UI.
ARTIFACT_PATH = (
    "ppettit_nlp/nlp_project/"
    "bridge_Llama-3-1-8B-Instruct_TO_Qwen2-5-0-5B-Instruct_checkpoint:best"
)
engine.load_checkpoint(ARTIFACT_PATH)
print("Checkpoint loaded! You can now continue training or run inference.")

# To continue training from the loaded checkpoint, uncomment:
# engine.run(epochs=1)

## 5. Debugging & Analysis

Test specific components or run ad-hoc experiments.

In [None]:
# Smoke-test the bridge on a single free-form prompt (up to 100 new tokens).
# Left as the cell's last expression so the notebook displays whatever
# generate_demo returns.
prompt = "Explain how a CPU works."
engine.evaluator.generate_demo(prompt, max_new_tokens=100)

In [None]:
# Report the bridge's average key/value gate values (4 decimal places).
stats = engine.bridge.get_gate_stats()
for gate_label, stat_name in (("Key", "key_avg"), ("Value", "value_avg")):
    print(f"{gate_label} Gate Avg: {stats[stat_name]:.4f}")

## 6. Visualizations (Optional)

Generate publication-ready visualizations and upload them to WandB.

In [None]:
# Run the full visualization suite. This cell is active; skip or comment it out
# to avoid running it.
from h2c_bridge.visualization import run_all_visualizations
# NOTE(review): this mutates the shared `config` in place. The engine was built
# with the same object, so it may see this value too — confirm. 100 matches the
# value already set in the initialization cell.
config["mmlu_sample_size"] = 100
run_all_visualizations(engine, config)