# Markov Mixing Investigation in mHC

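This notebook runs experiments to investigate whether the doubly stochastic residual-mixing matrices $H_{\text{res}}$ in mHC exhibit Markov-chain mixing. Mixing here means that the product of the per-layer matrices converges toward the uniform matrix $\frac{1}{n}\mathbf{1}\mathbf{1}^\top$ as depth grows, which would erase the distinction between the $n$ residual streams. The experiments track per-layer $|\lambda_2|$ / spectral gap and the cumulative product's distance to uniform.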

## Setup

In [None]:
# Mount Google Drive (optional - for saving checkpoints)
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# Clone the repository and make it the working directory.
# NOTE: re-running will likely just report that the destination already exists;
# the relative %cd assumes the current directory is the clone's parent.
!git clone https://github.com/tokenbender/mHC-manifold-constrained-hyper-connections.git
%cd mHC-manifold-constrained-hyper-connections

In [None]:
# Install dependencies
!pip install -e . --quiet
!pip install wandb --quiet

In [None]:
# Verify GPU is available: report the PyTorch version and, when CUDA is usable,
# the name and total memory (in GB) of device 0.
import torch

print(f"PyTorch version: {torch.__version__}")
cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {cuda_ok}")
if cuda_ok:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## Run Unit Tests

In [None]:
# Run all Markov mixing tests
# (runs from the repository root, where the clone cell left the working directory)
!pytest tests/test_markov_mixing.py -v

In [None]:
# Run existing HC tests to make sure nothing broke
# (regression check for the existing hyper-connections tests; also run from the repo root)
!pytest tests/test_hyper_connections.py -v

## Quick Sanity Check (No Training)

In [None]:
# Test spectral analysis functions directly on random Sinkhorn-projected matrices
# (no training needed). A fixed seed makes the printed numbers reproducible
# under Restart & Run All.
import torch
from hyper_connections.hyper_connections import sinkhorn_log
from analysis.spectral import analyze_h_res, spectral_gap
from analysis.markov_metrics import cumulative_product_metrics

SEED = 0
num_layers = 12
num_streams = 4

# Local generator: seeds only this cell's sampling, not torch's global RNG.
gen = torch.Generator().manual_seed(SEED)

# Create some random H_res matrices and report per-layer spectral properties.
h_res_list = []
for i in range(num_layers):
    logits = torch.randn(num_streams, num_streams, generator=gen)
    # NOTE(review): tau is presumably the Sinkhorn temperature; confirm in sinkhorn_log.
    H = sinkhorn_log(logits, num_iters=10, tau=0.05)
    h_res_list.append(H)

    props = analyze_h_res(H)
    print(f"Layer {i}: |λ₂|={props['lambda_2_abs']:.4f}, gap={props['spectral_gap']:.4f}")

print("\n--- Cumulative Product Analysis ---")
metrics = cumulative_product_metrics(h_res_list)
for i in range(len(metrics['dist_to_uniform'])):
    print(f"Depth {i}: dist_to_uniform={metrics['dist_to_uniform'][i]:.4f}, gap={metrics['spectral_gap'][i]:.4f}")

## Download FineWeb10B Data

In [None]:
# Move into the nanoGPT example; all later training/analysis cells use paths
# relative to examples/nanogpt (e.g. ../../scripts, ../../analysis_results).
# NOTE: relative %cd — re-running this cell from inside examples/nanogpt will fail.
%cd examples/nanogpt
!python data/fineweb10B/download.py

## Phase 1: Baseline Characterization

Train a 6-layer mHC model to establish baseline spectral properties.

In [None]:
# Optional: Login to W&B
# wandb.login() presumably prompts for an API key if none is configured;
# never hardcode the key in this notebook.
import wandb
wandb.login()

In [None]:
# Train 6-layer mHC baseline (quick test - 500 iters)
# key=value arguments presumably override values from the config file, and
# spectral_log_interval presumably sets how often H_res spectral metrics are
# logged — confirm in train.py. Comments cannot go between the lines below:
# the trailing `\` continues a single shell command.
!python train.py config/train_fineweb10B_mhc.py \
    max_iters=500 \
    eval_interval=100 \
    spectral_log_interval=100 \
    out_dir="out-mhc-baseline-test"

In [None]:
# Analyze the checkpoint
# The path matches the out_dir of the quick-test run above. No --output is given,
# so results go to the script's default location (see scripts/analyze_checkpoint.py).
!python ../../scripts/analyze_checkpoint.py --checkpoint out-mhc-baseline-test/ckpt.pt

## Full Training Run (6-layer baseline)

In [None]:
# Full 5000 iteration training
# No overrides: uses the config's own settings. The next cell expects the checkpoint
# at out-fineweb10B-mhc/ckpt.pt, presumably the config's out_dir — confirm.
!python train.py config/train_fineweb10B_mhc.py

In [None]:
# Analyze the trained checkpoint
# Results are written under the repo-level analysis_results/ directory
# (../../ relative to examples/nanogpt).
!python ../../scripts/analyze_checkpoint.py \
    --checkpoint out-fineweb10B-mhc/ckpt.pt \
    --output ../../analysis_results/baseline_6l

## Residual-Only Ablation

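Train the residual-only variant (`config/train_fineweb10B_mhc_resonly.py`) and compare its H_res spectra with the 6-layer baseline. This tests whether the H_pre/H_post injections compensate for mixing.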

In [None]:
# Train residual-only variant
# NOTE(review): the next cell expects out-fineweb10B-mhc-resonly/ckpt.pt —
# confirm this matches the config's out_dir.
!python train.py config/train_fineweb10B_mhc_resonly.py

In [None]:
# Analyze residual-only checkpoint
# Output goes next to the baseline results so the two can be compared directly.
!python ../../scripts/analyze_checkpoint.py \
    --checkpoint out-fineweb10B-mhc-resonly/ckpt.pt \
    --output ../../analysis_results/resonly_6l

## 48-Layer Deep Model

In [None]:
# Train 48-layer mHC
# NOTE(review): the next cell expects out-fineweb10B-mhc-48l/ckpt.pt —
# confirm this matches the config's out_dir.
!python train.py config/train_fineweb10B_mhc_48l.py

In [None]:
# Analyze 48-layer checkpoint
# Deep model: the cumulative product spans many more layers than the 6-layer runs.
!python ../../scripts/analyze_checkpoint.py \
    --checkpoint out-fineweb10B-mhc-48l/ckpt.pt \
    --output ../../analysis_results/mhc_48l

## Compare Sinkhorn vs Orthostochastic

In [None]:
# Train with orthostochastic projection
# Same 6-layer config; only the H_res projection, out_dir and W&B run name are overridden.
!python train.py config/train_fineweb10B_mhc.py \
    mhc_h_res_proj="orthostochastic" \
    out_dir="out-mhc-orthostochastic" \
    wandb_run_name="mhc-orthostochastic"

In [None]:
# Analyze orthostochastic checkpoint
# --projection presumably tells the script how to rebuild H_res from the stored
# logits; it should match the mhc_h_res_proj used in training.
!python ../../scripts/analyze_checkpoint.py \
    --checkpoint out-mhc-orthostochastic/ckpt.pt \
    --projection orthostochastic \
    --output ../../analysis_results/orthostochastic_6l

## Custom Analysis

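Run the analysis directly in Python for more control. These cells assume the working directory is still `examples/nanogpt` (set by the data-download step), so the repository root is added to `sys.path` as `../..`.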

In [None]:
import sys
sys.path.insert(0, '../..')

import torch
import matplotlib.pyplot as plt
from hyper_connections.hyper_connections import sinkhorn_log
from analysis.spectral import analyze_h_res
from analysis.markov_metrics import cumulative_product_metrics

def load_and_analyze(checkpoint_path, projection='sinkhorn'):
    """Load checkpoint and run full analysis."""
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    state_dict = ckpt['model']
    
    # Extract H_res matrices
    h_res_keys = sorted([k for k in state_dict if 'H_res_logits' in k])
    h_res_list = []
    per_layer = {}
    
    for i, key in enumerate(h_res_keys):
        logits = state_dict[key]
        H = sinkhorn_log(logits, 10, 0.05)
        h_res_list.append(H)
        per_layer[i] = analyze_h_res(H)
    
    cumulative = cumulative_product_metrics(h_res_list)
    
    return per_layer, cumulative, h_res_list

In [None]:
# Visualize results
def plot_mixing_analysis(per_layer, cumulative):
    """Render a 2x2 dashboard of H_res mixing diagnostics.

    Panels: per-layer |λ₂| (top-left), cumulative distance to the uniform
    matrix (top-right), cumulative spectral gap (bottom-left), and per-layer
    entropy (bottom-right).

    Args:
        per_layer: dict mapping layer index -> metrics dict; the
            'lambda_2_abs' and 'entropy' entries are read.
        cumulative: dict whose 'dist_to_uniform' and 'spectral_gap' entries
            are sequences indexed by depth.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    (ax_lambda, ax_dist), (ax_gap, ax_entropy) = axes

    # Top-left: second-eigenvalue magnitude per layer, with reference thresholds.
    ax_lambda.plot([per_layer[idx]['lambda_2_abs'] for idx in sorted(per_layer.keys())], 'o-')
    ax_lambda.axhline(y=0.9, color='g', linestyle='--', label='Near-permutation threshold')
    ax_lambda.axhline(y=0.7, color='orange', linestyle='--', label='Moderate mixing threshold')
    ax_lambda.set(xlabel='Layer', ylabel='|λ₂|', title='Second Eigenvalue per Layer')
    ax_lambda.legend()
    ax_lambda.set_ylim(0, 1.05)

    # Top-right and bottom-left: properties of the cumulative product vs depth.
    cumulative_panels = (
        (ax_dist, 'dist_to_uniform', 'Frobenius distance', 'Distance to Uniform Matrix'),
        (ax_gap, 'spectral_gap', 'Spectral gap', 'Cumulative Product Spectral Gap'),
    )
    for panel_ax, metric_key, y_label, panel_title in cumulative_panels:
        panel_ax.plot(cumulative[metric_key], 'o-')
        panel_ax.set(xlabel='Depth (cumulative layers)', ylabel=y_label, title=panel_title)

    # Bottom-right: entropy of each layer's matrix.
    ax_entropy.plot([per_layer[idx]['entropy'] for idx in sorted(per_layer.keys())], 'o-')
    ax_entropy.set(xlabel='Layer', ylabel='Entropy', title='Matrix Entropy per Layer')

    plt.tight_layout()
    plt.show()
    
# Example usage (uncomment after training):
# per_layer, cumulative, h_res_list = load_and_analyze('out-fineweb10B-mhc/ckpt.pt')
# plot_mixing_analysis(per_layer, cumulative)

## Visualize H_res Matrices

In [None]:
def plot_h_res_heatmaps(h_res_list, max_show=12, cols=4):
    """Visualize H_res matrices as a grid of heatmaps.

    Args:
        h_res_list: sequence of 2-D torch tensors (e.g. from
            ``load_and_analyze``). Tensors on GPU or tracking gradients are
            accepted; they are detached and copied to CPU for plotting.
        max_show: maximum number of matrices to draw (the first `max_show`).
        cols: number of heatmaps per grid row.

    The color scale is fixed to [0, 1], the range of doubly stochastic
    entries. If there is nothing to show, a message is printed and the
    function returns without creating a figure.
    """
    n_show = min(len(h_res_list), max_show)
    if n_show == 0:
        # plt.subplots(0, cols) would raise; report instead.
        print("No H_res matrices to plot.")
        return

    rows = (n_show + cols - 1) // cols

    # squeeze=False always yields a 2-D array of Axes. The previous special case
    # for n_show == 1 wrapped the (1, cols) array in a list and then crashed,
    # because subplots(1, 4) still returns an array rather than a single Axes.
    fig, axes = plt.subplots(rows, cols, figsize=(3 * cols, 3 * rows), squeeze=False)
    flat_axes = axes.flatten()

    for i in range(n_show):
        ax = flat_axes[i]
        H = h_res_list[i].detach().cpu().numpy()
        im = ax.imshow(H, cmap='Blues', vmin=0, vmax=1)
        ax.set_title(f'Layer {i}')
        ax.set_xticks(range(H.shape[1]))  # columns
        ax.set_yticks(range(H.shape[0]))  # rows
        plt.colorbar(im, ax=ax, fraction=0.046)

    # Hide unused subplots in the last row.
    for ax in flat_axes[n_show:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Example usage (uncomment after training):
# plot_h_res_heatmaps(h_res_list)

## Compare Multiple Checkpoints

In [None]:
def compare_checkpoints(checkpoint_paths, labels):
    """Compare mixing behavior across multiple checkpoints on one figure.

    Args:
        checkpoint_paths: sequence of checkpoint paths, each analyzed with
            ``load_and_analyze`` (default projection).
        labels: legend label for each checkpoint, in the same order. Labels
            should be unique; a repeated label overwrites the earlier entry.

    Missing checkpoints (FileNotFoundError) are reported and skipped. If none
    can be loaded, a message is printed and nothing is plotted.

    Raises:
        ValueError: if `checkpoint_paths` and `labels` differ in length.
            Plain ``zip`` would otherwise silently drop the unmatched entries.
    """
    checkpoint_paths = list(checkpoint_paths)
    labels = list(labels)
    if len(checkpoint_paths) != len(labels):
        raise ValueError(
            f"Got {len(checkpoint_paths)} checkpoint paths but {len(labels)} labels"
        )

    results = {}
    for path, label in zip(checkpoint_paths, labels):
        try:
            per_layer, cumulative, _ = load_and_analyze(path)
        except FileNotFoundError:
            print(f"Checkpoint not found: {path}")
            continue
        results[label] = {
            'per_layer': per_layer,
            'cumulative': cumulative
        }

    if not results:
        print("No checkpoints found!")
        return

    fig, (ax_dist, ax_lambda) = plt.subplots(1, 2, figsize=(14, 5))

    # Left: distance of the cumulative H_res product from the uniform matrix vs depth.
    for label, data in results.items():
        ax_dist.plot(data['cumulative']['dist_to_uniform'], 'o-', label=label)
    ax_dist.set_xlabel('Depth')
    ax_dist.set_ylabel('Distance to Uniform')
    ax_dist.set_title('Cumulative Mixing Comparison')
    ax_dist.legend()

    # Right: raw (not averaged) |λ₂| per layer for each checkpoint.
    for label, data in results.items():
        layer_metrics = data['per_layer']
        lambda2s = [layer_metrics[i]['lambda_2_abs'] for i in sorted(layer_metrics.keys())]
        ax_lambda.plot(lambda2s, 'o-', label=label)
    # Reference line at the near-permutation threshold used in plot_mixing_analysis.
    ax_lambda.axhline(y=0.9, color='gray', linestyle='--', alpha=0.5)
    ax_lambda.set_xlabel('Layer')
    ax_lambda.set_ylabel('|λ₂|')
    ax_lambda.set_title('Per-Layer Second Eigenvalue')
    ax_lambda.legend()

    plt.tight_layout()
    plt.show()

# Example usage (uncomment after training multiple models):
# compare_checkpoints(
#     ['out-fineweb10B-mhc/ckpt.pt', 'out-fineweb10B-mhc-48l/ckpt.pt'],
#     ['6-layer', '48-layer']
# )