# Interpretune SAELens Tutorial

![Fine-Tuning Scheduler logo](logo_fts.png){height="55px" width="401px"}

### Intro

[Interpretune](https://github.com/speediedan/interpretune) is a flexible framework for exploring, analyzing and tuning
LLM world models. In this tutorial, we'll walk through a simple example of using Interpretune to pursue interpretability
research with SAELens. As we'll see, Interpretune handles the required execution context composition, allowing us to
use the same code in a variety of contexts, depending upon the level of abstraction required.

As a long-time PyTorch and PyTorch Lightning contributor, I've found that PyTorch Lightning is the right level of
abstraction for a large variety of ML research contexts. Some contexts benefit from using core PyTorch directly, though,
and some users prefer core PyTorch for other reasons, including maximizing portability. As will be demonstrated here,
Interpretune maximizes flexibility and portability by adhering to a well-defined protocol that allows auto-composition
of our research module with the adapters required for execution in a wide variety of contexts. In this example, we'll
execute the same module with core PyTorch and PyTorch Lightning, demonstrating the use of `SAELens` with Interpretune
for interpretability research.

> Note - **this is a WIP**, but this is the core idea. If you have any feedback, please let me know!

## A note on memory usage

In these exercises, we'll be loading some pretty large models into memory (e.g. Gemma 2-2B and its SAEs, as well as a host of other models in later sections of the material). It's useful to have functions that help profile memory usage, so that if you encounter OOM errors you can try to clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're no longer using them) it should be possible to run all the exercises in this material on a Colab Pro notebook, and all the exercises minus the handful involving Gemma on a free Colab notebook.

<details>
<summary>See this dropdown for some functions which you might find helpful, and how to use them.</summary>

First, we can run some code to inspect our current memory usage. Here's me running this code during the exercise set on SAE circuits, after having already loaded in the Gemma models from the previous section. This was on a Colab Pro notebook.

```python
# Profile memory usage, and delete gemma models if we've loaded them in
namespace = globals().copy() | locals()
part32_utils.profile_pytorch_memory(namespace=namespace, filter_device="cuda:0")
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 35.88 GB
Total = 39.56 GB
Free = 3.68 GB
┌──────────────────────┬────────────────────────┬──────────┬─────────────┐
│ Name                 │ Object                 │ Device   │   Size (GB) │
├──────────────────────┼────────────────────────┼──────────┼─────────────┤
│ gemma_2_2b           │ HookedSAETransformer   │ cuda:0   │       11.94 │
│ gpt2                 │ HookedSAETransformer   │ cuda:0   │        0.61 │
│ gemma_2_2b_sae       │ SAE                    │ cuda:0   │        0.28 │
│ sae_resid_dirs       │ Tensor (4, 24576, 768) │ cuda:0   │        0.28 │
│ gpt2_sae             │ SAE                    │ cuda:0   │        0.14 │
│ logits               │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ logits_with_ablation │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ clean_logits         │ Tensor (4, 15, 50257)  │ cuda:0   │        0.01 │
│ _                    │ Tensor (16, 128, 768)  │ cuda:0   │        0.01 │
│ clean_sae_acts_post  │ Tensor (4, 15, 24576)  │ cuda:0   │        0.01 │
└──────────────────────┴────────────────────────┴──────────┴─────────────┘</pre>

From this, we see that we've allocated a lot of memory for the Gemma model, so let's delete it. We'll also run some code to move any remaining objects on the GPU which are larger than 100MB to the CPU, and print the memory status again.

```python
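import gc

import torch

del gemma_2_2b
del gemma_2_2b_sae

THRESHOLD = 0.1  # GB
# Move any module still tracked by the garbage collector that exceeds THRESHOLD to the CPU
for obj in gc.get_objects():
    try:
        if isinstance(obj, torch.nn.Module) and part32_utils.get_tensors_size(obj) / 1024**3 > THRESHOLD:
            if hasattr(obj, "cuda"):
                obj.cpu()
            if hasattr(obj, "reset"):
                obj.reset()
    except Exception:  # some tracked objects raise on inspection, so skip them
        pass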

# Move our gpt2 model & SAEs back to GPU (we'll need them for the exercises we're about to do)
gpt2.to(device)
gpt2_saes = {layer: sae.to(device) for layer, sae in gpt2_saes.items()}

part32_utils.print_memory_status()
```

<pre style="font-family: Consolas; font-size: 14px">Allocated = 14.90 GB
Reserved = 39.56 GB
Free = 24.66</pre>

Mission success! We've managed to free up a lot of memory. Note that the loop which moves large objects tracked by the garbage collector to the CPU is often necessary to free up the memory. We can't just delete the objects directly because PyTorch can still sometimes keep references to them (i.e. their tensors) in memory. In fact, if you add code to the for loop above to print out `obj.shape` when `obj` is a tensor, you'll see that a lot of those tensors are actually Gemma model weights, even after you've deleted `gemma_2_2b`.
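
The point about lingering references can be reproduced without a GPU. The sketch below is a minimal, standard-library illustration, not part of the tutorial's utilities: `FakeWeights` and `cache` are hypothetical stand-ins for a large tensor and for whatever else still holds on to it. It shows that `del` removes only a name, and that the object is released once the last reference is gone.

```python
import gc
import weakref


class FakeWeights:
    """Hypothetical stand-in for a large tensor (illustration only)."""

    def __init__(self, name: str) -> None:
        self.name = name


weights = FakeWeights("gemma_layer_0")
cache = {"last_activations": weights}  # a second, easy-to-forget reference
probe = weakref.ref(weights)  # lets us check liveness without keeping the object alive

del weights  # removes only the name `weights`
gc.collect()
alive_after_del = probe() is not None  # the object survives because `cache` still references it

cache.clear()  # drop the remaining reference
gc.collect()
alive_after_clear = probe() is not None

print(f"alive after del: {alive_after_del}, alive after clearing cache: {alive_after_clear}")
```

The same reasoning applies to model weights: memory comes back only after every holder of the tensors has let go of them or moved them off the device.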

</details>

#### Imports

In [None]:
from transformer_lens import ActivationCache  # noqa: F401
from tabulate import tabulate

import interpretune as it  # registered analysis ops will be available as it.<op> when analysis is imported
from it_examples import _ACTIVE_PATCHES  # noqa: F401  # TODO: add note about this unless patched in SL before release
from it_examples.example_module_registry import MODULE_EXAMPLE_REGISTRY  # TODO: move to hub once implemented
from interpretune import ITSessionConfig, ITSession, SAELensFromPretrainedConfig, SAEAnalysisTargets

### Configure our IT Session


Here we define or customize our session configuration, which includes:
1. The experiment/task module and datamodule (in this case, `rte` for the RTE task).
    * We can customize any module, datamodule, or adapter-specific configuration options we want to use. In this case, we set the target `sae_cfgs` that we want to use for our analysis. We could also customize generation parameters, tokenization, the pretrained/config-based model we want to use (in this case, GPT2), etc.
2. The adapter context we want to use. In this case, `core` PyTorch (vs. e.g. Lightning) and `sae_lens` (vs. e.g. `transformer_lens`).

When an `ITSession` is created, the selected adapter context will trigger composition of the relevant adapters with our experiment/task module and datamodule. The intention of this abstraction is to enable the same experiment/task logic to be used unchanged across a broad variety of PyTorch framework and analytical package contexts.
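
To make the idea of adapter composition concrete, here is a minimal sketch of one way such composition could work, using dynamic class creation over mixins. This is an illustration under stated assumptions and not Interpretune's actual mechanism: every name in it (`RTEModule`, `CoreAdapter`, `LightningAdapter`, `SAELensAdapter`, `ADAPTER_REGISTRY`, `compose`) is hypothetical.

```python
from typing import Dict, Tuple, Type


class RTEModule:
    """Hypothetical experiment/task logic, written once."""

    def describe(self) -> str:
        return "rte task"


class CoreAdapter:
    """Hypothetical adapter for a plain PyTorch-style loop."""

    def describe(self) -> str:
        return "core -> " + super().describe()


class LightningAdapter:
    """Hypothetical adapter for a Lightning-style trainer."""

    def describe(self) -> str:
        return "lightning -> " + super().describe()


class SAELensAdapter:
    """Hypothetical adapter adding SAE-specific behaviour."""

    def describe(self) -> str:
        return "sae_lens -> " + super().describe()


ADAPTER_REGISTRY: Dict[str, Type] = {
    "core": CoreAdapter,
    "lightning": LightningAdapter,
    "sae_lens": SAELensAdapter,
}


def compose(adapter_ctx: Tuple[str, ...], module_cls: Type) -> Type:
    """Create a class whose bases are the selected adapters followed by the task module."""
    adapters = tuple(ADAPTER_REGISTRY[name] for name in adapter_ctx)
    class_name = "".join(adapter.__name__ for adapter in adapters) + module_cls.__name__
    return type(class_name, adapters + (module_cls,), {})


core_module = compose(("core", "sae_lens"), RTEModule)()
lightning_module = compose(("lightning", "sae_lens"), RTEModule)()
print(core_module.describe())
print(lightning_module.describe())
```

The task class is untouched in both cases; only the adapter context passed to `compose` changes, which is the property the session abstraction is aiming for.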


In [None]:
# Load our demo config (this will be done from the hub once that is available)
base_itdm_cfg, base_it_cfg, dm_cls, m_cls = MODULE_EXAMPLE_REGISTRY.get('gpt2.rte_demo.sae_lens')

# update our config with our desired SAE analysis targets
sae_targets = SAEAnalysisTargets(sae_release="gpt2-small-hook-z-kk", target_layers=[9, 10])
sae_cfgs = [SAELensFromPretrainedConfig(release=sae_fqn.release, sae_id=sae_fqn.sae_id) for sae_fqn
            in sae_targets.sae_fqns]
base_it_cfg.sae_cfgs = sae_cfgs

# configure our session with our desired adapter composition, core and sae_lens in this case
session_cfg = ITSessionConfig(adapter_ctx=(it.Adapter.core, it.Adapter.sae_lens), datamodule_cfg=base_itdm_cfg,
                              module_cfg=base_it_cfg, datamodule_cls=dm_cls, module_cls=m_cls)

# start our session
it_session = ITSession(session_cfg)
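
The cell above expands one SAE release and a list of target layers into one config per SAE. A rough sketch of that expansion is shown below. It is not the `SAEAnalysisTargets` implementation: the classes `SaeFqn` and `Targets` are hypothetical, and the `blocks.{layer}.hook_z` id pattern is an assumption about how SAE ids in this release are named.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class SaeFqn:
    """Hypothetical fully qualified SAE name: a release plus an SAE id."""

    release: str
    sae_id: str


@dataclass
class Targets:
    """Hypothetical stand-in for an analysis-targets object."""

    sae_release: str
    target_layers: List[int]
    hook_template: str = "blocks.{layer}.hook_z"  # assumed id pattern, not confirmed by the tutorial

    @property
    def sae_fqns(self) -> Tuple[SaeFqn, ...]:
        # One fully qualified name per requested layer
        return tuple(
            SaeFqn(self.sae_release, self.hook_template.format(layer=layer)) for layer in self.target_layers
        )


targets = Targets(sae_release="gpt2-small-hook-z-kk", target_layers=[9, 10])
for fqn in targets.sae_fqns:
    print(fqn.release, fqn.sae_id)
```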

### Run Demo Analysis 

#### Define Our Analysis Run

Next, we define the analysis we want to run. This includes defining:

1. our latent space targets (`sae_analysis_targets` in this case)
2. one or more analysis configurations (which can use manual or generated analysis steps)
3. the analysis runner

The `AnalysisRunner` is the core component of Interpretune that handles the execution of our analysis. It takes care of running the analysis operations defined in our analysis set, managing the execution context, and storing the results.
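
As a mental model for what a runner of this kind does, the sketch below loops over analysis configurations, applies each one to a limited number of batches, and stores the results keyed by configuration name. It is a simplified illustration only: `ToyAnalysisCfg` and `run_analysis` are hypothetical names, and the real `AnalysisRunner` also manages execution context and result storage that are not modelled here.

```python
from dataclasses import dataclass
from itertools import islice
from typing import Callable, Dict, Iterable, List, Sequence


@dataclass
class ToyAnalysisCfg:
    """Hypothetical analysis configuration: a name plus an op applied to each batch."""

    name: str
    op: Callable[[List[float]], float]


def run_analysis(
    cfgs: Sequence[ToyAnalysisCfg], batches: Iterable[List[float]], limit_batches: int
) -> Dict[str, List[float]]:
    """Run every configured op over at most `limit_batches` batches."""
    limited = list(islice(batches, limit_batches))
    results: Dict[str, List[float]] = {}
    for cfg in cfgs:
        results[cfg.name] = [cfg.op(batch) for batch in limited]
    return results


data = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
cfgs = [ToyAnalysisCfg("batch_sum", sum), ToyAnalysisCfg("batch_max", max)]
results = run_analysis(cfgs, data, limit_batches=3)
print(results)
```

Keying the results by configuration name mirrors how the later cells look up results with `analysis_results[<op>.name]`.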

In [None]:
from interpretune import AnalysisRunner, AnalysisCfg, AnalysisStore
# Define our `AnalysisRunner`. We set:
# 1. our analysis targets across all analysis configurations we want to run in the next analysis run
# 2. batch and epoch limits
# 3. ignore any manual `analysis_step` in our provided module because we want to generate analysis steps based on
#    provided operations
run_kwargs = dict(sae_analysis_targets=sae_targets, it_session=it_session, max_epochs=1)
run_config = dict(limit_analysis_batches=3, ignore_manual=True, **run_kwargs)
runner = AnalysisRunner(run_cfg=run_config)

# Define our Analysis Configurations
# here we demo a few different op compositions involving logit differences
auto_logit_diffs_base_cfg = AnalysisCfg(target_op=it.logit_diffs_base, save_prompts=False, save_tokens=False)
auto_logit_diffs_sae_cfg = AnalysisCfg(target_op=it.logit_diffs_sae, save_prompts=True, save_tokens=True)
auto_logit_diffs_attr_grad_cfg = AnalysisCfg(target_op=it.logit_diffs_attr_grad, save_prompts=True, save_tokens=True)
auto_logit_diffs_attr_ablation_cfg = AnalysisCfg(target_op=it.logit_diffs_attr_ablation, save_prompts=False,
                                                 save_tokens=False)

#### Run the Analysis


In [None]:
analysis_results = runner.run_analysis(analysis_cfgs=(auto_logit_diffs_base_cfg, auto_logit_diffs_sae_cfg,
                                                      auto_logit_diffs_attr_grad_cfg,
                                                      auto_logit_diffs_attr_ablation_cfg))

#### Set convenience variables for exploratory analysis

In [None]:
run_cfg = runner.run_cfg
sl_test_module = run_cfg.module  # convenience handle to the module used in the analysis
artifact_cfg = run_cfg.artifact_cfg
# Set tutorial_active_ops based on the keys in analysis_results
if isinstance(analysis_results, dict):
    tutorial_active_ops = set(analysis_results.keys())
elif hasattr(run_cfg, "analysis_cfg") and run_cfg.analysis_cfg:
    # Single analysis configuration
    tutorial_active_ops = {run_cfg.analysis_cfg.name}
else:
    tutorial_active_ops = set()

# If analysis_results is an AnalysisStore, convert to a dict with a single entry
if isinstance(analysis_results, AnalysisStore):
    analysis_results = {run_cfg.analysis_cfg.name: analysis_results}

print(f"Analysis completed for {len(analysis_results) if isinstance(analysis_results, dict) else 1} operations:")
for cfg_name in (analysis_results.keys() if isinstance(analysis_results, dict) else [run_cfg.analysis_cfg.name]):
    print(f"- {cfg_name}")



### Review Demo Results

#### Clean vs SAE Sample-wise Logit Diffs

In [None]:
if {it.logit_diffs_base.name, it.logit_diffs_sae.name}.issubset(tutorial_active_ops):
    from interpretune.analysis import base_vs_sae_logit_diffs
    base_vs_sae_logit_diffs(sae=analysis_results[it.logit_diffs_sae.name],
                            base_ref=analysis_results[it.logit_diffs_base.name],
                            top_k=artifact_cfg.top_k_clean_logit_diffs,
                            tokenizer=sl_test_module.datamodule.tokenizer)
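
For readers new to the metric, a logit difference here can be thought of as the logit of the correct answer minus the logit of the incorrect one, computed once for the clean model and once with the SAE spliced in. The sketch below uses made-up numbers and a hypothetical `logit_diff` helper for a two-way choice. It is not the `base_vs_sae_logit_diffs` implementation.

```python
from typing import List, Sequence


def logit_diff(answer_logits: Sequence[float], correct_idx: int) -> float:
    """Correct-answer logit minus the other answer's logit (two-way choice)."""
    incorrect_idx = 1 - correct_idx
    return answer_logits[correct_idx] - answer_logits[incorrect_idx]


# Made-up per-example logits over the two answer tokens, plus the correct index per example
labels: List[int] = [0, 1, 0]
clean_logits = [[2.0, 0.5], [0.25, 1.25], [0.5, 1.0]]
sae_logits = [[1.5, 0.5], [0.5, 1.0], [0.25, 1.0]]

clean_diffs = [logit_diff(logits, label) for logits, label in zip(clean_logits, labels)]
sae_diffs = [logit_diff(logits, label) for logits, label in zip(sae_logits, labels)]

# A negative delta means splicing in the SAE reduced the margin for the correct answer
deltas = [sae - clean for sae, clean in zip(sae_diffs, clean_diffs)]
for i, (clean, sae, delta) in enumerate(zip(clean_diffs, sae_diffs, deltas)):
    print(f"example {i}: clean={clean:+.2f} sae={sae:+.2f} delta={delta:+.2f}")
```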


#### Proportion Correct Answers on Dataset By Analysis Op

In [None]:
from interpretune.analysis import compute_correct

pred_summaries = {op: compute_correct(summ, op) for op, summ in analysis_results.items()}
table_rows = []
for op, (total_correct, percentage_correct, _) in pred_summaries.items():
    table_rows.append([op, total_correct, f"{percentage_correct:.2f}%"])

print(tabulate(table_rows, headers=["Op", "Total Correct", "Percentage Correct"], tablefmt="grid"))
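
The summary table boils down to a count and a percentage per op. The sketch below is a minimal version of that bookkeeping, assuming an example counts as correct when its logit difference is positive. `summarize_correct` is a hypothetical helper and makes no claim about how `compute_correct` works internally.

```python
from typing import Dict, List, Tuple


def summarize_correct(logit_diffs: List[float]) -> Tuple[int, float]:
    """Return (number correct, percentage correct), treating a positive logit diff as correct."""
    total_correct = sum(1 for diff in logit_diffs if diff > 0)
    percentage = 100.0 * total_correct / len(logit_diffs) if logit_diffs else 0.0
    return total_correct, percentage


# Made-up logit diffs for two hypothetical ops
diffs_by_op: Dict[str, List[float]] = {
    "clean": [1.5, 1.0, -0.5, 0.25],
    "with_sae": [1.0, -0.25, -0.75, 0.5],
}
summaries = {op: summarize_correct(diffs) for op, diffs in diffs_by_op.items()}
for op, (total, pct) in summaries.items():
    print(f"{op:<10} {total:>3} {pct:6.2f}%")
```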

#### Per Batch Ablation Effect Graphs [Optional]

In [None]:
if artifact_cfg.latent_effects_graphs and it.logit_diffs_attr_ablation.name in tutorial_active_ops:
    # TODO: add note that only latent effects associated with correct answers currently displayed
    # TODO: allow toggling correct filtering during runs
    analysis_results[it.logit_diffs_attr_ablation.name].plot_latent_effects(
        per_batch=artifact_cfg.latent_effects_graphs_per_batch
    )

#### Per-SAE Ablation Effects

In [None]:

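# the activation summary below is read from the `logit_diffs_sae` results, so both ops need to have run
if {it.logit_diffs_attr_ablation.name, it.logit_diffs_sae.name}.issubset(tutorial_active_ops):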
    ablation_batch_preds = pred_summaries[it.logit_diffs_attr_ablation.name].batch_predictions
    activation_summary = analysis_results[it.logit_diffs_sae.name].calc_activation_summary()

    ablation_metrics = analysis_results[it.logit_diffs_attr_ablation.name].calculate_latent_metrics(
        pred_summ=pred_summaries[it.logit_diffs_attr_ablation.name],
        activation_summary=activation_summary,
        # filter_by_correct=True,
        run_name="logit_diffs.attribution.ablation"
    )

    tables = ablation_metrics.create_attribution_tables(top_k=artifact_cfg.top_k_latents_table, filter_type='both',
                                                        per_sae=artifact_cfg.latents_table_per_sae)

    for title, table in tables.items():
        print(f"\n{title}\n{table}\n")

    sl_test_module.display_latent_dashboards(ablation_metrics, title="Ablation-Mediated Latent Analysis",
                                             sae_release=runner.run_cfg.sae_analysis_targets.sae_release,
                                             top_k=artifact_cfg.top_k_latent_dashboards)
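
Conceptually, the ablation effect of a latent is the change in the metric when that latent's activation is set to zero and everything else is left alone. The sketch below shows this on a toy linear "logit diff" with made-up weights. `ablation_effects` and `toy_logit_diff` are hypothetical and stand in for a full forward pass through the model with the SAE spliced in.

```python
from typing import Callable, List


def ablation_effects(latents: List[float], metric: Callable[[List[float]], float]) -> List[float]:
    """Effect of zeroing each latent in turn: metric(ablated) - metric(baseline)."""
    baseline = metric(latents)
    effects = []
    for i in range(len(latents)):
        ablated = list(latents)
        ablated[i] = 0.0  # zero-ablate a single latent
        effects.append(metric(ablated) - baseline)
    return effects


def toy_logit_diff(latents: List[float]) -> float:
    """Made-up linear readout from latent activations to a logit difference."""
    weights = [2.0, -1.0, 0.5]
    return sum(w * a for w, a in zip(weights, latents))


latents = [1.0, 2.0, 0.0]
effects = ablation_effects(latents, toy_logit_diff)

# Rank latents by the size of their effect, largest first
ranking = sorted(range(len(effects)), key=lambda i: abs(effects[i]), reverse=True)
print(effects, ranking)
```

Note that a latent which is inactive on the input (activation 0) has no ablation effect, which is one reason an activation summary is useful alongside these metrics. The cost is one forward pass per latent, which motivates the cheaper gradient-based estimate.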


#### Per-SAE Attribution Patching Effects

In [None]:
if it.logit_diffs_attr_grad.name in tutorial_active_ops:
    # per-SAE activation summaries are calculated using our AnalysisStore since the relevant keys are present,
    # no need to provide a separate activation summary from another comparison cache in this case as with ablation
    activation_summary = analysis_results[it.logit_diffs_attr_grad.name].calc_activation_summary()
    attribution_patching_metrics = analysis_results[it.logit_diffs_attr_grad.name].calculate_latent_metrics(
        pred_summ=pred_summaries[it.logit_diffs_attr_grad.name],
        run_name="logit_diffs.attribution.grad_based"
    )

    tables = attribution_patching_metrics.create_attribution_tables(top_k=artifact_cfg.top_k_latents_table,
                                                                    filter_type='both',
                                                                    per_sae=artifact_cfg.latents_table_per_sae)

    for title, table in tables.items():
        print(f"\n{title}\n{table}\n")

    sl_test_module.display_latent_dashboards(attribution_patching_metrics,
                                             title="Attribution Patching-Mediated Latent Analysis",
                                             sae_release=runner.run_cfg.sae_analysis_targets.sae_release,
                                             top_k=artifact_cfg.top_k_latent_dashboards)
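
Attribution patching replaces the per-latent forward passes of ablation with a first-order estimate: the effect of zeroing a latent is approximated by `(0 - activation) * gradient of the metric with respect to that activation`. The sketch below illustrates this on a toy metric with made-up coefficients. The gradient is taken by central finite differences purely to keep the example dependency-free (in practice it would come from a backward pass), and all function names are hypothetical.

```python
from typing import Callable, List


def finite_diff_grad(metric: Callable[[List[float]], float], latents: List[float], eps: float = 1e-4) -> List[float]:
    """Central-difference gradient of the metric with respect to each latent activation."""
    grads = []
    for i in range(len(latents)):
        up = list(latents)
        down = list(latents)
        up[i] += eps
        down[i] -= eps
        grads.append((metric(up) - metric(down)) / (2 * eps))
    return grads


def attribution_estimates(latents: List[float], grads: List[float]) -> List[float]:
    """First-order estimate of the metric change from setting each latent to zero."""
    return [(0.0 - activation) * grad for activation, grad in zip(latents, grads)]


def toy_logit_diff(latents: List[float]) -> float:
    """Made-up metric with an interaction term between the two latents."""
    a0, a1 = latents
    return 2.0 * a0 - a1 + 0.5 * a0 * a1


latents = [1.0, 2.0]
grads = finite_diff_grad(toy_logit_diff, latents)
estimates = attribution_estimates(latents, grads)
print([round(estimate, 4) for estimate in estimates])
```

Because this toy metric is linear in each latent taken separately, the estimates coincide with exact zero-ablation here. With stronger nonlinearities the two can diverge, which is what the parity comparison in the next section is for.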


#### Per-SAE Ablation vs Attribution-Patching Effect Parity

In [None]:
if {it.logit_diffs_attr_grad.name, it.logit_diffs_attr_ablation.name}.issubset(tutorial_active_ops):
    from interpretune.analysis import latent_metrics_scatter
    # Visualize, for each hook, how closely the ablation and attribution patching latent metrics agree
    latent_metrics_scatter(
        ablation_metrics,
        attribution_patching_metrics,
        label1="Ablation",
        label2="Attribution Patching"
    )
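
Alongside the scatter plot, a single number can summarize how well the two methods agree. The sketch below computes a Pearson correlation between two lists of per-latent effects and finds the latent with the largest disagreement. The effect values are made up for illustration and `pearson` is a hypothetical helper, not part of `latent_metrics_scatter`.

```python
import math
from typing import List


def pearson(xs: List[float], ys: List[float]) -> float:
    """Pearson correlation coefficient between two equal-length lists."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    spread_x = math.sqrt(sum((x - mean_x) ** 2 for x in xs))
    spread_y = math.sqrt(sum((y - mean_y) ** 2 for y in ys))
    return cov / (spread_x * spread_y)


# Made-up per-latent effects from the two methods
ablation = [-3.0, 1.0, 0.0, 0.5]
attribution = [-2.5, 1.2, 0.0, 0.4]

r = pearson(ablation, attribution)
# Index of the latent where the two methods disagree the most
worst = max(range(len(ablation)), key=lambda i: abs(ablation[i] - attribution[i]))
print(f"pearson r = {r:.3f}, largest disagreement at latent index {worst}")
```

A correlation close to 1 suggests the cheaper gradient-based estimate is a reasonable proxy for ablation on these latents, while outliers such as the `worst` index are natural candidates for a closer look.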
