In [1]:
try:
    import google.colab # type: ignore
    from google.colab import output
    COLAB = True
    %pip install sae-lens transformer-lens
except:
    COLAB = False
    from IPython import get_ipython # type: ignore
    ipython = get_ipython(); assert ipython is not None
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Standard imports
import os
import torch
import numpy as np
from tqdm import tqdm
import plotly.express as px
import pandas as pd
import einops
from jaxtyping import Float, Int
from torch import Tensor

torch.set_grad_enabled(False)

# Device setup
GPU_TO_USE = 1

if torch.backends.mps.is_available():
    device = "mps"
else:
    device = f"cuda:{GPU_TO_USE}" if torch.cuda.is_available() else "cpu"

print(f"Device: {device}")

# utility to clear variables out of the memory & and clearing cuda cache
import gc
def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()

Device: cuda:1


In [2]:
from pathlib import Path

def get_data_path(data_folder, in_colab=COLAB):
  if in_colab:
    from google.colab import drive
    drive.mount('/content/drive')

    return Path(f'/content/drive/MyDrive/{data_folder}')
  else:
    return Path(f'./{data_folder}')
  
datapath = get_data_path('./data')
datapath

PosixPath('data')

In [4]:
import sys
import os

# Add the parent directory (sfc_deception) to sys.path
sys.path.append(os.path.abspath(os.path.join('..')))

from classes.sfc_model import *

In [3]:
import pickle

def load_dict(filename):
    with open(filename, 'rb') as f:
        data_dict = pickle.load(f)
    return data_dict

## Analysis utils

In [10]:
def get_contributing_components_by_token(cache, threshold):
    """
    Identifies components in the cache whose contribution scores are greater than a given threshold.
    Also includes their numerical contribution scores and sorts the components by scores.

    Args:
        cache (dict): Dictionary with keys representing component names and values being tensors of contribution scores.
        threshold (float): The threshold value above which components are considered significant.

    Returns:
        tuple: A tuple containing:
            - dict: Dictionary where keys are token positions (integers) and values are lists of tuples.
                    Each tuple contains (component_name, contribution_score) sorted by contribution scores in descending order.
            - int: Total count of contributing components across all tokens.
    """
    contributing_components = {}
    total_count = 0  # Initialize the total count of contributing components

    for component, tensor in cache.items():
        if 'hook_sae_error' in component:
            # Single scalar contribution per token [n_context]
            high_contrib_tokens = torch.where(tensor > threshold)[0]

            for token_idx in high_contrib_tokens:
                token_idx = token_idx.item()
                if token_idx not in contributing_components:
                    contributing_components[token_idx] = []

                # Add component name and its contribution score
                contributing_components[token_idx].append(
                    (component, tensor[token_idx].item())
                )
                total_count += 1  # Increment count for each contributing component

        elif 'hook_sae_acts_post' in component:
            # Per-dimension contribution [d_sae]
            high_contrib_tokens, high_contrib_features = torch.where(tensor > threshold)

            for token_idx, feat_idx in zip(high_contrib_tokens, high_contrib_features):
                token_idx = token_idx.item()
                feat_idx = feat_idx.item()
                component_name = f"{component}_{feat_idx}"

                if token_idx not in contributing_components:
                    contributing_components[token_idx] = []

                contributing_components[token_idx].append(
                    (component_name, tensor[token_idx, feat_idx].item())
                )
                total_count += 1  # Increment count for each contributing component

    # Sort each token's component list by the contribution scores in descending order
    for token_idx in contributing_components:
        contributing_components[token_idx] = sorted(
            contributing_components[token_idx], key=lambda x: x[1], reverse=True
        )

    # Sort the keys (token positions) in ascending order
    sorted_contributing_components = dict(sorted(contributing_components.items()))

    return sorted_contributing_components, total_count  # Return both the sorted components and the count

def get_contributing_components(cache, threshold):
    """
    Identifies components in the cache whose contribution scores are greater than a given threshold.
    Works with scalar and 1-dimensional tensors, omitting token positions in the output.

    Args:
        cache (dict): Dictionary with keys representing component names and values being tensors of contribution scores.
        threshold (float): The threshold value above which components are considered significant.

    Returns:
        tuple: A tuple containing:
            - list: List of tuples, each containing (component_name, contribution_score) sorted by contribution scores in descending order.
            - int: Total count of contributing components across all entries.
    """
    contributing_components = []
    total_count = 0  # Initialize the total count of contributing components

    for component, tensor in cache.items():
        # For error terms (scalar scores)
        if 'hook_sae_error' in component:
            # Check if the scalar is above the threshold
            if tensor.item() > threshold:
                # Add component name and its contribution score
                contributing_components.append((component, tensor.item()))
                total_count += 1

        # For post-activation terms (1D tensor)
        elif 'hook_sae_acts_post' in component:
            high_contrib_indices = torch.where(tensor > threshold)[0]

            for idx in high_contrib_indices:
                # Generate component name including feature index
                component_name = f"{component}_{idx.item()}"
                
                # Add component name and its contribution score
                contributing_components.append((component_name, tensor[idx].item()))
                total_count += 1

    # Sort components by contribution scores in descending order
    sorted_contributing_components = sorted(contributing_components, key=lambda x: x[1], reverse=True)

    return sorted_contributing_components, total_count  # Return both the sorted components and the count

def summarize_contributing_components_by_token(contributing_components, show_layers=False):
    """
    Generates hierarchical summary statistics for the contributing components of each token, with optional layer-level details.

    Args:
        contributing_components (dict): Dictionary where keys are token positions (integers) and values are lists of tuples.
                                         Each tuple contains (component_name, contribution_score).
        show_layers (bool): If True, includes layer statistics in the output.

    Returns:
        dict: Hierarchical overview dictionary with summary statistics for each token.
    """
    overview = {}

    for token, components in contributing_components.items():
        # Initialize counters for the current token
        total_count = len(components)
        resid_total = 0
        mlp_total = 0
        attn_total = 0
        
        resid_latents = 0
        resid_errors = 0
        mlp_latents = 0
        mlp_errors = 0
        attn_latents = 0
        attn_errors = 0

        # Initialize layer-specific counters if layer statistics are enabled
        layer_stats = {'resid': {}, 'mlp': {}, 'attn': {}} if show_layers else None

        # Count components and categorize
        for component_name, _ in components:
            layer = int(component_name.split('.')[1])  # Extract the layer number
            
            if 'resid' in component_name:
                resid_total += 1
                is_error = 'hook_sae_error' in component_name
                if is_error:
                    resid_errors += 1
                else:
                    resid_latents += 1
                
                # Track layer-specific stats for resid components
                if show_layers:
                    if layer not in layer_stats['resid']:
                        layer_stats['resid'][layer] = {'total': 0, 'Latents': 0, 'Errors': 0}
                    layer_stats['resid'][layer]['total'] += 1
                    layer_stats['resid'][layer]['Errors' if is_error else 'Latents'] += 1

            elif 'mlp' in component_name:
                mlp_total += 1
                is_error = 'hook_sae_error' in component_name
                if is_error:
                    mlp_errors += 1
                else:
                    mlp_latents += 1
                
                # Track layer-specific stats for mlp components
                if show_layers:
                    if layer not in layer_stats['mlp']:
                        layer_stats['mlp'][layer] = {'total': 0, 'Latents': 0, 'Errors': 0}
                    layer_stats['mlp'][layer]['total'] += 1
                    layer_stats['mlp'][layer]['Errors' if is_error else 'Latents'] += 1

            elif 'attn' in component_name:
                attn_total += 1
                is_error = 'hook_sae_error' in component_name
                if is_error:
                    attn_errors += 1
                else:
                    attn_latents += 1
                
                # Track layer-specific stats for attn components
                if show_layers:
                    if layer not in layer_stats['attn']:
                        layer_stats['attn'][layer] = {'total': 0, 'Latents': 0, 'Errors': 0}
                    layer_stats['attn'][layer]['total'] += 1
                    layer_stats['attn'][layer]['Errors' if is_error else 'Latents'] += 1

        # Compile the statistics into the hierarchical overview dictionary
        overview[token] = {
            'total': total_count,
            'resid': {
                'total': resid_total,
                'Latents': resid_latents,
                'Errors': resid_errors,
            },
            'mlp': {
                'total': mlp_total,
                'Latents': mlp_latents,
                'Errors': mlp_errors,
            },
            'attn': {
                'total': attn_total,
                'Latents': attn_latents,
                'Errors': attn_errors,
            }
        }

        # Include layer-specific statistics if enabled
        if show_layers:
            overview[token]['resid']['layers'] = layer_stats['resid']
            overview[token]['mlp']['layers'] = layer_stats['mlp']
            overview[token]['attn']['layers'] = layer_stats['attn']

    return overview

def summarize_contributing_components(contributing_components, show_layers=False):
    """
    Generates hierarchical summary statistics for a flat list of contributing components,
    with optional layer-level details.

    Args:
        contributing_components (list of tuples): List of tuples where each tuple contains 
                                                  (component_name, contribution_score).
        show_layers (bool): If True, includes layer statistics in the output.

    Returns:
        dict: Hierarchical overview dictionary with summary statistics for each component type (resid, mlp, attn).
    """
    overview = {
        'resid': {'total': 0, 'Latents': 0, 'Errors': 0, 'layers': {} if show_layers else None},
        'mlp': {'total': 0, 'Latents': 0, 'Errors': 0, 'layers': {} if show_layers else None},
        'attn': {'total': 0, 'Latents': 0, 'Errors': 0, 'layers': {} if show_layers else None}
    }

    for component_name, score in contributing_components:
        # Extract component type and layer information
        layer = int(component_name.split('.')[1])
        component_type = 'resid' if 'resid' in component_name else 'mlp' if 'mlp' in component_name else 'attn'
        is_error = 'hook_sae_error' in component_name

        # Update the main component type counters
        overview[component_type]['total'] += 1
        if is_error:
            overview[component_type]['Errors'] += 1
        else:
            overview[component_type]['Latents'] += 1

        # Update layer-specific stats if layer statistics are enabled
        if show_layers:
            if layer not in overview[component_type]['layers']:
                overview[component_type]['layers'][layer] = {'total': 0, 'Latents': 0, 'Errors': 0}
            overview[component_type]['layers'][layer]['total'] += 1
            overview[component_type]['layers'][layer]['Errors' if is_error else 'Latents'] += 1

    # Remove 'layers' key if show_layers is False
    if not show_layers:
        for component_type in overview:
            del overview[component_type]['layers']

    return overview

## Aggregation case

In [50]:
aggregation_type = AttributionAggregation.ALL_TOKENS

### Truthful nodes

In [51]:
aggregation_type = AttributionAggregation.ALL_TOKENS

truthful_nodes_fname = datapath / f'{aggregation_type.value}_agg_truthful_scores.pkl'

truthful_nodes_scores = load_dict(truthful_nodes_fname)
truthful_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [54]:
node_scores, total_components = get_contributing_components(truthful_nodes_scores, 0.01)
print(f"Total contributing components: {total_components}")

node_scores[:10]

Total contributing components: 1695


[('blocks.29.hook_resid_post.hook_sae_acts_post_8545', 0.99609375),
 ('blocks.34.hook_resid_post.hook_sae_acts_post_1063', 0.953125),
 ('blocks.32.hook_resid_post.hook_sae_acts_post_1518', 0.8828125),
 ('blocks.30.hook_resid_post.hook_sae_acts_post_12918', 0.83203125),
 ('blocks.25.hook_resid_post.hook_sae_acts_post_15410', 0.8203125),
 ('blocks.28.hook_resid_post.hook_sae_acts_post_3206', 0.7890625),
 ('blocks.33.hook_resid_post.hook_sae_acts_post_1725', 0.74609375),
 ('blocks.27.hook_resid_post.hook_sae_acts_post_2532', 0.59765625),
 ('blocks.26.hook_resid_post.hook_sae_acts_post_2899', 0.55078125),
 ('blocks.35.hook_resid_post.hook_sae_acts_post_1063', 0.51171875)]

In [55]:
node_scores_summary = summarize_contributing_components(node_scores, show_layers=True)
node_scores_summary

{'resid': {'total': 1385,
  'Latents': 1372,
  'Errors': 13,
  'layers': {29: {'total': 22, 'Latents': 21, 'Errors': 1},
   34: {'total': 27, 'Latents': 26, 'Errors': 1},
   32: {'total': 33, 'Latents': 32, 'Errors': 1},
   30: {'total': 20, 'Latents': 20, 'Errors': 0},
   25: {'total': 26, 'Latents': 25, 'Errors': 1},
   28: {'total': 31, 'Latents': 31, 'Errors': 0},
   33: {'total': 27, 'Latents': 26, 'Errors': 1},
   27: {'total': 32, 'Latents': 32, 'Errors': 0},
   26: {'total': 26, 'Latents': 26, 'Errors': 0},
   35: {'total': 30, 'Latents': 30, 'Errors': 0},
   31: {'total': 29, 'Latents': 28, 'Errors': 1},
   40: {'total': 32, 'Latents': 32, 'Errors': 0},
   39: {'total': 28, 'Latents': 28, 'Errors': 0},
   38: {'total': 35, 'Latents': 35, 'Errors': 0},
   37: {'total': 28, 'Latents': 28, 'Errors': 0},
   36: {'total': 32, 'Latents': 32, 'Errors': 0},
   41: {'total': 25, 'Latents': 25, 'Errors': 0},
   24: {'total': 26, 'Latents': 25, 'Errors': 1},
   10: {'total': 34, 'Latents

### Deceptive nodes

In [60]:
aggregation_type = AttributionAggregation.ALL_TOKENS

deceptive_nodes_fname = datapath / f'{aggregation_type.value}_agg_deceptive_scores.pkl'

deceptive_nodes_scores = load_dict(deceptive_nodes_fname)
deceptive_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [61]:
node_scores, total_components = get_contributing_components(deceptive_nodes_scores, 0.01)
print(f"Total contributing components: {total_components}")

node_scores

Total contributing components: 3899


[('blocks.4.hook_resid_post.hook_sae_acts_post_2471', 0.55078125),
 ('blocks.0.hook_resid_post.hook_sae_acts_post_10914', 0.51171875),
 ('blocks.1.hook_resid_post.hook_sae_acts_post_11754', 0.5),
 ('blocks.14.hook_mlp_out.hook_sae_error', 0.478515625),
 ('blocks.2.hook_resid_post.hook_sae_acts_post_88', 0.47265625),
 ('blocks.2.hook_resid_post.hook_sae_acts_post_11760', 0.455078125),
 ('blocks.15.hook_resid_post.hook_sae_error', 0.451171875),
 ('blocks.1.hook_resid_post.hook_sae_acts_post_7812', 0.44140625),
 ('blocks.6.hook_resid_post.hook_sae_acts_post_7093', 0.400390625),
 ('blocks.3.hook_resid_post.hook_sae_acts_post_15246', 0.345703125),
 ('blocks.23.hook_resid_post.hook_sae_error', 0.296875),
 ('blocks.20.hook_resid_post.hook_sae_error', 0.291015625),
 ('blocks.5.hook_resid_post.hook_sae_acts_post_11393', 0.2734375),
 ('blocks.10.hook_resid_post.hook_sae_acts_post_6339', 0.26953125),
 ('blocks.7.hook_resid_post.hook_sae_acts_post_10093', 0.2578125),
 ('blocks.4.hook_resid_post.ho

In [62]:
node_scores_summary = summarize_contributing_components(node_scores, show_layers=True)
node_scores_summary

{'resid': {'total': 3021,
  'Latents': 3001,
  'Errors': 20,
  'layers': {4: {'total': 101, 'Latents': 101, 'Errors': 0},
   0: {'total': 173, 'Latents': 172, 'Errors': 1},
   1: {'total': 133, 'Latents': 133, 'Errors': 0},
   2: {'total': 136, 'Latents': 136, 'Errors': 0},
   15: {'total': 146, 'Latents': 145, 'Errors': 1},
   6: {'total': 132, 'Latents': 132, 'Errors': 0},
   3: {'total': 115, 'Latents': 115, 'Errors': 0},
   23: {'total': 28, 'Latents': 27, 'Errors': 1},
   20: {'total': 62, 'Latents': 61, 'Errors': 1},
   5: {'total': 126, 'Latents': 126, 'Errors': 0},
   10: {'total': 139, 'Latents': 138, 'Errors': 1},
   7: {'total': 131, 'Latents': 131, 'Errors': 0},
   8: {'total': 133, 'Latents': 133, 'Errors': 0},
   14: {'total': 123, 'Latents': 122, 'Errors': 1},
   19: {'total': 116, 'Latents': 116, 'Errors': 0},
   9: {'total': 113, 'Latents': 113, 'Errors': 0},
   11: {'total': 142, 'Latents': 142, 'Errors': 0},
   22: {'total': 53, 'Latents': 53, 'Errors': 0},
   17: {'

## No aggregation case

In [30]:
aggregation_type = AttributionAggregation.NONE

### Truthful nodes

In [63]:
aggregation_type = AttributionAggregation.NONE

truthful_nodes_fname = datapath / f'{aggregation_type.value}_agg_truthful_scores.pkl'

truthful_nodes_scores = load_dict(truthful_nodes_fname)
truthful_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [64]:
node_scores, total_components = get_contributing_components_by_token(truthful_nodes_scores, 0.01)
print(f"Total contributing components: {total_components}")

node_scores

Total contributing components: 1336


{0: [('blocks.11.hook_resid_post.hook_sae_acts_post_6339', 0.01025390625)],
 1: [('blocks.11.hook_resid_post.hook_sae_acts_post_6339', 0.01373291015625)],
 3: [('blocks.1.hook_resid_post.hook_sae_error', 0.01458740234375),
  ('blocks.0.hook_resid_post.hook_sae_error', 0.0125732421875)],
 4: [('blocks.8.hook_resid_post.hook_sae_acts_post_4692', 0.01165771484375),
  ('blocks.8.hook_resid_post.hook_sae_error', 0.0115966796875),
  ('blocks.8.hook_resid_post.hook_sae_acts_post_2340', 0.01019287109375)],
 5: [('blocks.0.hook_resid_post.hook_sae_acts_post_2250', 0.01123046875)],
 6: [('blocks.0.hook_resid_post.hook_sae_acts_post_15681', 0.025390625),
  ('blocks.4.hook_resid_post.hook_sae_acts_post_11151', 0.0157470703125),
  ('blocks.2.hook_resid_post.hook_sae_acts_post_7709', 0.01123046875)],
 9: [('blocks.3.hook_resid_post.hook_sae_acts_post_8716', 0.0155029296875),
  ('blocks.4.hook_resid_post.hook_sae_acts_post_16383', 0.0120849609375),
  ('blocks.0.hook_resid_post.hook_sae_acts_post_1090

In [65]:
node_scores_summary = summarize_contributing_components_by_token(node_scores, show_layers=False)
node_scores_summary

{0: {'total': 1,
  'resid': {'total': 1, 'Latents': 1, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 1: {'total': 1,
  'resid': {'total': 1, 'Latents': 1, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 3: {'total': 2,
  'resid': {'total': 2, 'Latents': 0, 'Errors': 2},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 4: {'total': 3,
  'resid': {'total': 3, 'Latents': 2, 'Errors': 1},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 5: {'total': 1,
  'resid': {'total': 1, 'Latents': 1, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 6: {'total': 3,
  'resid': {'total': 3, 'Latents': 3, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'La

### Deceptive nodes

In [66]:
aggregation_type = AttributionAggregation.NONE

deceptive_nodes_fname = datapath / f'{aggregation_type.value}_agg_deceptive_scores.pkl'

deceptive_nodes_scores = load_dict(deceptive_nodes_fname)
deceptive_nodes_scores.keys()

dict_keys(['blocks.0.attn.hook_z.hook_sae_error', 'blocks.0.attn.hook_z.hook_sae_acts_post', 'blocks.0.hook_mlp_out.hook_sae_error', 'blocks.0.hook_mlp_out.hook_sae_acts_post', 'blocks.0.hook_resid_post.hook_sae_error', 'blocks.0.hook_resid_post.hook_sae_acts_post', 'blocks.1.attn.hook_z.hook_sae_error', 'blocks.1.attn.hook_z.hook_sae_acts_post', 'blocks.1.hook_mlp_out.hook_sae_error', 'blocks.1.hook_mlp_out.hook_sae_acts_post', 'blocks.1.hook_resid_post.hook_sae_error', 'blocks.1.hook_resid_post.hook_sae_acts_post', 'blocks.2.attn.hook_z.hook_sae_error', 'blocks.2.attn.hook_z.hook_sae_acts_post', 'blocks.2.hook_mlp_out.hook_sae_error', 'blocks.2.hook_mlp_out.hook_sae_acts_post', 'blocks.2.hook_resid_post.hook_sae_error', 'blocks.2.hook_resid_post.hook_sae_acts_post', 'blocks.3.attn.hook_z.hook_sae_error', 'blocks.3.attn.hook_z.hook_sae_acts_post', 'blocks.3.hook_mlp_out.hook_sae_error', 'blocks.3.hook_mlp_out.hook_sae_acts_post', 'blocks.3.hook_resid_post.hook_sae_error', 'blocks.3.ho

In [67]:
node_scores, total_components = get_contributing_components_by_token(deceptive_nodes_scores, 0.01)
print(f"Total contributing components: {total_components}")

node_scores

Total contributing components: 2960


{0: [('blocks.3.hook_resid_post.hook_sae_acts_post_15561', 0.034912109375),
  ('blocks.4.hook_resid_post.hook_sae_acts_post_9803', 0.0269775390625),
  ('blocks.3.hook_resid_post.hook_sae_acts_post_6861', 0.01953125),
  ('blocks.4.hook_resid_post.hook_sae_acts_post_16152', 0.0194091796875),
  ('blocks.2.hook_resid_post.hook_sae_acts_post_9887', 0.01806640625),
  ('blocks.0.hook_resid_post.hook_sae_acts_post_11517', 0.017578125),
  ('blocks.3.hook_resid_post.hook_sae_acts_post_14256', 0.016845703125),
  ('blocks.1.hook_resid_post.hook_sae_acts_post_8052', 0.01611328125),
  ('blocks.0.attn.hook_z.hook_sae_acts_post_8785', 0.013671875),
  ('blocks.4.hook_resid_post.hook_sae_acts_post_1970', 0.0125732421875),
  ('blocks.1.hook_resid_post.hook_sae_acts_post_15561', 0.0123291015625),
  ('blocks.11.hook_resid_post.hook_sae_acts_post_6339', 0.01080322265625),
  ('blocks.0.hook_resid_post.hook_sae_acts_post_6339', 0.0103759765625),
  ('blocks.2.hook_resid_post.hook_sae_acts_post_11460', 0.010314

In [68]:
node_scores_summary = summarize_contributing_components_by_token(node_scores, show_layers=False)
node_scores_summary

{0: {'total': 14,
  'resid': {'total': 13, 'Latents': 13, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 1, 'Latents': 1, 'Errors': 0}},
 1: {'total': 4,
  'resid': {'total': 4, 'Latents': 4, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 2: {'total': 10,
  'resid': {'total': 9, 'Latents': 9, 'Errors': 0},
  'mlp': {'total': 1, 'Latents': 1, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 3: {'total': 5,
  'resid': {'total': 5, 'Latents': 5, 'Errors': 0},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'total': 0, 'Latents': 0, 'Errors': 0}},
 4: {'total': 30,
  'resid': {'total': 26, 'Latents': 24, 'Errors': 2},
  'mlp': {'total': 3, 'Latents': 3, 'Errors': 0},
  'attn': {'total': 1, 'Latents': 1, 'Errors': 0}},
 5: {'total': 18,
  'resid': {'total': 18, 'Latents': 17, 'Errors': 1},
  'mlp': {'total': 0, 'Latents': 0, 'Errors': 0},
  'attn': {'tota