# Interpretune SAELens Tutorial

![Fine-Tuning Scheduler logo](logo_fts.png){height="55px" width="401px"}

### Intro

[Interpretune](https://github.com/speediedan/interpretune) is a flexible framework for exploring, analyzing and tuning LLM world models. In this tutorial, we'll walk through a simple example of using Interpretune for interpretability research with SAELens. As we'll see, Interpretune handles the required execution context composition, allowing us to use the same code in a variety of contexts, depending on the level of abstraction required.

As a long-time PyTorch and PyTorch Lightning contributor, I've found that PyTorch Lightning offers the right level of abstraction for a large variety of ML research contexts, but some contexts benefit from using core PyTorch directly. Some users may also prefer core PyTorch for other reasons, including maximizing portability. As demonstrated here, Interpretune maximizes flexibility and portability by adhering to a well-defined protocol that lets our research module be auto-composed with the adapters required for execution in a wide variety of contexts. In this example, we'll execute the same module with core PyTorch and PyTorch Lightning, demonstrating the use of `SAELens` with Interpretune for interpretability research.

> Note - **this is a WIP**, but this is the core idea. If you have any feedback, please let me know!

## A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions that help profile memory usage, so that if you encounter OOM errors you can clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects once you no longer need them) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all but the handful involving Gemma on a free Colab notebook.

<details>
<summary>See this dropdown for some functions which you might find helpful, and how to use them.</summary>

First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.

```python
# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
part32_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘</pre>

From this, we see that a lot of memory is allocated to the Gemma model, so let's delete it. We'll also run some code to move any remaining modules larger than roughly 100MB from the GPU to the CPU, and print the memory status again.

```python
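import gc

import torch

del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
for obj in gc.get_objects():
    try:
        # Move any large module still tracked by the garbage collector off the GPU
        if isinstance(obj, torch.nn.Module) and part32_utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:
        # Some objects returned by gc.get_objects() can raise when inspected; skip them
        pass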

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

part32_utils.print_memory_status()
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66 GB</pre>

Mission success! We've managed to free up a lot of memory. Note that the loop which moves large modules tracked by the garbage collector to the CPU is often necessary to actually release the memory: simply deleting the variables isn't enough, because references can still linger elsewhere (e.g. in other objects that hold their tensors). In fact, if you add code to the for loop above to print `obj.shape` when `obj` is a tensor, you'll see that many of those tensors are actually Gemma model weights, even after you've deleted `gemma_2_2b`.
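
`part32_utils` comes from external course material and isn't defined in this notebook. If you don't have it, here is a minimal sketch of a comparable profiler, assuming all you need is a per-object size table for tensors and modules in a namespace. `tensor_nbytes`, `object_device` and `summarize_namespace_memory` are hypothetical helpers, not part of PyTorch or any other library.

```python
from typing import Optional

import torch


def tensor_nbytes(obj) -> int:
    """Bytes held by a tensor, or by all parameters and buffers of an nn.Module."""
    if isinstance(obj, torch.Tensor):
        return obj.element_size() * obj.nelement()
    if isinstance(obj, torch.nn.Module):
        # parameters()/buffers() de-duplicate shared (e.g. tied) tensors by default
        tensors = list(obj.parameters()) + list(obj.buffers())
        return sum(t.element_size() * t.nelement() for t in tensors)
    return 0


def object_device(obj):
    """Device of a tensor, or of a module's first parameter/buffer (None if it has none)."""
    if isinstance(obj, torch.Tensor):
        return obj.device
    return next((t.device for t in list(obj.parameters()) + list(obj.buffers())), None)


def summarize_namespace_memory(namespace: dict,
                               filter_device: Optional[str] = None) -> list[tuple[str, str, float]]:
    """Return (name, type, size in GB) rows for tensors/modules in `namespace`, largest first."""
    rows = []
    for name, obj in namespace.items():
        if not isinstance(obj, (torch.Tensor, torch.nn.Module)):
            continue
        if filter_device is not None and str(object_device(obj)) != filter_device:
            continue
        rows.append((name, type(obj).__name__, tensor_nbytes(obj) / 1024**3))
    return sorted(rows, key=lambda row: row[2], reverse=True)
```

In a notebook you could call `summarize_namespace_memory(globals().copy() | locals(), filter_device="cuda:0")` and print the rows. Unlike a full profiler, this only sees objects bound to names in the namespace, so tensors kept alive by hidden references won't show up.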

</details>

#### Imports

In [None]:
import logging
import os
from typing import Any, Dict, Optional, Tuple, List, Callable, Generator
from dataclasses import dataclass
from functools import partial

import evaluate
import datasets
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedTokenizerBase
from transformers.tokenization_utils_base import BatchEncoding
from datasets.arrow_dataset import LazyDict
from tabulate import tabulate
from transformer_lens import ActivationCache  # noqa: F401

from interpretune.base.config.datamodule import PromptConfig, ITDataModuleConfig
from interpretune.base.config.module import ITConfig
from interpretune.base.config.mixins import GenerativeClassificationConfig
from interpretune.base.components.mixins import ProfilerHooksMixin
from interpretune.base.datamodules import ITDataModule
from interpretune.utils.logging import rank_zero_warn
from interpretune.utils.session_runner import AnalysisRunnerCfg, AnalysisRunner, AnalysisSetCfg
from interpretune.utils.types import STEP_OUTPUT
from interpretune.utils.tokenization import _sanitize_input_name
from interpretune.adapters.sae_lens import SAELensFromPretrainedConfig
from interpretune.base.config.shared import Adapter
from interpretune.base.contract.session import ITSessionConfig, ITSession
from it_examples import _ACTIVE_PATCHES # noqa: F401  # TODO: add note about this unless patched in SL before release
from interpretune.base.analysis import (AnalysisBatch, base_vs_sae_logit_diffs, latent_metrics_scatter, compute_correct,
                                        SAEAnalysisTargets)
from interpretune.base.ops import ANALYSIS_OPS


### Tutorial Configuration

In [None]:
# Define our tutorial configuration
# By default, we run all analysis ops available via SAEAnalysisMixin.
# Change the config here to run a different set of analysis ops, use different SAE targets, etc.

sae_targets = SAEAnalysisTargets(sae_release="gpt2-small-hook-z-kk", target_layers=[9, 10])
# TODO: move non-sae analysis target cfg to a separate ArtifactCfg class
# tutorial_config = AnalysisSetCfg(sae_analysis_targets=sae_targets)  # default 1 full epoch and all analysis ops

########################################################################################################################
# Various example configs
#-----------------------------------------------------------------------------------------------------------------------
# tutorial_config = AnalysisSetCfg(analysis_ops=(ANALYSIS_OPS['logit_diffs.base'], ANALYSIS_OPS['logit_diffs.sae']))
w_batch_limit_kwargs = dict(limit_analysis_batches=3, sae_analysis_targets=sae_targets)
# tutorial_config = AnalysisSetCfg(analysis_ops=(ANALYSIS_OPS['logit_diffs.attribution.grad_based'],), **w_batch_limit_kwargs) # patching-only
# tutorial_config = AnalysisSetCfg(analysis_ops=(ANALYSIS_OPS['logit_diffs.attribution.ablation'],), **w_batch_limit_kwargs)  # ablation-only
# tutorial_config = AnalysisSetCfg(analysis_ops=(ANALYSIS_OPS['logit_diffs.base'],), **w_batch_limit_kwargs)  # no-sae only
# tutorial_config = AnalysisSetCfg(analysis_ops=(ANALYSIS_OPS['logit_diffs.sae'],), **w_batch_limit_kwargs)  # w-sae only
# tutorial_config = AnalysisSetCfg(analysis_ops=(ANALYSIS_OPS['logit_diffs.base'], ANALYSIS_OPS['logit_diffs.sae']), # clean-only
#                                  **w_batch_limit_kwargs)
tutorial_config = AnalysisSetCfg(**w_batch_limit_kwargs)  # all analysis ops, but limit to 3 batches
########################################################################################################################

### Define Our IT Data Module


In [None]:
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

log = logging.getLogger(__name__)

TASK_TEXT_FIELD_MAP = {"rte": ("premise", "hypothesis"), "boolq": ("passage", "question")}
TASK_NUM_LABELS = {"boolq": 2, "rte": 2}
DEFAULT_TASK = "rte"
INVALID_TASK_MSG = f" is an invalid task_name. Proceeding with the default task: {DEFAULT_TASK!r}"

class RTEBoolqDataModule(ITDataModule):
    def __init__(self, itdm_cfg: ITDataModuleConfig) -> None:
        if itdm_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(itdm_cfg.task_name + INVALID_TASK_MSG)
            itdm_cfg.task_name = DEFAULT_TASK
        itdm_cfg.text_fields = TASK_TEXT_FIELD_MAP[itdm_cfg.task_name]
        super().__init__(itdm_cfg=itdm_cfg)

    def prepare_data(self, target_model: Optional[torch.nn.Module] = None) -> None:
        """Load the SuperGLUE dataset."""
        # N.B. prepare_data is called in a single process (rank 0, either per node or globally) so do not use it to
        # assign state (e.g. self.x=y)
        # note for raw pytorch we require a target_model
        # NOTE [HF Datasets Transformation Caching]:
        # HF Datasets' transformation cache fingerprinting algo necessitates constructing these partials, as the hash
        # is generated from the function args, dataset file and mapping args (https://bit.ly/HF_Datasets_fingerprint_algo)
        tokenization_func = partial(
            self.encode_for_rteboolq,
            tokenizer=self.tokenizer,
            text_fields=self.itdm_cfg.text_fields,
            prompt_cfg=self.itdm_cfg.prompt_cfg,
            template_fn=self.itdm_cfg.prompt_cfg.model_chat_template_fn,
            tokenization_pattern=self.itdm_cfg.cust_tokenization_pattern,
        )
        dataset = datasets.load_dataset("super_glue", self.itdm_cfg.task_name, trust_remote_code=True)
        for split in dataset.keys():
            dataset[split] = dataset[split].map(tokenization_func, **self.itdm_cfg.prepare_data_map_cfg)
            dataset[split] = self._remove_unused_columns(dataset[split], target_model)
        dataset.save_to_disk(self.itdm_cfg.dataset_path)

    def dataloader_factory(self, split: str, use_train_batch_size: bool = False) -> DataLoader:
        dataloader_kwargs = {"dataset": self.dataset[split], "collate_fn":self.data_collator,
                             **self.itdm_cfg.dataloader_kwargs}
        dataloader_kwargs['batch_size'] = self.itdm_cfg.train_batch_size if use_train_batch_size else \
            self.itdm_cfg.eval_batch_size
        return DataLoader(**dataloader_kwargs)

    # TODO: change to partialmethod's?
    def train_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='train', use_train_batch_size=True)

    def val_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def test_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    def predict_dataloader(self) -> DataLoader:
        return self.dataloader_factory(split='validation')

    #TODO: relax PreTrainedTokenizerBase to the protocol that is actually required
    @staticmethod
    def encode_for_rteboolq(example_batch: LazyDict, tokenizer: PreTrainedTokenizerBase, text_fields: List[str],
                            prompt_cfg: PromptConfig, template_fn: Callable,
                            tokenization_pattern: Optional[str] = None) -> BatchEncoding:
        example_batch['sequences'] = []
        # TODO: use promptsource instead of this manual approach after tinkering
        for field1, field2 in zip(example_batch[text_fields[0]],
                                  example_batch[text_fields[1]]):
            if prompt_cfg.cust_task_prompt:
                task_prompt = (prompt_cfg.cust_task_prompt['context'] + "\n" +
                               field1 + "\n" +
                               prompt_cfg.cust_task_prompt['question'] + "\n" +
                               field2)
            else:
                field2 = field2.rstrip('.')
                task_prompt = field1 + prompt_cfg.ctx_question_join + field2 + prompt_cfg.question_suffix
            sequence = template_fn(task_prompt=task_prompt, tokenization_pattern=tokenization_pattern)
            example_batch['sequences'].append(sequence)
        features = tokenizer.batch_encode_plus(example_batch["sequences"], padding="longest",
                                               padding_side=tokenizer.padding_side)
        features["labels"] = example_batch["label"]  # Rename label to labels, easier to pass to model forward
        features = _sanitize_input_name(tokenizer.model_input_names, features)
        return features
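
To make the default branch of `encode_for_rteboolq` (no `cust_task_prompt`) concrete, here is a plain-Python sketch of the prompt text it assembles before `template_fn` and tokenization are applied. The join and suffix strings are the `RTEBoolqPromptConfig` defaults from the IT module section; `build_rte_prompt` is a hypothetical helper, and the premise/hypothesis pair is a made-up toy example rather than an RTE sample.

```python
# Defaults copied from RTEBoolqPromptConfig
CTX_QUESTION_JOIN = " Does the previous passage imply that "
QUESTION_SUFFIX = "? Answer with only one word, either Yes or No."


def build_rte_prompt(premise: str, hypothesis: str,
                     ctx_question_join: str = CTX_QUESTION_JOIN,
                     question_suffix: str = QUESTION_SUFFIX) -> str:
    """Mirror the default branch of `encode_for_rteboolq` (no custom task prompt, no chat template)."""
    # The hypothesis's trailing period is stripped so it reads as an embedded question
    hypothesis = hypothesis.rstrip(".")
    return premise + ctx_question_join + hypothesis + question_suffix


print(build_rte_prompt("The cat sat on the mat.", "A cat is sitting."))
# The cat sat on the mat. Does the previous passage imply that A cat is sitting? Answer with only one word, either Yes or No.
```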

### Define Our IT Module

In [None]:
@dataclass(kw_only=True)
class RTEBoolqEntailmentMapping:
    entailment_mapping: Tuple = ("Yes", "No")  # RTE style, invert mapping for BoolQ
    entailment_mapping_indices: Optional[torch.Tensor] = None


@dataclass(kw_only=True)
class RTEBoolqPromptConfig(PromptConfig):
    ctx_question_join: str = " Does the previous passage imply that "
    question_suffix: str = "? Answer with only one word, either Yes or No."
    cust_task_prompt: Optional[Dict[str, Any]] = None


class RTEBoolqModule(torch.nn.Module):
    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # when using TransformerLens, we need to manually calculate our loss from logit output
        self.loss_fn = CrossEntropyLoss()

    def setup(self, *args, **kwargs) -> None:
        super().setup(*args, **kwargs)
        self._init_entailment_mapping()

    def _before_it_cfg_init(self, it_cfg: ITConfig) -> ITConfig:
        if it_cfg.task_name not in TASK_NUM_LABELS.keys():
            rank_zero_warn(it_cfg.task_name + INVALID_TASK_MSG)
            it_cfg.task_name = DEFAULT_TASK
        it_cfg.num_labels = 0 if it_cfg.generative_step_cfg.enabled else TASK_NUM_LABELS[it_cfg.task_name]
        return it_cfg

    def load_metric(self) -> None:
        self.metric = evaluate.load(
            "super_glue", self.it_cfg.task_name, experiment_id=self._it_state._init_hparams["experiment_id"]
        )

    def _init_entailment_mapping(self) -> None:
        ent_cfg, tokenizer = self.it_cfg, self.datamodule.tokenizer
        token_ids = tokenizer.convert_tokens_to_ids(ent_cfg.entailment_mapping)
        device = self.device if isinstance(self.device, torch.device) else self.output_device
        ent_cfg.entailment_mapping_indices = torch.tensor(token_ids).to(device)

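    def labels_to_ids(self, labels: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # map each label index to the token id of its expected answer, also returning the original labels
        return torch.take(self.it_cfg.entailment_mapping_indices, labels), labels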

    @ProfilerHooksMixin.memprofilable
    def training_step(self, batch: BatchEncoding, batch_idx: int) -> STEP_OUTPUT:
        # TODO: need to be explicit about the compatibility constraints/contract
        # TODO: note that this example uses generative_step_cfg and lm_head except for the test_step where we demo how
        # to use the GenerativeStepMixin to run inference with or without a generative_step_cfg enabled as well as with
        # different heads (e.g., seqclassification or LM head in this case)
        answer_logits, labels, *_ = self.logits_and_labels(batch, batch_idx)
        loss = self.loss_fn(answer_logits, labels)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    @ProfilerHooksMixin.memprofilable
    def validation_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        answer_logits, labels, orig_labels, *_ = self.logits_and_labels(batch, batch_idx)
        val_loss = self.loss_fn(answer_logits, labels)
        self.log("val_loss", val_loss, prog_bar=True, sync_dist=True)
        self.collect_answers(answer_logits, orig_labels)

    @ProfilerHooksMixin.memprofilable
    def test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        if self.it_cfg.generative_step_cfg.enabled:
            self.generative_classification_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)
        else:
            self.default_test_step(batch, batch_idx, dataloader_idx=dataloader_idx)

    def generative_classification_test_step(
        self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0
    ) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self.it_generate(batch, **self.it_cfg.generative_step_cfg.lm_generation_cfg.generate_kwargs)
        self.collect_answers(outputs.logits, labels)

    def default_test_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        self.collect_answers(outputs.logits, labels)

    def predict_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
        labels = batch.pop("labels")
        outputs = self(**batch)
        return self.collect_answers(outputs, labels, mode="return")

    # def analysis_step(self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0) -> Optional[STEP_OUTPUT]:
    #     label_ids, orig_labels = self.labels_to_ids(batch.pop("labels"))
    #     analysis_batch = AnalysisBatch(labels=label_ids, orig_labels=orig_labels)
    #     self.run_with_ctx(analysis_batch, batch, batch_idx)
    #     self.loss_and_logit_diffs(analysis_batch, batch, batch_idx)
    #     self.analysis_cfg.analysis_store.save(analysis_batch, batch, tokenizer=self.datamodule.tokenizer)

    def analysis_step(
        self, batch: BatchEncoding, batch_idx: int, dataloader_idx: int = 0
    ) -> Generator[STEP_OUTPUT, None, None]:
        label_ids, orig_labels = self.labels_to_ids(batch.pop("labels"))
        analysis_batch = AnalysisBatch(labels=label_ids, orig_labels=orig_labels)
        self.run_with_ctx(analysis_batch, batch, batch_idx)
        self.loss_and_logit_diffs(analysis_batch, batch, batch_idx)
        yield from self.analysis_cfg.save_batch(analysis_batch, batch, tokenizer=self.datamodule.tokenizer)

    def collect_answers(self, logits: torch.Tensor | tuple, labels: torch.Tensor, mode: str = "log") -> Optional[Dict]:
        logits = self.standardize_logits(logits)
        per_example_answers, _ = torch.max(logits, dim=-2)
        preds = torch.argmax(per_example_answers, dim=-1)
        metric_dict = self.metric.compute(predictions=preds, references=labels)
        # TODO: check if this type casting is still required for lightning torchmetrics, bug should be fixed now...
        metric_dict = {k: torch.tensor(v, device=self.device).to(torch.float32) for k, v in metric_dict.items()}
        if mode == "log":
            self.log_dict(metric_dict, prog_bar=True, sync_dist=True)
        else:
            return metric_dict

    def standardize_logits(self, logits: torch.Tensor | tuple) -> torch.Tensor:
        # to support generative classification/non-generative classification configs and LM/SeqClassification heads we
        # adhere to the following logits logical shape invariant:
        # [batch size, positions to consider, answers to consider]
        if isinstance(logits, tuple):
            logits = torch.stack(logits, dim=1)
        logits = logits.to(device=self.device)
        if logits.ndim == 2:  # if answer logits have already been squeezed
            logits = logits.unsqueeze(1)
        if logits.shape[-1] != self.it_cfg.num_labels:
            logits = torch.index_select(logits, -1, self.it_cfg.entailment_mapping_indices)
            if not self.it_cfg.generative_step_cfg.enabled:
                logits = logits[:, -1:, :]
        return logits

    def logits_and_labels(
        self, batch: BatchEncoding, batch_idx: int, run_ctx: str, hook_names: Optional[str] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Any]:
        # returns (final-position logits, answer token ids, original labels, activation cache if any)
        label_ids, labels = self.labels_to_ids(batch.pop("labels"))
        logits, cache = self.run_with_ctx(run_ctx, batch, batch_idx, hook_names=hook_names)
        # TODO: add another layer of abstraction here to handle different model output types? Tradeoffs to consider...
        if not isinstance(logits, torch.Tensor):
            logits = logits.logits
            assert isinstance(logits, torch.Tensor), f"Expected logits to be a torch.Tensor but got {type(logits)}"
        return torch.squeeze(logits[:, -1, :], dim=1), label_ids, labels, cache
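
The label handling in this module (`labels_to_ids`, `standardize_logits` and `collect_answers`) reduces to selecting the answer-token columns from the vocabulary logits and taking an argmax over them. Below is a minimal plain-PyTorch sketch of the non-generative path; `answer_predictions` is a hypothetical helper, and the token ids come from a toy 6-token vocabulary rather than the real ids returned by `tokenizer.convert_tokens_to_ids(("Yes", "No"))`.

```python
import torch


def answer_predictions(vocab_logits: torch.Tensor, answer_token_ids: torch.Tensor) -> torch.Tensor:
    """Predict label indices from [batch, positions, vocab] logits (non-generative path)."""
    answer_logits = torch.index_select(vocab_logits, -1, answer_token_ids)  # [B, P, num_answers]
    answer_logits = answer_logits[:, -1:, :]                                # keep the last position
    per_example, _ = torch.max(answer_logits, dim=-2)                       # [B, num_answers]
    return torch.argmax(per_example, dim=-1)                                # [B]


# Toy vocab: hypothetical ids where "Yes" -> 4 and "No" -> 1 (RTE: label 0 = entailment -> "Yes")
answer_token_ids = torch.tensor([4, 1])
labels = torch.tensor([0, 1])
# labels_to_ids equivalent: the token id of each example's expected answer
target_token_ids = torch.take(answer_token_ids, labels)

vocab_logits = torch.zeros(2, 3, 6)
vocab_logits[0, -1, 4] = 2.0  # example 0 favors "Yes"
vocab_logits[1, -1, 1] = 3.0  # example 1 favors "No"
preds = answer_predictions(vocab_logits, answer_token_ids)
print(preds, target_token_ids, (preds == labels).float().mean())
```

Keeping the answers as a small `[batch, positions, answers]` tensor is what lets the same argmax logic serve both generative (several positions) and non-generative (last position only) configurations.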


### Configure our IT Session


In [None]:
from interpretune.adapters.transformer_lens import TLensGenerationConfig
from interpretune.base.config.mixins import HFFromPretrainedConfig
from interpretune.adapters.transformer_lens import ITLensFromPretrainedNoProcessingConfig
from interpretune.base.config.shared import ITSharedConfig, AutoCompConfig

shared_cfg = ITSharedConfig(model_name_or_path='gpt2', task_name='rte', tokenizer_id_overrides={'pad_token_id': 50256},
                            tokenizer_kwargs={'model_input_names': ['input'], 'padding_side': 'left',
                                              'add_bos_token': True})
datamodule_cfg = ITDataModuleConfig(prompt_cfg=RTEBoolqPromptConfig(), train_batch_size=2, eval_batch_size=2,
                                    signature_columns=['input', 'labels'], prepare_data_map_cfg={"batched": True})
genclassif_cfg = GenerativeClassificationConfig(enabled=True, lm_generation_cfg=TLensGenerationConfig(max_new_tokens=1))
hf_cfg = HFFromPretrainedConfig(pretrained_kwargs={'torch_dtype': 'float32'}, model_head='transformers.GPT2LMHeadModel')
tl_cfg = ITLensFromPretrainedNoProcessingConfig(model_name="gpt2-small", default_padding_side='left')
sae_cfgs = [SAELensFromPretrainedConfig(release=sae_fqn.release, sae_id=sae_fqn.sae_id) for sae_fqn
            in tutorial_config.sae_analysis_targets.sae_fqns]
auto_comp_cfg = AutoCompConfig(module_cfg_name='RTEBoolqConfig', module_cfg_mixin=RTEBoolqEntailmentMapping)
module_cfg = ITConfig(auto_comp_cfg=auto_comp_cfg, generative_step_cfg=genclassif_cfg, hf_from_pretrained_cfg=hf_cfg,
                      tl_cfg=tl_cfg, sae_cfgs=sae_cfgs)
session_cfg = ITSessionConfig(adapter_ctx=(Adapter.core, Adapter.sae_lens),
                              datamodule_cls=RTEBoolqDataModule, module_cls=RTEBoolqModule,
                              shared_cfg=shared_cfg, datamodule_cfg=datamodule_cfg, module_cfg=module_cfg)
it_session = ITSession(session_cfg)
# TODO: consider opening an SAELens PR so that the warning at
# https://github.com/jbloomAus/SAELens/blob/aa8f42bf06d9c68bb890f4881af0aac916ecd17c/sae_lens/sae.py#L144-L151
# first checks whether the loaded model has a default config override specified in ``pretrained_saes.yaml`` (e.g.
# 'gpt2-small-res-jb', config_overrides: model_from_pretrained_kwargs: center_writing_weights: true) and, if so,
# skips the arguably spurious warning

### Run Demo Analyses

In [None]:
# We could, of course, manually initialize IT components here and run the exploratory analysis ops one by one
# (or our own custom analysis)
# it_init(**it_session)
# TODO: add more demo manual commands and likely a separate demo tutorial for manual analysis (AI model generated)
#       after abstractions stabilize and are better refined/integrated


run_config = AnalysisRunnerCfg(it_session=it_session, max_epochs=1, analysis_set_cfg=tutorial_config)
runner = AnalysisRunner(run_cfg=run_config)
sl_test_module = run_config.module
analysis_results = runner.run_analysis_set()



### Review Demo Results

#### Clean vs SAE Sample-wise Logit Diffs

In [None]:
if {ANALYSIS_OPS['logit_diffs.base'], ANALYSIS_OPS['logit_diffs.sae']}.issubset(set(tutorial_config.analysis_ops)):
    base_vs_sae_logit_diffs(sae=analysis_results[ANALYSIS_OPS['logit_diffs.sae']],
                            base_ref=analysis_results[ANALYSIS_OPS['logit_diffs.base']],
                            top_k=tutorial_config.top_k_clean_logit_diffs,
                            tokenizer=sl_test_module.datamodule.tokenizer)
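
For reference, a logit difference in this kind of binary task is commonly defined as the correct answer token's logit minus the incorrect answer token's logit, so positive values mean the model prefers the correct answer. The sketch below implements that common definition on toy numbers; it is an assumption about what the `logit_diffs.*` ops measure, so consult the Interpretune source for the exact computation (sign conventions, position handling). `logit_diffs` here is a hypothetical helper.

```python
import torch


def logit_diffs(answer_logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Per-example logit(correct answer) - logit(incorrect answer) for a binary task.

    answer_logits: [batch, 2] logits over the two answer tokens (e.g. "Yes", "No")
    labels: [batch] index of the correct answer token
    """
    correct = answer_logits.gather(-1, labels.unsqueeze(-1)).squeeze(-1)
    incorrect = answer_logits.gather(-1, (1 - labels).unsqueeze(-1)).squeeze(-1)
    return correct - incorrect


# Toy answer logits without and with an SAE spliced into the forward pass
base = torch.tensor([[2.0, 0.5], [0.1, 1.1]])
with_sae = torch.tensor([[1.5, 0.7], [0.9, 0.4]])
labels = torch.tensor([0, 1])
base_diffs, sae_diffs = logit_diffs(base, labels), logit_diffs(with_sae, labels)
# Rank examples by how much the SAE changes their logit diff
delta = (sae_diffs - base_diffs).abs()
print(base_diffs, sae_diffs, delta.argsort(descending=True))
```

Under this definition, an example flips from correct to incorrect when its logit diff changes sign, as the second toy example does once the SAE is spliced in.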


#### Proportion Correct Answers on Dataset By Analysis Op

In [None]:
# TODO: NEXT: debug reconstruction here.
pred_summaries = {op: compute_correct(summ, op) for op, summ in analysis_results.items()}

table_rows = []
for op, (total_correct, percentage_correct, _) in pred_summaries.items():
    table_rows.append([op, total_correct, f"{percentage_correct:.2f}%"])

print(tabulate(table_rows, headers=["Op", "Total Correct", "Percentage Correct"], tablefmt="grid"))

#### Per-Batch Ablation Effect Graphs [Optional]

In [None]:
ablation_op = ANALYSIS_OPS['logit_diffs.attribution.ablation']
if tutorial_config.latent_effects_graphs and ablation_op in tutorial_config.analysis_ops:
    # TODO: add note that only latent effects associated with correct answers currently displayed
    # TODO: allow toggling correct filtering during runs
    analysis_results[ablation_op].plot_latent_effects(per_batch=tutorial_config.latent_effects_graphs_per_batch)

#### Per-SAE Ablation Effects

In [None]:

if ANALYSIS_OPS['logit_diffs.attribution.ablation'] in tutorial_config.analysis_ops:
    ablation_batch_preds = pred_summaries[ANALYSIS_OPS['logit_diffs.attribution.ablation']].batch_predictions
    activation_summary = analysis_results[ANALYSIS_OPS['logit_diffs.sae']].calc_activation_summary()
    ablation_metrics = analysis_results[ANALYSIS_OPS['logit_diffs.attribution.ablation']].calculate_latent_metrics(
        pred_summ=pred_summaries[ANALYSIS_OPS['logit_diffs.attribution.ablation']],
        activation_summary=activation_summary,
        # filter_by_correct=True,
        run_name="logit_diffs.attribution.ablation"
    )

    tables = ablation_metrics.create_attribution_tables(top_k=tutorial_config.top_k_latents_table, filter_type='both',
                                                        per_sae=tutorial_config.latents_table_per_sae)

    for title, table in tables.items():
        print(f"\n{title}\n{table}\n")

    sl_test_module.display_latent_dashboards(ablation_metrics, title="Ablation-Mediated Latent Analysis",
                                             sae_release=tutorial_config.sae_analysis_targets.sae_release,
                                             top_k=tutorial_config.top_k_latent_dashboards)
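
Conceptually, a latent's ablation effect is the change in a metric (such as a logit difference) when that latent's activation is zeroed before the SAE reconstruction is passed onward. The sketch below illustrates the idea with a toy SAE decoder and a toy nonlinear readout standing in for the rest of the model. `W_dec`, `b_dec`, `readout`, `metric_from_latents` and `ablation_effects` are all hypothetical; Interpretune's `logit_diffs.attribution.ablation` op runs the real model with hooks, and its sign convention may differ from the "ablated minus clean" used here.

```python
import torch

torch.manual_seed(0)
d_sae, d_model = 8, 4
W_dec = torch.randn(d_sae, d_model)  # toy SAE decoder weights (hypothetical)
b_dec = torch.randn(d_model)         # toy SAE decoder bias (hypothetical)
readout = torch.randn(d_model)       # toy direction standing in for "rest of model -> logit diff"


def metric_from_latents(latent_acts: torch.Tensor) -> torch.Tensor:
    """Decode SAE latents and read out a per-example scalar metric."""
    recon = latent_acts @ W_dec + b_dec
    return torch.tanh(recon) @ readout


def ablation_effects(latent_acts: torch.Tensor) -> torch.Tensor:
    """Zero-ablate each latent in turn and return the [batch, d_sae] change in the metric."""
    clean = metric_from_latents(latent_acts)
    effects = torch.empty_like(latent_acts)
    for i in range(latent_acts.shape[-1]):
        ablated = latent_acts.clone()
        ablated[:, i] = 0.0
        effects[:, i] = metric_from_latents(ablated) - clean  # ablated minus clean
    return effects


latent_acts = torch.relu(torch.randn(3, d_sae))  # non-negative, like ReLU SAE activations
latent_acts[:, 0] = 0.0                          # an inactive latent
effects = ablation_effects(latent_acts)
print(effects[:, 0])  # ablating an already-inactive latent leaves the metric unchanged
```

Note that this needs one extra forward pass per latent, which becomes expensive for SAEs with tens of thousands of latents.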



#### Per-SAE Attribution Patching Effects

In [None]:
if ANALYSIS_OPS['logit_diffs.attribution.grad_based'] in tutorial_config.analysis_ops:
    # per-SAE activation summaries are calculated using our AnalysisStore since the relevant keys are present,
    # no need to provide a separate activation summary from another comparison cache in this case as with ablation
    activation_summary = analysis_results[ANALYSIS_OPS['logit_diffs.attribution.grad_based']].calc_activation_summary()
    attribution_patching_metrics = analysis_results[ANALYSIS_OPS['logit_diffs.attribution.grad_based']].calculate_latent_metrics(
        pred_summ=pred_summaries[ANALYSIS_OPS['logit_diffs.attribution.grad_based']],
        run_name="logit_diffs.attribution.grad_based"
    )

    tables = attribution_patching_metrics.create_attribution_tables(top_k=tutorial_config.top_k_latents_table,
                                                                    filter_type='both',
                                                                    per_sae=tutorial_config.latents_table_per_sae)

    for title, table in tables.items():
        print(f"\n{title}\n{table}\n")

    sl_test_module.display_latent_dashboards(attribution_patching_metrics, 
                                             title="Attribution Patching-Mediated Latent Analysis",
                                             sae_release=tutorial_config.sae_analysis_targets.sae_release, 
                                             top_k=tutorial_config.top_k_latent_dashboards)
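
Attribution patching replaces the per-latent forward passes with a first-order Taylor approximation: the effect of moving latent i from its activation $a_i$ to 0 is estimated as $(0 - a_i)\,\partial m / \partial a_i$, which requires only one forward and one backward pass. Here is a minimal sketch using a toy decoder and readout (all names hypothetical, not Interpretune's implementation). For a purely linear metric the approximation is exact, and the example checks this against the closed-form zero-ablation effect.

```python
import torch

torch.manual_seed(0)
d_sae, d_model = 8, 4
W_dec = torch.randn(d_sae, d_model)  # toy SAE decoder weights (hypothetical)
b_dec = torch.randn(d_model)         # toy SAE decoder bias (hypothetical)
readout = torch.randn(d_model)       # toy direction standing in for "rest of model -> logit diff"


def metric_from_latents(latent_acts: torch.Tensor, nonlinear: bool = True) -> torch.Tensor:
    """Per-example scalar metric computed from SAE latent activations."""
    recon = latent_acts @ W_dec + b_dec
    if nonlinear:
        recon = torch.tanh(recon)
    return recon @ readout


def attribution_patching_effects(latent_acts: torch.Tensor, nonlinear: bool = True) -> torch.Tensor:
    """First-order estimate of each latent's zero-ablation effect: (0 - a_i) * d(metric)/d(a_i)."""
    acts = latent_acts.detach().clone().requires_grad_(True)
    # Summing over the batch is safe: each example's metric depends only on its own activations
    total = metric_from_latents(acts, nonlinear).sum()
    (grad,) = torch.autograd.grad(total, acts)
    return (0.0 - latent_acts) * grad  # [batch, d_sae]


latent_acts = torch.relu(torch.randn(3, d_sae))
approx = attribution_patching_effects(latent_acts, nonlinear=False)
# For a linear metric, zero-ablating latent i changes it by exactly -a_i * (W_dec @ readout)_i
exact = -latent_acts * (W_dec @ readout)
print(torch.allclose(approx, exact, atol=1e-5))
```

With the nonlinearity enabled, the two estimates diverge, which is why comparing them per latent is informative.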


#### Per-SAE Ablation vs Attribution-Patching Effect Parity

In [None]:
if {ANALYSIS_OPS['logit_diffs.attribution.grad_based'], ANALYSIS_OPS['logit_diffs.attribution.ablation']}.issubset(tutorial_config.analysis_ops):
    # Visualize results for each hook
    # Call the function with our metrics
    latent_metrics_scatter(ablation_metrics, attribution_patching_metrics, label1="Ablation", label2="Attribution Patching")
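
To complement the scatter plot with a number, you could summarize agreement between the two per-latent effect estimates, for example with a Pearson correlation and the fraction of latents where both methods agree on the sign of the effect. This minimal sketch assumes you have already extracted matching effect tensors from `ablation_metrics` and `attribution_patching_metrics`. How to do that depends on the Interpretune metrics objects and isn't shown here; `effect_parity` is a hypothetical helper.

```python
import torch


def effect_parity(effects_a: torch.Tensor, effects_b: torch.Tensor) -> dict[str, float]:
    """Summarize agreement between two same-shaped per-latent effect estimates."""
    a, b = effects_a.flatten().float(), effects_b.flatten().float()
    pearson_r = torch.corrcoef(torch.stack([a, b]))[0, 1].item()
    # Fraction of latents (ignoring those with zero effect under both methods) whose effects share a sign
    nonzero = (a != 0) | (b != 0)
    sign_agreement = (torch.sign(a[nonzero]) == torch.sign(b[nonzero])).float().mean().item()
    return {"pearson_r": pearson_r, "sign_agreement": sign_agreement}


# Synthetic effects, only to show usage
ablation_like = torch.tensor([1.0, -2.0, 3.0, 0.0])
print(effect_parity(ablation_like, 2 * ablation_like))
```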


In [None]:
# TODO: continue implementation here
# NEXT: implement DFA using activationstore on rte val dataset and plot for top latents 
