# H2C Bridge - Colab Development Notebook

In [None]:
# Clone repository
REPO_URL = "https://github.com/parkerpettit/h2c-bridge.git"

import os
if not os.path.exists("/content/h2c-bridge"):
    !git clone $REPO_URL /content/h2c-bridge
    print("✅ Repository cloned")
else:
    print("Repository already exists, pulling latest changes...")
    !cd /content/h2c-bridge && git pull

%cd /content/h2c-bridge
print(f"Working directory: {os.getcwd()}")

Cloning into '/content/h2c-bridge'...
remote: Enumerating objects: 289, done.[K
remote: Counting objects: 100% (289/289), done.[K
remote: Compressing objects: 100% (202/202), done.[K
remote: Total 289 (delta 179), reused 192 (delta 83), pack-reused 0 (from 0)[K
Receiving objects: 100% (289/289), 151.76 KiB | 21.68 MiB/s, done.
Resolving deltas: 100% (179/179), done.
✅ Repository cloned
/content/h2c-bridge
Working directory: /content/h2c-bridge


In [None]:
# Install package in editable mode
!pip install -q -e .
!pip install -q -U bitsandbytes
print("Package installed")

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.4/59.4 MB[0m [31m45.6 MB/s[0m eta [36m0:00:00[0m
[?25hPackage installed


In [None]:
import os
from google.colab import userdata
from huggingface_hub import login
import wandb


# Authenticate with Hugging Face and Weights & Biases using Colab secrets.
HF_SECRET_NAME = "HF_TOKEN"   # change if needed

# NOTE(review): depending on the google.colab version, `userdata.get` may raise
# instead of returning None when a secret is missing — confirm. The checks below
# additionally reject empty values, which would otherwise be passed on to login().
hf_token = userdata.get(HF_SECRET_NAME)
if not hf_token:
    raise ValueError(f"No Hugging Face token found in Colab secrets under '{HF_SECRET_NAME}'.")

# Export the token as well so code that reads HF_TOKEN from the environment sees it.
os.environ["HF_TOKEN"] = hf_token
login(token=hf_token)
print("Logged in to HuggingFace")

wandb_key = userdata.get("wandb")
if not wandb_key:
    raise ValueError("No Weights & Biases API key found in Colab secrets under 'wandb'.")

os.environ["WANDB_API_KEY"] = wandb_key
wandb.login(key=wandb_key)
print(f"Logged in as: {wandb.api.viewer()['entity']}")

Note: Environment variable`HF_TOKEN` is set and is the current active token independently from the token you've just configured.


Logged in to HuggingFace


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mppettit[0m ([33mppettit_nlp[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Logged in as: ppettit_nlp


## 1. Initialize Components

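Set up models, data, and configuration. After editing `config.py` or `factory.py`, re-run this cell to test the changes; because Python caches imported modules, restart the runtime (or enable `%autoreload 2`) first so the edits are actually picked up.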

In [None]:
from h2c_bridge.config import get_default_config
from h2c_bridge.factory import H2CModelFactory
from h2c_bridge.data.datamodule import H2CDataModule
from h2c_bridge.utils import set_seed, clear_gpu

# Fix the random seed first so repeated runs are comparable
set_seed(42)

# Reference baseline metrics (presumably from an earlier baseline run — verify
# before relying on them for comparisons)
baseline_metrics = {
    "receiver_only": {"acc": 0.3566, "latency_s": 0.3349},
    "sharer_only": {"acc": 0.6242, "latency_s": 0.6213},
    "text_to_text": {"acc": 0.3354, "latency_s": 1.2960},
}

# Development overrides layered on top of the package defaults
dev_overrides = {
    # Model pair: the sharer's hidden state is bridged into the receiver
    "SHARER_ID": "meta-llama/Llama-3.1-8B-Instruct",
    "RECEIVER_ID": "Qwen/Qwen2.5-0.5B-Instruct",

    # Dataset size / optimisation
    "MAX_SAMPLES": 250_000,  # cap on OpenHermes samples used to pretrain the bridge
    "BATCH_SIZE": 12,
    "lr": 1e-4,

    # Evaluation / logging cadence, measured in training steps
    "eval_every": 500,
    "log_bridge_every": 50,  # how often bridge gate stats go to wandb

    # Training
    "epochs": 1,
    "gate_warmup_steps": 0,

    # MMLU evaluation
    "mmlu_sample_size": 100,  # samples per category; None means use every sample
    "mmlu_eval_split": "validation",  # ~1.5k samples; "test" has ~14k

    # Logging
    "wandb_log_examples": 10,  # examples logged to WandB per eval mode
    "BASELINES": baseline_metrics,
}

config = get_default_config()
config.update(dev_overrides)

print("Initializing Factory...")
factory = H2CModelFactory(
    config["SHARER_ID"],
    config["RECEIVER_ID"],
)
tok_sharer, tok_receiver = factory.load_tokenizers()

print("Initializing Data Module...")
dm = H2CDataModule(tok_sharer, tok_receiver, config)
print("✅ Setup complete")

Random seed set to 42
Initializing Factory...
--- [ModelFactory] Loading Tokenizers...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Initializing Data Module...
✅ Setup complete


[receiver_only] Acc: 35.66% | Err: 0.82% | Latency: 0.3349s  
[sharer_only] Acc: 62.42% | Err: 0.55% | Latency: 0.3454s  
[text_to_text] Acc: 33.54% | Err: 0.48% | Latency: 1.2960s  


## 2. Quick Baseline Check (Optional)

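Build the engine so the evaluation pipeline can be verified before training. The baseline evaluation itself is commented out in the next code cell; uncomment it to recompute the baselines.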

In [None]:
from h2c_bridge.training.engine import H2CEngine

# Instantiate the engine from the factory, data module and config;
# learning rate and evaluation interval are taken from the config as well.
engine = H2CEngine(
    factory,
    dm,
    config,
    config["lr"],
    config["eval_every"],
)

# The optional baseline evaluation is kept (commented out) in the following cell.



--- [ModelFactory] Loading LLMs (Frozen + Quantized)...


config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/659 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

--- [ModelFactory] Initializing Bridge...
--- [Bridge] Aligning Top 24 layers (Sharer: 32, Receiver: 24)
Total parameters: 58,669,104
Trainable parameters: 58,669,104
Size (MB): 223.8 MB (float32)
--- [DataModule] Loading Datasets (Max 250000)...
Loading OpenHermes-2.5 (train)...


README.md: 0.00B [00:00, ?B/s]

openhermes2_5.json:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

Generating train split:   0%|          | 0/1001551 [00:00<?, ? examples/s]

Processed 249647 valid conversation pairs (Skipped 353 > 2048 tokens).
--- [MMLU] Loading auxiliary_train split...


README.md: 0.00B [00:00, ?B/s]

dataset_infos.json: 0.00B [00:00, ?B/s]

all/test-00000-of-00001.parquet:   0%|          | 0.00/3.50M [00:00<?, ?B/s]

all/validation-00000-of-00001.parquet:   0%|          | 0.00/408k [00:00<?, ?B/s]

all/dev-00000-of-00001.parquet:   0%|          | 0.00/76.5k [00:00<?, ?B/s]

all/auxiliary_train-00000-of-00001.parqu(…):   0%|          | 0.00/47.5M [00:00<?, ?B/s]

Generating test split:   0%|          | 0/14042 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1531 [00:00<?, ? examples/s]

Generating dev split:   0%|          | 0/285 [00:00<?, ? examples/s]

Generating auxiliary_train split:   0%|          | 0/99842 [00:00<?, ? examples/s]

--- [MMLU] Found 1 subjects
--- [MMLU] Processed 62500 examples (auxiliary_train).
--- [DataModule] Combined Train Source Sizes: 249647 OpenHermes + 62500 MMLU Aux = 312147 total
--- [DataModule] Split: 309025 Train | 3122 Val
--- [DataModule] Setting up MMLU Eval (validation split, 100 per subject)...
--- [MMLU] Loading validation split...
--- [MMLU] Found 57 subjects
--- [MMLU] Processed 1461 examples (validation).
--- [Scheduler] Cosine schedule: 25753 total steps, 2575 warmup steps


In [None]:
# Optional: uncomment both lines to evaluate the baselines on the engine's MMLU
# loader and store the result in config["BASELINES"], replacing the values that
# were hard-coded when the config was built.
# baseline_results = engine.mmlu_evaluator.evaluate_baselines(engine.mmlu_loader)
# config["BASELINES"] = baseline_results

## 3. Training

Train the bridge (uncomment `engine.run(...)` in the cell below to start training). Checkpoints are automatically uploaded to WandB as artifacts with the following aliases:
- `latest`: most recent checkpoint
- `best`: highest accuracy
- `final`: end of training

In [None]:
# Clear GPU and re-initialize for clean training run
import gc

# Release the previously built engine *before* clearing the cache. While `engine`
# still references the old models/bridge their GPU memory cannot be reclaimed, and
# the replacement engine would be constructed alongside the old one.
# NOTE(review): if the factory itself keeps references to the loaded models this
# will not free them — confirm against H2CModelFactory.
engine = None
gc.collect()
clear_gpu()
engine = H2CEngine(factory, dm, config, config["lr"], config["eval_every"])
# engine._perform_eval()  # Run initial eval
# Start training
# engine.run(epochs=1)

Cleared GPU cache.


--- [ModelFactory] Loading LLMs (Frozen + Quantized)...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

--- [ModelFactory] Initializing Bridge...
--- [Bridge] Aligning Top 24 layers (Sharer: 32, Receiver: 24)
Total parameters: 58,669,104
Trainable parameters: 58,669,104
Size (MB): 223.8 MB (float32)
--- [Scheduler] Cosine schedule: 25753 total steps, 2575 warmup steps


## 4. Resume from Checkpoint (Optional)

Load a checkpoint from WandB artifacts to continue training or run inference.

In [None]:
# Example: restore the checkpoint tagged `best` from WandB.
# Artifact paths have the form "<entity>/<project>/<artifact_name>:<alias>";
# the exact value is listed in the WandB UI under Artifacts.
ARTIFACT_PATH = (
    "ppettit_nlp/nlp_project/"
    "bridge_Llama-3-1-8B-Instruct_TO_Qwen2-5-0-5B-Instruct_checkpoint:best"
)

engine.load_checkpoint(ARTIFACT_PATH)
print("Checkpoint loaded! You can now continue training or run inference.")
# To continue training from the restored state:
# engine.run(epochs=1)

Loading checkpoint from WandB artifact: ppettit_nlp/nlp_project/bridge_Llama-3-1-8B-Instruct_TO_Qwen2-5-0-5B-Instruct_checkpoint:best


[34m[1mwandb[0m: Downloading large artifact 'bridge_Llama-3-1-8B-Instruct_TO_Qwen2-5-0-5B-Instruct_checkpoint:best', 672.36MB. 1 files...
[34m[1mwandb[0m:   1 of 1 files downloaded.  
Done. 00:00:42.4 (15.9MB/s)


Restored optimizer state
Restored scheduler state (LR: 1.44e-05)
Loaded from step 20000, accuracy: 44.21%
Checkpoint loaded! You can now continue training or run inference.


## 5. Debugging & Analysis

Test specific components or run ad-hoc experiments.

In [None]:
# Qualitative spot check: generate a demo completion for one hand-picked prompt
prompt = "Explain how a CPU works."
engine.evaluator.generate_demo(
    prompt,
    max_new_tokens=100,
)

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



PROMPT: Explain how a CPU works.
--------------------
[Vanilla]:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
Certainly! A Central Processing Unit (CPU) is the brain of a computer and it's responsible for executing instructions to perform various tasks such as reading data from memory, performing arithmetic operations, handling logical operations, etc.

Here’s a step-by-step explanation of how a CPU works:

1. **Initialization**: The CPU starts with an initial state that includes all registers, flags, and other necessary variables.

2. **Memory Access**:
   - The CPU reads data from memory into the cache.

[Bridged]:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Explain how a CPU works.<|im_end|>
<|im_start|>assistant
A Central Processing Unit (CPU) is the brain of a computer system and it performs all the necessary

In [None]:
# Report the bridge's average key and value gate statistics
stats = engine.bridge.get_gate_stats()
print(
    f"Key Gate Avg: {stats['key_avg']:.4f}",
    f"Value Gate Avg: {stats['value_avg']:.4f}",
    sep="\n",
)

Key Gate Avg: 0.7097
Value Gate Avg: 0.6808


## 6. Visualizations (Optional)

Generate publication-ready visualizations and upload to WandB.

In [None]:
# Run the full visualization suite; it evaluates the bridge and the baselines
# and logs the resulting charts to WandB.
from h2c_bridge.visualization import (
    run_all_visualizations,
)

# Samples per MMLU category used for this run
config.update({"mmlu_sample_size": 100})
run_all_visualizations(engine, config)


VISUALIZATION SUITE - Publication Ready

[Viz] Generating charts for theme(s): dark, light

[Viz] Running H2C Bridge evaluation...


Eval [bridge]:   0%|          | 0/244 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
You're using a Qwen2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


[bridge] Acc: 45.17% | Err: 0.00% | Latency: 0.1568s | Avg tokens: 0.0 | Per-token: 0.0000s
[WandB] Logged 10 examples to eval_examples/bridge
[Viz] H2C Bridge: Accuracy=45.17%, Latency=0.1568s

[Viz] Running baseline evaluations for comparison...

>>> RUNNING BASELINES (Detailed)

--- Running Baseline (Detailed): receiver_only ---


Eval [receiver_only]:   0%|          | 0/244 [00:00<?, ?it/s]

[receiver_only] Acc: 34.77% | Err: 0.00% | Latency: 0.0155s | Avg tokens: 0.0 | Per-token: 0.0000s
[WandB] Logged 10 examples to eval_examples/receiver_only
--- Running Baseline (Detailed): sharer_only ---


Eval [sharer_only]:   0%|          | 0/244 [00:00<?, ?it/s]

[sharer_only] Acc: 60.16% | Err: 0.00% | Latency: 0.0312s | Avg tokens: 0.0 | Per-token: 0.0000s
[WandB] Logged 10 examples to eval_examples/sharer_only
--- Running Baseline (Detailed): text_to_text ---


Eval [text_to_text]:   0%|          | 0/244 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[text_to_text] Acc: 33.74% | Err: 0.00% | Latency: 1.0830s | Avg tokens: 0.0 | Per-token: 0.0000s
[WandB] Logged 10 examples to eval_examples/text_to_text
[Viz] receiver_only: Accuracy=34.77%
[Viz] sharer_only: Accuracy=60.16%
[Viz] text_to_text: Accuracy=33.74%

Generating DARK theme charts (1/2)
[Viz] Theme switched to: dark
[Viz] Generating performance charts...
[Viz] Performance charts logged.
[Viz] Generating gate dynamics...
[Viz] Gate dynamics logged.
[Viz] Generating injection heatmap for: 'Explain why the sky is blue....'

PROMPT: Explain why the sky is blue.
--------------------
[Vanilla]:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
The

[Bridged]:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Explain why the sky is blue.<|im_end|>
<|im_start|>assistant
The
--------------------

[Viz] Injection heatmap log


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



[Viz] Training summary logged.

Generating LIGHT theme charts (2/2)
[Viz] Theme switched to: light
[Viz] Generating performance charts...
[Viz] Performance charts logged.
[Viz] Generating gate dynamics...
[Viz] Gate dynamics logged.
[Viz] Generating injection heatmap for: 'Explain why the sky is blue....'

PROMPT: Explain why the sky is blue.
--------------------
[Vanilla]:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
<|im_end|>
<|im_start|>assistant
The

[Bridged]:
<|im_start|>system
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>
<|im_start|>user
Explain why the sky is blue.<|im_end|>
<|im_start|>assistant
The
--------------------

[Viz] Injection heatmap logged.
[Viz] Generating probability shift chart...
[Viz] Probability shift logged.
[Viz] Generating comparative category breakdown...
[Viz] Comparative category breakdown logged.
[Viz] Generating comparative confusion matrices...
[V


This figure includes Axes that are not compatible with tight_layout, so results might be incorrect.



[Viz] Training summary logged.

All visualizations logged to WandB successfully!
  - Themes generated: dark, light
  - 10 chart types per theme
  - Includes baseline comparisons in confusion matrix & category breakdown
  - Use '_dark' or '_light' suffix in WandB to filter

