# Interpretune SAELens Tutorial

![Fine-Tuning Scheduler logo](logo_fts.png){height="55px" width="401px"}

### Intro

[Interpretune](https://github.com/speediedan/interpretune) is a flexible framework for exploring, analyzing and tuning LLM world models. In this tutorial, we'll walk through a simple example of using Interpretune to pursue interpretability research with SAELens. As we'll see, Interpretune handles the required execution context composition, allowing us to use the same code in a variety of contexts, depending upon the level of abstraction required.

As a long-time PyTorch and PyTorch Lightning contributor, I've found that PyTorch Lightning is the right level of abstraction for a large variety of ML research contexts. Some contexts benefit from using core PyTorch directly, though, and some users prefer core PyTorch for other reasons, such as maximizing portability. As will be demonstrated here, Interpretune maximizes flexibility and portability by adhering to a well-defined protocol that allows auto-composition of our research module with the adapters required for execution in a wide variety of contexts. In this example, we'll execute the same module with core PyTorch and PyTorch Lightning, demonstrating the use of `SAELens` with Interpretune for interpretability research.

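To make the auto-composition idea concrete, here is a minimal, framework-free sketch. It is not Interpretune's actual implementation: `StepProtocol`, `ToyModule`, `CoreLoopAdapter`, `HistoryAdapter` and `compose` are hypothetical names invented for illustration. The only assumption is that a research module satisfying a small protocol can be combined with different adapter classes at runtime, so the same module code runs in more than one execution context.

```python
from typing import List, Protocol, Sequence, runtime_checkable


@runtime_checkable
class StepProtocol(Protocol):
    """The (hypothetical) contract a research module must satisfy."""

    def training_step(self, batch: Sequence[float], batch_idx: int) -> float: ...


class ToyModule:
    """Research logic only; it knows nothing about how it will be executed."""

    def training_step(self, batch: Sequence[float], batch_idx: int) -> float:
        return sum(batch) / len(batch)


class CoreLoopAdapter:
    """Stand-in for a plain execution context: a bare loop over batches."""

    def run(self, batches: Sequence[Sequence[float]]) -> List[float]:
        return [self.training_step(batch, idx) for idx, batch in enumerate(batches)]


class HistoryAdapter:
    """Stand-in for a richer context that layers bookkeeping on top of the core loop."""

    def run(self, batches: Sequence[Sequence[float]]) -> List[float]:
        results = super().run(batches)  # resolved via the composed class's MRO
        self.history = list(results)
        return results


def compose(module_cls: type, *adapters: type) -> type:
    """Build a new class whose MRO places the adapters ahead of the user module."""
    if not issubclass(module_cls, StepProtocol):
        raise TypeError(f"{module_cls.__name__} does not implement StepProtocol")
    name = "".join(adapter.__name__ for adapter in adapters) + module_cls.__name__
    return type(name, (*adapters, module_cls), {})


batches = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
core_session = compose(ToyModule, CoreLoopAdapter)()
rich_session = compose(ToyModule, HistoryAdapter, CoreLoopAdapter)()
core_results = core_session.run(batches)
rich_results = rich_session.run(batches)
```

Both sessions execute the identical `ToyModule.training_step`; only the composed adapters differ. That is the property the protocol-based design is meant to provide.
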
> Note - **this is a WIP**, but this is the core idea. If you have any feedback, please let me know!

## A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions which can help profile memory usage for you, so that if you encounter OOM errors you can try and clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're not using them any more) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.

<details>
<summary>See this dropdown for some functions which you might find helpful, and how to use them.</summary>

First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.

```python
# Profile memory usage of everything in the current namespace that lives on the GPU
namespace = globals().copy() | locals()
part32_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘</pre>

From this, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining objects on the GPU which are larger than 100MB to the CPU, and print the memory status again.

```python
import gc

import torch

del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
for obj in gc.get_objects():
    try:
        if isinstance(obj, torch.nn.Module) and part32_utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:
        # Some tracked objects raise when inspected; skip them rather than abort the sweep
        pass

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

part32_utils.print_memory_status()
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66</pre>

Mission success! We've managed to free up a lot of memory. Note that the loop which moves every large module tracked by the garbage collector to the CPU is often necessary to actually free the memory. Deleting the names alone isn't enough, because PyTorch can still sometimes keep references to the objects (i.e. their tensors) in memory. In fact, if you add code to the for loop above to print out `obj.shape` when `obj` is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even once you've deleted `gemma_2_2b`.

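The point about lingering references can be reproduced without PyTorch. The sketch below is a simplified illustration using only the standard library: `FakeModel`, `count_live` and `hook_registry` are hypothetical names, and the list stands in for whatever internal structure might still hold a reference to a deleted model.

```python
import gc


class FakeModel:
    """Stand-in for a large model; instances of user-defined classes are tracked by the gc."""

    def __init__(self, name: str) -> None:
        self.name = name


def count_live(cls: type) -> int:
    """Count instances of exactly `cls` that the garbage collector still tracks."""
    return sum(1 for obj in gc.get_objects() if type(obj) is cls)


model = FakeModel("big")
hook_registry = [model]  # a second, easily forgotten reference

del model  # removes the name only; the object is still reachable via hook_registry
alive_after_del = count_live(FakeModel)

hook_registry.clear()  # drop the last remaining reference
gc.collect()
alive_after_clear = count_live(FakeModel)

print(alive_after_del, alive_after_clear)
```

In this toy case the object survives `del` because a second reference exists, and it only disappears once that reference is dropped. The sweep over `gc.get_objects()` in the cell above is a blunt way of reaching such survivors when you cannot easily find who is holding them.
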
</details>

#### Imports

In [1]:
import logging
import os
from typing import Any, Dict, Optional, Tuple, List, Callable, Set
from dataclasses import dataclass, field
from functools import partial

import evaluate
import datasets
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizerBase
from transformers.tokenization_utils_base import BatchEncoding
from datasets.arrow_dataset import LazyDict
from tabulate import tabulate
from tqdm.auto import tqdm
from transformer_lens import ActivationCache  # noqa: F401

from interpretune.adapters.core import ITModule
from interpretune.base.call import it_init
from interpretune.base.config.datamodule import PromptConfig, ITDataModuleConfig
from interpretune.base.config.module import ITConfig
from interpretune.base.config.mixins import GenerativeClassificationConfig
from interpretune.base.components.mixins import ProfilerHooksMixin
from interpretune.base.datamodules import ITDataModule
from interpretune.utils.logging import rank_zero_warn
from interpretune.utils.types import STEP_OUTPUT
from interpretune.utils.tokenization import _sanitize_input_name
from interpretune.adapters.sae_lens import SAELensFromPretrainedConfig
from interpretune.base.config.shared import Adapter
from interpretune.base.contract.session import ITSessionConfig, ITSession
from it_examples import _ACTIVE_PATCHES # noqa: F401  # TODO: add note about this unless patched in SL before release
from interpretune.utils.analysis import (SaveCfg, AnalysisCache, AnalysisBatch, AnalysisCfg, run_with_ctx,
                                         construct_names_filter, loss_and_logit_diffs, create_attribution_tables,
                                         calculate_latent_metrics, display_latent_dashboards, calc_activation_summary,
                                         plot_latent_effects, display_ref_vs_sae_logit_diffs, latent_metrics_scatter,
                                         compute_correct, AnalysisMode)


### Tutorial Configuration

In [2]:
# define our Tutorial Configuration
# By default, we will run all analysis modes available with LatentAttributionMixin:
@dataclass
class TutorialConfig:
    analysis_mode_demos: Set[AnalysisMode] = field(default_factory=lambda: {
        AnalysisMode.clean_no_sae, AnalysisMode.clean_w_sae, AnalysisMode.attr_patching, AnalysisMode.ablation})
    limit_test_batches: int = 3
    target_layers: List[int] = field(default_factory=lambda: [9, 10])
    latent_effects_graphs: bool = True
    latent_effects_graphs_per_batch: bool = False  # can be overwhelming with many batches
    latents_table_per_sae: bool = True
    top_k_latents_table: int = 2
    top_k_latent_dashboards: int = 1  # (don't set too high, num dashboards = top_k_latent_dashboards * num_hooks * 2)
    top_k_clean_logit_diffs: int = 10
    sae_release: str = "gpt2-small-hook-z-kk"
    sae_hook_point: str = "hook_sae_acts_post"

    def __post_init__(self):
        if self.latent_effects_graphs_per_batch and not self.latent_effects_graphs:
            print("Note: Setting latent_effects_graphs to True since latent_effects_graphs_per_batch is True")
            self.latent_effects_graphs = True

# Change config here if you want to run a different set of analysis modes
# e.g. TutorialConfig(analysis_mode_demos={AnalysisMode.attr_patching})
#tutorial_config = TutorialConfig()
# various test configs
# tutorial_config = TutorialConfig(analysis_mode_demos={AnalysisMode.ablation})
# tutorial_config = TutorialConfig(analysis_mode_demos={AnalysisMode.attr_patching}, limit_test_batches=-1)
# tutorial_config = TutorialConfig(analysis_mode_demos={AnalysisMode.clean_no_sae, AnalysisMode.clean_w_sae},
#                                  limit_test_batches=-1)
# tutorial_config = TutorialConfig(analysis_mode_demos={AnalysisMode.clean_no_sae, AnalysisMode.clean_w_sae, 
#                                                       AnalysisMode.ablation, AnalysisMode.attr_patching}, 
#                                                       limit_test_batches=-1)
tutorial_config = TutorialConfig(analysis_mode_demos={AnalysisMode.clean_no_sae, AnalysisMode.clean_w_sae,
                                                      AnalysisMode.ablation, AnalysisMode.attr_patching},
                                                      limit_test_batches=3)

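`TutorialConfig` relies on two standard dataclass idioms: `field(default_factory=...)` for mutable defaults and `__post_init__` for reconciling dependent flags. The sketch below isolates both with a hypothetical `DemoConfig` (not part of Interpretune) whose fields are simplified stand-ins for the ones above.

```python
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class DemoConfig:
    # default_factory gives every instance its own fresh set/list
    modes: Set[str] = field(default_factory=lambda: {"clean_w_sae", "ablation"})
    target_layers: List[int] = field(default_factory=lambda: [9, 10])
    effects_graphs: bool = False
    effects_graphs_per_batch: bool = False

    def __post_init__(self) -> None:
        # Per-batch graphs only make sense if graphs are enabled at all
        if self.effects_graphs_per_batch and not self.effects_graphs:
            self.effects_graphs = True


first = DemoConfig()
second = DemoConfig(effects_graphs_per_batch=True)
first.target_layers.append(11)  # must not leak into other instances

print(first.target_layers, second.target_layers, second.effects_graphs)
```

Because each instance gets its own list, mutating `first.target_layers` leaves `second.target_layers` untouched, and `__post_init__` keeps the two graph flags consistent regardless of which one the caller sets.
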


### Define Our IT Data Module


In [3]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

log = logging.getLogger(__name__)

TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("passage", "question")}
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
DEFAULT_TASK = "rte"
INVALID_TASK_MSG = f" is an invalid task_name. Proceeding with the default task: {DEFAULT_TASK!r}"

class RTEBoolqDataModule(ITDataModule):
    def __init__(self, itdm_cfg: ITDataModuleConfig) -> None:
        if itdm_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(itdm_cfg.task_name + INVALID_TASK_MSG)
            itdm_cfg.task_name = DEFAULT_TASK
        itdm_cfg.text_fields = TASK_TEXT_FIELD_MAP[itdm_cfg.task_name]
        super().__init__(itdm_cfg=itdm_cfg)

    def prepare_data(self, target_model: Optional[torch.nn.Module] = None) -> None:
        """Load the SuperGLUE dataset."""
        # N.B. prepare_data is called in a single process (rank 0, either per node or globally) so do not use it to
        # assign state (e.g. self.x=y)
        # note for raw pytorch we require a target_model
        # NOTE [HF Datasets Transformation Caching]:
        # HF Datasets' transformation cache fingerprinting algo necessitates construction of these partials as the hash
        # is generated using function args, dataset file, mapping args (https://bit.ly/HF_Datasets_fingerprint_algo)
        tokenization_func = partial(
            self.encode_for_rteboolq,
            tokenizer=self.tokenizer,
            text_fields=self.itdm_cfg.text_fields,
            prompt_cfg=self.itdm_cfg.prompt_cfg,
            template_fn=self.itdm_cfg.prompt_cfg.model_chat_template_fn,
            tokenization_pattern=self.itdm_cfg.cust_tokenization_pattern,
        )
        dataset = datasets.load_dataset("super_glue", self.itdm_cfg.task_name, trust_remote_code=True)
        for split in dataset.keys():
            dataset[split] = dataset[split].map(tokenization_func, **self.itdm_cfg.prepare_data_map_cfg)
            dataset[split] = self._remove_unused_columns(dataset[split], target_model)
        dataset.save_to_disk(self.itdm_cfg.dataset_path)

    def dataloader_factory(self, split: str, use_train_batch_size: bool = False) -> DataLoader:
        dataloader_kwargs = {"dataset": self.dataset[split], "collate_fn":self.data_collator,
                             **self.itdm_cfg.dataloader_kwargs}
        dataloader_kwargs['batch_size'] = self.itdm_cfg.train_batch_size if use_train_batch_size else \
            self.itdm_cfg.eval_batch_size
        return DataLoader(**dataloader_kwargs)

    # TODO: change to partialmethod's?
    def train_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='train', use_train_batch_size=True)

    def val_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def test_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def predict_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    #TODO: relax PreTrainedTokenizerBase to the protocol that is actually required
    @staticmethod
    def encode_for_rteboolq(example_batch: LazyDict, tokenizer: PreTrainedTokenizerBase, text_fields: List[str],
                            prompt_cfg: PromptConfig, template_fn: Callable,
                            tokenization_pattern: Optional[str] = None) -> BatchEncoding:
        example_batch['sequences'] = []
        # TODO: use promptsource instead of this manual approach after tinkering
        for field1, field2 in zip(example_batch[text_fields[0]],
                                  example_batch[text_fields[1]]):
            if prompt_cfg.cust_task_prompt:
                task_prompt = (prompt_cfg.cust_task_prompt['context'] + "\n" +
                               field1 + "\n" +
                               prompt_cfg.cust_task_prompt['question'] + "\n" +
                               field2)
            else:
                field2 = field2.rstrip('.')
                task_prompt = (field1 + prompt_cfg.ctx_question_join + field2
                               + prompt_cfg.question_suffix)
            sequence = template_fn(task_prompt=task_prompt, tokenization_pattern=tokenization_pattern)
            example_batch['sequences'].append(sequence)
        features = tokenizer.batch_encode_plus(example_batch["sequences"], padding="longest",
                                               padding_side=tokenizer.padding_side)
        features["labels"] = example_batch["label"]  # Rename label to labels, easier to pass to model forward
        features = _sanitize_input_name(tokenizer.model_input_names, features)
        return features

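To see what the default (non-custom) branch of `encode_for_rteboolq` produces before templating and tokenization, here is a small standalone sketch. The two string constants are the default values of `ctx_question_join` and `question_suffix` from `RTEBoolqPromptConfig`; `build_task_prompt` is a hypothetical helper, and the premise/hypothesis pair is made up for illustration rather than taken from the dataset.

```python
CTX_QUESTION_JOIN = " Does the previous passage imply that "
QUESTION_SUFFIX = "? Answer with only one word, either Yes or No."


def build_task_prompt(premise: str, hypothesis: str) -> str:
    """Mirror the default branch: strip the trailing period, then join the pieces."""
    return premise + CTX_QUESTION_JOIN + hypothesis.rstrip(".") + QUESTION_SUFFIX


prompt = build_task_prompt(
    premise="The cat sat on the mat.",
    hypothesis="A cat is on a mat.",
)
print(prompt)
```

The trailing period of the hypothesis is removed so that the question mark in the suffix directly follows the hypothesis text. The `cust_task_prompt` branch and the chat-template step (`template_fn`) are intentionally left out of this sketch.
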
### Define Our IT Module

In [4]:
@dataclass(kw_only=True)
class RTEBoolqEntailmentMapping:
    entailment_mapping: Tuple = ("Yes", "No")  # RTE style, invert mapping for BoolQ
    entailment_mapping_indices: Optional[torch.Tensor] = None


@dataclass(kw_only=True)
class RTEBoolqPromptConfig(PromptConfig):
    ctx_question_join: str = " Does the previous passage imply that "
    question_suffix: str = "? Answer with only one word, either Yes or No."
    cust_task_prompt: Optional[Dict[str, Any]] = None


class RTEBoolqModule(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # when using TransformerLens, we need to manually calculate our loss from logit output
        self.loss_fn = CrossEntropyLoss()

    def setup(self, *args, **kwargs) -> None:
        super().setup(*args, **kwargs)
        self._init_entailment_mapping()

    def _before_it_cfg_init(self, it_cfg: ITConfig) -> ITConfig:
        if it_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(it_cfg.task_name + INVALID_TASK_MSG)
            it_cfg.task_name = DEFAULT_TASK
        it_cfg.num_labels = 0 if it_cfg.generative_step_cfg.enabled else TASK_NUM_LABELS[it_cfg.task_name]
        return it_cfg

    def load_metric(self) -> None:
        self.metric = evaluate.load(
            "super_glue", self.it_cfg.task_name, experiment_id=self._it_state._init_hparams["experiment_id"]
        )

    def _init_entailment_mapping(self) -> None:
        ent_cfg, tokenizer = self.it_cfg, self.datamodule.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(ent_cfg.entailment_mapping)
        device = self.device if isinstance(self.device, torch.device) else self.output_device
        ent_cfg.entailment_mapping_indices = torch.tensor(token_ids).to(device)

    def labels_to_ids(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return torch.take(self.it_cfg.entailment_mapping_indices, labels), labels

    @ProfilerHooksMixin.memprofilable
    def training_step(self, batch: BatchEncoding, batch_idx: int) -> STEP_OUTPUT:
        # TODO: need to be explicit about the compatibility constraints/contract
        # TODO: note that this example uses generative_step_cfg and lm_head except for the test_step where we demo how
        # to use the GenerativeStepMixin to run inference with or without a generative_step_cfg enabled as well as with
        # different heads (e.g., seqclassification or LM head in this case)
        answer_logits, labels, *_ = self.logits_and_labels(batch, batch_idx)
        loss = self.loss_fn(answer_logits, labels)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    @ProfilerHooksMixin.memprofilable
    def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        answer_logits, labels, orig_labels, cache = self.logits_and_labels(batch, batch_idx)
        val_loss = self.loss_fn(answer_logits, labels)
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
        self.collect_answers(answer_logits, orig_labels)

    @ProfilerHooksMixin.memprofilable
    def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        if self.it_cfg.generative_step_cfg.enabled:
            self.generative_classification_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        else:
            self.default_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)

    def generative_classification_test_step(
        self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0
    ) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self.it_generate(batch, **self.it_cfg.generative_step_cfg.lm_generation_cfg.generate_kwargs)
        self.collect_answers(outputs.logits, labels)

    def default_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        self.collect_answers(outputs.logits, labels)

    def predict_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        return self.collect_answers(outputs.logits, labels, mode="return")

    def collect_answers(self, logits: torch.Tensor | tuple, labels: torch.Tensor, mode: str = "log") -> Optional[Dict]:
        logits = self.standardize_logits(logits)
        per_example_answers, _ = torch.max(logits, dim=-2)
        preds = torch.argmax(per_example_answers, dim=-1)
        metric_dict = self.metric.compute(predictions=preds, references=labels)
        # TODO: check if this type casting is still required for lightning torchmetrics, bug should be fixed now...
        metric_dict = {
            name: torch.tensor(value, device=self.device).to(torch.float32) for name, value in metric_dict.items()
        }
        if mode == "log":
            self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
        else:
            return metric_dict

    def standardize_logits(self, logits: torch.Tensor | tuple) -> torch.Tensor:
        # to support generative classification/non-generative classification configs and LM/SeqClassification heads we
        # adhere to the following logits logical shape invariant:
        # [batch size, positions to consider, answers to consider]
        if isinstance(logits, tuple):
            logits = torch.stack(logits, dim=1)
        logits = logits.to(device=self.device)
        if logits.ndim == 2:  # if answer logits have already been squeezed
            logits = logits.unsqueeze(1)
        if logits.shape[-1] != self.it_cfg.num_labels:
            logits = torch.index_select(logits, -1, self.it_cfg.entailment_mapping_indices)
            if not self.it_cfg.generative_step_cfg.enabled:
                logits = logits[:, -1:, :]
        return logits

    def logits_and_labels(
        self, batch: BatchEncoding, batch_idx: int, run_ctx: str, hook_names: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        label_ids, labels = self.labels_to_ids(batch.pop("labels"))
        logits, cache = self.run_with_ctx(run_ctx, batch, batch_idx, hook_names=hook_names)
        # TODO: add another layer of abstraction here to handle different model output types? Tradeoffs to consider...
        if not isinstance(logits, torch.Tensor):
            logits = logits.logits
            assert isinstance(logits, torch.Tensor), f"Expected logits to be a torch.Tensor but got {type(logits)}"
        return torch.squeeze(logits[:, -1, :], dim=1), label_ids, labels, cache

    def analysis_test_step(
        self, batch: BatchEncoding, batch_idx: int, analysis_cfg: AnalysisCfg, dataloader_idx: int = 0
    ) -> Optional[STEP_OUTPUT]:
        label_ids, orig_labels = self.labels_to_ids(batch.pop("labels"))
        analysis_batch = AnalysisBatch(labels=label_ids, orig_labels=orig_labels)
        run_with_ctx(self, analysis_batch, analysis_cfg, batch, batch_idx)
        loss_and_logit_diffs(self, analysis_batch, analysis_cfg, batch, batch_idx)
        analysis_cfg.analysis_cache.save(analysis_batch, batch, tokenizer=self.datamodule.tokenizer)

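The shape invariant documented in `standardize_logits` (`[batch size, positions to consider, answers to consider]`) and the reduction in `collect_answers` can be illustrated without tensors. The sketch below uses nested Python lists in place of `torch.Tensor`. `select_answer_logits` and `predict_answers` are hypothetical helpers, and the vocabulary size and answer token ids are made up.

```python
from typing import List

Logits3D = List[List[List[float]]]  # [batch, positions, vocab or answers]


def select_answer_logits(vocab_logits: Logits3D, answer_token_ids: List[int]) -> Logits3D:
    """Keep only the columns for the answer tokens (the role torch.index_select plays above)."""
    return [[[position[tok] for tok in answer_token_ids] for position in example]
            for example in vocab_logits]


def predict_answers(answer_logits: Logits3D) -> List[int]:
    """Max over positions for each answer, then argmax over answers."""
    preds = []
    for example in answer_logits:
        num_answers = len(example[0])
        best_per_answer = [max(position[a] for position in example) for a in range(num_answers)]
        preds.append(max(range(num_answers), key=best_per_answer.__getitem__))
    return preds


# Two examples, one position each, toy vocabulary of size 4
vocab_logits = [[[0.1, 0.9, 0.3, 2.0]], [[1.5, 0.2, 0.7, -1.0]]]
answer_token_ids = [3, 0]  # hypothetical ids for ("Yes", "No")

answer_logits = select_answer_logits(vocab_logits, answer_token_ids)
preds = predict_answers(answer_logits)
print(answer_logits, preds)
```

The predictions index into the entailment mapping rather than the vocabulary, so a prediction of `0` means the first mapped answer and `1` the second, matching how the class labels are compared in the metric.
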

### Configure our IT Session


In [None]:
from interpretune.adapters.transformer_lens import TLensGenerationConfig
from interpretune.base.config.mixins import HFFromPretrainedConfig
from interpretune.adapters.transformer_lens import ITLensFromPretrainedNoProcessingConfig
from interpretune.base.config.shared import ITSharedConfig, AutoCompConfig

shared_cfg = ITSharedConfig(model_name_or_path='gpt2', task_name='rte', tokenizer_id_overrides={'pad_token_id': 50256},
                  tokenizer_kwargs={'model_input_names': ['input'], 'padding_side': 'left', 'add_bos_token': True})
datamodule_cfg = ITDataModuleConfig(prompt_cfg=RTEBoolqPromptConfig(), train_batch_size=2, eval_batch_size=2,
                                    signature_columns=['input', 'labels'], prepare_data_map_cfg={"batched": True})
genclassif_cfg = GenerativeClassificationConfig(enabled=True, lm_generation_cfg=TLensGenerationConfig(max_new_tokens=1))
hf_cfg = HFFromPretrainedConfig(pretrained_kwargs={'torch_dtype': 'float32'}, model_head='transformers.GPT2LMHeadModel')
tl_cfg = ITLensFromPretrainedNoProcessingConfig(model_name="gpt2-small", default_padding_side='left')
sae_cfgs = [SAELensFromPretrainedConfig(release="gpt2-small-hook-z-kk", sae_id=f"blocks.{layer}.hook_z")
            for layer in tutorial_config.target_layers]

auto_comp_cfg = AutoCompConfig(module_cfg_name='RTEBoolqConfig', module_cfg_mixin=RTEBoolqEntailmentMapping)
module_cfg = ITConfig(auto_comp_cfg=auto_comp_cfg, generative_step_cfg=genclassif_cfg, hf_from_pretrained_cfg=hf_cfg,
                      tl_cfg=tl_cfg, sae_cfgs=sae_cfgs)

session_cfg = ITSessionConfig(adapter_ctx=(Adapter.core, Adapter.sae_lens),
                              datamodule_cls=RTEBoolqDataModule, module_cls=RTEBoolqModule,
                              shared_cfg=shared_cfg, datamodule_cfg=datamodule_cfg, module_cfg=module_cfg)
it_session = ITSession(session_cfg)
# TODO: maybe open a PR to update the warning at
# https://github.com/jbloomAus/SAELens/blob/aa8f42bf06d9c68bb890f4881af0aac916ecd17c/sae_lens/sae.py#L144-L151
# so that it inspects whether the loaded model has a default config override specified in ``pretrained_saes.yaml``
# (e.g. 'gpt2-small-res-jb', config_overrides: model_from_pretrained_kwargs: center_writing_weights: true) and, if
# so, avoids giving an arguably spurious warning to the user

### Run Demo Analyses

In [None]:
# TODO: consider instantiating a trainer here once initial exploration is done
# manually init IT components for now
it_init(**it_session)
sl_test_module = it_session.module
assert sl_test_module.it_cfg.entailment_mapping_indices is not None

# TODO: move to module
def run_test_split_fn(
    module: ITModule,
    datamodule: ITDataModule,
    limit_test_batches: int,
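    analysis_cfg: Optional[AnalysisCfg] = None,
    step_fn: str = "analysis_test_step",
    *args,
    **kwargs,
):
    # dataclasses.field() only has meaning inside a dataclass body, so build the default config here instead
    if analysis_cfg is None:
        analysis_cfg = AnalysisCfg()
    dataloader = datamodule.test_dataloader()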
    module._it_state._current_epoch = 0  # TODO: test removing this, prob not needed in this context
    step_func = getattr(module, step_fn)
    for batch_idx, batch in tqdm(enumerate(dataloader)):
        if batch_idx >= limit_test_batches >= 0:
            break
        batch = module.batch_to_device(batch)
        step_func(batch, batch_idx, *args, analysis_cfg=analysis_cfg, **kwargs)
        module.global_step += 1
    return analysis_cfg.analysis_cache  # return cache handle for further analysis

logit_diff_summs = {}
test_cache_base_args = dict(limit_test_batches=tutorial_config.limit_test_batches, **it_session)
torch.set_grad_enabled(False)
names_filter = construct_names_filter(sl_test_module, target_hooks=(tutorial_config.target_layers,
                                                                    [tutorial_config.sae_hook_point]))

def run_clean_w_sae():
    logit_diff_summs[AnalysisMode.clean_w_sae] = AnalysisCache(save_cfg=SaveCfg(prompts=True, tokens=True))
    analysis_cfg = AnalysisCfg(mode=AnalysisMode.clean_w_sae, analysis_cache=logit_diff_summs[AnalysisMode.clean_w_sae],
                                names_filter=names_filter)
    run_test_split_fn(analysis_cfg=analysis_cfg, **test_cache_base_args)

if AnalysisMode.clean_w_sae in tutorial_config.analysis_mode_demos:
    run_clean_w_sae()

if AnalysisMode.clean_no_sae in tutorial_config.analysis_mode_demos:
    logit_diff_summs[AnalysisMode.clean_no_sae] = (
        AnalysisCache() if AnalysisMode.clean_w_sae in tutorial_config.analysis_mode_demos
        else AnalysisCache(save_cfg=SaveCfg(prompts=True, tokens=True))
    )
    analysis_cfg = AnalysisCfg(analysis_cache=logit_diff_summs[AnalysisMode.clean_no_sae])
    run_test_split_fn(analysis_cfg=analysis_cfg, **test_cache_base_args)

if AnalysisMode.ablation in tutorial_config.analysis_mode_demos:
    # Run clean_w_sae mode as well since ablation mode requires the base logit diffs
    if AnalysisMode.clean_w_sae not in logit_diff_summs:
        print("Note: running clean_w_sae first to get the base logit diffs and other cache values")
        run_clean_w_sae()
    logit_diff_summs[AnalysisMode.ablation] = AnalysisCache()
    analysis_cfg = AnalysisCfg(mode=AnalysisMode.ablation, analysis_cache=logit_diff_summs[AnalysisMode.ablation],
                               names_filter=names_filter,
                               base_logit_diffs=logit_diff_summs[AnalysisMode.clean_w_sae].logit_diffs,
                               answer_indices=logit_diff_summs[AnalysisMode.clean_w_sae].answer_indices,
                               alive_latents=logit_diff_summs[AnalysisMode.clean_w_sae].alive_latents)
    run_test_split_fn(analysis_cfg=analysis_cfg, **test_cache_base_args)

if AnalysisMode.attr_patching in tutorial_config.analysis_mode_demos:
    torch.set_grad_enabled(True)
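    # only save prompts and tokens here if the clean_w_sae run has not already cached them
    clean_w_sae_cached = AnalysisMode.clean_w_sae in logit_diff_summs
    prompts_tokens_flags = {} if clean_w_sae_cached else {"prompts": True, "tokens": True}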
    logit_diff_summs[AnalysisMode.attr_patching] = AnalysisCache(save_cfg=SaveCfg(**prompts_tokens_flags))
    analysis_cfg = AnalysisCfg(
        mode=AnalysisMode.attr_patching,
        analysis_cache=logit_diff_summs[AnalysisMode.attr_patching],
        names_filter=names_filter
    )
    run_test_split_fn(analysis_cfg=analysis_cfg, **test_cache_base_args)
    torch.set_grad_enabled(False)

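The `batch_idx >= limit_test_batches >= 0` check in `run_test_split_fn` is a chained comparison, which is why `limit_test_batches=-1` means "no limit". A minimal standalone sketch of the same idiom follows, using a hypothetical `limited_batches` generator that is not part of Interpretune.

```python
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")


def limited_batches(batches: Iterable[T], limit: int) -> Iterator[Tuple[int, T]]:
    """Yield (batch_idx, batch) pairs, stopping after `limit` batches when limit >= 0."""
    for batch_idx, batch in enumerate(batches):
        # Chained comparison: (batch_idx >= limit) and (limit >= 0)
        if batch_idx >= limit >= 0:
            break
        yield batch_idx, batch


all_batches = ["b0", "b1", "b2", "b3", "b4"]
first_three = [batch for _, batch in limited_batches(all_batches, 3)]
unlimited = [batch for _, batch in limited_batches(all_batches, -1)]
none_at_all = [batch for _, batch in limited_batches(all_batches, 0)]
print(first_three, len(unlimited), none_at_all)
```

With a negative limit the second half of the chain (`limit >= 0`) is always false, so the loop never breaks early. With a limit of `0` the loop breaks before the first batch is processed.
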

### Review Demo Results

#### Clean vs SAE Sample-wise Logit Diffs

In [None]:
if {AnalysisMode.clean_no_sae, AnalysisMode.clean_w_sae}.issubset(tutorial_config.analysis_mode_demos):
    display_ref_vs_sae_logit_diffs(sae=logit_diff_summs[AnalysisMode.clean_w_sae],
                                   no_sae_ref=logit_diff_summs[AnalysisMode.clean_no_sae],
                                   top_k=tutorial_config.top_k_clean_logit_diffs,
                                   tokenizer=sl_test_module.datamodule.tokenizer)


#### Proportion Correct Answers on Dataset By Analysis Mode

In [None]:
pred_summaries = {mode: compute_correct(summ, mode) for mode, summ in logit_diff_summs.items()}

table_rows = []
for mode, (total_correct, percentage_correct, _) in pred_summaries.items():
    table_rows.append([mode, total_correct, f"{percentage_correct:.2f}%"])

print(tabulate(table_rows, headers=["Mode", "Total Correct", "Percentage Correct"], tablefmt="grid"))

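For readers who want to see the arithmetic behind such a table, here is a rough standalone sketch. It does not reproduce `compute_correct` itself. `PredSummary` and `summarize_predictions` are hypothetical, the three fields are only inferred from how `pred_summaries` is unpacked and accessed in this notebook, and the predictions and labels are made-up toy data.

```python
from typing import List, NamedTuple, Sequence


class PredSummary(NamedTuple):
    total_correct: int
    percentage_correct: float
    batch_predictions: List[List[bool]]  # per batch, per example: was the prediction correct?


def summarize_predictions(preds: Sequence[Sequence[int]], labels: Sequence[Sequence[int]]) -> PredSummary:
    """Compare per-batch predictions with labels and aggregate correctness."""
    batch_predictions = [
        [pred == label for pred, label in zip(batch_preds, batch_labels)]
        for batch_preds, batch_labels in zip(preds, labels)
    ]
    total = sum(len(batch) for batch in batch_predictions)
    total_correct = sum(sum(batch) for batch in batch_predictions)
    percentage = 100.0 * total_correct / total if total else 0.0
    return PredSummary(total_correct, percentage, batch_predictions)


# Three batches of two examples each (toy data)
summary = summarize_predictions(preds=[[1, 0], [1, 1], [0, 0]], labels=[[1, 1], [1, 0], [0, 0]])
row = ["toy_mode", summary.total_correct, f"{summary.percentage_correct:.2f}%"]
print(row)
```

A named tuple supports both tuple unpacking and attribute access, which matches how the summaries are consumed in the cells of this section.
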
#### Per Batch Ablation Effect Graphs [Optional]

In [None]:
if tutorial_config.latent_effects_graphs and AnalysisMode.ablation in tutorial_config.analysis_mode_demos:
    # TODO: add note that only latent effects associated with correct answers currently displayed
    # TODO: allow toggling correct filtering during runs
    plot_latent_effects(logit_diff_summs[AnalysisMode.ablation],
                        per_batch=tutorial_config.latent_effects_graphs_per_batch)

#### Per-SAE Ablation Effects

In [None]:

if AnalysisMode.ablation in tutorial_config.analysis_mode_demos:
    ablation_batch_preds = pred_summaries[AnalysisMode.ablation].batch_predictions
    activation_summary = calc_activation_summary(logit_diff_summs[AnalysisMode.clean_w_sae])
    ablation_metrics = calculate_latent_metrics(
        analysis_cache=logit_diff_summs[AnalysisMode.ablation],
        pred_summ=pred_summaries[AnalysisMode.ablation],
        activation_summary=activation_summary,
        # filter_by_correct=True,
        run_name="ablation"
    )

    tables = create_attribution_tables(metrics=ablation_metrics,
                                       top_k=tutorial_config.top_k_latents_table, filter_type='both',
                                       per_sae=tutorial_config.latents_table_per_sae)

    for title, table in tables.items():
        print(f"\n{title}\n{table}\n")

    display_latent_dashboards(ablation_metrics, title="Ablation-Mediated Latent Analysis",
                              sae_release=tutorial_config.sae_release, top_k=tutorial_config.top_k_latent_dashboards)



#### Per-SAE Attribution Patching Effects

In [None]:
if AnalysisMode.attr_patching in tutorial_config.analysis_mode_demos:
    # per-SAE activation summaries are calculated using our AnalysisCache since the relevant keys are present,
    # no need to provide a separate activation summary from another comparison cache in this case as with ablation
    activation_summary = calc_activation_summary(logit_diff_summs[AnalysisMode.attr_patching])
    attribution_patching_metrics = calculate_latent_metrics(
        analysis_cache=logit_diff_summs[AnalysisMode.attr_patching],
        pred_summ=pred_summaries[AnalysisMode.attr_patching],
        run_name="attr_patching"
    )

    tables = create_attribution_tables(metrics=attribution_patching_metrics,
                                       top_k=tutorial_config.top_k_latents_table, filter_type='both',
                                       per_sae=tutorial_config.latents_table_per_sae)

    for title, table in tables.items():
        print(f"\n{title}\n{table}\n")

    display_latent_dashboards(attribution_patching_metrics, title="Attribution Patching-Mediated Latent Analysis",
                            sae_release=tutorial_config.sae_release, top_k=tutorial_config.top_k_latent_dashboards)


#### Per-SAE Ablation vs Attribution-Patching Effect Parity

In [None]:
if {AnalysisMode.attr_patching, AnalysisMode.ablation}.issubset(tutorial_config.analysis_mode_demos):
    # Visualize the per-hook results by plotting the two sets of latent metrics against each other
    latent_metrics_scatter(ablation_metrics, attribution_patching_metrics,
                           label1="Ablation", label2="Attribution Patching")

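Some intuition for why the two methods should roughly agree: attribution patching is commonly described as a first-order approximation of an activation intervention, estimating the effect as the change in activation multiplied by the gradient of the metric. The sketch below checks this on a toy scalar metric. It is an assumption-laden illustration rather than Interpretune's computation: `metric`, `ablation_effect` and `attribution_estimate` are hypothetical, the "latents" are three plain floats, and the gradient is taken by finite differences instead of autograd.

```python
from typing import Callable, List

Metric = Callable[[List[float]], float]


def metric(latents: List[float]) -> float:
    """Toy stand-in for a logit-diff metric: linear in every latent except latent 0."""
    weights = [1.0, -2.0, 0.5]
    linear = sum(w * a for w, a in zip(weights, latents))
    return linear + 0.1 * latents[0] ** 2


def ablation_effect(fn: Metric, latents: List[float], idx: int) -> float:
    """Exact effect of zero-ablating one latent: metric(ablated) - metric(clean)."""
    ablated = list(latents)
    ablated[idx] = 0.0
    return fn(ablated) - fn(latents)


def attribution_estimate(fn: Metric, latents: List[float], idx: int, eps: float = 1e-6) -> float:
    """First-order estimate: (ablated value - clean value) * d(metric)/d(latent)."""
    up, down = list(latents), list(latents)
    up[idx] += eps
    down[idx] -= eps
    grad = (fn(up) - fn(down)) / (2 * eps)  # central finite difference
    return (0.0 - latents[idx]) * grad


latents = [2.0, 1.0, 4.0]
for idx in range(len(latents)):
    exact = ablation_effect(metric, latents, idx)
    approx = attribution_estimate(metric, latents, idx)
    print(f"latent {idx}: ablation={exact:+.3f} attribution={approx:+.3f}")
```

For the latents that enter the metric linearly the two numbers coincide, while the quadratic term on latent 0 opens a gap between them. A scatter plot of ablation versus attribution-patching effects can be read the same way: points far from the diagonal mark latents for which the linear approximation is less faithful.
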



In [13]:
# TODO: continue implementation here
#   - finish moving/testing LatentMetrics calculation for attr_patching cell to analysis.py module
#   - move analysis_test_step components into LatentAttributionMixin in saelens adapter
#   - re-test padding right, verify/adjust if necessary for multiple answer positions
#   - may need to refactor per-hook data structures to be more efficient for summarization
#    (once desired summary stats stabilize)
#   - apply created pipeline to gemma2 and other models!
#   - demo framework flexibility by running Lightning-supported steps with lightning and non-lightning supported custom
#     steps (analysis step and sae training steps) with IT basictrainer and/or raw pytorch training loop
#   - explore using a threshold positive logit diff for tuning
#   - though relative logits remain comparable, investigate impact of logit magnitude being substantially smaller with
#     padding_side=left (should be able to re-run with padding_side=right)
