# TinySQL: Edge Attribution Patching (EAP)

**Background:** A "TinySQL" model takes two inputs: 1) an Instruction, which is an English sentence requesting data, and 2) a Context, which is a SQL CREATE TABLE statement. The model outputs a Response, which is a SQL SELECT statement.

**Notebook purpose:** Visualize the flow of information through attention heads and MLP layers when a token is corrupted. We corrupt one of: 1) the instruction table name, 2) an instruction field name, 3) the context table name, or 4) a context field name.

**Notebook details:** This notebook:

- Was developed on Google Colab using an A100
- Runs with BM1, BM2 or BM3 with base/CS1/CS2/CS3 models.
- Requires a GITHUB_TOKEN secret to access Martian TinySQL code repository.
- Requires a HF_TOKEN secret to access Martian HuggingFace repository.
- Was developed under a grant provided by withmartian.com ( https://withmartian.com )
- Is based on the [Aaquib111](https://github.com/Aaquib111/edge-attribution-patching) implementation.
- Has roughly the same output format as ACDC in [auto-circuit](https://github.com/UFO-101/auto-circuit).

**Notebook workflow:** The workflow is as follows:
- We start by loading a model using HookedTransformer of [TransformerLens](https://github.com/TransformerLensOrg/TransformerLens).
- We then generate a dataset of batch_size examples using CorruptFeatureTestGenerator of TinySQL. Each example has the following elements:

        clean_prompt
        clean_tokens
        corrupt_prompt
        corrupt_tokens
        clean_tokenizer_index
        corrupt_tokenizer_index

- We run Edge Attribution Patching on our dataset, using the normalized logit difference as the metric:
$$\frac{\text{patched_logit_diff} - \text{corrupted_logit_diff}}{\text{clean_logit_diff} - \text{corrupted_logit_diff}}$$

- We keep the **top_n_edges** highest-scoring edges.
- We plot the result in the form of a Sankey diagram.
- We save the results as a JSON file.
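
The normalization in the metric above can be checked with a small sketch. The numbers and the function name `normalized_metric` are illustrative assumptions, not values from the notebook. By construction, a run that reproduces the clean logit difference scores 1 and a run that reproduces the corrupted one scores 0.

```python
def normalized_metric(patched_logit_diff, clean_logit_diff, corrupted_logit_diff):
    """Rescale a logit difference so that corrupted maps to 0 and clean maps to 1."""
    return (patched_logit_diff - corrupted_logit_diff) / (
        clean_logit_diff - corrupted_logit_diff
    )

# Hypothetical logit differences, chosen only for illustration.
clean_ld = 4.0
corrupt_ld = -2.0

clean_score = normalized_metric(clean_ld, clean_ld, corrupt_ld)
corrupt_score = normalized_metric(corrupt_ld, clean_ld, corrupt_ld)
halfway_score = normalized_metric(1.0, clean_ld, corrupt_ld)
print(clean_score, corrupt_score, halfway_score)
```

A patched run that lands halfway between the two baselines scores 0.5, which makes scores comparable across features and models.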

**Results Interpretation**

- The EAP algorithm uses the following relationship to compute the score of an edge:
$$(e_{\text{clean}} - e_{\text{corr}}) \cdot \frac{\partial}{\partial e_{\text{clean}}} L(x_{\text{clean}} | \text{do}(E = e_{\text{clean}}))$$

Where:
  - $e_{\text{clean}} - e_{\text{corr}}$ is the difference between the clean and corrupted activations at the upstream node i of the edge.
  - $\frac{\partial}{\partial e_{\text{clean}}} L$ is the gradient of our metric $L$ at the downstream node j of the edge.

So the score of an edge represents how a corruption in the upstream node of the edge influences the metric.

  - A positive score means that the edge contributes to the clean output.
  - A negative score means that the edge contributes to the corrupt output.

**The top_n_edges are selected based on the absolute value of their scores.**
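
As a minimal sketch of the scoring rule and the selection step, the toy example below computes the dot product of an activation difference with a gradient for three hypothetical edges, then keeps the two with the largest absolute score. The edge names, activations and gradients are made up for illustration; in the notebook the gradients come from a backward pass through the model, not from hand-written lists.

```python
def eap_score(e_clean, e_corr, grad):
    """First-order edge score: (e_clean - e_corr) dotted with dL/de."""
    return sum((c - k) * g for c, k, g in zip(e_clean, e_corr, grad))

# Hypothetical edges: (name, clean activation, corrupt activation, gradient).
toy_edges = [
    ("head.0.1 -> mlp.1", [1.0, 2.0], [0.0, 0.0], [0.5, 0.25]),
    ("mlp.0 -> head.1.0", [1.0, 0.0], [3.0, 0.0], [1.0, 1.0]),
    ("head.0.0 -> mlp.0", [0.2, 0.2], [0.1, 0.2], [1.0, 0.0]),
]

scores = {name: eap_score(c, k, g) for name, c, k, g in toy_edges}

# Rank by absolute value so that strongly negative edges are kept too.
top_n = 2
top_edges = sorted(scores.items(), key=lambda kv: abs(kv[1]), reverse=True)[:top_n]
print(top_edges)
```

Ranking by absolute value keeps the edge with score -2.0 ahead of the edge with score 1.0, even though a plain descending sort would have dropped it.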

# Import libraries
Imports standard libraries. Do not read.

In [None]:
!pip install -U pandas plotly transformer_lens jaxtyping -q

In [None]:
!git clone -b minimal-implementation --single-branch https://github.com/abirharrasse/edge-attribution-patching.git


In [None]:
from google.colab import userdata  # must be imported before the secret is read
github_token = userdata.get("GITHUB_TOKEN")

!pip install --upgrade git+https://{github_token}@github.com/withmartian/TinySQL.git

import TinySQL as qts
from IPython.display import clear_output
import einops
import torch
import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
from getpass import getpass
from google.colab import userdata
pio.renderers.default = "colab"

%cd /content/edge-attribution-patching
from IPython import get_ipython
ipython = get_ipython()
if ipython is not None:
    # run_line_magic replaces the deprecated ipython.magic call
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")


import torch as t
from torch import Tensor

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformer_lens import HookedTransformer
import json
from datetime import datetime
from typing import Dict, List, Set, Tuple, Optional, Any
import numpy as np
from collections import defaultdict

from eap.eap_wrapper import EAP

from jaxtyping import Float

device = t.device('cuda') if t.cuda.is_available() else t.device('cpu')
print(f'Device: {device}')

# EAP


In [None]:
model_num = 1                     # 1=TinyStories, 2=Qwen, 3=Llama
cs_num = 1                        # 0=BaseModel, 1=CS1, 2=CS2 or 3=CS3
feature_name = qts.ENGTABLENAME   # Instruction (ENGTABLENAME, ENGFIELDNAME) or Context (DEFTABLENAME, DEFFIELDNAME) feature to test.
use_novel_names = False           # If True, we corrupt using words not found in the clean prompt or create sql e.g. "little" or "hammer"
use_synonyms_field = False
use_synonyms_table = False
num_fields = 2                    # Number of table fields in data from clean/corrupt data generator
num_examples = 1
batch_size = 30
top_n_edges = 10                        # We plot the top n scoring edges

# Helping functions

In [None]:
def plot_eap_sankey(top_edges, layer_spacing=False):
    # Get unique nodes
    nodes = set()
    for from_edge, to_edge, _ in top_edges:
        nodes.add(from_edge)
        nodes.add(to_edge)
    nodes = list(nodes)
    node_idx = {node: idx for idx, node in enumerate(nodes)}

    # Color palette for nodes
    COLOR_PALETTE = [
        "rgba(31, 119, 180, 0.8)",    # blue
        "rgba(255, 127, 14, 0.8)",    # orange
        "rgba(44, 160, 44, 0.8)",     # green
        "rgba(214, 39, 40, 0.8)",     # red
        "rgba(148, 103, 189, 0.8)",   # purple
        "rgba(140, 86, 75, 0.8)",     # brown
        "rgba(227, 119, 194, 0.8)",   # pink
        "rgba(127, 127, 127, 0.8)",   # gray
        "rgba(188, 189, 34, 0.8)",    # yellow
        "rgba(23, 190, 207, 0.8)"     # cyan
    ]

    # Assign unique colors to nodes
    node_colors = {}
    for i, node in enumerate(nodes):
        node_colors[node] = COLOR_PALETTE[i % len(COLOR_PALETTE)]

    # Calculate scores per node
    node_scores = defaultdict(float)
    for from_edge, to_edge, score in top_edges:
        node_scores[from_edge] += score
        node_scores[to_edge] += score

    # Define the sankey nodes with scores
    node_labels = [f"{node}<br>{node_scores[node]:.3f}" for node in nodes]
    node_colors = [node_colors[node] for node in nodes]

    # Setup for layer spacing
    lyr_nodes: Dict[int, List[str]] = defaultdict(list)
    for node in nodes:
        layer = int(node.split('.')[1]) if '.' in node else 0
        lyr_nodes[layer].append(node)

    # Define the sankey edges
    sources, targets, values, labels, colors = [], [], [], [], []
    included_layer_nodes: Dict[int, List[str]] = defaultdict(list)

    # Plot all edges without threshold filtering
    for from_edge, to_edge, score in top_edges:
        source_idx = node_idx[from_edge]
        target_idx = node_idx[to_edge]

        sources.append(source_idx)
        targets.append(target_idx)
        values.append(abs(score))
        labels.append(f"{from_edge} → {to_edge}<br>{score:.3f}")

        if score == 0:
            edge_color = "rgba(0,0,0,0.1)"
        elif score > 0:
            edge_color = "rgba(0,0,255,0.3)"
        else:
            edge_color = "rgba(255,0,0,0.3)"
        colors.append(edge_color)

        source_layer = int(from_edge.split('.')[1]) if '.' in from_edge else 0
        target_layer = int(to_edge.split('.')[1]) if '.' in to_edge else 0
        included_layer_nodes[source_layer].append(from_edge)
        included_layer_nodes[target_layer].append(to_edge)

    if layer_spacing:
        ordered_lyr_nodes = [nodes for _, nodes in sorted(included_layer_nodes.items())]
        ghost_edge_val = 1e-6

        for lyr_1_nodes, lyr_2_nodes in zip(ordered_lyr_nodes[:-1], ordered_lyr_nodes[1:]):
            first_lyr_1_node = lyr_1_nodes[0]
            first_lyr_2_node = lyr_2_nodes[0]

            for lyr_1_node in lyr_1_nodes:
                sources.append(node_idx[lyr_1_node])
                targets.append(node_idx[first_lyr_2_node])
                values.append(ghost_edge_val)
                labels.append("")
                colors.append("rgba(0,255,0,0.0)")

            for lyr_2_node in lyr_2_nodes:
                sources.append(node_idx[first_lyr_1_node])
                targets.append(node_idx[lyr_2_node])
                values.append(ghost_edge_val)
                labels.append("")
                colors.append("rgba(0,255,0,0.0)")

    # Create the Sankey diagram
    fig = go.Figure(go.Sankey(
        arrangement="perpendicular",
        node=dict(
            label=node_labels,
            color=node_colors,
            line=dict(width=0.0),
            pad=15,
            thickness=20,
        ),
        link=dict(
            source=sources,
            target=targets,
            value=values,
            label=labels,
            color=colors,
            arrowlen=25
        ),
        domain={'y': [0, 1]}
    ))

    # Update layout
    n_layers = len(included_layer_nodes)
    h = max(250, 400)
    w = max(50 * n_layers, 600)

    fig.update_layout(
        font_size=14,
        height=h,
        width=w,
        plot_bgcolor='white',
        paper_bgcolor='white',
        margin=dict(t=20, l=20, r=20, b=20),
        hoverlabel=dict(font_size=14)
    )

    return fig
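
The core of `plot_eap_sankey` is the conversion of `(source, target, score)` tuples into the index-based lists that a Sankey diagram expects. The sketch below isolates that step without Plotly. The function name `sankey_link_data`, the plain color names and the toy edges are assumptions for illustration; nodes are sorted here only to make the order deterministic, whereas the notebook iterates over an unordered set.

```python
def sankey_link_data(edges):
    """Turn (source, target, score) tuples into index-based Sankey link lists."""
    # Sorted for a deterministic node order.
    nodes = sorted({name for src, dst, _ in edges for name in (src, dst)})
    node_idx = {name: i for i, name in enumerate(nodes)}
    sources = [node_idx[src] for src, _, _ in edges]
    targets = [node_idx[dst] for _, dst, _ in edges]
    # Link widths must be non-negative, so the sign is carried by the color.
    values = [abs(score) for _, _, score in edges]
    colors = ["blue" if s > 0 else "red" if s < 0 else "gray" for _, _, s in edges]
    return nodes, sources, targets, values, colors

toy_edges = [("head.0.1", "mlp.1", 0.8), ("mlp.0", "mlp.1", -0.3)]
nodes, sources, targets, values, colors = sankey_link_data(toy_edges)
print(nodes, sources, targets, values, colors)
```

As in the notebook, positive scores are drawn in blue and negative scores in red, while the link width reflects only the magnitude.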

In [None]:
def get_notebook_name():
   try:
       from IPython import get_ipython
       kernel = get_ipython()
       if kernel is None:
           return "unknown_script"

       # Get the notebook path
       path = kernel.kernel.session.config['IPKernelApp']['connection_file']
       # Extract just the notebook name from the path
       notebook_name = path.split('/')[-1].replace('kernel-', '').replace('.json', '')
       return notebook_name
   except Exception:
       return "unknown_script"  # Fallback if not in a notebook or error occurs

def format_float(value):
   if isinstance(value, float):
       return round(value, 8)
   return value

def rename_edge(edge_name: str) -> str:
   # Handle MLP case
   if edge_name.startswith("mlp."):
       layer_num = edge_name.split(".")[1]
       return f"l{layer_num}.mlp"
   # Handle attention head case
   elif edge_name.startswith("head."):
       parts = edge_name.split(".")
       layer_num, head_num = parts[1], parts[2]
       component = parts[3] if len(parts) > 3 else ""
       return f"l{layer_num}.h{head_num}" + (f".{component}" if component else "")
   return edge_name

def tensor_to_native(value):
   if isinstance(value, torch.Tensor):
       result = value.item() if value.numel() == 1 else value.tolist()
       return format_float(result)
   elif isinstance(value, dict):
       return {k: tensor_to_native(v) for k, v in value.items()}
   elif isinstance(value, (list, tuple)):
       return [tensor_to_native(v) for v in value]
   elif isinstance(value, float):
       return format_float(value)
   return value

def collect_eap_results(
   model_name: int,
   hf_model: str,
   dataset_name: int,
   feature_name: str,
   use_novel_names: bool,
   use_synonyms_field: bool,
   use_synonyms_table: bool,
   num_sql_fields: int,
   batch_size: int,
   n_batches:int,
   n_positions: int,
   n_layers: int,
   n_heads: int,
   edges: List[Tuple[str, str, float]],
   top_n_edges: int,
   clean_logit_diff: Any,
   corrupt_logit_diff: Any,
   clean_metric: Any,
   corrupt_metric: Any,
   source: str = None,
   additional_params: Optional[Dict[str, Any]] = None
):
   notebook_name = source if source else get_notebook_name()

   results = {
       "metadata": {
           "source": notebook_name,
           "timestamp": datetime.now().isoformat(),
           "model": model_name,
           "commandset": dataset_name,
           "hf_model": hf_model,
           "feature_name": feature_name,
           "use_novel_names": use_novel_names,
           "use_synonyms_field": use_synonyms_field,
           "use_synonyms_table": use_synonyms_table,
           "num_sql_fields": num_sql_fields,
           "batch_size": batch_size,
           "n_batches": n_batches,
           "n_positions":n_positions,
           "n_layers": n_layers,
           "n_heads":n_heads
       },
       "parameters": {
           "top_n_edges": top_n_edges,
           "clean_logit_diff": tensor_to_native(clean_logit_diff),
           "corrupt_logit_diff": tensor_to_native(corrupt_logit_diff),
           "clean_metric": tensor_to_native(clean_metric),
           "corrupt_metric": tensor_to_native(corrupt_metric)
       },
       "edges": [
           {
               "from_edge": rename_edge(str(from_edge)),
               "to_edge": rename_edge(str(to_edge)),
               "score": format_float(tensor_to_native(score))
           }
           for from_edge, to_edge, score in edges
       ]
   }

   if additional_params:
       results["parameters"].update(
           {k: tensor_to_native(v) for k, v in additional_params.items()}
       )

   return results

def save_eap_results(results: Dict[str, Any], filepath: str) -> None:
   serializable_results = tensor_to_native(results)
   with open(filepath, 'w') as f:
       json.dump(serializable_results, f, indent=2)

def load_eap_results(filepath: str) -> Dict[str, Any]:
   with open(filepath, 'r') as f:
       return json.load(f)

def results_to_edges(results: Dict[str, Any]) -> List[Tuple[str, str, float]]:
   return [
       (edge["from_edge"], edge["to_edge"], edge["score"])
       for edge in results["edges"]
   ]

def plot_eap_json_sankey(top_edges: List[Tuple[str, str, float]], layer_spacing: bool = False):
   from collections import defaultdict

   # Extract unique nodes and create node mapping
   nodes = set()
   for src, dst, _ in top_edges:
       nodes.add(src)
       nodes.add(dst)
   nodes = list(nodes)
   node_to_idx = {node: idx for idx, node in enumerate(nodes)}

   # Color palette for nodes
   COLOR_PALETTE = [
       "rgba(31, 119, 180, 0.8)",    # blue
       "rgba(255, 127, 14, 0.8)",    # orange
       "rgba(44, 160, 44, 0.8)",     # green
       "rgba(214, 39, 40, 0.8)",     # red
       "rgba(148, 103, 189, 0.8)",   # purple
       "rgba(140, 86, 75, 0.8)",     # brown
       "rgba(227, 119, 194, 0.8)",   # pink
       "rgba(127, 127, 127, 0.8)",   # gray
       "rgba(188, 189, 34, 0.8)",    # yellow
       "rgba(23, 190, 207, 0.8)"     # cyan
   ]

   # Calculate scores per node
   node_scores = defaultdict(float)
   for src, dst, score in top_edges:
       node_scores[src] += score
       node_scores[dst] += score

   # Assign colors and create labels with scores
   node_colors = [COLOR_PALETTE[i % len(COLOR_PALETTE)] for i in range(len(nodes))]
   node_labels = [f"{node}<br>{node_scores[node]:.3f}" for node in nodes]

   # Get layer information for each node
   lyr_nodes: Dict[int, List[str]] = defaultdict(list)
   for node in nodes:
       try:
           layer_num = int(node.split('.')[0][1:])  # Get number after 'l'
           lyr_nodes[layer_num].append(node)
       except (IndexError, ValueError):
           lyr_nodes[0].append(node)

   # Create Sankey diagram data
   sources, targets, values, labels, colors = [], [], [], [], []
   included_layer_nodes: Dict[int, List[str]] = defaultdict(list)

   # Plot all edges
   for src, dst, score in top_edges:
       source_idx = node_to_idx[src]
       target_idx = node_to_idx[dst]

       sources.append(source_idx)
       targets.append(target_idx)
       values.append(abs(score))
       labels.append(f"{src} → {dst}<br>{score:.3f}")

       if score == 0:
           edge_color = "rgba(0,0,0,0.1)"
       elif score > 0:
           edge_color = "rgba(0,0,255,0.3)"
       else:
           edge_color = "rgba(255,0,0,0.3)"
       colors.append(edge_color)

       source_layer = int(src.split('.')[0][1:]) if '.' in src else 0
       target_layer = int(dst.split('.')[0][1:]) if '.' in dst else 0
       included_layer_nodes[source_layer].append(src)
       included_layer_nodes[target_layer].append(dst)

   # Add ghost edges for layer spacing if enabled
   if layer_spacing:
       ordered_lyr_nodes = [nodes for _, nodes in sorted(included_layer_nodes.items())]
       ghost_edge_val = 1e-6

       for lyr_1_nodes, lyr_2_nodes in zip(ordered_lyr_nodes[:-1], ordered_lyr_nodes[1:]):
           first_lyr_1_node = lyr_1_nodes[0]
           first_lyr_2_node = lyr_2_nodes[0]

           for lyr_1_node in lyr_1_nodes:
               sources.append(node_to_idx[lyr_1_node])
               targets.append(node_to_idx[first_lyr_2_node])
               values.append(ghost_edge_val)
               labels.append("")
               colors.append("rgba(0,255,0,0.0)")

           for lyr_2_node in lyr_2_nodes:
               sources.append(node_to_idx[first_lyr_1_node])
               targets.append(node_to_idx[lyr_2_node])
               values.append(ghost_edge_val)
               labels.append("")
               colors.append("rgba(0,255,0,0.0)")

   # Create the figure
   fig = go.Figure(go.Sankey(
       arrangement="perpendicular",
       node=dict(
           label=node_labels,
           color=node_colors,
           line=dict(width=0.0),
           pad=15,
           thickness=20,
       ),
       link=dict(
           source=sources,
           target=targets,
           value=values,
           label=labels,
           color=colors,
           arrowlen=25
       ),
       domain={'y': [0, 1]}
   ))

   # Update layout
   n_layers = len(included_layer_nodes)
   h = max(250, 400)
   w = max(50 * n_layers, 600)

   fig.update_layout(
       font_size=14,
       height=h,
       width=w,
       plot_bgcolor='white',
       paper_bgcolor='white',
       margin=dict(t=20, l=20, r=20, b=20),
       hoverlabel=dict(font_size=14)
   )

   return fig
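
The two plotting functions parse layer numbers differently because `rename_edge` changes the node naming scheme before results are saved. The sketch below uses a condensed version of `rename_edge` and a hypothetical helper `layer_of` to show the mapping; the raw node names are illustrative guesses at what the EAP graph emits.

```python
def rename_edge(edge_name: str) -> str:
    """Condensed from the notebook helper: 'mlp.2' becomes 'l2.mlp'."""
    if edge_name.startswith("mlp."):
        return f"l{edge_name.split('.')[1]}.mlp"
    if edge_name.startswith("head."):
        parts = edge_name.split(".")
        suffix = f".{parts[3]}" if len(parts) > 3 else ""
        return f"l{parts[1]}.h{parts[2]}{suffix}"
    return edge_name

def layer_of(renamed: str) -> int:
    """Layer number parsed as in plot_eap_json_sankey (digits after 'l')."""
    return int(renamed.split(".")[0][1:])

raw_names = ["mlp.2", "head.3.7", "head.3.7.q"]  # hypothetical node names
renamed = [rename_edge(n) for n in raw_names]
layers = [layer_of(n) for n in renamed]
print(renamed, layers)
```

For raw names the layer is the second dot-separated field, which is why `plot_eap_sankey` reads `split('.')[1]` while `plot_eap_json_sankey` strips the leading `l` from the first field.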

In [None]:
def modify_hf_wte_copy(model: AutoModelForCausalLM) -> AutoModelForCausalLM:

    config = model.config
    new_model = AutoModelForCausalLM.from_config(config)
    new_model.load_state_dict(model.state_dict(), strict=False)

    original_wte = model.transformer.wte
    original_weight = original_wte.weight.data  # Shape: [vocab_size, hidden_size]

    new_weight = original_weight[:-1, :]  # Shape: [vocab_size - 1, hidden_size]

    new_wte = torch.nn.Embedding(new_weight.size(0), new_weight.size(1))
    new_wte.weight.data = new_weight

    new_model.transformer.wte = new_wte

    original_unembed_weight = model.lm_head.weight.data  # Shape: [vocab_size, hidden_size]

    new_unembed_weight = original_unembed_weight[:-1, :]  # Shape: [vocab_size - 1, hidden_size]

    new_model.lm_head = torch.nn.Linear(new_unembed_weight.size(1), new_unembed_weight.size(0), bias=False)
    new_model.lm_head.weight.data = new_unembed_weight

    return new_model

# Load model

In [None]:
base_model_name = qts.sql_interp_model_location(model_num, 0, synonym=False)
model_name = qts.sql_interp_model_location(model_num, cs_num, synonym=True)

hf_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer_sql = AutoTokenizer.from_pretrained(base_model_name)

fixed_hf_model = modify_hf_wte_copy(hf_model)

model_sql = HookedTransformer.from_pretrained(
            base_model_name,
            hf_model=fixed_hf_model,
            device=device,  # reuse the device selected in the import cell
            fold_ln=False,
            center_writing_weights=False,
            center_unembed=False,
            tokenizer=tokenizer_sql,
        )

model_sql.cfg.use_attn_result = True
model_sql.cfg.attn_only = False
model_sql.cfg.use_hook_mlp_in = True
model_sql.cfg.use_split_qkv_input = True

# Calculate results

In [None]:
# Key global "input" variables
clean_prompt = ""
corrupt_prompt = ""
clean_tokenizer_index = qts.UNKNOWN_VALUE # Tokenizer vocab index for clean word
corrupt_tokenizer_index = qts.UNKNOWN_VALUE # Tokenizer vocab index for corrupted word
answer_token_index = qts.UNKNOWN_VALUE # Token index in sql command answer of clean/corrupt word

# Key global "results" variables
clean_logit_diff = qts.UNKNOWN_VALUE
corrupt_logit_diff = qts.UNKNOWN_VALUE

In [None]:
many_examples = []
for i in range(num_examples):
    if model_num > 0:
        # Generate a batch of clean and corrupt prompts for feature_name
        generator = qts.CorruptFeatureTestGenerator(model_num, cs_num, tokenizer_sql, use_novel_names=use_novel_names, use_synonyms_field=use_synonyms_field, use_synonyms_table=use_synonyms_table, num_fields = num_fields)
        examples = generator.generate_feature_examples(feature_name, batch_size)

    examples_prep = []
    for example in examples:

        clean_tokenizer_index = example.clean_tokenizer_index
        corrupt_tokenizer_index = example.corrupt_tokenizer_index
        answer_token_index = example.answer_token_index

        # Truncate the clean_prompt at answer_token_index
        clean_prompt = example.clean_BatchItem.get_alpaca_prompt() + example.clean_BatchItem.sql_statement
        clean_tokens = tokenizer_sql(clean_prompt)["input_ids"]
        clean_tokens = clean_tokens[:answer_token_index]
        clean_prompt = tokenizer_sql.decode(clean_tokens)

        # Truncate the corrupt_prompt at answer_token_index
        corrupt_prompt = example.corrupt_BatchItem.get_alpaca_prompt() + example.corrupt_BatchItem.sql_statement
        corrupt_tokens = tokenizer_sql(corrupt_prompt)["input_ids"]
        corrupt_tokens = corrupt_tokens[:answer_token_index]
        corrupt_prompt = tokenizer_sql.decode(corrupt_tokens)

        examples_prep.append({
            "clean_prompt": clean_prompt,
            "clean_tokens": clean_tokens,
            "corrupt_prompt": corrupt_prompt,
            "corrupt_tokens": corrupt_tokens,
            "clean_tokenizer_index": example.clean_tokenizer_index,
            "corrupt_tokenizer_index": example.corrupt_tokenizer_index,
        })
    many_examples.append(examples_prep)
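
The truncation step above cuts each prompt just before the answer token, so that the model's next-token prediction is the clean or corrupted word. A minimal sketch, assuming whitespace splitting as a stand-in for the real tokenizer and using made-up prompts (the function name `truncate_before_answer` is hypothetical):

```python
def truncate_before_answer(tokens, answer_token_index):
    """Keep only the tokens that precede the answer token."""
    return tokens[:answer_token_index]

# Hypothetical prompts; whitespace splitting stands in for the real tokenizer.
clean_tokens = "show price from orders ### SELECT price FROM orders".split()
corrupt_tokens = "show price from users ### SELECT price FROM users".split()
answer_token_index = 8  # position of the table name in the SQL answer

clean_cut = truncate_before_answer(clean_tokens, answer_token_index)
corrupt_cut = truncate_before_answer(corrupt_tokens, answer_token_index)

# Positions where the truncated clean and corrupt prompts differ.
diff_positions = [i for i, (a, b) in enumerate(zip(clean_cut, corrupt_cut)) if a != b]
print(clean_cut[-1], diff_positions)
```

Both truncated prompts have the same length and end on the same token, so they can be stacked into a single tensor per batch, and they differ only at the corrupted position.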

In [None]:
def ave_logit_diff(
    logits: Float[Tensor, 'batch seq d_vocab'],
    examples_prep,
    per_prompt: bool = False
):
    '''
    Return average logit difference between correct and incorrect answers
    '''
    if hasattr(logits, 'logits'):
        logits = logits.logits
    batch_size = logits.size(0)

    clean_logits = logits[range(batch_size), -1, [example["clean_tokenizer_index"] for example in examples_prep]]
    corrupt_logits = logits[range(batch_size), -1, [example["corrupt_tokenizer_index"] for example in examples_prep]]
    logit_diff = clean_logits - corrupt_logits
    return logit_diff if per_prompt else logit_diff.mean()

# Compute averages across all batches
all_clean_logit_diffs = []
all_corrupt_logit_diffs = []
all_clean_metrics = []
all_corrupt_metrics = []

# First compute all logit differences
with t.no_grad():
    for examples_prep in many_examples:
        clean_logits = model_sql(t.tensor([example["clean_tokens"] for example in examples_prep]).to(device))
        corrupt_logits = model_sql(t.tensor([example["corrupt_tokens"] for example in examples_prep]).to(device))

        clean_logit_diff = ave_logit_diff(clean_logits, examples_prep).item()
        corrupt_logit_diff = ave_logit_diff(corrupt_logits, examples_prep).item()

        all_clean_logit_diffs.append(clean_logit_diff)
        all_corrupt_logit_diffs.append(corrupt_logit_diff)

# Compute final averages for logit differences
final_clean_logit_diff = sum(all_clean_logit_diffs) / len(all_clean_logit_diffs)
final_corrupt_logit_diff = sum(all_corrupt_logit_diffs) / len(all_corrupt_logit_diffs)

def metric(
    logits: Float[Tensor, "batch seq_len d_vocab"],
    corrupted_logit_diff: float = final_corrupt_logit_diff,
    clean_logit_diff: float = final_clean_logit_diff,
    examples = examples_prep
):
    patched_logit_diff = ave_logit_diff(logits, examples)
    return (patched_logit_diff - corrupted_logit_diff) / (clean_logit_diff - corrupted_logit_diff)

def negative_metric(logits: Float[Tensor, "batch seq_len d_vocab"]):
    return -metric(logits)

# Now compute metrics using the final averages
with t.no_grad():
    for examples_prep in many_examples:
        clean_logits = model_sql(t.tensor([example["clean_tokens"] for example in examples_prep]).to(device))
        corrupt_logits = model_sql(t.tensor([example["corrupt_tokens"] for example in examples_prep]).to(device))

        clean_metric = metric(clean_logits, final_corrupt_logit_diff, final_clean_logit_diff, examples_prep)
        corrupt_metric = metric(corrupt_logits, final_corrupt_logit_diff, final_clean_logit_diff, examples_prep)

        all_clean_metrics.append(clean_metric)
        all_corrupt_metrics.append(corrupt_metric)

# Compute final metric averages
final_clean_metric = sum(all_clean_metrics) / len(all_clean_metrics)
final_corrupt_metric = sum(all_corrupt_metrics) / len(all_corrupt_metrics)

print(f'Clean direction: {final_clean_logit_diff}, Corrupt direction: {final_corrupt_logit_diff}')
print(f'Clean metric: {final_clean_metric}, Corrupt metric: {final_corrupt_metric}')
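
The indexing inside `ave_logit_diff` can be hard to read in tensor form. The sketch below reproduces it with nested Python lists: for each prompt it takes the logits at the last position and subtracts the corrupt token's logit from the clean token's logit. The logit values and token ids are invented for illustration, and the list-based signature is an assumption that differs from the notebook's tensor version.

```python
def ave_logit_diff(logits, clean_ids, corrupt_ids):
    """Mean over the batch of (clean logit - corrupt logit) at the last position."""
    diffs = [
        seq[-1][clean_id] - seq[-1][corrupt_id]
        for seq, clean_id, corrupt_id in zip(logits, clean_ids, corrupt_ids)
    ]
    return sum(diffs) / len(diffs)

# Hypothetical logits with shape [batch=2][seq=2][d_vocab=3].
toy_logits = [
    [[0.0, 0.0, 0.0], [2.0, 0.5, -1.0]],
    [[0.0, 0.0, 0.0], [1.0, 3.0, 0.0]],
]
clean_ids = [0, 1]    # vocabulary index of the clean word per prompt
corrupt_ids = [2, 0]  # vocabulary index of the corrupted word per prompt

mean_diff = ave_logit_diff(toy_logits, clean_ids, corrupt_ids)
print(mean_diff)
```

Only the final position matters because the prompts were truncated just before the answer token.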

In [None]:
# Initialize tracking
first_batch = True
edge_tracking = {}  # Tracks edges that appear among the top_n_edges of each batch
total_batches = len(many_examples)

print("\nProcessing batches and tracking consistent top edges:")
# Process each batch
for batch_idx, examples_prep in enumerate(many_examples):
    print(f"\nBatch {batch_idx + 1}/{total_batches}:")

    # Prepare inputs for this batch
    clean_tokens = t.tensor([example["clean_tokens"] for example in examples_prep]).to(device)
    corrupt_tokens = t.tensor([example["corrupt_tokens"] for example in examples_prep]).to(device)

    # Run EAP for this batch
    batch_graph = EAP(
        model_sql,
        clean_tokens,
        corrupt_tokens,
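        # Bind this batch's examples; metric's default argument was fixed when it was defined
        lambda logits: metric(logits, examples=examples_prep),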
        upstream_nodes=["mlp", "head"],
        downstream_nodes=["mlp", "head"],
    )

    # Get the top_n_edges highest-scoring edges for this batch
    top_edges = batch_graph.top_edges(n=top_n_edges, abs_scores=True)
    current_batch_edges = set()  # Track edges in current batch

    # Track edges and their scores
    for from_edge, to_edge, score in top_edges:
        edge_key = (from_edge, to_edge)
        current_batch_edges.add(edge_key)

        if edge_key not in edge_tracking:
            edge_tracking[edge_key] = {
                'scores': [score],
                'appearances': 1
            }
        else:
            edge_tracking[edge_key]['scores'].append(score)
            edge_tracking[edge_key]['appearances'] += 1

    if first_batch:
        accumulated_graph = batch_graph
        first_batch = False

# Find edges that were consistently among the top_n_edges
consistent_edges = []
for edge_key, data in edge_tracking.items():
    if data['appearances'] == total_batches:  # Edge was among the top_n_edges in every batch
        from_edge, to_edge = edge_key
        mean_score = sum(data['scores']) / len(data['scores'])
        consistent_edges.append((from_edge, to_edge, mean_score))

# Sort consistent edges by absolute mean score
consistent_edges.sort(key=lambda x: abs(x[2]), reverse=True)

print(f"\nEdges that appeared in the top {top_n_edges} consistently across all batches:")
for from_edge, to_edge, mean_score in consistent_edges:
    print(f'{from_edge} -> [{round(mean_score, 3)}] -> {to_edge}')

print(f"\nFound {len(consistent_edges)} edges that were consistently in the top {top_n_edges} across all {total_batches} batches")
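
The consistency filter above can be sketched independently of the model: an edge is kept only if it appears in the top list of every batch, and its reported score is the mean across batches. The function name `find_consistent_edges` and the per-batch edge lists are hypothetical, and the sketch assumes each edge appears at most once per batch.

```python
from collections import defaultdict

def find_consistent_edges(batches_top_edges):
    """Keep edges present in every batch's top list, with their mean score."""
    scores = defaultdict(list)
    for top_edges in batches_top_edges:
        for src, dst, score in top_edges:
            scores[(src, dst)].append(score)

    n_batches = len(batches_top_edges)
    kept = [
        (src, dst, sum(vals) / len(vals))
        for (src, dst), vals in scores.items()
        if len(vals) == n_batches  # one appearance per batch
    ]
    kept.sort(key=lambda edge: abs(edge[2]), reverse=True)
    return kept

# Hypothetical top-edge lists for two batches.
batch_1 = [("head.0.1", "mlp.1", 1.0), ("mlp.0", "mlp.1", -0.4)]
batch_2 = [("head.0.1", "mlp.1", 0.5), ("mlp.0", "head.1.0", 0.5)]

consistent = find_consistent_edges([batch_1, batch_2])
print(consistent)
```

With a single batch, as in the default configuration (`num_examples = 1`), every top edge is trivially consistent and the filter returns the full top list.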

# Plot results

In [None]:
fig = plot_eap_sankey(consistent_edges, layer_spacing=True)
fig.show()

# Save results

In [None]:
def eap_file_name():
    return f"eap_results.BM{model_num}.CS{cs_num}_cs1.{feature_name}.{str(use_novel_names)}.{str(use_synonyms_table)}.{str(use_synonyms_field)}.json"

In [None]:
results = collect_eap_results(
    source='tinysql_EAP_JSON',
    model_name= model_num,
    dataset_name=cs_num,
    hf_model=model_name,
    feature_name= feature_name,
    use_novel_names= use_novel_names,
    use_synonyms_field=use_synonyms_field,
    use_synonyms_table=use_synonyms_table,
    num_sql_fields=num_fields,
    n_batches = num_examples,
    batch_size = batch_size,
    n_positions=clean_tokens.shape[1],
    n_layers=model_sql.cfg.n_layers,
    n_heads=model_sql.cfg.n_heads,
    top_n_edges = top_n_edges,
    edges=consistent_edges,
    # Save the averages over all batches rather than the last batch's values
    clean_logit_diff=final_clean_logit_diff,
    corrupt_logit_diff=final_corrupt_logit_diff,
    clean_metric=final_clean_metric,
    corrupt_metric=final_corrupt_metric,
)

save_eap_results(results, eap_file_name())

# Test reload of saved results

In [None]:
loaded_results = load_eap_results(eap_file_name())
edges = results_to_edges(loaded_results)

fig = plot_eap_json_sankey(edges, layer_spacing=True)
fig.show()
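
The save and reload cycle can be exercised without running the model. The sketch below writes a small results dictionary with the same top-level keys as `collect_eap_results` to a temporary file and converts the reloaded edges back into tuples. The metadata values, edge names and file name are placeholders chosen for illustration.

```python
import json
import os
import tempfile

# Placeholder results with the same top-level layout as collect_eap_results.
toy_results = {
    "metadata": {"model": 1, "commandset": 1, "batch_size": 30},
    "parameters": {"top_n_edges": 10},
    "edges": [
        {"from_edge": "l0.h1", "to_edge": "l1.mlp", "score": 0.75},
        {"from_edge": "l0.mlp", "to_edge": "l1.h0.q", "score": -0.5},
    ],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "eap_results.example.json")
    with open(path, "w") as f:
        json.dump(toy_results, f, indent=2)
    with open(path, "r") as f:
        reloaded = json.load(f)

# Same conversion as results_to_edges in the notebook.
edge_tuples = [(e["from_edge"], e["to_edge"], e["score"]) for e in reloaded["edges"]]
print(edge_tuples)
```

Because the saved node names already use the `l<layer>.` prefix, the reloaded tuples can be passed to `plot_eap_json_sankey` but not to `plot_eap_sankey`.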