# Interpretune SAELens Tutorial

![Fine-Tuning Scheduler logo](logo_fts.png){height="55px" width="401px"}

### Intro

[Interpretune](https://github.com/speediedan/interpretune) is a flexible framework for exploring, analyzing and tuning LLM world models. In this tutorial, we'll walk through a simple example of using Interpretune to pursue interpretability research with SAELens. As we'll see, Interpretune handles the required execution context composition, allowing us to use the same code in a variety of contexts, depending upon the level of abstraction required.

As a long-time PyTorch and PyTorch Lightning contributor, I've found that PyTorch Lightning is the right level of abstraction for a large variety of ML research contexts. Some contexts nonetheless benefit from using core PyTorch directly, and some users prefer it for other reasons, such as maximizing portability. As will be demonstrated here, Interpretune maximizes flexibility and portability by adhering to a well-defined protocol that allows auto-composition of our research module with the adapters required for execution in a wide variety of contexts. In this example, we'll execute the same module with core PyTorch and PyTorch Lightning, demonstrating the use of `SAELens` with Interpretune for interpretability research.
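
To make the idea of composing one research module with different adapters a little more concrete, here is a minimal sketch in plain Python. It is not Interpretune's implementation: `ResearchModule`, `CoreAdapter`, `RecordingAdapter` and `compose` are hypothetical names invented for illustration, and the only "protocol" assumed is that the module provides a `training_step(batch, batch_idx)` method.

```python
class ResearchModule:
    """User code: implements the step protocol and knows nothing about adapters."""

    def training_step(self, batch, batch_idx):
        return sum(batch) / len(batch)


class CoreAdapter:
    """Hypothetical adapter that drives the module with a plain Python loop."""

    def run(self, batches):
        return [self.training_step(b, i) for i, b in enumerate(batches)]


class RecordingAdapter:
    """Hypothetical adapter standing in for a framework-managed loop that also records outputs."""

    def run(self, batches):
        outputs = [self.training_step(b, i) for i, b in enumerate(batches)]
        self.recorded = list(outputs)
        return outputs


def compose(module_cls, adapter_cls):
    """Create a new class whose MRO places the adapter ahead of the user module."""
    name = f"{module_cls.__name__}__{adapter_cls.__name__}"
    return type(name, (adapter_cls, module_cls), {})


batches = [[1.0, 3.0], [2.0, 6.0]]
core_out = compose(ResearchModule, CoreAdapter)().run(batches)
recording_out = compose(ResearchModule, RecordingAdapter)().run(batches)
print(core_out, recording_out)
```

The same `ResearchModule` code runs unchanged under both adapters, which is the property the rest of this tutorial relies on.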

> Note - **this is a WIP**, but this is the core idea. If you have any feedback, please let me know!

## A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions which can help profile memory usage for you, so that if you encounter OOM errors you can try to clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're not using them any more) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.

<details>
<summary>See this dropdown for some functions which you might find helpful, and how to use them.</summary>

First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.

```python
# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
part32_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘</pre>
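
The tensor sizes reported in the table can be sanity-checked from shape alone. The sketch below assumes 4-byte (float32) elements and the same `1024**3` bytes-per-GB convention used elsewhere in this section; `tensor_size_gb` is a hypothetical helper, not part of `part32_utils`.

```python
import math

BYTES_PER_GB = 1024**3


def tensor_size_gb(shape, bytes_per_element=4):
    """Size in GB of a dense tensor; 4 bytes per element assumes float32."""
    return math.prod(shape) * bytes_per_element / BYTES_PER_GB


sae_resid_dirs_gb = tensor_size_gb((4, 24576, 768))
logits_gb = tensor_size_gb((4, 15, 50257))
print(f"sae_resid_dirs: {sae_resid_dirs_gb:.2f} GB, logits: {logits_gb:.2f} GB")
```

Under those assumptions, the two values round to the figures shown in the `sae_resid_dirs` and `logits` rows of the table.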

From this, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining objects on the GPU which are larger than 100MB to the CPU, and print the memory status again.

```python
del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
for obj in gc.get_objects():
    try:
        if isinstance(obj, torch.nn.Module) and part32_utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:  # some tracked objects raise on inspection; skip them
        pass

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

part32_utils.print_memory_status()
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66 GB</pre>

Mission success! We've managed to free up a lot of memory. Note that the loop which moves objects tracked by the garbage collector to the CPU is often necessary to actually free up the memory. We can't just delete the objects directly because PyTorch can still sometimes keep references to them (i.e. their tensors) in memory. In fact, if you add code to the for loop above to print out `obj.shape` when `obj` is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even once you've deleted `gemma_2_2b`.
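
The reason `del` alone may not release memory is ordinary Python reference semantics: `del` removes a name, not the object. Here is a minimal, PyTorch-free sketch of that behaviour, where `FakeModel` and `registry` are hypothetical stand-ins for a large model and for whatever else happens to hold a reference to it.

```python
import gc
import weakref


class FakeModel:
    """Stand-in for a large model (hypothetical, for illustration only)."""


model = FakeModel()
registry = [model]          # a second reference, e.g. held by a cache or hook
probe = weakref.ref(model)  # lets us observe the object without keeping it alive

del model                   # removes the name only
gc.collect()
still_alive = probe() is not None  # the registry still references the object

registry.clear()            # drop the last remaining reference
gc.collect()
freed = probe() is None

print(still_alive, freed)
```

The object is only reclaimed once every reference is gone, which is why the loop above walks the garbage collector's tracked objects rather than relying on `del gemma_2_2b`.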

</details>

## Setup (don't read, just run)

In [1]:
import os
import gc
import itertools
import math
import random
import sys
import logging
from typing import Any, Dict, Optional, Tuple, List, Callable, Literal
from dataclasses import dataclass, field
from pprint import pformat
from pathlib import Path
from collections import Counter
from copy import deepcopy
from functools import partial

import evaluate
import datasets
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from torch import Tensor
from transformers import PreTrainedTokenizerBase
from transformers.tokenization_utils_base import BatchEncoding
from datasets.arrow_dataset import LazyDict
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
#from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from tabulate import tabulate
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, utils
from transformer_lens.hook_points import HookPoint
from sae_lens import SAE, HookedSAETransformer
from sae_lens.toolkit.pretrained_saes_directory import (
    #PretrainedSAELookup,
    get_pretrained_saes_directory,
    #get_repo_id_and_folder_name,
)

from it_examples.example_module_registry import MODULE_EXAMPLE_REGISTRY
from interpretune.adapters.transformer_lens import ITLensConfig
from interpretune.adapters.sae_lens import SAELensConfig
from interpretune.adapters.core import ITModule
from interpretune.base.call import it_init
from interpretune.base.config.datamodule import PromptConfig, ITDataModuleConfig
from interpretune.base.config.module import ITConfig
from interpretune.base.config.mixins import GenerativeClassificationConfig, BaseGenerationConfig, HFGenerationConfig
from interpretune.base.components.mixins import ProfilerHooksMixin
from interpretune.base.datamodules import ITDataModule
from interpretune.utils.logging import rank_zero_warn
from interpretune.utils.types import STEP_OUTPUT
from interpretune.utils.tokenization import _sanitize_input_name
from interpretune.adapters.transformer_lens import ITLensFromPretrainedConfig, ITLensCustomConfig
from interpretune.adapters.sae_lens import SAELensFromPretrainedConfig, SAELensCustomConfig
from interpretune.base.config.shared import Adapter
from interpretune.base.contract.session import ITSessionConfig, ITSession
from interpretune.utils.types import StrOrPath
from tests import seed_everything
from tests.utils import get_model_input_dtype
from base_defaults import BaseCfg
from it_examples.notebooks.saelens_adapter_example.sl_utils import sl_mem_utils
#from interpretune.base.components.cli import IT_BASE

# TODO: re-enable cuda detection after initial cpu-only debugging
device = "cpu"
#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EXAMPLE_DIR = Path(IT_BASE) / "experiments" / "notebooks" / "saelens_adapter_example"
# if str(EXAMPLE_DIR) not in sys.path:
#     sys.path.append(str(EXAMPLE_DIR))

# import sl_utils.part31_tests as part31_tests
# import sl_utils.part31_utils as part31_utils
# import sl_utils.part32_tests as part32_tests
# import sl_utils.sl_mem_utils as part32_utils
# from sl_utils.plotly_utils import imshow, line

#MAIN = __name__ == "__main__"

/home/speediedan/repos/interpretune/src/interpretune/adapters/transformer_lens.py:130: Interpretune manages the HF model instantiation via `model_name_or_path`. Since `tokenizer_name was not provided, the value provided for `tl_cfg.cfg.tokenizer_name` will be used for `tokenizer_name`.
/home/speediedan/repos/interpretune/src/interpretune/adapters/transformer_lens.py:130: Interpretune manages the HF model instantiation via `model_name_or_path`. Since `tokenizer_name` was provided, `tl_cfg.cfg.tokenizer_name` will be ignored.


## Setup Our IT Module


In [2]:
log = logging.getLogger(__name__)

TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("passage", "question")}
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
DEFAULT_TASK = "rte"
INVALID_TASK_MSG = f" is an invalid task_name. Proceeding with the default task: {DEFAULT_TASK!r}"


@dataclass(kw_only=True)
class RTEBoolqEntailmentMapping:
    entailment_mapping: Tuple = ("Yes", "No")  # RTE style, invert mapping for BoolQ
    entailment_mapping_indices: Optional[torch.Tensor] = None


@dataclass(kw_only=True)
class RTEBoolqPromptConfig(PromptConfig):
    ctx_question_join: str = 'Does the previous passage imply that '
    question_suffix: str = '? Answer with only one word, either Yes or No.'
    cust_task_prompt: Optional[Dict[str, Any]] = None

#from interpretune.adapters.registration import register_user_composition

# if you want to add custom attributes to module (ITConfig) or datamodule (ITDataModuleConfig) configs, you can have
# Interpretune generate the corresponding adapter configs like so:

# @dataclass(kw_only=True)
# class RTEBoolqConfig(RTEBoolqEntailmentMapping, ITConfig):
#     ...

#register_user_composition(RTEBoolqConfig, RTEBoolqEntailmentMapping, ITConfig)
# register_user_composition("RTEBoolqConfig", RTEBoolqEntailmentMapping, ITConfig)

# which is equivalent to manually generating the configs as follows:

# @dataclass(kw_only=True)
# class RTEBoolqConfig(RTEBoolqEntailmentMapping, ITConfig):
#     ...

# @dataclass(kw_only=True)
# class RTEBoolqConfig__transformer_lens(RTEBoolqEntailmentMapping, ITLensConfig):
#     ...

# @dataclass(kw_only=True)
# class RTEBoolqConfig__sae_lens(RTEBoolqEntailmentMapping, SAELensConfig):
#     ...


class RTEBoolqDataModule(ITDataModule):
    def __init__(self, itdm_cfg: ITDataModuleConfig) -> None:
        if itdm_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(itdm_cfg.task_name + INVALID_TASK_MSG)
            itdm_cfg.task_name = DEFAULT_TASK
        itdm_cfg.text_fields = TASK_TEXT_FIELD_MAP[itdm_cfg.task_name]
        super().__init__(itdm_cfg=itdm_cfg)

    def prepare_data(self, target_model: Optional[torch.nn.Module] = None) -> None:
        """Load the SuperGLUE dataset."""
        # N.B. prepare_data is called in a single process (rank 0, either per node or globally) so do not use it to
        # assign state (e.g. self.x=y)
        # note for raw pytorch we require a target_model
        # NOTE [HF Datasets Transformation Caching]:
        # HF Datasets' transformation cache fingerprinting algo necessitates construction of these partials as the hash
        # is generated using function args, dataset file, mapping args: https://bit.ly/HF_Datasets_fingerprint_algo)
        tokenization_func = partial(
            self.encode_for_rteboolq,
            tokenizer=self.tokenizer,
            text_fields=self.itdm_cfg.text_fields,
            prompt_cfg=self.itdm_cfg.prompt_cfg,
            template_fn=self.itdm_cfg.prompt_cfg.model_chat_template_fn,
            tokenization_pattern=self.itdm_cfg.cust_tokenization_pattern,
        )
        dataset = datasets.load_dataset("super_glue", self.itdm_cfg.task_name, trust_remote_code=True)
        for split in dataset.keys():
            dataset[split] = dataset[split].map(tokenization_func, **self.itdm_cfg.prepare_data_map_cfg)
            dataset[split] = self._remove_unused_columns(dataset[split], target_model)
        dataset.save_to_disk(self.itdm_cfg.dataset_path)

    def dataloader_factory(self, split: str, use_train_batch_size: bool = False) -> DataLoader:
        dataloader_kwargs = {"dataset": self.dataset[split], "collate_fn":self.data_collator,
                             **self.itdm_cfg.dataloader_kwargs}
        dataloader_kwargs['batch_size'] = self.itdm_cfg.train_batch_size if use_train_batch_size else \
            self.itdm_cfg.eval_batch_size
        return DataLoader(**dataloader_kwargs)

    # TODO: change to partialmethod's?
    def train_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='train', use_train_batch_size=True)

    def val_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def test_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def predict_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    #TODO: relax PreTrainedTokenizerBase to the protocol that is actually required
    @staticmethod
    def encode_for_rteboolq(example_batch: LazyDict, tokenizer: PreTrainedTokenizerBase, text_fields: List[str],
                            prompt_cfg: PromptConfig, template_fn: Callable,
                            tokenization_pattern: Optional[str] = None) -> BatchEncoding:
        example_batch['sequences'] = []
        # TODO: use promptsource instead of this manual approach after tinkering
        for field1, field2 in zip(example_batch[text_fields[0]],
                                  example_batch[text_fields[1]]):
            if prompt_cfg.cust_task_prompt:
                task_prompt = (prompt_cfg.cust_task_prompt['context'] + "\n" +
                               field1 + "\n" +
                               prompt_cfg.cust_task_prompt['question'] + "\n" +
                               field2)
            else:
                task_prompt = (field1 + prompt_cfg.ctx_question_join + field2 \
                               + prompt_cfg.question_suffix)
            sequence = template_fn(task_prompt=task_prompt, tokenization_pattern=tokenization_pattern)
            example_batch['sequences'].append(sequence)
        features = tokenizer.batch_encode_plus(example_batch["sequences"], padding="longest",
                                               padding_side=tokenizer.padding_side)
        features["labels"] = example_batch["label"]  # Rename label to labels, easier to pass to model forward
        features = _sanitize_input_name(tokenizer.model_input_names, features)
        return features

# TODO: maybe keep as tensors via cat for future analysis rather than separate batches
@dataclass(kw_only=True)
class LogitDiffsSumm:
    logit_diffs: list[torch.Tensor] = field(default_factory=list)
    answer_logits: list[torch.Tensor] = field(default_factory=list)
    loss: list[torch.Tensor] = field(default_factory=list)
    labels: list[torch.Tensor] = field(default_factory=list)
    orig_labels: list[torch.Tensor] = field(default_factory=list)
    preds: list[torch.Tensor] = field(default_factory=list)
    caches: list[ActivationCache] = field(default_factory=list)
    alive_latents: list[int] = field(default_factory=list)
    answer_indices: list[torch.Tensor] = field(default_factory=list)
    correct_activations: list[torch.Tensor] = field(default_factory=list)
    ablation_effects: list[torch.Tensor] = field(default_factory=list)
    tokens: list[torch.Tensor] = field(default_factory=list)
    prompts: list[str] = field(default_factory=list)

DEFAULT_DECODE_KWARGS = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True)

logit_diff_summaries = {}

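def boolean_logits_to_avg_logit_diff(
    logits: Float[Tensor, "batch seq 2"],
    target_indices: torch.Tensor,
    reduction: Literal["mean", "sum"] | None = "mean",
    keep_as_tensor: bool = False,
) -> torch.Tensor | list[float] | float:
    """
    Returns the logit difference (correct answer logit minus incorrect answer logit) for a batch of two-way
    prompts. `reduction` averages or sums over the batch (None keeps one value per example) and
    `keep_as_tensor` controls whether a tensor or plain Python value(s) are returned.
    """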
    incorrect_indices = 1 - target_indices
    correct_logits = torch.gather(logits, 2, torch.reshape(target_indices, (-1,1,1))).squeeze()
    incorrect_logits = torch.gather(logits, 2, torch.reshape(incorrect_indices, (-1,1,1))).squeeze()
    logit_diff = correct_logits - incorrect_logits
    if reduction is not None:
        logit_diff = logit_diff.mean() if reduction == "mean" else logit_diff.sum()
    return logit_diff if keep_as_tensor else logit_diff.tolist()

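def resolve_answer_indices(tokens):
    # Index of the last non-padding token in each row. 50256 is the pad token id configured for gpt2 in this
    # tutorial (see `tokenizer_id_overrides` in the session config); only meaningful with right padding.
    nonpadding_mask = tokens != 50256
    answer_indices = torch.where(nonpadding_mask, 1, 0).sum(dim=1) - 1
    return answer_indices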

def batch_alive_latents(answer_indices, cache, hook_names):
    acts = cache[hook_names]
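    # squeeze only the trailing dim so a single alive latent still yields a list rather than a bare int
    alive_latents = (acts[torch.arange(acts.size(0)), answer_indices, :] > 0).any(dim=0).nonzero().squeeze(-1).tolist()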
    return alive_latents

def ablate_sae_latent(
    sae_acts: Tensor,
    hook: HookPoint,
    latent_idx: int | None = None,
    #seq_pos: int | None = None,
    seq_pos: torch.Tensor | None = None,  # batched
) -> Tensor:
    """
    Ablate a particular latent at one sequence position per example. `seq_pos` is a batched tensor of shape
    [batch] holding the position to ablate in each example, and `latent_idx` selects the latent to zero out.
    """
    #sae_acts[:, seq_pos, latent_idx] = 0.0
    sae_acts[torch.arange(sae_acts.size(0)), seq_pos, latent_idx] = 0.0
    return sae_acts

class RTEBoolqModule(torch.nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # when using TransformerLens, we need to manually calculate our loss from logit output
        self.loss_fn = CrossEntropyLoss()

    def setup(self, *args, **kwargs) -> None:
        super().setup(*args, **kwargs)
        self._init_entailment_mapping()

    def _before_it_cfg_init(self, it_cfg: ITConfig) -> ITConfig:
        if it_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(it_cfg.task_name + INVALID_TASK_MSG)
            it_cfg.task_name = DEFAULT_TASK
        it_cfg.num_labels = 0 if it_cfg.generative_step_cfg.enabled else TASK_NUM_LABELS[it_cfg.task_name]
        return it_cfg

    def load_metric(self) -> None:
        self.metric = evaluate.load("super_glue", self.it_cfg.task_name,
                                    experiment_id=self._it_state._init_hparams['experiment_id'])

    def _init_entailment_mapping(self) -> None:
        ent_cfg, tokenizer = self.it_cfg, self.datamodule.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(ent_cfg.entailment_mapping)
        device = self.device if isinstance(self.device, torch.device) else self.output_device
        ent_cfg.entailment_mapping_indices = torch.tensor(token_ids).to(device)

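    def labels_to_ids(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # map class-index labels to answer token ids; the original labels are returned alongside
        return torch.take(self.it_cfg.entailment_mapping_indices, labels), labels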

    # TODO: add this helper function to SL adapter?
    def run_with_ctx(self, mode: str, batch: BatchEncoding, batch_idx: int, fwd_hooks_cfg: Optional[Dict] = None, 
                     hooks_filter: Optional[str]= None, **kwargs):
        if mode == 'clean':
            return self(**batch), None
        elif mode == 'cache_with_saes':
            return self.model.run_with_cache_with_saes(**batch, saes=self.sae_handles, names_filter=hooks_filter)
        elif mode == 'hooks_with_saes':
            assert fwd_hooks_cfg is not None
            per_latent_answer_logits = {}
            for latent_idx in tqdm(fwd_hooks_cfg['alive_latents'][batch_idx]):
                batch_answer_idxs = fwd_hooks_cfg['answer_indices'][batch_idx]
                #shape_summ = sl_mem_utils.summarize_cuda_tensors_by_shape()
                answer_logits = self.model.run_with_hooks_with_saes(
                    **batch,
                    saes=self.sae_handles,
                    #names_filter=hooks_filter,
                    clear_contexts=True,
                    fwd_hooks=[
                        (fwd_hooks_cfg['hook_names'],
                         partial(fwd_hooks_cfg['hook_fn'], latent_idx=latent_idx, seq_pos=batch_answer_idxs)
                        )
                    ],
                )
                per_latent_answer_logits[latent_idx] = answer_logits  # answer_logits.detach().cpu()
                #shape_summ_after = sl_mem_utils.summarize_cuda_tensors_by_shape()
            return per_latent_answer_logits, None
        else:
            # return a (logits, cache) pair like the other modes so callers can always unpack two values
            return self.model.run_with_saes(**batch, saes=self.sae_handles), None

    def ablation_logits_with_labels(
        self,
        batch: BatchEncoding,
        batch_idx: int,
        run_ctx: str,
        fwd_hooks_cfg: Dict
    ) -> Tuple[Dict, torch.Tensor, torch.Tensor, Optional[ActivationCache]]:
        label_ids, labels = self.labels_to_ids(batch.pop("labels"))
        batch_sz = batch['input'].size(0)
        answer_indices = fwd_hooks_cfg['answer_indices'][batch_idx]
        per_latent_logits, cache = self.run_with_ctx(run_ctx, batch, batch_idx, fwd_hooks_cfg=fwd_hooks_cfg)
        per_latent_logits = {k: v[torch.arange(batch_sz), answer_indices, :] for k, v in per_latent_logits.items()}
        #logits = self(**batch)
        # TODO: add another layer of abstraction here to handle different model output types? Tradeoffs to consider...
        # if not isinstance(logits, torch.Tensor):
        #     logits = logits.logits
        #     assert isinstance(logits, torch.Tensor), f"Expected logits to be a torch.Tensor but got {type(logits)}"
        return per_latent_logits, label_ids, labels, cache

    def logits_and_labels(self, batch: BatchEncoding, batch_idx: int, run_ctx: str = 'clean',
                          hooks_filter: Optional[str] = None) -> Tuple:
        # `run_ctx` defaults to 'clean' so that `training_step` and `validation_step` can omit it
        label_ids, labels = self.labels_to_ids(batch.pop("labels"))
        logits, cache = self.run_with_ctx(run_ctx, batch, batch_idx, hooks_filter=hooks_filter)
        #logits = self(**batch)
        # TODO: add another layer of abstraction here to handle different model output types? Tradeoffs to consider...
        if not isinstance(logits, torch.Tensor):
            logits = logits.logits
            assert isinstance(logits, torch.Tensor), f"Expected logits to be a torch.Tensor but got {type(logits)}"
        return torch.squeeze(logits[:, -1, :], dim=1), label_ids, labels, cache

    @ProfilerHooksMixin.memprofilable
    def training_step(self, batch: BatchEncoding, batch_idx: int) -> STEP_OUTPUT:
        # TODO: need to be explicit about the compatibility constraints/contract
        # TODO: note that this example uses generative_step_cfg and lm_head except for the test_step where we demo how to
        # use the GenerativeStepMixin to run inference with or without a generative_step_cfg enabled as well as with
        # different heads (e.g., seqclassification or LM head in this case)
        answer_logits, labels, *_ = self.logits_and_labels(batch, batch_idx)
        loss = self.loss_fn(answer_logits, labels)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    @ProfilerHooksMixin.memprofilable
    def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        answer_logits, labels, orig_labels, cache = self.logits_and_labels(batch, batch_idx)
        val_loss = self.loss_fn(answer_logits, labels)
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
        self.collect_answers(answer_logits, orig_labels)

    @ProfilerHooksMixin.memprofilable
    def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        if self.it_cfg.generative_step_cfg.enabled:
            self.generative_classification_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        else:
            self.default_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)

    def generative_classification_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> \
        Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self.it_generate(batch, **self.it_cfg.generative_step_cfg.lm_generation_cfg.generate_kwargs)
        self.collect_answers(outputs.logits, labels)

    def save_summary(self, batch: BatchEncoding, summ_map: Dict, out_summ: Optional[LogitDiffsSumm] = None,
                     save_prompts: bool = False, save_tokens: bool = False,
                     save_caches: bool = False, cache: Optional[ActivationCache] = None) -> Optional[Dict]:
        if save_prompts:
            summ_map["prompts"] = self.datamodule.tokenizer.batch_decode(batch['input'], **DEFAULT_DECODE_KWARGS)
        if save_tokens:
            summ_map["tokens"] = batch['input'].detach().cpu()
        if save_caches:
            summ_map["caches"] = cache
        if out_summ:
            for key, val in summ_map.items():
                getattr(out_summ, key).append(val.detach().cpu() if isinstance(val, torch.Tensor) else val)
        else:
            return summ_map

    def ablation_test_step(self, batch: BatchEncoding, batch_idx: int, logit_diff_fn: Callable, 
                                 run_ctx: str = 'hooks_with_saes', fwd_hooks_cfg: Optional[Dict] = None,
                                 out_summ: Optional[LogitDiffsSumm] = None, save_caches: bool = False,
                                 save_prompts: bool = False, save_tokens: bool = False,
                                 dataloader_idx: int = 0, *args, **kwargs) -> \
        Optional[STEP_OUTPUT]:
        #torch.cuda.memory._dump_snapshot(Path(self.core_log_dir) / "before_first_ablation_step.pickle")
        # namespace = globals().copy() | locals()
        # sl_mem_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
        per_latent_answer_logits, labels, orig_labels, cache = self.ablation_logits_with_labels(
            batch, batch_idx, run_ctx=run_ctx, fwd_hooks_cfg=fwd_hooks_cfg
        )
        preds, per_latent_loss, per_latent_logit_diffs, per_latent_preds = None, {}, {}, {}
        #aggregate_activations = torch.zeros(self.sae_handles[0].cfg.d_sae)
        ablation_effects = torch.zeros(batch['input'].size(0), self.sae_handles[0].cfg.d_sae)
        # TODO: return per-latent cache, preds, logit_diffs and loss or just ablation_effects?
        for latent_idx, logits in per_latent_answer_logits.items():
            per_latent_loss[latent_idx] = self.loss_fn(logits, labels)
            logits = self.standardize_logits(logits)
            per_example_answers, _ = torch.max(logits, dim=-2)
            per_latent_preds[latent_idx] = torch.argmax(per_example_answers, axis=-1)
            logit_diffs = logit_diff_fn(logits, target_indices=orig_labels, reduction=None, keep_as_tensor=True)
            example_mask = (logit_diffs > 0).cpu()
            per_latent_logit_diffs[latent_idx] = logit_diffs[example_mask].detach().cpu()
            # for the edge case where the mask/saved batch tensors are scalars (usually due to a batch size of 1)
            for t in [example_mask, fwd_hooks_cfg['base_logit_diffs'][batch_idx]]:
                if t.dim() == 0:
                    t.unsqueeze_(0)
            ablation_effects[example_mask, latent_idx] = (
                fwd_hooks_cfg['base_logit_diffs'][batch_idx][example_mask] - per_latent_logit_diffs[latent_idx]
            )
            #self.model.reset_hooks(including_permanent=True)

        summ_map = {"loss": per_latent_loss, "logit_diffs": per_latent_logit_diffs, "labels": labels,
                    "orig_labels": orig_labels, "preds": per_latent_preds, "ablation_effects": ablation_effects}
        #torch.cuda.memory._dump_snapshot(Path(self.core_log_dir) / "after_first_ablation_step_before_save.pickle")
        self.save_summary(batch, summ_map, out_summ, save_prompts, save_tokens, save_caches, cache)
        #torch.cuda.memory._dump_snapshot(Path(self.core_log_dir) / "after_first_ablation_step_after_save.pickle")
        #pass

    def activation_cache_test_step(self, batch: BatchEncoding, batch_idx: int, logit_diff_fn: Callable, 
                                   hooks_filter: str, run_ctx: str = 'clean', out_summ: Optional[LogitDiffsSumm] = None,
                                   save_caches: bool = False, save_prompts: bool = False, save_tokens: bool = False,
                                   dataloader_idx: int = 0) -> \
        Optional[STEP_OUTPUT]:
        answer_logits, labels, orig_labels, cache = self.logits_and_labels(batch, batch_idx, run_ctx=run_ctx, hooks_filter=hooks_filter)
        loss = self.loss_fn(answer_logits, labels)
        answer_logits = self.standardize_logits(answer_logits)
        per_example_answers, _ = torch.max(answer_logits, dim=-2)
        preds = torch.argmax(per_example_answers, axis=-1)  # type: ignore[call-arg]
        logit_diffs = logit_diff_fn(answer_logits, target_indices=orig_labels, reduction=None, keep_as_tensor=True)
        for t in [logit_diffs]:
            if t.dim() == 0:
                t.unsqueeze_(0)
        summ_map = {"loss": loss, "logit_diffs": logit_diffs, "labels": labels, "orig_labels": orig_labels, 
                    "preds": preds, "answer_logits": answer_logits}
        if run_ctx == 'cache_with_saes':
            answer_indices = resolve_answer_indices(batch['input'].detach().cpu()) if \
                self.datamodule.tokenizer.padding_side == 'right' else torch.full((batch['input'].size(0),), -1)
            alive_latents = batch_alive_latents(answer_indices, cache, hooks_filter)
            correct_activations = cache[hooks_filter][(logit_diffs > 0), -1, :]
            #avg_correct_activation = .mean(dim=0)
            summ_map.update({"answer_indices": answer_indices, "alive_latents": alive_latents,
                             "correct_activations": correct_activations})
        self.save_summary(batch, summ_map, out_summ, save_prompts, save_tokens, save_caches, cache)

    def default_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        self.collect_answers(outputs.logits, labels)

    def predict_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        return self.collect_answers(outputs, labels, mode='return')

    def collect_answers(self, logits: torch.Tensor | tuple, labels: torch.Tensor, mode: str = 'log') -> Optional[Dict]:
        logits = self.standardize_logits(logits)
        per_example_answers, _ = torch.max(logits, dim=-2)
        preds = torch.argmax(per_example_answers, axis=-1)  # type: ignore[call-arg]
        metric_dict = self.metric.compute(predictions=preds, references=labels)
        # TODO: check if this type casting is still required for lightning torchmetrics, bug should be fixed now...
        metric_dict = dict(map(lambda x: (x[0], torch.tensor(x[1], device=self.device)
                                          .to(torch.float32)), 
                               metric_dict.items()))
        if mode == 'log':
            self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
        else:
            return metric_dict

    def standardize_logits(self, logits: torch.Tensor | tuple) -> torch.Tensor:
        # To support generative/non-generative classification configs and LM/SeqClassification heads, we
        # adhere to the following logical shape invariant for logits:
        # [batch size, positions to consider, answers to consider]
        if isinstance(logits, tuple):
            # stack tuple entries (e.g. per-step generation scores) along the positions dimension
            logits = torch.stack(logits, dim=1)
        logits = logits.to(device=self.device)
        if logits.ndim == 2:  # if answer logits have already been squeezed
            logits = logits.unsqueeze(1)
        if logits.shape[-1] != self.it_cfg.num_labels:
            logits = torch.index_select(logits, -1, self.it_cfg.entailment_mapping_indices)
            if not self.it_cfg.generative_step_cfg.enabled:
                logits = logits[:, -1:, :]
        return logits
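
To make the `[batch size, positions to consider, answers to consider]` invariant concrete, here is a minimal, framework-free sketch that uses nested lists in place of tensors. The helper name `predict_from_logits` and the toy values are illustrative assumptions rather than part of Interpretune. It mirrors the reduction in `collect_answers`: a max over the positions dimension followed by an argmax over the answers dimension.

```python
from typing import List

# Logical shape: [batch size, positions to consider, answers to consider]
Logits = List[List[List[float]]]


def predict_from_logits(logits: Logits) -> List[int]:
    """Hypothetical list-based analogue of the max-over-positions, argmax-over-answers reduction."""
    preds = []
    for example in logits:
        # max over positions, separately for each candidate answer (like torch.max(..., dim=-2))
        per_answer_max = [max(column) for column in zip(*example)]
        # index of the best-scoring answer (like torch.argmax(..., dim=-1))
        preds.append(max(range(len(per_answer_max)), key=per_answer_max.__getitem__))
    return preds


toy_logits = [
    [[0.1, 0.9], [0.3, 0.2]],   # example 0: answer 1 scores highest (at position 0)
    [[1.5, -0.4], [0.0, 0.7]],  # example 1: answer 0 scores highest (at position 0)
]
print(predict_from_logits(toy_logits))
```

Because the reduction only depends on the last two dimensions, the same logic applies whether we consider a single position (non-generative classification) or several generated positions.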


## Configure our IT Session


In [3]:
from interpretune.adapters.transformer_lens import ITLensFromPretrainedNoProcessingConfig, TLensGenerationConfig
from interpretune.base.config.mixins import HFFromPretrainedConfig
from interpretune.base.config.shared import ITSharedConfig, AutoCompConfig

#torch.cuda.memory._record_memory_history()

shared_cfg = ITSharedConfig(model_name_or_path='gpt2', task_name='rte', tokenizer_id_overrides={'pad_token_id': 50256},
                  tokenizer_kwargs={'model_input_names': ['input'], 'padding_side': 'left', 'add_bos_token': True})
datamodule_cfg = ITDataModuleConfig(prompt_cfg=RTEBoolqPromptConfig(), train_batch_size=2, eval_batch_size=2,
                                    signature_columns=['input', 'labels'], prepare_data_map_cfg={"batched": True})
genclassif_cfg = GenerativeClassificationConfig(enabled=True, lm_generation_cfg=TLensGenerationConfig(max_new_tokens=1))
hf_cfg = HFFromPretrainedConfig(pretrained_kwargs={'torch_dtype': 'float32'}, model_head='transformers.GPT2LMHeadModel')
tl_cfg = ITLensFromPretrainedNoProcessingConfig(model_name="gpt2-small", default_padding_side='left')
# use the attention (hook_z) SAE release here instead of the residual stream (hook_resid_pre) release
# sae_cfgs = [
#     SAELensFromPretrainedConfig(release="gpt2-small-res-jb", sae_id="blocks.9.hook_resid_pre"),
#     SAELensFromPretrainedConfig(release="gpt2-small-res-jb", sae_id="blocks.10.hook_resid_pre"),
# ]
sae_cfgs = [
    SAELensFromPretrainedConfig(release="gpt2-small-hook-z-kk", sae_id="blocks.9.hook_z"),
    SAELensFromPretrainedConfig(release="gpt2-small-hook-z-kk", sae_id="blocks.10.hook_z"),
]
auto_comp_cfg = AutoCompConfig(module_cfg_name='RTEBoolqConfig', module_cfg_mixin=RTEBoolqEntailmentMapping)
module_cfg = ITConfig(auto_comp_cfg=auto_comp_cfg, generative_step_cfg=genclassif_cfg, hf_from_pretrained_cfg=hf_cfg,
                      tl_cfg=tl_cfg, sae_cfgs=sae_cfgs)

session_cfg = ITSessionConfig(adapter_ctx=(Adapter.core, Adapter.sae_lens),
                              datamodule_cls=RTEBoolqDataModule, module_cls=RTEBoolqModule,
                              shared_cfg=shared_cfg, datamodule_cfg=datamodule_cfg, module_cfg=module_cfg)
it_session = ITSession(session_cfg)
# TODO: maybe open a PR for the below
# https://github.com/jbloomAus/SAELens/blob/aa8f42bf06d9c68bb890f4881af0aac916ecd17c/sae_lens/sae.py#L144-L151 warning
# that inspects whether the loaded model has a default config override specified in ``pretrained_saes.yaml`` (e.g.
# 'gpt2-small-res-jb', config_overrides: model_from_pretrained_kwargs: center_writing_weights: true) and if so, avoids
# giving an arguably spurious warning to the user

/home/speediedan/repos/interpretune/src/interpretune/adapters/transformer_lens.py:282: Overriding `device_map` passed to TransformerLens to transform pretrained weights on cpu prior to moving the model to target device: None


Loaded pretrained model gpt2-small into HookedTransformer


  weights = torch.load(file_path, map_location=device)


In [4]:
# repr(module_cfg)

In [5]:
# prompt = "hello world"
# sl_test_module = it_session.module
# filter_sae_acts = lambda name: "hook_sae_acts_post" in name
# cache_dict = {"fwd": {}, "bwd": {}}


# def cache_hook(act, hook, dir: Literal["fwd", "bwd"]):
#     cache_dict[dir][hook.name] = act.detach()

# with sl_test_module.model.saes(saes=sl_test_module.sae_handles):
#     # We add hooks to cache values from the forward and backward pass respectively
#     with sl_test_module.model.hooks(
#         fwd_hooks=[(filter_sae_acts, partial(cache_hook, dir="fwd"))],
#         bwd_hooks=[(filter_sae_acts, partial(cache_hook, dir="bwd"))],
#     ):
#         # fill fwd/bwd cache, hooks then removed on cm exit
#         out = sl_test_module.model(prompt)
#         out[0, -1, 42].backward()

# cache_dict = {k: ActivationCache(cache_dict[k], sl_test_module.model) for k in cache_dict.keys()}
# for cache in cache_dict.values():
#     assert isinstance(cache, ActivationCache)
#     for sae_handle in sl_test_module.sae_handles:
#         assert sae_handle.name + ".hook_sae_acts_post" in cache

In [6]:
# torch.set_grad_enabled(False)

# pass

# gpt2: HookedSAETransformer = HookedSAETransformer.from_pretrained("gpt2-small", device=device)

# gpt2_sae, cfg_dict, sparsity = SAE.from_pretrained(
#     release="gpt2-small-res-jb",
#     sae_id="blocks.7.hook_resid_pre",
#     device=str(device),
# )

In [7]:
def display_dashboard(
    sae_release="gpt2-small-res-jb",
    sae_id="blocks.9.hook_resid_pre",
    latent_idx=0,
    width=800,
    height=600,
):
    release = get_pretrained_saes_directory()[sae_release]
    neuronpedia_id = release.neuronpedia_id[sae_id]

    url = f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"

    print(url)
    display(IFrame(url, width=width, height=height))

sl_test_module = it_session.module
gpt2_sae = sl_test_module.saes[0].handle  # just inspect the first SAE for now
latent_idx = random.randrange(gpt2_sae.cfg.d_sae)  # randrange excludes the upper bound, so the index is always valid
#display_dashboard(latent_idx=latent_idx, sae_id=gpt2_sae.name)
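
As a small standalone illustration of what `display_dashboard` does before embedding the page, the sketch below samples a valid latent index and builds the corresponding Neuronpedia embed URL. The helper names `sample_latent_idx` and `neuronpedia_embed_url`, as well as the dictionary size used here, are assumptions for illustration. The URL pattern follows the f-string in the cell above.

```python
import random
from urllib.parse import urlencode


def sample_latent_idx(d_sae: int, rng: random.Random) -> int:
    """Draw a valid latent index in [0, d_sae)."""
    # randrange excludes the upper bound, unlike randint, which includes it
    return rng.randrange(d_sae)


def neuronpedia_embed_url(neuronpedia_id: str, latent_idx: int) -> str:
    """Hypothetical helper mirroring the URL pattern used by display_dashboard."""
    query = urlencode({
        "embed": "true",
        "embedexplanation": "true",
        "embedplots": "true",
        "embedtest": "true",
        "height": 300,
    })
    return f"https://neuronpedia.org/{neuronpedia_id}/{latent_idx}?{query}"


rng = random.Random(0)
d_sae = 24576  # assumed dictionary size, for illustration only
idx = sample_latent_idx(d_sae, rng)
print(neuronpedia_embed_url("gpt2-small/9-att-kk", idx))
```

Keeping URL construction separate from the `IFrame` display makes it easy to log or share the dashboard link without rendering it.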

In [8]:

# attn_saes = {
#     layer: SAE.from_pretrained(
#         "gpt2-small-hook-z-kk",
#         f"blocks.{layer}.hook_z",
#         device=str(device),
#     )[0]
#     for layer in range(sl_test_module.cfg.n_layers)
# }

layer = 9

# display_dashboard(
#     sae_release="gpt2-small-hook-z-kk",
#     sae_id=f"blocks.{layer}.hook_z",
#     latent_idx=2,  # or you can try `random.randint(0, attn_saes[layer].cfg.d_sae)`
# )

In [9]:
# TODO: consider instantiating a trainer here once initial exploration is done
# manually init IT components for now
it_init(**it_session)
assert sl_test_module.it_cfg.entailment_mapping_indices is not None

#from copy import deepcopy

# TODO: move to module
# TODO: constrain out_summ analysis typing appropriately once usage pattern stabilizes
def run_test_split_fn(
    module: ITModule,
    datamodule: ITDataModule,
    limit_test_batches: int,
    out_summ: Any,
    step_fn: str = "test_step",
    *args,
    **kwargs,
):
    dataloader = datamodule.test_dataloader()
    module._it_state._current_epoch = 0  # TODO: test removing this, prob not needed in this context
    #module.model.eval()
    step_func = getattr(module, step_fn)
    for batch_idx, batch in tqdm(enumerate(dataloader)):
        # a negative limit_test_batches (e.g. -1) disables the limit
        if batch_idx >= limit_test_batches >= 0:
            break
        batch = module.batch_to_device(batch)
        step_func(batch, batch_idx, *args, out_summ=out_summ, **kwargs)
        module.global_step += 1
    return out_summ

for handle in sl_test_module.sae_handles:
    handle.requires_grad_(False)
hook_sae_acts_post = f"{sl_test_module.sae_handles[0].cfg.hook_name}.hook_sae_acts_post"

test_cache_base_args = dict(
    limit_test_batches=-1, logit_diff_fn=boolean_logits_to_avg_logit_diff, hooks_filter=hook_sae_acts_post, **it_session
)

logit_diff_summs = {"sae_cache": LogitDiffsSumm(), "ablation": LogitDiffsSumm()}

#logit_diff_summs_sae_cache = LogitDiffsSumm()

# TODO: maybe combine prompt and tokens collection into a single function, wrapping clean, sae and ablated modes
run_clean_sanity_check = False

sl_test_module.model.requires_grad_(False)
#{n: p.requires_grad for n,p in sl_test_module.model.named_parameters()}
if run_clean_sanity_check:
    logit_diff_summs["clean"] = LogitDiffsSumm()
    #logit_diff_summs_clean = LogitDiffsSumm()
    run_test_split_fn(
        run_ctx="clean",
        step_fn="activation_cache_test_step",
        out_summ=logit_diff_summs["clean"],
        #out_summ=logit_diff_summs_clean,
        **test_cache_base_args,
    )

#gc.collect()
#torch.cuda.memory._dump_snapshot(Path(sl_test_module.core_log_dir) / "before_cache_with_saes_run.pickle")
#logit_diff_summs_sae_cache = LogitDiffsSumm()
run_test_split_fn(
    run_ctx="cache_with_saes",
    step_fn="activation_cache_test_step",
    out_summ=logit_diff_summs["sae_cache"],
    #out_summ=logit_diff_summs_sae_cache,
    save_prompts=True,
    save_tokens=True,
    #save_caches=True,
    **test_cache_base_args,
)

#namespace = globals().copy() | locals()
#sl_mem_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")

ablation_fwd_hook_cfg = {
    "hook_names": hook_sae_acts_post,
    "hook_fn": ablate_sae_latent,
    "alive_latents": logit_diff_summs["sae_cache"].alive_latents,
    "answer_indices": logit_diff_summs["sae_cache"].answer_indices,
    "base_logit_diffs": logit_diff_summs["sae_cache"].logit_diffs,
    # "alive_latents": deepcopy(logit_diff_summs["sae_cache"].alive_latents),
    # "answer_indices": deepcopy(logit_diff_summs["sae_cache"].answer_indices),
    # "base_logit_diffs": deepcopy(logit_diff_summs["sae_cache"].logit_diffs),
    # "alive_latents": deepcopy(logit_diff_summs_sae_cache.alive_latents),
    # "answer_indices": deepcopy(logit_diff_summs_sae_cache.answer_indices),
    # "base_logit_diffs": deepcopy(logit_diff_summs_sae_cache.logit_diffs),
}

#del logit_diff_summs["sae_cache"]
#del logit_diff_summs_sae_cache
#gc.collect()

#torch.cuda.memory._record_memory_history()
#torch.cuda.memory._dump_snapshot(Path(sl_test_module.core_log_dir) / "before_ablation_run.pickle")
#logit_diff_summs_ablation = LogitDiffsSumm()
run_test_split_fn(
    run_ctx="hooks_with_saes",
    step_fn="ablation_test_step",
    fwd_hooks_cfg=ablation_fwd_hook_cfg,
    out_summ=logit_diff_summs["ablation"],
    #out_summ=logit_diff_summs_ablation,
    **test_cache_base_args,
)



LogitDiffsSumm(logit_diffs=[{70: tensor([0.4429]), 4963: tensor([0.4689]), 9492: tensor([0.4424]), 17491: tensor([0.4424]), 17550: tensor([0.4424]), 18914: tensor([0.4424]), 21642: tensor([0.4424]), 21988: tensor([0.4363])}, {1750: tensor([0.2972, 0.6439]), 2102: tensor([0.2970, 0.6442]), 7482: tensor([0.2928, 0.6439]), 7823: tensor([0.2970, 0.6500]), 8721: tensor([0.2970, 0.6475]), 13212: tensor([0.2970, 0.6440]), 13701: tensor([0.2970, 0.6442]), 16325: tensor([0.2970, 0.6423]), 24551: tensor([0.2970, 0.6424])}, {4963: tensor([]), 6600: tensor([]), 7690: tensor([]), 12680: tensor([])}, {70: tensor([]), 1172: tensor([]), 3986: tensor([]), 4963: tensor([]), 7267: tensor([]), 11786: tensor([]), 12817: tensor([]), 13187: tensor([]), 17588: tensor([]), 21328: tensor([])}, {70: tensor([0.1190]), 1063: tensor([0.0987]), 2107: tensor([0.1192]), 4963: tensor([0.0734]), 5374: tensor([0.1192]), 7146: tensor([0.1192]), 12943: tensor([0.1069]), 17489: tensor([0.1269])}, {2421: tensor([0.3577]), 45
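
The `limit_test_batches=-1` argument used above relies on the chained comparison in `run_test_split_fn`. The sketch below isolates that loop so its behavior for negative, zero and positive limits is easy to check. `take_batches` is a hypothetical stand-in that simply records which batches would be processed.

```python
from typing import Iterable, List


def take_batches(batches: Iterable[str], limit_test_batches: int) -> List[str]:
    """Hypothetical stand-in for the loop in run_test_split_fn, collecting the batches it would process."""
    processed = []
    for batch_idx, batch in enumerate(batches):
        # chained comparison: true only if the limit is non-negative AND has been reached
        if batch_idx >= limit_test_batches >= 0:
            break
        processed.append(batch)
    return processed


batches = ["b0", "b1", "b2", "b3"]
print(take_batches(batches, 2))   # first two batches only
print(take_batches(batches, -1))  # a negative limit disables the cap
print(take_batches(batches, 0))   # a limit of zero processes nothing
```

Using a small positive limit is a convenient way to smoke-test the caching and ablation passes before running them on the full test split.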

In [10]:
if run_clean_sanity_check:

    translated_labels = [
        sl_test_module.datamodule.tokenizer.batch_decode(labels, **DEFAULT_DECODE_KWARGS)
        for labels in logit_diff_summs["clean"].labels
    ]

    df = pd.DataFrame(
        {
            "prompt": logit_diff_summs["sae_cache"].prompts,
            "correct_answer": translated_labels,
            "clean_logit_diff": logit_diff_summs["clean"].logit_diffs,
            "sae_logit_diff": logit_diff_summs["sae_cache"].logit_diffs,
        }
    )
    df = df.explode(["prompt", "correct_answer", "clean_logit_diff", "sae_logit_diff"])
    df["sample_id"] = range(len(df))
    df = df[["sample_id", "prompt", "correct_answer", "clean_logit_diff", "sae_logit_diff"]]
    df = df[df.clean_logit_diff > 0].sort_values(by="clean_logit_diff", ascending=False)

    print(
        tabulate(
            df,
            headers=["Sample ID", "Prompt", "Answer", "Clean Logit Diff", "SAE Logit Diff"],
            maxcolwidths=[None, 80, None, None, None],
            tablefmt="grid",
            numalign="left",
            floatfmt="+.3f",
            showindex="never",
        )
    )


In [11]:
mode_correct = {}

for mode, summ in logit_diff_summs.items():
    if mode in ["clean", "sae_cache"]:
        correct_statuses = torch.cat([(labels == preds) for labels, preds in zip(summ.orig_labels, summ.preds)])
        positive_logit_diff_statuses = torch.cat([(logit_diffs > 0) for logit_diffs in summ.logit_diffs])
        # a prediction should be correct exactly when its logit diff is positive
        assert torch.equal(correct_statuses, positive_logit_diff_statuses)
        total_correct = correct_statuses.sum().item()
        percentage_correct = total_correct / len(correct_statuses) * 100
        mode_correct[mode] = (total_correct, percentage_correct)
    else:
        # for ablation, take each example's most frequent prediction across all ablated latents
        ablation_per_batch_preds = [torch.stack(list(pl.values())).mode(dim=0).values.cpu() for pl in summ.preds]
        num_correct = [(labels == preds).nonzero().unique().size(0) for labels, preds in zip(summ.orig_labels, ablation_per_batch_preds)]
        total_correct = sum(num_correct)
        percentage_correct = total_correct / len(torch.cat(summ.orig_labels)) * 100
        mode_correct[mode] = (total_correct, percentage_correct)

table_rows = []
for mode, (total_correct, percentage_correct) in mode_correct.items():
    table_rows.append([mode, total_correct, f"{percentage_correct:.2f}%"])

print(tabulate(table_rows, headers=["Mode", "Total Correct", "Percentage Correct"], tablefmt="grid"))

+-----------+-----------------+----------------------+
| Mode      |   Total Correct | Percentage Correct   |
+===========+=================+======================+
| sae_cache |             129 | 46.57%               |
+-----------+-----------------+----------------------+
| ablation  |             129 | 46.57%               |
+-----------+-----------------+----------------------+
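
For the ablation mode, each batch yields one set of predictions per ablated latent, so the cell above reduces them to a single prediction per example by taking the most frequent value. Here is a minimal list-based sketch of that reduction. The helper `modal_preds` and all values are made up for illustration, and the toy data avoids ties since tie-breaking may differ from `torch.mode`.

```python
from collections import Counter
from typing import Dict, List


def modal_preds(preds_by_latent: Dict[int, List[int]]) -> List[int]:
    """Most frequent prediction per example across all single-latent ablations."""
    per_example = zip(*preds_by_latent.values())  # transpose to [example][latent]
    return [Counter(example_preds).most_common(1)[0][0] for example_preds in per_example]


# toy batch of 3 examples; keys are ablated latent indices (made-up values)
preds_by_latent = {70: [1, 0, 1], 4963: [1, 1, 1], 9492: [0, 0, 1]}
labels = [1, 0, 0]

batch_preds = modal_preds(preds_by_latent)
num_correct = sum(int(p == y) for p, y in zip(batch_preds, labels))
print(batch_preds, f"{num_correct / len(labels) * 100:.2f}%")
```

If most single-latent ablations leave the prediction unchanged, the modal prediction matches the unablated one, which would be consistent with the identical accuracy reported for both modes in the table above.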


In [12]:
generate_per_batch_effects_graphs = False

if generate_per_batch_effects_graphs:
    for i, ablation_effects in enumerate(logit_diff_summs["ablation"].ablation_effects):
        px.line(
            ablation_effects.mean(dim=0).cpu().numpy(),
            title=f"Causal effects of latent ablation on logit diff of batch {i} ({len(logit_diff_summs['sae_cache'].alive_latents[i])} alive)",
            labels={"index": "Latent", "value": "Causal effect on logit diff"},
            template="ggplot2",
            width=1000,
        ).update_layout(showlegend=False).show()

In [13]:

k = 10

correct_acts = torch.cat(logit_diff_summs["sae_cache"].correct_activations)
avg_correct_act, num_correct_act = correct_acts.mean(dim=0), correct_acts.count_nonzero(dim=0)
proportion_correct_active = num_correct_act / len(correct_acts)
correct_mask = torch.cat([(labels == preds) for labels, preds in zip(logit_diff_summs["ablation"].orig_labels, ablation_per_batch_preds)])
per_example_ablation_effects = torch.cat(logit_diff_summs["ablation"].ablation_effects)[correct_mask]
total_ablation_effects = per_example_ablation_effects.sum(dim=0)
avg_ablation_effects = per_example_ablation_effects.mean(dim=0)
topk_latents_by_total_ablation_positive = total_ablation_effects.topk(k)
topk_latents_by_total_ablation_negative = total_ablation_effects.topk(k, largest=False)

for i, total_summ in enumerate([topk_latents_by_total_ablation_positive, topk_latents_by_total_ablation_negative]):
    table_rows = []
    for value, ind in zip(*total_summ):
        table_rows.append([
            ind.item(),
            value.item(),
            avg_ablation_effects[ind].item(),
            avg_correct_act[ind].item(),
            num_correct_act[ind].item(),
            proportion_correct_active[ind].item(),
        ])
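    # use single quotes inside the f-string so this also parses on Python versions before 3.12
    effect_header = f"Total {'Positive' if i == 0 else 'Negative'} Effect"
    print(tabulate(table_rows,
                   headers=["Latent Index", effect_header, "Mean Effect", "Mean Activation", "Num Examples Active",
                            "Proportion Correct Examples Active"], tablefmt="grid"))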

# Print the top 3 positive and negative dashboards
dashboard_k = 3
topk_latents_by_total_ablation_positive = total_ablation_effects.topk(dashboard_k)
topk_latents_by_total_ablation_negative = total_ablation_effects.topk(dashboard_k, largest=False)
for i, topk_ablation_latents in enumerate([topk_latents_by_total_ablation_positive, topk_latents_by_total_ablation_negative]):
    for value, ind in zip(*topk_ablation_latents):
        print(f"#{ind} had total causal effect {value:.2f} and was active in {num_correct_act[ind]} examples")
        display_dashboard(
            sae_release="gpt2-small-hook-z-kk",
            sae_id=f"blocks.{layer}.hook_z",
            latent_idx=int(ind),
        )



+----------------+-------------------------+---------------+-------------------+-----------------------+--------------------------------------+
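|   Latent Index |   Total Positive Effect |   Mean Effect |   Mean Activation |   Num Examples Active |   Proportion Correct Examples Active |
+================+=========================+===============+===================+=======================+======================================+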
|           9501 |               0.177643  |   0.00137708  |         0.042603  |                     2 |                           0.0155039  |
+----------------+-------------------------+---------------+-------------------+-----------------------+--------------------------------------+
|          22308 |               0.0970917 |   0.000752649 |         0.0381489 |                     3 |                           0.0232558  |
+----------------+-------------------------+---------------+-------------------+-----------------------+--------------------------------------+
|            234 |               0.080368  |   0.000623008 |         0.0267507 |                     4 |                           0.031

#22308 had total causal effect 0.10 and was active in 3 examples
https://neuronpedia.org/gpt2-small/9-att-kk/22308?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


#234 had total causal effect 0.08 and was active in 4 examples
https://neuronpedia.org/gpt2-small/9-att-kk/234?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


#4963 had total causal effect -1.91 and was active in 81 examples
https://neuronpedia.org/gpt2-small/9-att-kk/4963?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


#4634 had total causal effect -0.12 and was active in 1 examples
https://neuronpedia.org/gpt2-small/9-att-kk/4634?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300


#4692 had total causal effect -0.12 and was active in 1 examples
https://neuronpedia.org/gpt2-small/9-att-kk/4692?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300
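
The ranking step above selects the latents with the largest positive and largest negative total ablation effects. The sketch below reproduces that selection with the standard library on a plain dictionary. The helper `topk_by_total_effect` and the effect values are hypothetical and not taken from the run above.

```python
import heapq
from typing import Dict, List, Tuple


def topk_by_total_effect(total_effects: Dict[int, float], k: int, largest: bool = True) -> List[Tuple[int, float]]:
    """Return the k (latent index, total effect) pairs with the largest or smallest total effect."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    return pick(k, total_effects.items(), key=lambda item: item[1])


# made-up total effects, keyed by made-up latent indices
total_effects = {11: 0.18, 22: 0.10, 33: 0.08, 44: -1.91, 55: -0.12, 66: 0.01}

print(topk_by_total_effect(total_effects, k=2))                 # most positive effects
print(topk_by_total_effect(total_effects, k=2, largest=False))  # most negative effects
```

Ranking by total rather than mean effect favors latents that are active across many examples, which is why we also report the number of examples in which each latent is active.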


In [14]:
# TODO: continue implementation here
#   - continue with attribution patching reproduction of latent effects! (on 2 batch subset, plot top latents for dataset)
#   - extend attribution patching to collect latent effects at two different layers and replot top latents
#   - apply created pipeline to gemma2 and other models!
#   - go back and extend ablation to iterate over multiple sae layers
#   - explore using a threshold positive logit diff for tuning
#   - though relative logits remain comparable, investigate impact of logit magnitude being substantially smaller with padding_side=left (should be able to re-run with padding_side=right)



In [15]:
# def get_cache_fwd_and_bwd(model: HookedSAETransformer, saes: list[SAE], input, metric):
#     """
#     Get forward and backward caches for a model, given a metric.
#     """
#     filter_sae_acts = lambda name: "hook_sae_acts_post" in name

#     # This hook function will store activations in the appropriate cache
#     cache_dict = {"fwd": {}, "bwd": {}}

#     def cache_hook(act, hook, dir: Literal["fwd", "bwd"]):
#         cache_dict[dir][hook.name] = act.detach()

#     with model.saes(saes=saes):
#         # We add hooks to cache values from the forward and backward pass respectively
#         with model.hooks(
#             fwd_hooks=[(filter_sae_acts, partial(cache_hook, dir="fwd"))],
#             bwd_hooks=[(filter_sae_acts, partial(cache_hook, dir="bwd"))],
#         ):
#             # Forward pass fills the fwd cache, then backward pass fills the bwd cache (we don't care about metric value)
#             _ = metric(model(input)).backward()

#     return (
#         ActivationCache(cache_dict["fwd"], model),
#         ActivationCache(cache_dict["bwd"], model),
#     )


# clean_logits = gpt2.run_with_saes(prompts, saes=[attn_saes[layer]])
# clean_logit_diff = logits_to_ave_logit_diff(clean_logits)

# torch.set_grad_enabled(True)
# clean_cache, clean_grad_cache = get_cache_fwd_and_bwd(
#     gpt2,
#     [attn_saes[layer]],
#     prompts,
#     lambda logits: logits_to_ave_logit_diff(logits, keep_as_tensor=True, reduction="sum"),
# )
# torch.set_grad_enabled(False)

# # Extract activations and gradients
# hook_sae_acts_post = f"{attn_saes[layer].cfg.hook_name}.hook_sae_acts_post"
# clean_sae_acts_post = clean_cache[hook_sae_acts_post]
# clean_grad_sae_acts_post = clean_grad_cache[hook_sae_acts_post]

# # Compute attribution values for all latents, then index to get live ones
# attribution_values = (clean_grad_sae_acts_post * clean_sae_acts_post)[:, s2_pos, alive_latents].mean(0)

# # Visualize results
# px.scatter(
#     pd.DataFrame(
#         {
#             "Ablation": ablation_effects[alive_latents].cpu().numpy(),
#             "Attribution Patching": attribution_values.cpu().numpy(),
#             "Latent": alive_latents,
#         }
#     ),
#     x="Ablation",
#     y="Attribution Patching",
#     hover_data=["Latent"],
#     title="Attribution Patching vs Ablation",
#     template="ggplot2",
#     width=800,
#     height=600,
# ).add_shape(
#     type="line",
#     x0=attribution_values.min(),
#     x1=attribution_values.max(),
#     y0=attribution_values.min(),
#     y1=attribution_values.max(),
#     line=dict(color="red", width=2, dash="dash"),
# ).show()
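
To give some intuition for the planned attribution patching comparison, here is a toy, framework-free sketch. It compares the exact effect of zero-ablating each latent with a first-order estimate built from the gradient of the metric. The metric, its weights and the sign convention (ablated minus clean) are all assumptions chosen for illustration. A real implementation would obtain the gradients from the backward-hook cache rather than a hand-written derivative.

```python
from typing import List

W = [0.5, -1.0, 2.0]  # made-up linear weights on three latents
Q = 0.25              # made-up curvature applied to latent 0 only


def metric(acts: List[float]) -> float:
    """Toy metric: linear in every latent, plus a quadratic term on latent 0."""
    return sum(w * a for w, a in zip(W, acts)) + Q * acts[0] ** 2


def metric_grad(acts: List[float]) -> List[float]:
    """Hand-derived gradient of the toy metric with respect to each latent activation."""
    grad = list(W)
    grad[0] += 2 * Q * acts[0]
    return grad


def ablation_effects(acts: List[float]) -> List[float]:
    """Exact change in the metric when each latent is set to zero, one at a time."""
    base = metric(acts)
    effects = []
    for i in range(len(acts)):
        ablated = list(acts)
        ablated[i] = 0.0
        effects.append(metric(ablated) - base)
    return effects


def attribution_estimates(acts: List[float]) -> List[float]:
    """First-order estimate of each ablation effect: gradient times (ablated value - clean value)."""
    return [g * (0.0 - a) for g, a in zip(metric_grad(acts), acts)]


acts = [2.0, 0.0, 1.5]
print(ablation_effects(acts))
print(attribution_estimates(acts))
```

The estimate is exact for latents that enter the metric linearly and only approximate where there is curvature (latent 0 here), which is the kind of discrepancy the ablation vs. attribution scatter plot is meant to reveal. An inactive latent has zero effect under both methods.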