# H2C Bridge - Colab Development Notebook

In [1]:
# Clone repository
REPO_URL = "https://github.com/parkerpettit/h2c-bridge.git" 

import os
if not os.path.exists("/content/h2c-bridge"):
    !git clone $REPO_URL /content/h2c-bridge
    print("✅ Repository cloned")
else:
    print("Repository already exists, pulling latest changes...")
    !cd /content/h2c-bridge && git pull

%cd /content/h2c-bridge
print(f"Working directory: {os.getcwd()}")

Repository already exists, pulling latest changes...
Already up to date.
/content/h2c-bridge
Working directory: /content/h2c-bridge


In [2]:
# Install package in editable mode
!pip install -q -e .
!pip install -q -U bitsandbytes
print("Package installed")

  Preparing metadata (setup.py) ... [?25l[?25hdone
Package installed


In [3]:
import wandb
import getpass
import os

api_key = getpass.getpass('WandB API Key: ')
os.environ['WANDB_API_KEY'] = api_key
wandb.login(key=api_key)
print(f"✓ Logged in as: {wandb.api.viewer()['entity']}")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mppettit[0m ([33mppettit_nlp[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✓ Logged in as: ppettit_nlp


In [4]:
from huggingface_hub import login

token = getpass.getpass('HF Token: ')
os.environ['HF_TOKEN'] = token
login(token=token)
print("✓ Logged in to HuggingFace")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


✓ Logged in to HuggingFace


## 1. Initialize Components

Set up models, data, and configuration. Edit `config.py` or `factory.py` and re-run this cell to test changes.

In [5]:
from h2c_bridge.config import get_default_config
from h2c_bridge.factory import H2CModelFactory
from h2c_bridge.data.datamodule import H2CDataModule
from h2c_bridge.utils import set_seed, clear_gpu

# Set seed for reproducibility
set_seed(42)

# Get config and customize for development
config = get_default_config()
config.update({
        "SHARER_ID": "meta-llama/Llama-3.1-8B-Instruct",
        "RECEIVER_ID": "Qwen/Qwen2.5-0.5B-Instruct",
        
        # Dataset size
        "MAX_SAMPLES": 100_000,  # max samples of OpenHermes to pretrain bridge on
        "BATCH_SIZE": 1,
        "lr": 1e-4,
        
        # Evaluation frequency (in steps)
        "eval_every": 1000,
        "log_bridge_every": 50,  # log bridge gate stats to wandb
        
        # Training
        "epochs": 1,
        "gate_warmup_steps": 0,
        
        # MMLU evaluation
        "mmlu_sample_size": 1,  # samples per category (57 categories in validation = 57 total samples)
    
        # Logging
        "wandb_log_examples": 10,  # number of examples to log to WandB per eval mode
   
    })

     


print("Initializing Factory...")
factory = H2CModelFactory(config["SHARER_ID"], config["RECEIVER_ID"])
tok_sharer, tok_receiver = factory.load_tokenizers()

print("Initializing Data Module...")
dm = H2CDataModule(tok_sharer, tok_receiver, config)
print("✅ Setup complete")

Random seed set to 42
Initializing Factory...
--- [ModelFactory] Loading Tokenizers...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


Initializing Data Module...
✅ Setup complete


In [6]:
# from h2c_bridge.config import get_default_config
# from h2c_bridge.factory import H2CModelFactory
# from h2c_bridge.data.datamodule import H2CDataModule

# def inspect_tokens():
#     print(">>> Loading Configuration and Models...")
#     config = get_default_config()
    
#     # Initialize Factory
#     factory = H2CModelFactory(config["SHARER_ID"], config["RECEIVER_ID"])
#     tok_sharer, tok_receiver = factory.load_tokenizers()
    
#     print(f"Sharer Tokenizer: {tok_sharer.name_or_path}")
#     print(f"Receiver Tokenizer: {tok_receiver.name_or_path}")

#     # Initialize DataModule
#     print(">>> Loading DataModule...")
#     dm = H2CDataModule(tok_sharer, tok_receiver, config)
#     dm.setup()
    
#     train_loader, _ = dm.get_loaders()
    
#     print(">>> Fetching one batch...")
#     batch = next(iter(train_loader))
    
#     # Inspect Sharer Input
#     print("\n" + "="*50)
#     print("SHARER INPUT (First Example)")
#     print("="*50)
#     sharer_ids = batch['sharer_input_ids'][0]
#     # Decode with special tokens visible
#     sharer_text = tok_sharer.decode(sharer_ids, skip_special_tokens=False)
#     print(sharer_text)
    
#     # Inspect Receiver Input (Prompt + Target)
#     print("\n" + "="*50)
#     print("RECEIVER INPUT (First Example)")
#     print("="*50)
#     rec_prompt_ids = batch['receiver_prompt_ids'][0]
#     rec_kickstart_ids = batch['receiver_kickstart_ids'][0]
#     rec_target_ids = batch['receiver_target_ids'][0]
    
#     # Combine for full view
#     full_ids = list(rec_prompt_ids) + list(rec_kickstart_ids) + list(rec_target_ids)
    
#     # Decode with special tokens visible
#     rec_text = tok_receiver.decode(full_ids, skip_special_tokens=False)
#     print(rec_text)

# # inspect_tokens()

## 2. Quick Baseline Check (Optional)

Verify the evaluation pipeline works before training.

In [7]:
from h2c_bridge.training.engine import H2CEngine

# Create engine
engine = H2CEngine(factory, dm, config, lr=1e-4, eval_every=100)

# Run quick baseline check
print("Running baseline evaluation...")
baseline_results = engine.mmlu_evaluator.evaluate_baselines(engine.mmlu_loader, debug_mode=True)
config["BASELINES"] = baseline_results



--- [ModelFactory] Loading LLMs (Frozen + Quantized)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

--- [ModelFactory] Initializing Bridge...
--- [Bridge] Aligning Top 24 layers (Sharer: 32, Receiver: 24)
Total parameters: 52,334,640
Trainable parameters: 52,334,640
Size (MB): 199.6 MB (float32)
--- [DataModule] Loading Datasets (Max 100000)...
Loading OpenHermes-2.5 (train)...
Processed 99663 valid conversation pairs (Skipped 337 > 2048 tokens).
--- [MMLU] Loading auxiliary_train split...
--- [MMLU] Found 1 subjects
--- [MMLU] Processed 5 examples (auxiliary_train).
--- [DataModule] Combined Train Source Sizes: 99663 OpenHermes + 5 MMLU Aux = 99668 total
--- [DataModule] Split: 98671 Train | 997 Val
--- [DataModule] Setting up MMLU Eval (Validation Split)...
--- [MMLU] Loading validation split...
--- [MMLU] Found 57 subjects
--- [MMLU] Processed 57 examples (validation).
--- [Scheduler] Cosine schedule: 98671 total steps, 9867 warmup steps
Running baseline evaluation...

>>> RUNNING BASELINES (Deterministic)
[DEBUG MODE] Will stop after 5 samples

--- Running Baseline: receiver_only

Eval [receiver_only] [DEBUG]:   0%|          | 0/57 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



[DEBUG MODE] Stopping early after 5 samples
[receiver_only] Acc: 40.00% | Err: 0.00% | Latency: 0.7ms
[WandB] Logged 5 examples to eval_examples/receiver_only
--- Running Baseline: sharer_only ---


Eval [sharer_only] [DEBUG]:   0%|          | 0/57 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



[DEBUG MODE] Stopping early after 5 samples
[sharer_only] Acc: 40.00% | Err: 0.00% | Latency: 0.9ms
[WandB] Logged 5 examples to eval_examples/sharer_only
--- Running Baseline: text_to_text ---


Eval [text_to_text] [DEBUG]:   0%|          | 0/57 [00:00<?, ?it/s]


[DEBUG MODE] Stopping early after 5 samples
[text_to_text] Acc: 40.00% | Err: 0.00% | Latency: 3.4ms
[WandB] Logged 5 examples to eval_examples/text_to_text


## 3. Training

Train the bridge. Checkpoints automatically upload to WandB as artifacts with aliases:
- `latest`: most recent checkpoint
- `best`: highest accuracy
- `final`: end of training

In [None]:
# Clear GPU and re-initialize for clean training run
clear_gpu()
engine = H2CEngine(factory, dm, config, lr=1e-4, eval_every=10)

# Start training
engine.run(epochs=1)

Cleared GPU cache.


--- [ModelFactory] Loading LLMs (Frozen + Quantized)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

--- [ModelFactory] Initializing Bridge...
--- [Bridge] Aligning Top 24 layers (Sharer: 32, Receiver: 24)
Total parameters: 52,334,640
Trainable parameters: 52,334,640
Size (MB): 199.6 MB (float32)
--- [Scheduler] Cosine schedule: 98671 total steps, 9867 warmup steps
--- [Engine] Starting Training for 1 epochs...


Epoch 1:   0%|          | 0/98671 [00:00<?, ?it/s]


--- Evaluation at Step 100 ---


Validation Loop:   0%|          | 0/997 [00:00<?, ?it/s]

KeyboardInterrupt: 

## 4. Resume from Checkpoint (Optional)

Load a checkpoint from WandB artifacts to continue training or run inference.

In [None]:
# Example: Load the best checkpoint
# Format: "entity/project/artifact_name:alias"
# You can find this in your WandB UI under Artifacts

# ARTIFACT_PATH = "your-entity/nlp_project/bridge_Llama-3-1-8B-Instruct_TO_Qwen2-5-0-5B-Instruct_checkpoint:best"
# engine.load_checkpoint(ARTIFACT_PATH)
# print("Checkpoint loaded! You can now continue training or run inference.")

## 5. Debugging & Analysis

Test specific components or run ad-hoc experiments.

In [None]:
# Test a specific prompt
prompt = "Explain how a CPU works."
engine.evaluator.generate_demo(prompt, max_new_tokens=100)

In [None]:
# Check bridge gate statistics
stats = engine.bridge.get_gate_stats()
print(f"Key Gate Avg: {stats['key_avg']:.4f}")
print(f"Value Gate Avg: {stats['value_avg']:.4f}")

## 6. Visualizations (Optional)

Generate publication-ready visualizations and upload to WandB.

In [None]:
# Uncomment to run full visualization suite
# from h2c_bridge.visualization import run_all_visualizations
# run_all_visualizations(engine, config, themes=("dark", "light"))