# Interpretune SAELens Tutorial

![Fine-Tuning Scheduler logo](logo_fts.png){height="55px" width="401px"}

### Intro

[Interpretune](https://github.com/speediedan/interpretune) is a flexible framework for exploring, analyzing and tuning LLM world models. In this tutorial, we'll walk through a simple example of using Interpretune to pursue interpretability research with SAELens. As we'll see, Interpretune composes the required execution context for us, so the same code can run in a variety of contexts, depending on the level of abstraction required.

As a long-time PyTorch and PyTorch Lightning contributor, I've found that PyTorch Lightning is the right level of abstraction for a large variety of ML research contexts. Some contexts, however, benefit from using core PyTorch directly, and some users prefer core PyTorch for other reasons, such as maximizing portability. As demonstrated here, Interpretune maximizes flexibility and portability by adhering to a well-defined protocol that allows auto-composition of our research module with the adapters required for execution in a wide variety of contexts. In this example, we'll execute the same module with core PyTorch and PyTorch Lightning, demonstrating the use of `SAELens` with Interpretune for interpretability research.
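
To make the auto-composition idea concrete, here is a minimal sketch in plain Python. It is not Interpretune's implementation: `StepModule`, `MeanLossModule`, `PlainLoopAdapter`, `LoggingLoopAdapter` and `compose` are hypothetical names invented for illustration. The sketch assumes only that a research module satisfies a small protocol (a `training_step` method) and that each adapter supplies the loop that drives it.

```python
from typing import Protocol, runtime_checkable


@runtime_checkable
class StepModule(Protocol):
    """Hypothetical contract: the only thing an adapter needs from a module."""

    def training_step(self, batch: list, batch_idx: int) -> float: ...


class MeanLossModule:
    """Framework-agnostic research logic (toy example)."""

    def training_step(self, batch: list, batch_idx: int) -> float:
        return sum(batch) / len(batch)


class PlainLoopAdapter:
    """Toy stand-in for a bare execution context: a hand-written loop."""

    def fit(self, batches: list) -> list:
        return [self.training_step(b, i) for i, b in enumerate(batches)]


class LoggingLoopAdapter:
    """Toy stand-in for a richer, trainer-like context that also keeps logs."""

    def fit(self, batches: list) -> list:
        self.logs, losses = [], []
        for i, b in enumerate(batches):
            losses.append(self.training_step(b, i))
            self.logs.append(f"step={i} loss={losses[-1]:.2f}")
        return losses


def compose(module_cls: type, adapter_cls: type) -> type:
    """Build a new class that pairs the adapter's loop with the module's logic."""
    if not issubclass(module_cls, StepModule):
        raise TypeError(f"{module_cls.__name__} does not implement the StepModule protocol")
    return type(f"{adapter_cls.__name__}{module_cls.__name__}", (adapter_cls, module_cls), {})


# The same module class runs unchanged under two different execution contexts
batches = [[1.0, 3.0], [2.0, 4.0]]
plain_run = compose(MeanLossModule, PlainLoopAdapter)()
logged_run = compose(MeanLossModule, LoggingLoopAdapter)()
plain_losses = plain_run.fit(batches)
logged_losses = logged_run.fit(batches)
print(plain_losses, logged_losses, logged_run.logs)
```

The research logic lives in one class and never refers to the loop that calls it, which is the property that lets the module defined later in this tutorial run under more than one adapter.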

> Note - **this is a WIP**, but this is the core idea. If you have any feedback, please let me know!

## A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions that help profile memory usage, so that if you encounter OOM errors you can try to clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're no longer using them) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.

<details>
<summary>See this dropdown for some functions which you might find helpful, and how to use them.</summary>

First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.

```python
# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
part32_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘</pre>

From this, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining objects on the GPU which are larger than 100MB to the CPU, and print the memory status again.

```python
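import gc

import torch

# Drop our references to the Gemma model and its SAE
del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
for obj in gc.get_objects():
    try:
        # Only consider modules whose tensors exceed the size threshold
        if isinstance(obj, torch.nn.Module) and part32_utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:
        # Inspecting arbitrary tracked objects may raise; skip any that do
        pass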

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

part32_utils.print_memory_status()
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66 GB</pre>

Mission success! We've managed to free up a lot of memory. Note that the loop over the objects tracked by the garbage collector, which moves large modules to the CPU, is often necessary to free up the memory: deleting the objects directly isn't always enough, because PyTorch can still sometimes keep references to them (i.e. their tensors) in memory. In fact, if you add code to the for loop above to print out `obj.shape` when `obj` is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even once you've deleted `gemma_2_2b`.
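
The point about lingering references can be illustrated without a GPU. The following is a minimal sketch in plain Python, with hypothetical `Weights` and `Model` classes standing in for a real model and its tensors. It assumes CPython's reference-counting behavior and says nothing about how PyTorch itself holds references internally.

```python
import gc
import weakref


class Weights:
    """Stand-in for a large object such as a parameter tensor."""


class Model:
    def __init__(self) -> None:
        self.weights = Weights()


model = Model()
cache = {"layer0": model.weights}   # a second, easily forgotten reference
probe = weakref.ref(model.weights)  # checks liveness without keeping the object alive

# Deleting the name removes only one reference; the cache still holds the weights
del model
gc.collect()
still_alive_after_del = probe() is not None

# Once the last strong reference is gone, the object can actually be freed
cache.clear()
gc.collect()
freed_after_clearing_cache = probe() is None

print(still_alive_after_del, freed_after_clearing_cache)
```

`del` unbinds a name rather than destroying an object, which is why a sweep over `gc.get_objects()` can still find model weights after the model variable itself has been deleted.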

</details>

#### Imports

In [1]:
import logging
from typing import Any, Dict, Optional, Tuple, List, Callable, Literal
from dataclasses import dataclass
from functools import partial

import evaluate
import datasets
import pandas as pd
import plotly.express as px
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizerBase
from transformers.tokenization_utils_base import BatchEncoding
from datasets.arrow_dataset import LazyDict
from tabulate import tabulate
from tqdm.auto import tqdm
from transformer_lens import ActivationCache

from interpretune.adapters.core import ITModule
from interpretune.base.call import it_init
from interpretune.base.config.datamodule import PromptConfig, ITDataModuleConfig
from interpretune.base.config.module import ITConfig
from interpretune.base.config.mixins import GenerativeClassificationConfig
from interpretune.base.components.mixins import ProfilerHooksMixin
from interpretune.base.datamodules import ITDataModule
from interpretune.utils.logging import rank_zero_warn
from interpretune.utils.types import STEP_OUTPUT
from interpretune.utils.tokenization import _sanitize_input_name
from interpretune.adapters.sae_lens import SAELensFromPretrainedConfig
from interpretune.base.config.shared import Adapter
from interpretune.base.contract.session import ITSessionConfig, ITSession
from it_examples import _ACTIVE_PATCHES # noqa: F401  # TODO: add note about this unless patched in SL before release
from interpretune.utils.analysis import (SaveCfg, AnalysisCache, boolean_logits_to_avg_logit_diff, batch_alive_latents,
                                         ablate_sae_latent, resolve_answer_indices, DEFAULT_DECODE_KWARGS,
                                         display_dashboard)


### Define Our IT Data Module


In [2]:
log = logging.getLogger(__name__)

TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("passage", "question")}
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
DEFAULT_TASK = "rte"
INVALID_TASK_MSG = f" is an invalid task_name. Proceeding with the default task: {DEFAULT_TASK!r}"

class RTEBoolqDataModule(ITDataModule):
    def __init__(self, itdm_cfg: ITDataModuleConfig) -> None:
        if itdm_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(itdm_cfg.task_name + INVALID_TASK_MSG)
            itdm_cfg.task_name = DEFAULT_TASK
        itdm_cfg.text_fields = TASK_TEXT_FIELD_MAP[itdm_cfg.task_name]
        super().__init__(itdm_cfg=itdm_cfg)

    def prepare_data(self, target_model: Optional[torch.nn.Module] = None) -> None:
        """Load the SuperGLUE dataset."""
        # N.B. prepare_data is called in a single process (rank 0, either per node or globally) so do not use it to
        # assign state (e.g. self.x=y)
        # note for raw pytorch we require a target_model
        # NOTE [HF Datasets Transformation Caching]:
        # HF Datasets' transformation cache fingerprinting algo necessitates construction of these partials as the hash
        # is generated using function args, dataset file, mapping args: https://bit.ly/HF_Datasets_fingerprint_algo
        tokenization_func = partial(
            self.encode_for_rteboolq,
            tokenizer=self.tokenizer,
            text_fields=self.itdm_cfg.text_fields,
            prompt_cfg=self.itdm_cfg.prompt_cfg,
            template_fn=self.itdm_cfg.prompt_cfg.model_chat_template_fn,
            tokenization_pattern=self.itdm_cfg.cust_tokenization_pattern,
        )
        dataset = datasets.load_dataset("super_glue", self.itdm_cfg.task_name, trust_remote_code=True)
        for split in dataset.keys():
            dataset[split] = dataset[split].map(tokenization_func, **self.itdm_cfg.prepare_data_map_cfg)
            dataset[split] = self._remove_unused_columns(dataset[split], target_model)
        dataset.save_to_disk(self.itdm_cfg.dataset_path)

    def dataloader_factory(self, split: str, use_train_batch_size: bool = False) -> DataLoader:
        dataloader_kwargs = {"dataset": self.dataset[split], "collate_fn": self.data_collator,
                             **self.itdm_cfg.dataloader_kwargs}
        dataloader_kwargs['batch_size'] = self.itdm_cfg.train_batch_size if use_train_batch_size else \
            self.itdm_cfg.eval_batch_size
        return DataLoader(**dataloader_kwargs)

    # TODO: change to partialmethods?
    def train_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='train', use_train_batch_size=True)

    def val_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def test_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def predict_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    #TODO: relax PreTrainedTokenizerBase to the protocol that is actually required
    @staticmethod
    def encode_for_rteboolq(example_batch: LazyDict, tokenizer: PreTrainedTokenizerBase, text_fields: List[str],
                            prompt_cfg: PromptConfig, template_fn: Callable,
                            tokenization_pattern: Optional[str] = None) -> BatchEncoding:
        example_batch['sequences'] = []
        # TODO: use promptsource instead of this manual approach after tinkering
        for field1, field2 in zip(example_batch[text_fields[0]],
                                  example_batch[text_fields[1]]):
            if prompt_cfg.cust_task_prompt:
                task_prompt = (prompt_cfg.cust_task_prompt['context'] + "\n" +
                               field1 + "\n" +
                               prompt_cfg.cust_task_prompt['question'] + "\n" +
                               field2)
            else:
                task_prompt = (field1 + prompt_cfg.ctx_question_join + field2 \
                               + prompt_cfg.question_suffix)
            sequence = template_fn(task_prompt=task_prompt, tokenization_pattern=tokenization_pattern)
            example_batch['sequences'].append(sequence)
        features = tokenizer.batch_encode_plus(example_batch["sequences"], padding="longest",
                                               padding_side=tokenizer.padding_side)
        features["labels"] = example_batch["label"]  # Rename label to labels, easier to pass to model forward
        features = _sanitize_input_name(tokenizer.model_input_names, features)
        return features
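
To see what the prompt-building branch of `encode_for_rteboolq` produces before templating and tokenization, here is a small stand-alone sketch. `build_task_prompt` is a hypothetical helper written for this illustration; the two default strings are copied from `RTEBoolqPromptConfig` (defined in the module section of this tutorial), and the example premise, hypothesis and custom prompt values are made up.

```python
from typing import Optional

# Defaults copied from RTEBoolqPromptConfig in this tutorial
CTX_QUESTION_JOIN = "Does the previous passage imply that "
QUESTION_SUFFIX = "? Answer with only one word, either Yes or No."


def build_task_prompt(field1: str, field2: str, cust_task_prompt: Optional[dict] = None) -> str:
    """Hypothetical helper mirroring the prompt-building branch of encode_for_rteboolq."""
    if cust_task_prompt:
        # Custom prompts place each piece on its own line
        return "\n".join([cust_task_prompt["context"], field1, cust_task_prompt["question"], field2])
    # Default prompts concatenate the pieces directly, without adding separators
    return field1 + CTX_QUESTION_JOIN + field2 + QUESTION_SUFFIX


premise, hypothesis = "The cat sat on the mat.", "the cat is on the mat"
default_prompt = build_task_prompt(premise, hypothesis)
custom_prompt = build_task_prompt(premise, hypothesis, {"context": "Passage:", "question": "Claim:"})
print(default_prompt)
print(custom_prompt)
```

Note that the default branch inserts no whitespace between the first field and the join string, so any separator has to come from the data or from the configured strings.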

### Define Our IT Module

In [None]:

@dataclass(kw_only=True)
class RTEBoolqEntailmentMapping:
    entailment_mapping: Tuple = ("Yes", "No")  # RTE style, invert mapping for BoolQ
    entailment_mapping_indices: Optional[torch.Tensor] = None


@dataclass(kw_only=True)
class RTEBoolqPromptConfig(PromptConfig):
    ctx_question_join: str = 'Does the previous passage imply that '
    question_suffix: str = '? Answer with only one word, either Yes or No.'
    cust_task_prompt: Optional[Dict[str, Any]] = None


class RTEBoolqModule(torch.nn.Module):

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # when using TransformerLens, we need to manually calculate our loss from logit output
        self.loss_fn = CrossEntropyLoss()

    def setup(self, *args, **kwargs) -> None:
        super().setup(*args, **kwargs)
        self._init_entailment_mapping()

    def _before_it_cfg_init(self, it_cfg: ITConfig) -> ITConfig:
        if it_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(it_cfg.task_name + INVALID_TASK_MSG)
            it_cfg.task_name = DEFAULT_TASK
        it_cfg.num_labels = 0 if it_cfg.generative_step_cfg.enabled else TASK_NUM_LABELS[it_cfg.task_name]
        return it_cfg

    def load_metric(self) -> None:
        self.metric = evaluate.load("super_glue", self.it_cfg.task_name,
                                    experiment_id=self._it_state._init_hparams['experiment_id'])

    def _init_entailment_mapping(self) -> None:
        ent_cfg, tokenizer = self.it_cfg, self.datamodule.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(ent_cfg.entailment_mapping)
        device = self.device if isinstance(self.device, torch.device) else self.output_device
        ent_cfg.entailment_mapping_indices = torch.tensor(token_ids).to(device)

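    def labels_to_ids(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # returns (answer token ids selected by each label, the original integer labels)
        return torch.take(self.it_cfg.entailment_mapping_indices, labels), labels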

    @ProfilerHooksMixin.memprofilable
    def training_step(self, batch: BatchEncoding, batch_idx: int) -> STEP_OUTPUT:
        # TODO: need to be explicit about the compatibility constraints/contract
        # TODO: note that this example uses generative_step_cfg and lm_head except for the test_step where we demo how
        # to use the GenerativeStepMixin to run inference with or without a generative_step_cfg enabled as well as with
        # different heads (e.g., seqclassification or LM head in this case)
        answer_logits, labels, *_ = self.logits_and_labels(batch, batch_idx)
        loss = self.loss_fn(answer_logits, labels)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    @ProfilerHooksMixin.memprofilable
    def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        answer_logits, labels, orig_labels, cache = self.logits_and_labels(batch, batch_idx)
        val_loss = self.loss_fn(answer_logits, labels)
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
        self.collect_answers(answer_logits, orig_labels)

    @ProfilerHooksMixin.memprofilable
    def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        if self.it_cfg.generative_step_cfg.enabled:
            self.generative_classification_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        else:
            self.default_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)

    def generative_classification_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> \
        Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self.it_generate(batch, **self.it_cfg.generative_step_cfg.lm_generation_cfg.generate_kwargs)
        self.collect_answers(outputs.logits, labels)

    def default_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        self.collect_answers(outputs.logits, labels)

    def predict_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        return self.collect_answers(outputs, labels, mode='return')

    def collect_answers(self, logits: torch.Tensor | tuple, labels: torch.Tensor, mode: str = 'log') -> Optional[Dict]:
        logits = self.standardize_logits(logits)
        per_example_answers, _ = torch.max(logits, dim=-2)
        preds = torch.argmax(per_example_answers, axis=-1)  # type: ignore[call-arg]
        metric_dict = self.metric.compute(predictions=preds, references=labels)
        # TODO: check if this type casting is still required for lightning torchmetrics, bug should be fixed now...
        metric_dict = dict(map(lambda x: (x[0], torch.tensor(x[1], device=self.device)
                                          .to(torch.float32)),
                               metric_dict.items()))
        if mode == 'log':
            self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
        else:
            return metric_dict

    def standardize_logits(self, logits: torch.Tensor | tuple) -> torch.Tensor:
        # to support generative classification/non-generative classification configs and LM/SeqClassification heads we
        # adhere to the following logits logical shape invariant:
        # [batch size, positions to consider, answers to consider]
        if isinstance(logits, tuple):
            logits = torch.stack(logits, dim=1)
        logits = logits.to(device=self.device)
        if logits.ndim == 2:  # if answer logits have already been squeezed
            logits = logits.unsqueeze(1)
        if logits.shape[-1] != self.it_cfg.num_labels:
            logits = torch.index_select(logits, -1, self.it_cfg.entailment_mapping_indices)
            if not self.it_cfg.generative_step_cfg.enabled:
                logits = logits[:, -1:, :]
        return logits

    def per_latent_answer_logits(self, batch: BatchEncoding, batch_idx: int, fwd_hooks_cfg: Dict) -> tuple[dict, None]:
        assert fwd_hooks_cfg is not None
        per_latent_answer_logits = {}
        for latent_idx in tqdm(fwd_hooks_cfg['alive_latents'][batch_idx]):
            batch_answer_idxs = fwd_hooks_cfg['answer_indices'][batch_idx]
            answer_logits = self.model.run_with_hooks_with_saes(
                **batch,
                saes=self.sae_handles,
                clear_contexts=True,
                fwd_hooks=[
                    (fwd_hooks_cfg['hook_names'],
                        partial(fwd_hooks_cfg['hook_fn'], latent_idx=latent_idx, seq_pos=batch_answer_idxs)
                    )
                ],
            )
            per_latent_answer_logits[latent_idx] = answer_logits
        return per_latent_answer_logits, None

    def attr_patch_logits(self, batch: BatchEncoding, batch_idx: int, logit_diff_fn: Callable, hooks_cfg: dict,
                          label_ids: torch.Tensor, orig_labels: torch.Tensor, answer_indices: torch.Tensor) \
                            -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                                     ActivationCache, ActivationCache]:
        assert hooks_cfg is not None
        with self.model.saes(saes=self.sae_handles):
            # We add hooks to cache values from the forward and backward pass respectively
            with self.model.hooks(
                fwd_hooks=hooks_cfg['fwd_hooks'],
                bwd_hooks=hooks_cfg['bwd_hooks'],
            ):
                # Forward pass fills the fwd cache, then backward pass fills the bwd cache
                answer_logits = self.model(**batch)
                answer_logits = answer_logits[torch.arange(batch['input'].size(0)), answer_indices]
                answer_logits = self.standardize_logits(answer_logits)
                per_example_answers, _ = torch.max(answer_logits, dim=-2)
                preds = torch.argmax(per_example_answers, axis=-1)  # type: ignore[call-arg]
                logit_diffs = logit_diff_fn(logits=answer_logits, target_indices=orig_labels)
                #loss = self.loss_fn(answer_logits, labels)
                #answer_logits = self.standardize_logits(answer_logits)
                #per_example_answers, _ = torch.max(answer_logits, dim=-2)
                #preds = torch.argmax(per_example_answers, axis=-1)  # type: ignore[call-arg]
                logit_diffs.sum().backward()
                if logit_diffs.dim() == 0:
                    logit_diffs.unsqueeze_(0)
                # TODO: return only answer_indices and alive_latent cache indices for each cache
        return (answer_logits, logit_diffs, preds, label_ids, orig_labels,
                ActivationCache(hooks_cfg["cache_dict"]["fwd"], self.model),
                ActivationCache(hooks_cfg["cache_dict"]["bwd"], self.model),)

    # TODO: add this helper function to SL adapter?
    def run_with_ctx(self, mode: str, batch: BatchEncoding, batch_idx: int, hook_names: Optional[str] = None, **kwargs):
        if mode == 'clean':
            return self(**batch), None
        elif mode == 'cache_with_saes':
            return self.model.run_with_cache_with_saes(**batch, saes=self.sae_handles, names_filter=hook_names)
        else:
            # return a (logits, cache) pair like the other modes so callers can always unpack two values
            return self.model.run_with_saes(**batch, saes=self.sae_handles), None

    def ablation_logits_with_labels(self, batch: BatchEncoding, batch_idx: int, fwd_hooks_cfg: Dict) \
            -> tuple[dict[Any, torch.Tensor], torch.Tensor, torch.Tensor, Any]:
        label_ids, labels = self.labels_to_ids(batch.pop("labels"))
        batch_sz = batch['input'].size(0)
        answer_indices = fwd_hooks_cfg['answer_indices'][batch_idx]
        per_latent_logits, cache = self.per_latent_answer_logits(batch, batch_idx, fwd_hooks_cfg=fwd_hooks_cfg)
        per_latent_logits = {k: v[torch.arange(batch_sz), answer_indices, :] for k, v in per_latent_logits.items()}
        return per_latent_logits, label_ids, labels, cache

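    def logits_and_labels(self, batch: BatchEncoding, batch_idx: int, run_ctx: str = 'clean',
                          hook_names: Optional[str] = None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        # run_ctx defaults to 'clean' so training_step and validation_step can call this without a run context
        label_ids, labels = self.labels_to_ids(batch.pop("labels"))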
        logits, cache = self.run_with_ctx(run_ctx, batch, batch_idx, hook_names=hook_names)
        # TODO: add another layer of abstraction here to handle different model output types? Tradeoffs to consider...
        if not isinstance(logits, torch.Tensor):
            logits = logits.logits
            assert isinstance(logits, torch.Tensor), f"Expected logits to be a torch.Tensor but got {type(logits)}"
        return torch.squeeze(logits[:, -1, :], dim=1), label_ids, labels, cache

    def ablation_test_step(self, batch: BatchEncoding, batch_idx: int, analysis_cache: AnalysisCache,
                           logit_diff_fn: Callable, fwd_hooks_cfg: Optional[Dict] = None,
                           dataloader_idx: int = 0, *args, **kwargs) -> \
        Optional[STEP_OUTPUT]:
        per_latent_answer_logits, labels, orig_labels, cache = self.ablation_logits_with_labels(
            batch, batch_idx, fwd_hooks_cfg=fwd_hooks_cfg)
        per_latent_loss, per_latent_logit_diffs, per_latent_preds = {}, {}, {}
        ablation_effects = torch.zeros(batch['input'].size(0), self.sae_handles[0].cfg.d_sae)
        # TODO: return per-latent cache, preds, logit_diffs and loss or just ablation_effects?
        for latent_idx, logits in per_latent_answer_logits.items():
            per_latent_loss[latent_idx] = self.loss_fn(logits, labels)
            logits = self.standardize_logits(logits)
            per_example_answers, _ = torch.max(logits, dim=-2)
            per_latent_preds[latent_idx] = torch.argmax(per_example_answers, axis=-1)
            logit_diffs = logit_diff_fn(logits, target_indices=orig_labels, reduction=None, keep_as_tensor=True)
            example_mask = (logit_diffs > 0).cpu()
            per_latent_logit_diffs[latent_idx] = logit_diffs[example_mask].detach().cpu()
            # for the edge case where the mask/saved batch tensors are scalars (usually due to a batch size of 1)
            for t in [example_mask, fwd_hooks_cfg['base_logit_diffs'][batch_idx]]:
                if t.dim() == 0:
                    t.unsqueeze_(0)
            ablation_effects[example_mask, latent_idx] = (
                fwd_hooks_cfg['base_logit_diffs'][batch_idx][example_mask] - per_latent_logit_diffs[latent_idx]
            )
        step_summ = {"loss": per_latent_loss, "logit_diffs": per_latent_logit_diffs, "labels": labels,
                    "orig_labels": orig_labels, "preds": per_latent_preds, "ablation_effects": ablation_effects}
        analysis_cache.save(step_summ, batch, cache, tokenizer=self.datamodule.tokenizer)

    def attribution_patch_test_step(self, batch: BatchEncoding, batch_idx: int, analysis_cache: AnalysisCache,
                                    hooks_cfg: Dict, logit_diff_fn: Callable,
                                    dataloader_idx: int = 0, *args, **kwargs) -> Optional[STEP_OUTPUT]:
        label_ids, orig_labels = self.labels_to_ids(batch.pop("labels"))
        answer_indices = resolve_answer_indices(batch['input'].detach().cpu()) if \
            self.datamodule.tokenizer.padding_side == 'right' else torch.full((batch['input'].size(0),), -1)
        answer_logits, logit_diffs, preds, labels, orig_labels, cache, grad_cache = (
            self.attr_patch_logits(
                batch, batch_idx, logit_diff_fn, hooks_cfg, label_ids, orig_labels, answer_indices
            )
        )
        attribution_values = torch.zeros(batch['input'].size(0), self.sae_handles[0].cfg.d_sae)
        # TODO: consider allowing passing these values from previous sae_cache run if provided to avoid recomputing
        alive_latents = batch_alive_latents(answer_indices, cache, hooks_cfg['hook_names'])
        batch_sae_acts_post = cache[hooks_cfg['hook_names']][torch.arange(batch['input'].size(0)), answer_indices]
        batch_grad_sae_acts_post = grad_cache[hooks_cfg['hook_names']][torch.arange(batch['input'].size(0)),
                                                                       answer_indices]
        for t in [batch_sae_acts_post, batch_grad_sae_acts_post]:
            if t.dim() == 2:
                t.unsqueeze_(1)
        correct_activations = torch.squeeze(batch_sae_acts_post[(logit_diffs > 0), :, :], dim=1)
        # Compute attribution values (gradient * activation) for the alive latents only; other entries stay zero
        attribution_values[:, alive_latents] = torch.squeeze(
            (batch_grad_sae_acts_post[:, :, alive_latents] * batch_sae_acts_post[:, :, alive_latents]).cpu(), dim=1
        )
        step_summ = {
            "attribution_values": attribution_values, "labels": labels, "orig_labels": orig_labels, #"loss": loss,
            "correct_activations": correct_activations, "alive_latents": alive_latents,
            "answer_logits": answer_logits, "logit_diffs": logit_diffs, "preds": preds}
        analysis_cache.save(step_summ, batch, cache, grad_cache, tokenizer=self.datamodule.tokenizer)

    def activation_cache_test_step(self, batch: BatchEncoding, batch_idx: int, analysis_cache: AnalysisCache,
                                   logit_diff_fn: Callable,
                                   hook_names: str, run_ctx: str = 'clean',
                                   dataloader_idx: int = 0) -> \
        Optional[STEP_OUTPUT]:
        answer_logits, labels, orig_labels, cache = self.logits_and_labels(batch, batch_idx, run_ctx=run_ctx,
                                                                           hook_names=hook_names)
        loss = self.loss_fn(answer_logits, labels)
        answer_logits = self.standardize_logits(answer_logits)
        per_example_answers, _ = torch.max(answer_logits, dim=-2)
        preds = torch.argmax(per_example_answers, axis=-1)  # type: ignore[call-arg]
        logit_diffs = logit_diff_fn(answer_logits, target_indices=orig_labels, reduction=None, keep_as_tensor=True)
        for t in [logit_diffs]:
            if t.dim() == 0:
                t.unsqueeze_(0)
        step_summ = {"loss": loss, "logit_diffs": logit_diffs, "labels": labels, "orig_labels": orig_labels,
                    "preds": preds, "answer_logits": answer_logits}
        if run_ctx == 'cache_with_saes':
            answer_indices = resolve_answer_indices(batch['input'].detach().cpu()) if \
                self.datamodule.tokenizer.padding_side == 'right' else torch.full((batch['input'].size(0),), -1)
            alive_latents = batch_alive_latents(answer_indices, cache, hook_names)
            correct_activations = cache[hook_names][(logit_diffs > 0), -1, :]
            step_summ.update({"answer_indices": answer_indices, "alive_latents": alive_latents,
                             "correct_activations": correct_activations})
        analysis_cache.save(step_summ, batch, cache, tokenizer=self.datamodule.tokenizer)
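
The two analysis steps above measure a latent's importance in related ways: `ablation_test_step` stores the base logit difference minus the logit difference with one latent ablated, while `attribution_patch_test_step` stores gradient times activation, which is a first-order estimate of that same quantity. The sketch below compares the two on a toy problem. Everything in it is an assumption made for illustration: `logit_diff` is a made-up smooth function of three latent activations, the weights are arbitrary, and gradients come from finite differences rather than backward hooks.

```python
import math


def logit_diff(acts: list) -> float:
    """Toy stand-in for the model's logit difference as a function of SAE latents."""
    weights = [0.8, -0.5, 0.3]
    return math.tanh(sum(w * a for w, a in zip(weights, acts)))


def numerical_grad(f, x: list, eps: float = 1e-6) -> list:
    """Central finite-difference gradient of f at x."""
    grads = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + eps] + x[i + 1:]
        lo = x[:i] + [x[i] - eps] + x[i + 1:]
        grads.append((f(hi) - f(lo)) / (2 * eps))
    return grads


acts = [0.5, 0.2, 0.0]  # the third latent is inactive ("dead") on this example
base = logit_diff(acts)

# Ablation: one forward pass per latent, zeroing that latent's activation
ablation_effects = [base - logit_diff(acts[:i] + [0.0] + acts[i + 1:]) for i in range(len(acts))]

# Attribution: a single gradient evaluation, then gradient * activation per latent
attribution_values = [g * a for g, a in zip(numerical_grad(logit_diff, acts), acts)]

for i, (abl, attr) in enumerate(zip(ablation_effects, attribution_values)):
    print(f"latent {i}: ablation={abl:+.4f} attribution={attr:+.4f}")
```

The estimates agree in sign and are close in magnitude here, and an inactive latent gets exactly zero under both measures. That is consistent with the module only filling in values for alive latents.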


### Configure Our IT Session


In [None]:
from interpretune.adapters.transformer_lens import TLensGenerationConfig
from interpretune.base.config.mixins import HFFromPretrainedConfig
from interpretune.adapters.transformer_lens import ITLensFromPretrainedNoProcessingConfig
from interpretune.base.config.shared import ITSharedConfig, AutoCompConfig


shared_cfg = ITSharedConfig(model_name_or_path='gpt2', task_name='rte', tokenizer_id_overrides={'pad_token_id': 50256},
                  tokenizer_kwargs={'model_input_names': ['input'], 'padding_side': 'left', 'add_bos_token': True})
datamodule_cfg = ITDataModuleConfig(prompt_cfg=RTEBoolqPromptConfig(), train_batch_size=2, eval_batch_size=2,
                                    signature_columns=['input', 'labels'], prepare_data_map_cfg={"batched": True})
genclassif_cfg = GenerativeClassificationConfig(enabled=True, lm_generation_cfg=TLensGenerationConfig(max_new_tokens=1))
hf_cfg = HFFromPretrainedConfig(pretrained_kwargs={'torch_dtype': 'float32'}, model_head='transformers.GPT2LMHeadModel')
tl_cfg = ITLensFromPretrainedNoProcessingConfig(model_name="gpt2-small", default_padding_side='left')
# use the attention-output (hook_z) SAE release here instead of the residual-stream (resid_pre) release below
# sae_cfgs = [
#     SAELensFromPretrainedConfig(release="gpt2-small-res-jb", sae_id="blocks.9.hook_resid_pre"),
#     SAELensFromPretrainedConfig(release="gpt2-small-res-jb", sae_id="blocks.10.hook_resid_pre"),
# ]
sae_cfgs = [
    SAELensFromPretrainedConfig(release="gpt2-small-hook-z-kk", sae_id="blocks.9.hook_z"),
    SAELensFromPretrainedConfig(release="gpt2-small-hook-z-kk", sae_id="blocks.10.hook_z"),
]
auto_comp_cfg = AutoCompConfig(module_cfg_name='RTEBoolqConfig', module_cfg_mixin=RTEBoolqEntailmentMapping)
module_cfg = ITConfig(auto_comp_cfg=auto_comp_cfg, generative_step_cfg=genclassif_cfg, hf_from_pretrained_cfg=hf_cfg,
                      tl_cfg=tl_cfg, sae_cfgs=sae_cfgs)

session_cfg = ITSessionConfig(adapter_ctx=(Adapter.core, Adapter.sae_lens),
                              datamodule_cls=RTEBoolqDataModule, module_cls=RTEBoolqModule,
                              shared_cfg=shared_cfg, datamodule_cfg=datamodule_cfg, module_cfg=module_cfg)
it_session = ITSession(session_cfg)
# TODO: maybe open a PR so that the warning at
# https://github.com/jbloomAus/SAELens/blob/aa8f42bf06d9c68bb890f4881af0aac916ecd17c/sae_lens/sae.py#L144-L151
# first inspects whether the loaded model has a default config override specified in ``pretrained_saes.yaml`` (e.g.
# 'gpt2-small-res-jb', config_overrides: model_from_pretrained_kwargs: center_writing_weights: true) and, if so, avoids
# giving an arguably spurious warning to the user
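
The configuration above sets `padding_side` to `'left'`, and the test steps in the module use `-1` as the answer index unless the tokenizer pads on the right, in which case they call `resolve_answer_indices`. The sketch below shows why the padding side matters, using plain lists instead of tensors. `pad_batch` and `last_real_token_indices` are hypothetical helpers written for this illustration; the latter is only a guess at the kind of lookup right padding requires, not the library function's implementation, and it assumes the pad id never appears as a real final token.

```python
PAD_ID = 50256  # the pad_token_id override used in shared_cfg


def pad_batch(seqs: list, side: str, pad_id: int = PAD_ID) -> list:
    """Pad every sequence to the longest length on the given side."""
    width = max(len(s) for s in seqs)
    if side == "left":
        return [[pad_id] * (width - len(s)) + s for s in seqs]
    return [s + [pad_id] * (width - len(s)) for s in seqs]


def last_real_token_indices(batch: list, pad_id: int = PAD_ID) -> list:
    """Hypothetical lookup of each row's last non-pad position (needed for right padding)."""
    indices = []
    for row in batch:
        idx = len(row) - 1
        while idx > 0 and row[idx] == pad_id:
            idx -= 1
        indices.append(idx)
    return indices


seqs = [[11, 12, 13], [21, 22]]
left, right = pad_batch(seqs, "left"), pad_batch(seqs, "right")

# With left padding the final position is always a real token, so index -1 works for every row
left_answers = [row[-1] for row in left]

# With right padding the final real token sits at a different index in each row
right_idx = last_real_token_indices(right)
right_answers = [row[i] for row, i in zip(right, right_idx)]

print(left_answers, right_idx, right_answers)
```

Left padding keeps the position that predicts the answer aligned across the batch, which is what lets `logits_and_labels` simply take `logits[:, -1, :]`.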

In [None]:
# TODO: consider instantiating a trainer here once initial exploration is done
# manually init IT components for now
it_init(**it_session)
sl_test_module = it_session.module
assert sl_test_module.it_cfg.entailment_mapping_indices is not None

layer = 9

# TODO: move to module
def run_test_split_fn(
    module: ITModule,
    datamodule: ITDataModule,
    limit_test_batches: int,
    analysis_cache: Optional[AnalysisCache] = None,
    step_fn: str = "test_step",
    *args,
    **kwargs,
):
    dataloader = datamodule.test_dataloader()
    analysis_cache = analysis_cache or AnalysisCache()
    module._it_state._current_epoch = 0  # TODO: test removing this, prob not needed in this context
    step_func = getattr(module, step_fn)
    for batch_idx, batch in tqdm(enumerate(dataloader)):
        if batch_idx >= limit_test_batches >= 0:
            break
        batch = module.batch_to_device(batch)
        step_func(batch, batch_idx, *args, analysis_cache=analysis_cache, **kwargs)
        module.global_step += 1
    return analysis_cache

hook_sae_acts_post = f"{sl_test_module.sae_handles[0].cfg.hook_name}.hook_sae_acts_post"

no_reduce_bool_logit_diff_fn = partial(boolean_logits_to_avg_logit_diff, reduction=None)
sum_reduce_bool_logit_diff_fn = partial(boolean_logits_to_avg_logit_diff, reduction="sum")


# TODO: make hook_names a list of strings instead of a single string and support it in downstream functions
test_cache_base_args = dict(
    # limit_test_batches=-1, hook_names=hook_sae_acts_post, **it_session  # -1 disables the batch limit
    limit_test_batches=2, hook_names=hook_sae_acts_post, **it_session
)

torch.set_grad_enabled(False)

logit_diff_summs = {}

# TODO: maybe combine prompt and tokens collection into a single function, wrapping clean, sae and ablated modes

run_clean_sae_cache = True
if run_clean_sae_cache:
    logit_diff_summs["sae_cache"] = AnalysisCache(save_cfg=SaveCfg(prompts=True, tokens=True))
    run_test_split_fn(
        run_ctx="cache_with_saes",
        step_fn="activation_cache_test_step",
        analysis_cache=logit_diff_summs["sae_cache"],
        logit_diff_fn=no_reduce_bool_logit_diff_fn,
        **test_cache_base_args,
    )

# TODO: after experimentation, clean these toggles up
run_clean_sanity_check = True
if run_clean_sanity_check:
    assert run_clean_sae_cache, "Need to run clean SAE cache first to get the base logit diffs and other cache values"
    logit_diff_summs["clean"] = AnalysisCache()
    run_test_split_fn(
        run_ctx="clean",
        step_fn="activation_cache_test_step",
        analysis_cache=logit_diff_summs["clean"],
        logit_diff_fn=no_reduce_bool_logit_diff_fn,
        **test_cache_base_args,
    )

run_ablation_sanity_check = True
if run_ablation_sanity_check:
    assert run_clean_sae_cache, "Need to run clean SAE cache first to get the base logit diffs and other cache values"
    ablation_fwd_hook_cfg = {
        "hook_names": hook_sae_acts_post,
        "hook_fn": ablate_sae_latent,
        "alive_latents": logit_diff_summs["sae_cache"].alive_latents,
        "answer_indices": logit_diff_summs["sae_cache"].answer_indices,
        "base_logit_diffs": logit_diff_summs["sae_cache"].logit_diffs,
    }

    logit_diff_summs["ablation"] = AnalysisCache()
    run_test_split_fn(
        run_ctx="hooks_with_saes",
        step_fn="ablation_test_step",
        fwd_hooks_cfg=ablation_fwd_hook_cfg,
        analysis_cache=logit_diff_summs["ablation"],
        logit_diff_fn=no_reduce_bool_logit_diff_fn,
        **test_cache_base_args,
    )

run_attr_patching = True

if run_attr_patching:
    torch.set_grad_enabled(True)
    # assert torch.is_grad_enabled()
    def cache_hook(act, hook, dir: Literal["fwd", "bwd"], cache_dict: dict):
        cache_dict[dir][hook.name] = act.detach()

    hooks_lambda = lambda name: "hook_sae_acts_post" in name

    cache_dict = {"fwd": {}, "bwd": {}}
    default_attr_patch_hook_cfg = {"hook_names": hook_sae_acts_post, "cache_dict": cache_dict}

    prompts_tokens_flags = {"prompts": True, "tokens": True} if not run_clean_sae_cache else {}
    # TODO: test multiple layer analysis via hooks_lambda instead of a specific hook name
    hooks_cfg = {**default_attr_patch_hook_cfg}
    hooks_cfg['fwd_hooks'] = [(hook_sae_acts_post, partial(cache_hook, dir="fwd", cache_dict=hooks_cfg["cache_dict"]))]
    hooks_cfg['bwd_hooks'] = [(hook_sae_acts_post, partial(cache_hook, dir="bwd", cache_dict=hooks_cfg["cache_dict"]))]
    logit_diff_summs["attr_patching"] = AnalysisCache(save_cfg=SaveCfg(**prompts_tokens_flags))
    run_test_split_fn(
        step_fn="attribution_patch_test_step",
        hooks_cfg=hooks_cfg,
        analysis_cache=logit_diff_summs["attr_patching"],
        logit_diff_fn=no_reduce_bool_logit_diff_fn,
        **test_cache_base_args,
    )
    torch.set_grad_enabled(False)
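
The attribution-patching pass above caches `hook_sae_acts_post` activations on the forward pass and their gradients on the backward pass. As a minimal, framework-free sketch of the idea behind combining the two: for a zero-ablation, a first-order estimate of a latent's effect on a metric is `(0 - activation) * gradient`. The toy metric, weights, and function names below are hypothetical and are not Interpretune or SAELens APIs, and the sign convention used by `attribution_patch_test_step` may differ.

```python
import math

# Hypothetical toy "model": the metric is tanh of a weighted sum of latent activations.
WEIGHTS = [0.5, -0.3, 0.05, 0.0]


def metric(acts):
    return math.tanh(sum(w * a for w, a in zip(WEIGHTS, acts)))


def metric_grad(acts):
    """Analytic gradient of the toy metric w.r.t. each latent activation."""
    z = sum(w * a for w, a in zip(WEIGHTS, acts))
    return [w * (1.0 - math.tanh(z) ** 2) for w in WEIGHTS]


def ablation_effects(acts):
    """Exact effect of zeroing each latent in turn (one forward pass per latent)."""
    base = metric(acts)
    effects = []
    for i in range(len(acts)):
        ablated = list(acts)
        ablated[i] = 0.0
        effects.append(metric(ablated) - base)
    return effects


def attribution_estimates(acts):
    """First-order estimate of the same effects from one forward and one backward pass."""
    return [(0.0 - a) * g for a, g in zip(acts, metric_grad(acts))]


acts = [1.0, 2.0, 1.5, 3.0]
exact = ablation_effects(acts)
approx = attribution_estimates(acts)
for i, (e, a) in enumerate(zip(exact, approx)):
    print(f"latent {i}: ablation={e:+.4f} attribution={a:+.4f}")
```

In this toy setup the estimate matches exact ablation closely for the latent with a small contribution and drifts for the latents with larger ones, while requiring a single backward pass instead of one forward pass per latent.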

In [None]:
if run_clean_sanity_check:

    translated_labels = [
        sl_test_module.datamodule.tokenizer.batch_decode(labels, **DEFAULT_DECODE_KWARGS)
        for labels in logit_diff_summs["clean"].labels
    ]

    df = pd.DataFrame(
        {
            "prompt": logit_diff_summs["sae_cache"].prompts,
            "correct_answer": translated_labels,
            "clean_logit_diff": logit_diff_summs["clean"].logit_diffs,
            "sae_logit_diff": logit_diff_summs["sae_cache"].logit_diffs,
        }
    )
    df = df.explode(["prompt", "correct_answer", "clean_logit_diff", "sae_logit_diff"])
    df["sample_id"] = range(len(df))
    df = df[["sample_id", "prompt", "correct_answer", "clean_logit_diff", "sae_logit_diff"]]
    df = df[df.clean_logit_diff > 0].sort_values(by="clean_logit_diff", ascending=False)

    print(
        tabulate(
            df,
            headers=["Sample ID", "Prompt", "Answer", "Clean Logit Diff", "SAE Logit Diff"],
            maxcolwidths=[None, 80, None, None, None],
            tablefmt="grid",
            numalign="left",
            floatfmt="+.3f",
            showindex="never",
        )
    )


In [None]:
mode_correct = {}

for mode, summ in logit_diff_summs.items():
    if mode in ["clean", "sae_cache", "attr_patching"]:
        correct = torch.cat([(labels == preds) for labels, preds in zip(summ.orig_labels, summ.preds)])
        positive_logit_diff = torch.cat([(logit_diffs > 0) for logit_diffs in summ.logit_diffs])
        # Sanity check: a prediction should be correct exactly when its logit diff is positive
        assert torch.eq(correct, positive_logit_diff).all()
        total_correct = correct.sum().item()
        percentage_correct = total_correct / len(correct) * 100
        mode_correct[mode] = (total_correct, percentage_correct)
    elif mode == "ablation":
        # Reduce the stacked per-key predictions of each batch to the most common prediction per example
        ablation_per_batch_preds = [torch.stack(list(pl.values())).mode(dim=0).values.cpu() for pl in summ.preds]
        num_correct = [
            (labels == preds).nonzero().unique().size(0)
            for labels, preds in zip(summ.orig_labels, ablation_per_batch_preds)
        ]
        total_correct = sum(num_correct)
        percentage_correct = total_correct / len(torch.cat(summ.orig_labels)) * 100
        mode_correct[mode] = (total_correct, percentage_correct)
    else:
        raise ValueError(f"Unknown mode: {mode}")

table_rows = []
for mode, (total_correct, percentage_correct) in mode_correct.items():
    table_rows.append([mode, total_correct, f"{percentage_correct:.2f}%"])

print(tabulate(table_rows, headers=["Mode", "Total Correct", "Percentage Correct"], tablefmt="grid"))
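
For the `ablation` mode, the cell above stacks the values of each per-batch prediction dict and takes the mode along the first dimension, leaving one prediction per example to compare against the labels. A plain-Python sketch of that reduction follows. The dict keys (assumed here to be ablated latent indices), the predictions, and the helper name are hypothetical and do not come from a real `AnalysisCache`.

```python
from collections import Counter

# Hypothetical predictions for one batch: one list of per-example predictions per ablated latent.
preds_by_latent = {
    101: [1, 0, 1, 1],
    202: [1, 0, 0, 1],
    303: [1, 1, 0, 1],
}
labels = [1, 0, 1, 1]


def most_common_per_example(preds_by_key):
    """Return the most frequent prediction for each example across all dict entries."""
    per_example = zip(*preds_by_key.values())  # transpose: one tuple of predictions per example
    return [Counter(example_preds).most_common(1)[0][0] for example_preds in per_example]


voted = most_common_per_example(preds_by_latent)
num_correct = sum(int(v == y) for v, y in zip(voted, labels))
print(voted, f"{num_correct}/{len(labels)} correct")
```

With three entries and binary predictions there are no ties in this sketch; how ties are broken in the real pipeline depends on `torch.mode`.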

In [None]:
generate_per_batch_effects_graphs = False
if generate_per_batch_effects_graphs and run_ablation_sanity_check:
    for i, ablation_effects in enumerate(logit_diff_summs["ablation"].ablation_effects):
        len_alive = len(logit_diff_summs['sae_cache'].alive_latents[i])
        px.line(
            ablation_effects.mean(dim=0).cpu().numpy(),
            title=f"Causal effects of latent ablation on logit diff of batch {i} ({len_alive} alive)",
            labels={"index": "Latent", "value": "Causal effect on logit diff"},
            template="ggplot2",
            width=1000,
        ).update_layout(showlegend=False).show()

In [None]:

if run_ablation_sanity_check:
    k = 10

    correct_acts = torch.cat(logit_diff_summs["sae_cache"].correct_activations)
    avg_correct_act, num_correct_act = correct_acts.mean(dim=0), correct_acts.count_nonzero(dim=0)
    proportion_correct_active = num_correct_act / len(correct_acts)
    correct_mask = torch.cat([(labels == preds) for labels, preds in zip(logit_diff_summs["ablation"].orig_labels,
                                                                         ablation_per_batch_preds)])
    per_example_ablation_effects = torch.cat(logit_diff_summs["ablation"].ablation_effects)[correct_mask]
    total_ablation_effects = per_example_ablation_effects.sum(dim=0)
    avg_ablation_effects = per_example_ablation_effects.mean(dim=0)

    topk_entries = {
        "positive": total_ablation_effects.topk(k),
        "negative": total_ablation_effects.topk(k, largest=False),
    }

    for label, total_summ in topk_entries.items():
        table_rows = []
        for value, ind in zip(*total_summ):
            table_rows.append([
                ind.item(),
                value.item(),
                avg_ablation_effects[ind].item(),
                avg_correct_act[ind].item(),
                num_correct_act[ind].item(),
                proportion_correct_active[ind].item(),
            ])
        print(tabulate(
            table_rows,
            headers=[
                "Latent Index",
                f"Total {label.capitalize()} Effect",
                "Mean Effect",
                "Mean Activation",
                "Num Examples Active",
                "Proportion Correct Examples Active",
            ],
            tablefmt="grid",
        ))

    # Display dashboards for the top 3 positive and negative latents by total ablation effect
    dashboard_k = 3
    topk_ablation_latents = {
        "ablation_positive": total_ablation_effects.topk(dashboard_k),
        "ablation_negative": total_ablation_effects.topk(dashboard_k, largest=False),
    }

    for label, topk_latents in topk_ablation_latents.items():
        for value, ind in zip(*topk_latents):
            print(f"#{ind} had total causal effect {value:.2f} and was active in {num_correct_act[ind]} examples")
            display_dashboard(
                sae_release="gpt2-small-hook-z-kk",
                sae_id=f"blocks.{layer}.hook_z",
                latent_idx=int(ind),
            )



In [None]:

if run_attr_patching:
    k = 10

    correct_acts = torch.cat(logit_diff_summs["attr_patching"].correct_activations)
    avg_correct_act, num_correct_act = correct_acts.mean(dim=0), correct_acts.count_nonzero(dim=0)
    proportion_correct_active = num_correct_act / len(correct_acts)
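    # Use the attribution patching cache explicitly instead of the leftover `summ` loop variable from an earlier cell
    correct_mask = torch.cat([(logit_diffs > 0) for logit_diffs in logit_diff_summs["attr_patching"].logit_diffs])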
    per_example_attribution_values = torch.cat(logit_diff_summs["attr_patching"].attribution_values)[correct_mask]
    total_attribution_values = per_example_attribution_values.sum(dim=0)
    avg_attribution_values = per_example_attribution_values.mean(dim=0)

    topk_entries = {
        "positive": total_attribution_values.topk(k),
        "negative": total_attribution_values.topk(k, largest=False),
    }

    for label, total_summ in topk_entries.items():
        table_rows = []
        for value, ind in zip(*total_summ):
            table_rows.append([
                ind.item(),
                value.item(),
                avg_attribution_values[ind].item(),
                avg_correct_act[ind].item(),
                num_correct_act[ind].item(),
                proportion_correct_active[ind].item(),
            ])
        print(tabulate(
            table_rows,
            headers=[
                "Latent Index",
                f"Total {label.capitalize()} Effect",
                "Mean Effect",
                "Mean Activation",
                "Num Examples Active",
                "Proportion Correct Examples Active",
            ],
            tablefmt="grid",
        ))

    # Display dashboards for the top 3 positive and negative latents by total attribution value
    dashboard_k = 3
    topk_attribution_latents = {
        "attribution_positive": total_attribution_values.topk(dashboard_k),
        "attribution_negative": total_attribution_values.topk(dashboard_k, largest=False),
    }

    for label, topk_latents in topk_attribution_latents.items():
        for value, ind in zip(*topk_latents):
            print(f"#{ind} had total attribution value {value:.2f} and was active in {num_correct_act[ind]} examples")
            display_dashboard(
                sae_release="gpt2-small-hook-z-kk",
                sae_id=f"blocks.{layer}.hook_z",
                latent_idx=int(ind),
            )



In [None]:
if run_attr_patching and run_ablation_sanity_check:
    # Visualize results
    px.scatter(
        pd.DataFrame(
            {
                # Move tensors to the CPU before converting to NumPy so this also works for GPU runs
                "Ablation": total_ablation_effects.cpu().numpy(),
                "Attribution Patching": total_attribution_values.cpu().numpy(),
                "Latent": torch.arange(total_attribution_values.size(0)).numpy(),
            }
        ),
        x="Ablation",
        y="Attribution Patching",
        hover_data=["Latent"],
        title="Attribution Patching vs Ablation",
        template="ggplot2",
        width=800,
        height=600,
    ).add_shape(
        type="line",
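        # y = x reference line; pass plain Python floats rather than 0-dim tensors
        x0=total_attribution_values.min().item(),
        x1=total_attribution_values.max().item(),
        y0=total_attribution_values.min().item(),
        y1=total_attribution_values.max().item(),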
        line=dict(color="red", width=2, dash="dash"),
    ).show()
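
The scatter plot above compares the two per-latent totals visually. One possible numeric companion, sketched here with hypothetical values rather than results from this notebook, is the Pearson correlation between the totals together with the overlap of their top-k latents. The helper names are illustrative and not part of Interpretune.

```python
import math


def pearson(xs, ys):
    """Pearson correlation of two equal-length sequences (assumes neither is constant)."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    std_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    std_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (std_x * std_y)


def topk_indices(values, k, largest=True):
    """Indices of the k largest (or smallest) values, returned as a set."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    return set(order[:k])


# Hypothetical per-latent totals, indexed by latent position.
ablation_totals = [-0.46, 0.54, -0.07, 0.00, 0.21]
attribution_totals = [-0.50, 0.60, -0.07, 0.00, 0.18]

r = pearson(ablation_totals, attribution_totals)
shared = topk_indices(ablation_totals, 2) & topk_indices(attribution_totals, 2)
print(f"pearson r = {r:.3f}, shared top-2 latents = {sorted(shared)}")
```

A high correlation with a low top-k overlap (or the reverse) would flag latents worth inspecting individually on the scatter plot.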



In [15]:
# TODO: continue implementation here
#   - unify hooks_cfg and fwd_hooks_cfg (replacing fwd_hooks_cfg with hooks_cfg fwd only config)
#   - abstract sae_cache_test_step to handle both ablation and attribution patching (and clean/sae_cache runs) via saelens adapter
#   - cleanup existing functions and toggle flags etc.
#   - extend attribution patching to collect latent effects at two different layers and replot top latents (look into
#     3898 latent ablation/attribution patching divergence if others emerge)
#   - re-test padding right, verify/adjust if necessary for multiple answer positions
#   - apply created pipeline to gemma2 and other models!
#   - demo framework flexibility by running Lightning-supported steps with Lightning and non-Lightning-supported custom
#     steps (analysis step and SAE training steps) with the IT basictrainer and/or a raw PyTorch training loop
#   - explore using a threshold positive logit diff for tuning
#   - though relative logits remain comparable, investigate impact of logit magnitude being substantially smaller with
#     padding_side=left (should be able to re-run with padding_side=right)
