In [32]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from jaxtyping import Float, Int
from torch import Tensor

# Disable gradient tracking globally for everything run in this kernel.
torch.set_grad_enabled(False)

# Device setup
GPU_TO_USE = 1  # preferred CUDA device index

if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    # Fall back to cuda:0 when the preferred index does not exist (e.g. single-GPU
    # machines or Colab); otherwise the first `.to(device)` fails with an
    # "invalid device ordinal" error.
    device = f"cuda:{GPU_TO_USE if GPU_TO_USE < torch.cuda.device_count() else 0}"
else:
    device = "cpu"

print(f"Device: {device}")

# Utility to free memory: run Python's garbage collector, then release cached
# allocator blocks on whichever accelerator backend is available.
import gc
def clear_cache():
    """Collect garbage and release cached accelerator memory.

    Intended to be called after `del`-ing large objects. Memory still referenced by
    live tensors is not freed; only unused cached blocks are released.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # The device setup prefers MPS when present, so release its cache too.
    # `torch.mps` only exists in newer torch releases, hence the hasattr guard.
    if torch.backends.mps.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Device: cuda:1


In [33]:
# NOTE(review): back-of-the-envelope size estimate. The meaning of each factor
# (batch? sequence length? SAE width? bytes per element? layer count?) is not
# documented here — TODO confirm. The divisor 8589934592 equals 2**33.
(
    2 * 16 * 3000 * 16000 * 2 * 42
    + 16 * 3000 * 131000 * 2 * 42
) / 2**33

76.51001214981079

In [34]:
from pathlib import Path

def get_data_path(data_folder, in_colab=COLAB):
    """Resolve the folder that holds this notebook's data files.

    Args:
        data_folder (str): Folder name or relative path, e.g. './data'.
        in_colab (bool): When truthy, mount Google Drive and resolve the folder under
            '/content/drive/MyDrive'; otherwise resolve it relative to the working
            directory. Defaults to the value COLAB had when this function was defined.

    Returns:
        pathlib.Path: The data folder path.
    """
    if not in_colab:
        return Path(f'./{data_folder}')

    from google.colab import drive
    drive.mount('/content/drive')
    return Path(f'/content/drive/MyDrive/{data_folder}')

datapath = get_data_path('./data')
datapath

PosixPath('data')

In [35]:
import sys
import os

# Add the parent directory (sfc_deception) to sys.path so `classes.*` can be imported.
# Guarded so that re-running this cell does not append duplicate entries.
_PARENT_DIR = os.path.abspath('..')
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

from classes.sfc_model import *

In [36]:
import pickle

def load_dict(filename):
    """Unpickle and return the object stored in `filename`.

    NOTE: unpickling can execute arbitrary code — only load files you trust.
    """
    with open(filename, 'rb') as handle:
        return pickle.load(handle)

## Analysis utils

In [37]:
def format_component_name(component, feat_idx=None):
    """
    Build a readable, fixed-shape name for a cached SAE hook entry.

    The hook name is split on '.'; the second field is the layer and the third the
    hook type (e.g. 'blocks.3.hook_resid_post.hook_sae_acts_post').

    Args:
        component (str): Cache key of the hook.
        feat_idx (int, optional): SAE feature index. It is used as the category unless
            the key refers to an SAE error term.

    Returns:
        str: "{layer}__{type}__{category}". Layers 0-9 get a leading zero. The type is
        'resid_post', 'mlp_out', 'attn_z' or 'unknown'. The category is 'error', the
        feature index, or 'unknown'.
    """
    pieces = component.split('.')
    layer, raw_type = pieces[1], pieces[2]

    type_aliases = {
        "hook_resid_post": "resid_post",
        "hook_mlp_out": "mlp_out",
        "attn": "attn_z",
    }
    readable_type = type_aliases.get(raw_type, "unknown")

    if 'hook_sae_error' in component:
        category = "error"
    elif feat_idx is not None:
        category = str(feat_idx)
    else:
        category = "unknown"

    # Zero-pad single-digit layers (presumably so names sort by layer as strings).
    if 0 <= int(layer) <= 9:
        layer = '0' + layer

    return f"{layer}__{readable_type}__{category}"


def get_contributing_components_by_token(cache, threshold):
    """
    Collect, per token position, the SAE components whose score exceeds `threshold`.

    Only cache keys containing 'hook_sae_error' or 'hook_sae_acts_post' are inspected;
    all other keys are skipped. Error entries are indexed along their first dimension
    (token position). Latent entries must be 2-D (token position, feature index).

    Args:
        cache (dict): Maps hook names to tensors of contribution scores.
        threshold (float): Strict lower bound (`score > threshold`) for a component to be kept.

    Returns:
        tuple:
            - dict: {token_position: [(formatted_name, score), ...]}, with token positions in
              ascending order. Each list is sorted by score, highest first; ties keep cache order.
            - int: Total number of (token, component) entries collected.
    """
    hits_by_token = {}
    n_hits = 0

    for hook_name, scores in cache.items():
        is_error = 'hook_sae_error' in hook_name
        is_latent = 'hook_sae_acts_post' in hook_name
        if not (is_error or is_latent):
            continue

        if is_error:
            for tok in torch.where(scores > threshold)[0].tolist():
                hits_by_token.setdefault(tok, []).append(
                    (format_component_name(hook_name), scores[tok].item())
                )
                n_hits += 1
        else:
            token_ids, feature_ids = torch.where(scores > threshold)
            for tok, feat in zip(token_ids.tolist(), feature_ids.tolist()):
                hits_by_token.setdefault(tok, []).append(
                    (format_component_name(hook_name, feat), scores[tok, feat].item())
                )
                n_hits += 1

    ordered = {
        tok: sorted(hits_by_token[tok], key=lambda pair: pair[1], reverse=True)
        for tok in sorted(hits_by_token)
    }
    return ordered, n_hits


def get_contributing_components(cache, threshold):
    """
    Collect SAE components whose score exceeds `threshold`, ignoring token positions.

    Only keys containing 'hook_sae_error' or 'hook_sae_acts_post' are inspected. Error
    entries must hold a single value, which is read with `.item()`. Latent entries are
    indexed along their first dimension, and that index is the feature index.

    Args:
        cache (dict): Maps hook names to tensors of contribution scores.
        threshold (float): Strict lower bound (`score > threshold`) for a component to be kept.

    Returns:
        tuple:
            - list: [(formatted_name, score), ...] sorted by score, highest first
              (ties keep cache order).
            - int: Number of entries in that list.
    """
    selected = []

    for hook_name, scores in cache.items():
        if 'hook_sae_error' in hook_name:
            if scores.item() > threshold:
                selected.append((format_component_name(hook_name), scores.item()))
        elif 'hook_sae_acts_post' in hook_name:
            feature_ids = torch.where(scores > threshold)[0].tolist()
            selected.extend(
                (format_component_name(hook_name, feat), scores[feat].item())
                for feat in feature_ids
            )

    selected.sort(key=lambda pair: pair[1], reverse=True)
    return selected, len(selected)

In [38]:
def summarize_contributing_components_by_token(contributing_components, show_layers=False):
    """
    Count latent vs. error components per hook type for every token position.

    Names are expected in the "{layer}__{type}__{category}" format; category 'error'
    marks an SAE error term. Types other than 'resid_post', 'mlp_out' and 'attn_z' are
    counted in the token's overall 'total' but in no per-type bucket.

    Args:
        contributing_components (dict): {token_position: [(component_name, score), ...]}.
        show_layers (bool): If True, each hook-type entry also gets a 'layers' dict
            mapping int layer -> {'total', 'Latents', 'Errors'} counts.

    Returns:
        dict: {token_position: {'total': n, 'resid_post': {...}, 'mlp_out': {...}, 'attn_z': {...}}}.
    """
    tracked_types = ('resid_post', 'mlp_out', 'attn_z')

    def _empty_counts():
        return {'total': 0, 'Latents': 0, 'Errors': 0}

    overview = {}
    for token, components in contributing_components.items():
        per_type = {ctype: _empty_counts() for ctype in tracked_types}
        per_layer = {ctype: {} for ctype in tracked_types}

        for component_name, _score in components:
            layer_str, ctype, category = component_name.split('__')
            layer_num = int(layer_str)
            if ctype not in per_type:
                continue  # untracked hook types are not tallied per type
            bucket = 'Errors' if category == 'error' else 'Latents'

            per_type[ctype]['total'] += 1
            per_type[ctype][bucket] += 1

            if show_layers:
                layer_counts = per_layer[ctype].setdefault(layer_num, _empty_counts())
                layer_counts['total'] += 1
                layer_counts[bucket] += 1

        token_summary = {'total': len(components)}
        for ctype in tracked_types:
            token_summary[ctype] = per_type[ctype]
            if show_layers:
                token_summary[ctype]['layers'] = per_layer[ctype]
        overview[token] = token_summary

    return overview


def summarize_contributing_components(contributing_components, show_layers=False):
    """
    Count latent vs. error components per hook type for a flat component list.

    Names are expected in the "{layer}__{type}__{category}" format; category 'error'
    marks an SAE error term. A type other than 'resid_post', 'mlp_out' or 'attn_z'
    raises KeyError.

    Args:
        contributing_components (list of tuples): [(component_name, score), ...].
        show_layers (bool): If True, each hook-type entry also gets a 'layers' dict
            mapping int layer -> {'total', 'Latents', 'Errors'} counts.

    Returns:
        dict: {'resid_post': {...}, 'mlp_out': {...}, 'attn_z': {...}} with
        'total', 'Latents', 'Errors' (and 'layers' when requested).
    """
    def _fresh_counts():
        return {'total': 0, 'Latents': 0, 'Errors': 0}

    overview = {}
    for ctype in ('resid_post', 'mlp_out', 'attn_z'):
        overview[ctype] = _fresh_counts()
        if show_layers:
            overview[ctype]['layers'] = {}

    for component_name, _score in contributing_components:
        layer_str, ctype, category = component_name.split('__')
        layer_num = int(layer_str)
        bucket = 'Errors' if category == 'error' else 'Latents'

        type_summary = overview[ctype]  # KeyError for an unrecognised hook type
        type_summary['total'] += 1
        type_summary[bucket] += 1

        if show_layers:
            layer_counts = type_summary['layers'].setdefault(layer_num, _fresh_counts())
            layer_counts['total'] += 1
            layer_counts[bucket] += 1

    return overview

## Aggregated case (`AttributionAggregation.ALL_TOKENS`)

In [39]:
aggregation_type = AttributionAggregation.ALL_TOKENS
NODES_PREFIX = 'resid_saes_128k'

def get_nodes_fname(truthful_nodes=True, nodes_prefix=NODES_PREFIX, aggregation=None):
    """Build the path of a pickled node-scores file under `datapath`.

    Args:
        truthful_nodes (bool): Select the 'truthful' (True) or 'deceptive' (False) scores file.
        nodes_prefix (str | None): Optional infix inserted after the aggregation tag;
            falsy values (None, '') omit it.
        aggregation (AttributionAggregation | None): Aggregation whose `.value` starts the
            file name. When None, falls back to the notebook-global `aggregation_type`,
            read at call time. Pass it explicitly to avoid depending on whichever cell
            last reassigned that global.

    Returns:
        pathlib.Path: datapath / '<agg>_agg[_<prefix>]_<truthful|deceptive>_scores.pkl'.
    """
    if aggregation is None:
        aggregation = aggregation_type
    nodes_type = 'truthful' if truthful_nodes else 'deceptive'
    if nodes_prefix:
        fname = f'{aggregation.value}_agg_{nodes_prefix}_{nodes_type}_scores.pkl'
    else:
        fname = f'{aggregation.value}_agg_{nodes_type}_scores.pkl'

    return datapath / fname

### Truthful nodes

In [74]:
# get_nodes_fname() reads the global `aggregation_type`, so set it before building the path.
aggregation_type = AttributionAggregation.ALL_TOKENS

truthful_nodes_fname = get_nodes_fname(truthful_nodes=True)

# Raw scores: a dict keyed by hook name (e.g. 'blocks.0.attn.hook_z.hook_sae_error').
truthful_nodes_scores = load_dict(truthful_nodes_fname)
# Sanity check on one entry's shape (the recorded output is torch.Size([]), i.e. a 0-d tensor).
truthful_nodes_scores['blocks.0.attn.hook_z.hook_sae_error'].shape

torch.Size([])

In [75]:
# This cell rebinds `truthful_nodes_scores` to the filtered list, so filtering that name
# directly would crash on a second run. Reload the raw per-component dict instead.
# NOTE: assumes `truthful_nodes_fname` still points at the ALL_TOKENS file from the previous cell.
truthful_nodes_raw_scores = load_dict(truthful_nodes_fname)
truthful_nodes_scores, total_components = get_contributing_components(truthful_nodes_raw_scores, 0.01)
print(f"Total contributing components: {total_components}")

truthful_nodes_scores

Total contributing components: 3262


[('33__resid_post__47402', 0.94140625),
 ('35__resid_post__55733', 0.890625),
 ('34__resid_post__113223', 0.82421875),
 ('36__resid_post__121683', 0.80078125),
 ('31__resid_post__108794', 0.7109375),
 ('32__resid_post__75203', 0.69921875),
 ('30__resid_post__108138', 0.62890625),
 ('27__resid_post__38183', 0.59765625),
 ('29__resid_post__6672', 0.57421875),
 ('28__resid_post__12999', 0.5625),
 ('26__resid_post__64318', 0.45703125),
 ('25__resid_post__76025', 0.40625),
 ('36__resid_post__error', 0.333984375),
 ('38__resid_post__88841', 0.326171875),
 ('25__mlp_out__error', 0.3125),
 ('41__resid_post__89775', 0.310546875),
 ('24__resid_post__34177', 0.298828125),
 ('18__attn_z__error', 0.294921875),
 ('41__resid_post__67881', 0.29296875),
 ('39__resid_post__115900', 0.291015625),
 ('06__resid_post__error', 0.27734375),
 ('35__resid_post__error', 0.265625),
 ('37__resid_post__21833', 0.26171875),
 ('40__resid_post__80858', 0.2578125),
 ('36__resid_post__107506', 0.2421875),
 ('10__resid_p

In [76]:
# Counts of above-threshold latents vs. SAE errors per hook type (resid_post / mlp_out / attn_z),
# with a per-layer breakdown. NOTE: `node_scores_summary` is reassigned by the no-aggregation cells below.
node_scores_summary = summarize_contributing_components(truthful_nodes_scores, show_layers=True)
node_scores_summary

{'resid_post': {'total': 2590,
  'Latents': 2570,
  'Errors': 20,
  'layers': {33: {'total': 38, 'Latents': 37, 'Errors': 1},
   35: {'total': 36, 'Latents': 35, 'Errors': 1},
   34: {'total': 40, 'Latents': 39, 'Errors': 1},
   36: {'total': 35, 'Latents': 34, 'Errors': 1},
   31: {'total': 48, 'Latents': 47, 'Errors': 1},
   32: {'total': 43, 'Latents': 43, 'Errors': 0},
   30: {'total': 38, 'Latents': 37, 'Errors': 1},
   27: {'total': 49, 'Latents': 49, 'Errors': 0},
   29: {'total': 29, 'Latents': 28, 'Errors': 1},
   28: {'total': 41, 'Latents': 41, 'Errors': 0},
   26: {'total': 34, 'Latents': 34, 'Errors': 0},
   25: {'total': 41, 'Latents': 41, 'Errors': 0},
   38: {'total': 59, 'Latents': 59, 'Errors': 0},
   41: {'total': 30, 'Latents': 30, 'Errors': 0},
   24: {'total': 33, 'Latents': 32, 'Errors': 1},
   39: {'total': 67, 'Latents': 67, 'Errors': 0},
   6: {'total': 69, 'Latents': 68, 'Errors': 1},
   37: {'total': 47, 'Latents': 47, 'Errors': 0},
   40: {'total': 48, 'Lat

#### Analysis

### Deceptive nodes

In [77]:
# get_nodes_fname() reads the global `aggregation_type`, so set it before building the path.
aggregation_type = AttributionAggregation.ALL_TOKENS

deceptive_nodes_fname = get_nodes_fname(truthful_nodes=False)

# Raw scores for the deceptive run: a dict keyed by hook name.
deceptive_nodes_scores = load_dict(deceptive_nodes_fname)
deceptive_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [78]:
# This cell rebinds `deceptive_nodes_scores` to the filtered list, so filtering that name
# directly would crash on a second run. Reload the raw per-component dict instead.
# NOTE: assumes `deceptive_nodes_fname` still points at the ALL_TOKENS file from the previous cell.
deceptive_nodes_raw_scores = load_dict(deceptive_nodes_fname)
deceptive_nodes_scores, total_components = get_contributing_components(deceptive_nodes_raw_scores, 0.01)
print(f"Total contributing components: {total_components}")

deceptive_nodes_scores

Total contributing components: 2957


[('14__mlp_out__error', 0.47265625),
 ('15__resid_post__error', 0.36328125),
 ('18__mlp_out__error', 0.302734375),
 ('01__resid_post__11754', 0.298828125),
 ('04__resid_post__2471', 0.294921875),
 ('00__resid_post__10914', 0.28515625),
 ('23__resid_post__error', 0.283203125),
 ('06__resid_post__7093', 0.279296875),
 ('02__resid_post__88', 0.27734375),
 ('02__resid_post__11760', 0.275390625),
 ('01__resid_post__7812', 0.2734375),
 ('05__resid_post__11393', 0.267578125),
 ('14__resid_post__error', 0.267578125),
 ('10__resid_post__6339', 0.2431640625),
 ('04__resid_post__10543', 0.2392578125),
 ('00__resid_post__5477', 0.2314453125),
 ('11__resid_post__6339', 0.22265625),
 ('08__resid_post__15120', 0.1904296875),
 ('02__resid_post__5829', 0.189453125),
 ('19__resid_post__6427', 0.189453125),
 ('07__resid_post__10093', 0.185546875),
 ('06__resid_post__1115', 0.1728515625),
 ('15__attn_z__error', 0.169921875),
 ('03__resid_post__5829', 0.1650390625),
 ('11__mlp_out__error', 0.1650390625),
 

In [79]:
# Latent vs. error counts per hook type for the deceptive run, with a per-layer breakdown.
summarize_contributing_components(deceptive_nodes_scores, show_layers=True)

{'resid_post': {'total': 2316,
  'Latents': 2304,
  'Errors': 12,
  'layers': {15: {'total': 109, 'Latents': 108, 'Errors': 1},
   1: {'total': 85, 'Latents': 85, 'Errors': 0},
   4: {'total': 85, 'Latents': 85, 'Errors': 0},
   0: {'total': 116, 'Latents': 115, 'Errors': 1},
   23: {'total': 21, 'Latents': 20, 'Errors': 1},
   6: {'total': 106, 'Latents': 106, 'Errors': 0},
   2: {'total': 89, 'Latents': 89, 'Errors': 0},
   5: {'total': 97, 'Latents': 97, 'Errors': 0},
   14: {'total': 98, 'Latents': 97, 'Errors': 1},
   10: {'total': 112, 'Latents': 111, 'Errors': 1},
   11: {'total': 113, 'Latents': 112, 'Errors': 1},
   8: {'total': 105, 'Latents': 105, 'Errors': 0},
   19: {'total': 84, 'Latents': 84, 'Errors': 0},
   7: {'total': 101, 'Latents': 101, 'Errors': 0},
   3: {'total': 85, 'Latents': 85, 'Errors': 0},
   9: {'total': 90, 'Latents': 90, 'Errors': 0},
   17: {'total': 88, 'Latents': 88, 'Errors': 0},
   25: {'total': 12, 'Latents': 11, 'Errors': 1},
   16: {'total': 96,

## Truthful vs Deceptive nodes analysis

In [80]:
# Number of components scoring above the 0.01 threshold in each run.
len(truthful_nodes_scores), len(deceptive_nodes_scores)

(3262, 2957)

In [81]:
import plotly.express as px
import math
import plotly.graph_objects as go

def plot_top_k_shared_nodes(clean_node_scores, corrupted_node_scores, K=1000, selection_criterion='clean'):
    """
    Scatter-plot the scores of nodes that appear in both score lists.

    Each point is one node name present in both inputs, placed at
    (clean score, corrupted score). A dashed y = x line marks equal scores.

    Args:
        clean_node_scores (iterable of (str, float)): (node_name, score) pairs for the clean run.
        corrupted_node_scores (iterable of (str, float)): (node_name, score) pairs for the corrupted run.
        K (int): Maximum number of shared nodes to plot.
        selection_criterion (str): 'clean' or 'corrupted'. Selects which run's score ranks
            the shared nodes before the top K are kept.

    Returns:
        None. The figure is displayed via `fig.show()`. If there is nothing to plot,
        a message is printed instead.

    Raises:
        ValueError: If `selection_criterion` is not 'clean' or 'corrupted'.
    """
    if selection_criterion == 'clean':
        rank_col = 1
    elif selection_criterion == 'corrupted':
        rank_col = 2
    else:
        raise ValueError("selection_criterion should be 'clean' or 'corrupted'")

    # Convert to dicts for lookup by node name (a repeated name keeps its last score).
    clean_dict = dict(clean_node_scores)
    corrupted_dict = dict(corrupted_node_scores)

    # Shared nodes in clean-input order. Iterating a set here instead would make the order
    # of tied scores depend on string-hash randomization, so the top-K cut below could
    # change between kernel restarts.
    common_data = [
        (node, clean_score, corrupted_dict[node])
        for node, clean_score in clean_dict.items()
        if node in corrupted_dict
    ]

    # Rank by the chosen run's score (stable sort) and keep the top K.
    top_k_data = sorted(common_data, key=lambda row: row[rank_col], reverse=True)[:K]

    if not top_k_data:
        # Nothing shared (or K <= 0): the min()/max() calls below would raise on empty lists.
        print("No shared nodes to plot.")
        return

    node_names = [row[0] for row in top_k_data]
    clean_scores = [row[1] for row in top_k_data]
    corrupted_scores = [row[2] for row in top_k_data]

    # Raw (untransformed) scores; hovering over a point shows its node name.
    fig = px.scatter(
        x=clean_scores, y=corrupted_scores, hover_name=node_names,
        labels={'x': 'Clean Node Score', 'y': 'Corrupted Node Score'},
        title=f"Top {len(top_k_data)} Shared Node Scores (ranked by {selection_criterion} score)",
    )
    # Set the marker size before the reference line exists, so it only affects the points.
    fig.update_traces(marker=dict(size=8))

    # y = x reference line spanning the plotted score range.
    lo = min(min(clean_scores), min(corrupted_scores))
    hi = max(max(clean_scores), max(corrupted_scores))
    fig.add_trace(go.Scatter(x=[lo, hi], y=[lo, hi], mode='lines',
                             line=dict(dash='dash', color='gray'), name='y = x'))
    fig.update_layout(xaxis_title="Score (Clean Task)", yaxis_title="Score (Corrupted Task)")

    fig.show()


# Truthful run plotted as "clean", deceptive as "corrupted", ranked by the deceptive score.
# K=5000 exceeds both list lengths (3262 and 2957 above), so every shared node is plotted.
plot_top_k_shared_nodes(truthful_nodes_scores, deceptive_nodes_scores, 
                        selection_criterion='corrupted', K=5000)

## No-aggregation case (`AttributionAggregation.NONE`, per-token scores)

In [64]:
# Redundant: the next cell sets the same value again before using it.
aggregation_type = AttributionAggregation.NONE

### Truthful nodes

In [68]:
# get_nodes_fname() reads the global `aggregation_type`, so set it before building the path.
aggregation_type = AttributionAggregation.NONE

# nodes_prefix=None selects the file name without a prefix ('<agg>_agg_truthful_scores.pkl').
truthful_nodes_fname = get_nodes_fname(truthful_nodes=True, nodes_prefix=None)

# NOTE: overwrites the ALL_TOKENS `truthful_nodes_fname` / `truthful_nodes_scores` from the cells above.
truthful_nodes_scores = load_dict(truthful_nodes_fname)
truthful_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [69]:
# Per-token components scoring above 0.01. NOTE: `node_scores` is reused for the deceptive run below.
node_scores, total_components = get_contributing_components_by_token(truthful_nodes_scores, 0.01)
print(f"Total contributing components: {total_components}")

node_scores

Total contributing components: 1336


{0: [('11__resid_post__6339', 0.01025390625)],
 1: [('11__resid_post__6339', 0.01373291015625)],
 3: [('01__resid_post__error', 0.01458740234375),
  ('00__resid_post__error', 0.0125732421875)],
 4: [('08__resid_post__4692', 0.01165771484375),
  ('08__resid_post__error', 0.0115966796875),
  ('08__resid_post__2340', 0.01019287109375)],
 5: [('00__resid_post__2250', 0.01123046875)],
 6: [('00__resid_post__15681', 0.025390625),
  ('04__resid_post__11151', 0.0157470703125),
  ('02__resid_post__7709', 0.01123046875)],
 9: [('03__resid_post__8716', 0.0155029296875),
  ('04__resid_post__16383', 0.0120849609375),
  ('00__resid_post__10904', 0.0111083984375),
  ('02__resid_post__5116', 0.01104736328125),
  ('05__resid_post__2583', 0.01068115234375),
  ('01__resid_post__9479', 0.010498046875)],
 10: [('07__resid_post__error', 0.01123046875)],
 11: [('05__resid_post__5427', 0.02099609375),
  ('01__resid_post__16294', 0.0205078125),
  ('00__resid_post__7425', 0.0189208984375),
  ('03__resid_post__9

In [70]:
# Per-token latent vs. error counts by hook type (no per-layer breakdown).
node_scores_summary = summarize_contributing_components_by_token(node_scores, show_layers=False)
node_scores_summary

{0: {'total': 1,
  'resid_post': {'total': 1, 'Latents': 1, 'Errors': 0},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 1: {'total': 1,
  'resid_post': {'total': 1, 'Latents': 1, 'Errors': 0},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 3: {'total': 2,
  'resid_post': {'total': 2, 'Latents': 0, 'Errors': 2},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 4: {'total': 3,
  'resid_post': {'total': 3, 'Latents': 2, 'Errors': 1},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 5: {'total': 1,
  'resid_post': {'total': 1, 'Latents': 1, 'Errors': 0},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 6: {'total': 3,
  'resid_post': {'total': 3, 'Latents': 3, 'Errors': 0},
  'mlp_out': {'t

### Deceptive nodes

In [71]:
# get_nodes_fname() reads the global `aggregation_type`, so set it before building the path.
aggregation_type = AttributionAggregation.NONE

# nodes_prefix=None selects the file name without a prefix ('<agg>_agg_deceptive_scores.pkl').
deceptive_nodes_fname = get_nodes_fname(truthful_nodes=False, nodes_prefix=None)

# NOTE: overwrites the ALL_TOKENS `deceptive_nodes_fname` / `deceptive_nodes_scores` from the cells above.
deceptive_nodes_scores = load_dict(deceptive_nodes_fname)
deceptive_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [72]:
# Per-token components scoring above 0.01 for the deceptive run.
# NOTE: overwrites the truthful run's `node_scores`.
node_scores, total_components = get_contributing_components_by_token(deceptive_nodes_scores, 0.01)
print(f"Total contributing components: {total_components}")

node_scores

Total contributing components: 2960


{0: [('03__resid_post__15561', 0.034912109375),
  ('04__resid_post__9803', 0.0269775390625),
  ('03__resid_post__6861', 0.01953125),
  ('04__resid_post__16152', 0.0194091796875),
  ('02__resid_post__9887', 0.01806640625),
  ('00__resid_post__11517', 0.017578125),
  ('03__resid_post__14256', 0.016845703125),
  ('01__resid_post__8052', 0.01611328125),
  ('00__attn_z__8785', 0.013671875),
  ('04__resid_post__1970', 0.0125732421875),
  ('01__resid_post__15561', 0.0123291015625),
  ('11__resid_post__6339', 0.01080322265625),
  ('00__resid_post__6339', 0.0103759765625),
  ('02__resid_post__11460', 0.01031494140625)],
 1: [('02__resid_post__4678', 0.0220947265625),
  ('03__resid_post__11636', 0.01153564453125),
  ('11__resid_post__6339', 0.01068115234375),
  ('04__resid_post__4678', 0.010498046875)],
 2: [('03__resid_post__13647', 0.0419921875),
  ('01__resid_post__14124', 0.031005859375),
  ('00__resid_post__10543', 0.02587890625),
  ('02__resid_post__6761', 0.02197265625),
  ('00__resid_pos

In [73]:
# Per-token latent vs. error counts by hook type for the deceptive run (no per-layer breakdown).
node_scores_summary = summarize_contributing_components_by_token(node_scores, show_layers=False)
node_scores_summary

{0: {'total': 14,
  'resid_post': {'total': 13, 'Latents': 13, 'Errors': 0},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 1, 'Latents': 1, 'Errors': 0}},
 1: {'total': 4,
  'resid_post': {'total': 4, 'Latents': 4, 'Errors': 0},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 2: {'total': 10,
  'resid_post': {'total': 9, 'Latents': 9, 'Errors': 0},
  'mlp_out': {'total': 1, 'Latents': 1, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 3: {'total': 5,
  'resid_post': {'total': 5, 'Latents': 5, 'Errors': 0},
  'mlp_out': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn_z': {'total': 0, 'Latents': 0, 'Errors': 0}},
 4: {'total': 30,
  'resid_post': {'total': 26, 'Latents': 24, 'Errors': 2},
  'mlp_out': {'total': 3, 'Latents': 3, 'Errors': 0},
  'attn_z': {'total': 1, 'Latents': 1, 'Errors': 0}},
 5: {'total': 18,
  'resid_post': {'total': 18, 'Latents': 17, 'Errors': 1},
  'mlp