# Feature Interaction Graph — Worked Example

This notebook demonstrates the full pipeline for constructing a feature interaction
graph on **GPT-2-small** with community SAEs from SAELens.

We'll walk through:
1. Configuring a small-scale run
2. Loading the model and SAEs
3. Collecting feature activations
4. Building the co-activation atlas
5. Identifying candidate interaction pairs
6. Measuring causal interactions
7. Building and analyzing the interaction graph
8. Visualizing the graph
9. Extracting behavior-specific subgraphs
10. Predicting and visualizing steering cascades

In [None]:
# Install dependencies (uncomment if needed)
# !pip install -e ".[dev]"

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
from pathlib import Path

from feature_graph.config import Config
from feature_graph.utils import setup_logging, set_seed

setup_logging("INFO")
set_seed(42)

# Check device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 1. Configuration

We configure a small-scale run for demonstration. For a full research run,
increase `n_tokens` to 1M+ and `top_k_features` to 200+.

In [None]:
cfg = Config(
    model_name="gpt2-small",
    sae_release="gpt2-small-res-jb",
    sae_id_template="blocks.{layer}.hook_resid_post",
    layers=[0, 1, 2, 3, 4, 5],  # First 6 layers for demo
    device=device,

    # Scale down for notebook demo
    n_tokens=10_000,
    context_length=128,
    batch_size=8,

    # Candidate filtering
    top_k_features=50,  # Small for demo
    layer_window=2,

    # Interaction measurement
    interaction_method="clamping",
    n_intervention_samples=20,  # Small for demo

    # Statistical testing
    significance_level=0.05,
    correction_method="fdr_bh",

    output_dir="outputs/demo",
    seed=42,
)

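# Save config for reproducibility (make sure the output directory exists first)
Path("outputs/demo").mkdir(parents=True, exist_ok=True)
cfg.save("outputs/demo/config.json")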
print(f"Config saved. Analyzing layers {cfg.layers} with top-{cfg.top_k_features} features.")

## 2. Load Model and SAEs

We use TransformerLens for the model and SAELens for the sparse autoencoders.
On first run, this will download the model and SAE weights.

In [None]:
from feature_graph.loading import load_model_and_saes, get_sae_dict_size

model, saes = load_model_and_saes(cfg)

print(f"\nModel: {cfg.model_name}")
print(f"d_model: {model.cfg.d_model}")
print(f"n_layers: {model.cfg.n_layers}")
print(f"\nLoaded SAEs for layers: {sorted(saes.keys())}")
for layer in sorted(saes.keys()):
    print(f"  Layer {layer}: {get_sae_dict_size(saes[layer])} features")

## 3. Collect Feature Activations

We run a corpus of text through the model and record SAE feature activations
at every layer. This gives us the raw data for co-activation analysis.

In [None]:
from feature_graph.activations import collect_activations

activation_store = collect_activations(model, saes, cfg)

print(f"\nCollected activations for {activation_store.n_tokens} tokens")
print(f"Layers: {activation_store.layers}")
for layer in activation_store.layers:
    acts = activation_store.activations[layer]
    freq = activation_store.feature_frequencies[layer]
    n_active = (freq > 0).sum()
    mean_freq = freq[freq > 0].mean() if n_active > 0 else 0
    print(f"  Layer {layer}: {acts.shape[1]} features, {n_active} active, mean freq={mean_freq:.4f}")

# Save activations
activation_store.save("outputs/demo/activations.h5")

## 4. Build Co-Activation Atlas

The co-activation atlas is the correlational scaffold. For each pair of features
across layers, we compute:
- **Co-activation frequency**: P(f_j > 0 | f_i > 0)
- **Pointwise mutual information**: log[ P(i, j) / (P(i) · P(j)) ], where P(i) = P(f_i > 0) and P(i, j) = P(f_i > 0, f_j > 0)
- **Co-activation ratio**: P(f_j > 0 | f_i > 0) / P(f_j > 0)

This tells us which pairs *might* interact — but correlation ≠ causation.
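
To make the three metrics concrete, here is a minimal, dependency-free sketch for a single feature pair. The function name `coactivation_metrics` and its set-based inputs are illustrative assumptions, not the API of `build_coactivation_atlas`, which presumably works on sparse matrices.

```python
import math

def coactivation_metrics(active_i, active_j, n_tokens):
    """Compute co-activation statistics for one feature pair.

    active_i, active_j: sets of token indices where each feature fires (> 0).
    n_tokens: total number of tokens in the corpus.
    (Hypothetical helper for illustration only.)
    """
    n_i, n_j = len(active_i), len(active_j)
    n_ij = len(active_i & active_j)
    if n_i == 0 or n_j == 0 or n_ij == 0:
        return None  # The metrics are undefined without any co-activation.
    p_i, p_j, p_ij = n_i / n_tokens, n_j / n_tokens, n_ij / n_tokens
    cond = n_ij / n_i                    # P(f_j > 0 | f_i > 0)
    ratio = cond / p_j                   # lift over the base rate of f_j
    pmi = math.log(p_ij / (p_i * p_j))   # log of the same lift
    return {"coactivation_freq": cond, "pmi": pmi, "coactivation_ratio": ratio}

# Toy corpus of 10 tokens: feature i fires on 4 of them, feature j on 3.
m = coactivation_metrics({0, 1, 2, 3}, {2, 3, 4}, n_tokens=10)
print(m)
```

Note that the co-activation ratio equals exp(PMI), so the two metrics rank pairs identically; they differ only in scale.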

In [None]:
from feature_graph.coactivation import build_coactivation_atlas

atlas = build_coactivation_atlas(activation_store, cfg)

print(f"\nCo-activation atlas built for {len(atlas.layer_pairs)} layer pairs:")
for lp in atlas.layer_pairs:
    pmi_mat = atlas.pmi.get(lp)
    if pmi_mat is not None:
        n_nonzero = pmi_mat.nnz
        total = pmi_mat.shape[0] * pmi_mat.shape[1]
        print(f"  {lp[0]} → {lp[1]}: {n_nonzero} / {total} pairs retained ({100*n_nonzero/total:.2f}%)")

atlas.save("outputs/demo/atlas.h5")

## 5. Identify Candidate Pairs

Testing every cross-layer pair causally is infeasible, so candidates are selected by a four-stage pruning pipeline:
1. **Importance filter**: Select top-K features by activation frequency × causal effect
2. **Co-activation filter**: Remove pairs with negligible co-activation
3. **Decoder alignment filter**: Remove pairs with near-zero direct alignment (adjacent layers)
4. **Layer locality**: Only test pairs within the layer window
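
The four stages can be sketched on toy data as follows. Everything here is an assumption made for illustration: the function `prune_candidates`, the `(layer, index)` feature ids, the dictionary inputs, and the thresholds `min_coact` and `min_align` are hypothetical and do not reflect how `identify_candidates` is implemented.

```python
from itertools import product

def prune_candidates(importance, coact, alignment, top_k, layer_window,
                     min_coact=0.05, min_align=0.01):
    """Toy four-stage filter over (layer, index) feature ids.

    importance: {feature_id: score}
    coact:      {(src, tgt): P(tgt active | src active)}
    alignment:  {(src, tgt): decoder alignment}, used for adjacent layers only
    min_coact and min_align are made-up thresholds for this sketch.
    """
    # Stage 1: keep the top-K features by importance score.
    kept = sorted(importance, key=importance.get, reverse=True)[:top_k]
    pairs = []
    for src, tgt in product(kept, kept):
        gap = tgt[0] - src[0]
        # Stage 4 (checked first because it is cheapest): layer locality.
        if not 1 <= gap <= layer_window:
            continue
        # Stage 2: drop pairs that rarely co-activate.
        if coact.get((src, tgt), 0.0) < min_coact:
            continue
        # Stage 3: for adjacent layers, drop near-zero decoder alignment.
        if gap == 1 and abs(alignment.get((src, tgt), 0.0)) < min_align:
            continue
        pairs.append((src, tgt))
    return pairs

importance = {(0, 1): 0.9, (1, 2): 0.8, (2, 3): 0.7, (3, 4): 0.1}
coact = {((0, 1), (1, 2)): 0.4, ((0, 1), (2, 3)): 0.3, ((1, 2), (2, 3)): 0.01}
alignment = {((0, 1), (1, 2)): 0.5}
pairs = prune_candidates(importance, coact, alignment, top_k=3, layer_window=2)
print(pairs)
```

The order of the filters does not change the result; running the cheap checks first only saves work.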

In [None]:
from feature_graph.candidates import identify_candidates

candidates = identify_candidates(
    atlas, saes, cfg,
    activation_store=activation_store,
    model=model,
)

print(f"\n{len(candidates)} candidate pairs identified")

# Show distribution across layer pairs
from collections import Counter
pair_counts = Counter((c.src.layer, c.tgt.layer) for c in candidates)
print("\nCandidates by layer pair:")
for (l_src, l_tgt), count in sorted(pair_counts.items()):
    print(f"  Layer {l_src} → {l_tgt}: {count} pairs")

# Show the strongest candidate pairs, sorted explicitly by PMI
print("\nTop 10 candidate pairs (by PMI):")
for c in sorted(candidates, key=lambda c: c.pmi, reverse=True)[:10]:
    print(f"  {c.src_id} → {c.tgt_id} | PMI={c.pmi:.2f} | coact={c.coactivation_prob:.3f} | align={c.decoder_alignment:.3f}")

## 6. Measure Causal Interactions

This is the core scientific step. For each candidate pair, we:
1. Sample inputs where the source feature is active
2. Clamp the source feature to high/low values
3. Measure the effect on the target feature
4. Compute interaction strength with statistical testing

**This step takes the most time.** For the demo config (20 samples × small candidate set), it should take a few minutes.
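
The clamping procedure can be illustrated without a model. In the sketch below, `toy_target` is a made-up stand-in for the forward pass, and the sign-flip permutation test is a simple substitute chosen for this example; the test that `measure_interactions` actually uses is not shown in this notebook.

```python
import random
import statistics

def toy_target(source_value, context):
    """Stand-in for a forward pass: target feature as a function of the source."""
    return max(0.0, 0.8 * source_value + context)

def clamp_effect(contexts, high=5.0, low=0.0):
    """Per-sample change in the target when the source is clamped high vs. low."""
    return [toy_target(high, c) - toy_target(low, c) for c in contexts]

def sign_flip_p_value(deltas, n_perm=2000, seed=0):
    """Two-sided sign-flip permutation test of the null hypothesis mean(deltas) == 0."""
    rng = random.Random(seed)
    observed = abs(statistics.fmean(deltas))
    hits = 0
    for _ in range(n_perm):
        flipped = [d if rng.random() < 0.5 else -d for d in deltas]
        if abs(statistics.fmean(flipped)) >= observed:
            hits += 1
    return (hits + 1) / (n_perm + 1)

rng = random.Random(42)
contexts = [rng.gauss(0.0, 1.0) for _ in range(20)]  # 20 samples, as in the demo config
deltas = clamp_effect(contexts)
mean_strength = statistics.fmean(deltas)
p_value = sign_flip_p_value(deltas)
print(f"mean strength={mean_strength:.3f}, p={p_value:.4f}")
```

With only 20 samples the smallest attainable p-value is limited, which is one reason the full-scale run uses a larger `n_intervention_samples`.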

In [None]:
from feature_graph.interactions import measure_interactions

interactions = measure_interactions(model, saes, candidates, cfg)

print(f"\n{len(interactions)} interactions measured")

# Count by type
type_counts = Counter(r.interaction_type for r in interactions)
print("\nInteraction types:")
for t, c in type_counts.items():
    print(f"  {t}: {c}")

# Count significant interactions
sig = [r for r in interactions if r.p_value < cfg.significance_level]
print(f"\n{len(sig)} significant interactions (p < {cfg.significance_level})")

# Show strongest interactions
sig.sort(key=lambda r: r.abs_strength, reverse=True)
print("\nTop 10 strongest significant interactions:")
for r in sig[:10]:
    print(f"  {r.src_id} → {r.tgt_id} | {r.interaction_type} | strength={r.mean_strength:.4f} | p={r.p_value:.4f}")
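
The config sets `correction_method="fdr_bh"`, i.e. Benjamini-Hochberg false discovery rate control. Whether `r.p_value` above is already corrected is not visible from this notebook, so the following is only a sketch of what the correction does, written as a hypothetical standalone helper rather than the project's own code.

```python
def benjamini_hochberg(p_values, alpha=0.05):
    """Return a list of booleans: True where the hypothesis is rejected at FDR alpha."""
    m = len(p_values)
    order = sorted(range(m), key=lambda i: p_values[i])
    # Find the largest 1-based rank k such that p_(k) <= (k / m) * alpha.
    cutoff_rank = 0
    for rank, idx in enumerate(order, start=1):
        if p_values[idx] <= rank / m * alpha:
            cutoff_rank = rank
    # Reject every hypothesis ranked at or below the cutoff.
    rejected = [False] * m
    for rank, idx in enumerate(order, start=1):
        if rank <= cutoff_rank:
            rejected[idx] = True
    return rejected

raw_p = [0.001, 0.045, 0.025, 0.2, 0.012]
keep = benjamini_hochberg(raw_p, alpha=0.05)
naive = [p < 0.05 for p in raw_p]
print("BH:   ", keep)
print("naive:", naive)
```

In this toy example the uncorrected threshold accepts four of the five p-values, while the corrected procedure accepts three. With thousands of candidate pairs the gap matters much more.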

## 7. Build and Analyze the Interaction Graph

We now construct the NetworkX graph from the significant interactions
and compute graph-theoretic statistics.
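
As a rough picture of the statistics involved, the sketch below computes density and degree-based hub ranking from a plain edge list, using only the standard library. The helper `degree_summary` and the `L{layer}_F{index}` node ids are assumptions for illustration; the notebook itself relies on NetworkX through `compute_graph_statistics` and `find_hubs`.

```python
from collections import Counter

def degree_summary(edges):
    """Compute density and per-node in/out degree for a directed edge list."""
    nodes = {n for edge in edges for n in edge}
    in_deg = Counter(tgt for _, tgt in edges)
    out_deg = Counter(src for src, _ in edges)
    n = len(nodes)
    # Directed density: edges present divided by the n * (n - 1) possible edges.
    density = len(edges) / (n * (n - 1)) if n > 1 else 0.0
    total = {node: in_deg[node] + out_deg[node] for node in nodes}
    # Rank by total degree, breaking ties by node id for determinism.
    hubs = sorted(total, key=lambda node: (-total[node], node))
    return {"n_nodes": n, "n_edges": len(edges), "density": density,
            "in_degree": in_deg, "out_degree": out_deg, "hubs": hubs}

edges = [("L0_F1", "L1_F7"), ("L0_F1", "L2_F3"),
         ("L1_F7", "L2_F3"), ("L0_F4", "L1_F7")]
summary = degree_summary(edges)
print(summary["density"], summary["hubs"][:3])
```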

In [None]:
from feature_graph.graph import build_interaction_graph, save_graph
from feature_graph.analysis import compute_graph_statistics, find_hubs, count_motifs, detect_communities

G = build_interaction_graph(interactions, cfg)

# Save in multiple formats
save_graph(G, "outputs/demo/graph.graphml", format="graphml")
save_graph(G, "outputs/demo/graph.json", format="json")

# Compute statistics
stats = compute_graph_statistics(G)
print(stats.summary())

In [None]:
# Hub identification
hubs = find_hubs(G, top_k=10, metric="total_degree")
print("\nTop 10 Hub Features (by total degree):")
for h in hubs:
    print(f"  {h['id']}: degree={h.get('total_degree_score', 0)}, "
          f"in={h['in_degree']}, out={h['out_degree']}, label='{h['label']}'")

In [None]:
# Community detection
communities = detect_communities(G)
print(f"\nDetected {len(communities)} communities")
for i, comm in enumerate(sorted(communities, key=len, reverse=True)[:5]):
    print(f"  Community {i}: {len(comm)} features")
    # Show a few features from each
    for fid in list(comm)[:3]:
        layer = G.nodes[fid].get('layer', '?')
        label = G.nodes[fid].get('label', '')
        print(f"    {fid} (layer {layer}): {label}")

In [None]:
# Motif counting
motifs = count_motifs(G)
print("\nInteraction Motifs:")
for motif, count in motifs.items():
    print(f"  {motif}: {count}")
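
Which motifs `count_motifs` reports is not shown here. One common directed motif is the feed-forward loop (a → b, b → c, and a → c), and the sketch below counts it on a plain edge list as an illustration of the idea; the function name is hypothetical.

```python
from collections import defaultdict

def count_feed_forward_loops(edges):
    """Count ordered triples (a, b, c) with edges a->b, b->c, and a->c."""
    succ = defaultdict(set)
    for src, tgt in edges:
        succ[src].add(tgt)
    count = 0
    for a in list(succ):
        for b in succ[a]:
            # .get avoids inserting new keys into the defaultdict while iterating.
            for c in succ.get(b, ()):
                if c != a and c in succ[a]:
                    count += 1
    return count

toy_edges = [("a", "b"), ("b", "c"), ("a", "c"), ("c", "d")]
n_ffl = count_feed_forward_loops(toy_edges)
print(n_ffl)
```

In a layered feature graph, a feed-forward loop means a source feature reaches a target both directly and through an intermediate feature.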

## 8. Visualization

Generate interactive visualizations of the interaction graph.

In [None]:
from feature_graph.visualization import (
    render_interactive_graph,
    render_neighborhood,
    plot_degree_distribution,
    plot_layer_distance_decay,
    plot_interaction_type_breakdown,
)

# Full graph visualization
render_interactive_graph(G, "outputs/demo/full_graph.html")
print("Full graph visualization saved to outputs/demo/full_graph.html")

# Statistical plots
fig_degree = plot_degree_distribution(G)
if fig_degree:
    fig_degree.show()

fig_decay = plot_layer_distance_decay(G)
if fig_decay:
    fig_decay.show()

fig_types = plot_interaction_type_breakdown(G)
if fig_types:
    fig_types.show()

In [None]:
# Neighborhood exploration for the top hub
if hubs:
    top_hub = hubs[0]['id']
    render_neighborhood(G, top_hub, n_hops=1, output_path="outputs/demo/top_hub_neighborhood.html")
    print(f"Neighborhood of {top_hub} saved to outputs/demo/top_hub_neighborhood.html")

## 9. Behavior-Specific Subgraph Extraction

We extract the subgraph relevant to factual recall behavior.
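
The `n_hops` parameter below suggests a neighborhood expansion around seed features. A minimal breadth-first sketch of that idea follows. It treats edges as undirected during expansion, which is an assumption, and `n_hop_subgraph` is a hypothetical helper, not the implementation of `extract_behavior_subgraph`.

```python
from collections import deque

def n_hop_subgraph(edges, seeds, n_hops):
    """Return the edges among nodes within n_hops of any seed, plus hop distances."""
    neighbors = {}
    for src, tgt in edges:
        neighbors.setdefault(src, set()).add(tgt)
        neighbors.setdefault(tgt, set()).add(src)
    dist = {s: 0 for s in seeds}
    queue = deque(seeds)
    while queue:
        node = queue.popleft()
        if dist[node] == n_hops:
            continue  # Do not expand beyond the hop limit.
        for nb in neighbors.get(node, ()):
            if nb not in dist:
                dist[nb] = dist[node] + 1
                queue.append(nb)
    kept_edges = [(s, t) for s, t in edges if s in dist and t in dist]
    return kept_edges, dist

chain = [("a", "b"), ("b", "c"), ("c", "d"), ("d", "e")]
sub_edges, hop_dist = n_hop_subgraph(chain, seeds=["c"], n_hops=1)
print(sub_edges, hop_dist)
```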

In [None]:
from feature_graph.subgraphs import extract_behavior_subgraph, PREDEFINED_BEHAVIORS

# Extract factual recall subgraph
factual_result = extract_behavior_subgraph(
    model, saes, G,
    behavior="factual_recall",
    cfg=cfg,
    n_hops=2,
    top_k_seeds=10,
)

print("\nFactual Recall Subgraph:")
print(f"  Seed features: {len(factual_result.seed_features)}")
print(f"  Subgraph nodes: {factual_result.subgraph.number_of_nodes()}")
print(f"  Subgraph edges: {factual_result.subgraph.number_of_edges()}")

print("\nTop seed features:")
for sf in factual_result.seed_features[:10]:
    print(f"  {sf['feature_id']}: effect_size={sf['effect_size']:.4f}")

# Visualize
if factual_result.subgraph.number_of_nodes() > 0:
    render_interactive_graph(
        factual_result.subgraph,
        "outputs/demo/factual_recall_subgraph.html",
        title="Factual Recall Circuit",
        seed_nodes=[sf['feature_id'] for sf in factual_result.seed_features[:10]],
    )
    print("\nVisualization saved to outputs/demo/factual_recall_subgraph.html")

## 10. Steering Cascade Prediction

Use the interaction graph to predict how steering a single feature propagates to the features downstream of it.
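
One simple way to think about such a prediction is linear propagation: each hop multiplies the incoming change by the edge's interaction strength and sums over paths. The sketch below implements that idea under the assumption of signed, linear edge weights on an acyclic graph. `predict_cascade_sketch` is a hypothetical name, and the real `predict_cascade` may use a different model.

```python
def predict_cascade_sketch(weights, source, steer_delta, n_hops):
    """Propagate a steering delta along weighted edges, linearly, for n_hops.

    weights: {(src, tgt): signed interaction strength}
    Returns {node: predicted activation change}, including the source itself.
    """
    out_edges = {}
    for (src, tgt), w in weights.items():
        out_edges.setdefault(src, []).append((tgt, w))
    total = {source: steer_delta}
    frontier = {source: steer_delta}
    for _ in range(n_hops):
        nxt = {}
        # Push each frontier node's change one hop further along its out-edges.
        for node, delta in frontier.items():
            for tgt, w in out_edges.get(node, []):
                nxt[tgt] = nxt.get(tgt, 0.0) + w * delta
        for node, delta in nxt.items():
            total[node] = total.get(node, 0.0) + delta
        frontier = nxt
    return total

toy_weights = {("a", "b"): 0.5, ("b", "c"): -0.4, ("a", "c"): 0.1}
effects = predict_cascade_sketch(toy_weights, "a", steer_delta=2.0, n_hops=2)
print(effects)
```

Here the direct excitatory path from `a` to `c` is outweighed by the inhibitory path through `b`, so the predicted net effect on `c` is negative. Indirect effects of this kind are what a one-hop analysis would miss.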

In [None]:
from feature_graph.steering import predict_cascade
from feature_graph.visualization import render_cascade_overlay

# Predict cascade from the top hub feature
if hubs:
    top_hub = hubs[0]['id']
    cascade = predict_cascade(G, top_hub, steer_delta=2.0, n_hops=2)
    print(cascade.summary())

    # Visualize cascade
    render_cascade_overlay(G, cascade, "outputs/demo/cascade.html")
    print("\nCascade visualization saved to outputs/demo/cascade.html")

## 11. Summary and Next Steps

This demo showed the complete pipeline on a small scale. For a full research run:

1. **Increase `n_tokens`** to 1M+ for robust co-activation statistics
2. **Increase `top_k_features`** to 200+ for a richer graph
3. **Increase `n_intervention_samples`** to 100+ for reliable causal estimates
4. **Analyze all layers** (set `layers=list(range(12))` for GPT-2-small)
5. **Compare methods**: run with `interaction_method='both'` to compare clamping vs. Jacobian
6. **Extract multiple behavior subgraphs** and compare their overlap
7. **Test steering predictions** by comparing predicted vs. actual cascading effects

In [None]:
print("\n" + "="*60)
print("Pipeline complete!")
print("="*60)
print("\nOutputs saved to: outputs/demo/")
print("  config.json                  - Run configuration")
print("  activations.h5               - Feature activations")
print("  atlas.h5                     - Co-activation atlas")
print("  graph.graphml                - Interaction graph (GraphML)")
print("  graph.json                   - Interaction graph (JSON)")
print("  full_graph.html              - Interactive graph visualization")
print("  top_hub_neighborhood.html    - Hub neighborhood")
print("  factual_recall_subgraph.html - Factual recall circuit")
print("  cascade.html                 - Steering cascade prediction")