# Interpretune Circuit Tracer Tutorial

![Fine-Tuning Scheduler logo](logo_fts.png){height="55px" width="401px"}

### Intro

[Interpretune](https://github.com/speediedan/interpretune) is a flexible framework for exploring, analyzing and tuning 
LLM world models. In this tutorial, we'll walk through a simple example of using Interpretune to pursue interpretability 
research with Circuit Tracer. As we'll see, Interpretune handles the required execution context composition, allowing us to 
use the same code in a variety of contexts, depending upon the level of abstraction required.

As a long-time PyTorch and PyTorch Lightning contributor, I've found the PyTorch Lightning framework is the right level 
of abstraction for a large variety of ML research contexts, but some contexts benefit from using core PyTorch directly. 
Additionally, some users may prefer to use the core PyTorch framework directly for a wide variety of reasons including 
maximizing portability. As will be demonstrated here, Interpretune maximizes flexibility and portability by adhering to 
a well-defined protocol that allows auto-composition of our research module with the adapters required for execution in 
a wide variety of contexts. The same module can be executed with either core PyTorch or PyTorch Lightning; in this example, 
we'll use core PyTorch to demonstrate the use of `Circuit Tracer` with Interpretune for circuit discovery and interpretability research.

> Note - **this is a WIP**, but this is the core idea. If you have any feedback, please let me know!

## A note on memory usage

In these exercises, we'll be loading language models into memory for circuit analysis. It's useful to have functions that help you profile memory usage, so that if you encounter OOM errors you can try to clear out unnecessary models. For example, we've found that with the right memory handling (i.e. deleting models and objects when you're not using them any more) it should be possible to run all the exercises in this material on a Colab Pro notebook.

<details>
<summary>See this dropdown for some functions which you might find helpful, and how to use them.</summary>

First, we can inspect our current GPU memory usage with a snippet like the one below, which you can run at any point during the circuit analysis exercises.

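```python
# Profile memory usage
import gc

import torch

if torch.cuda.is_available():
    allocated = torch.cuda.memory_allocated()
    reserved = torch.cuda.memory_reserved()
    free, total = torch.cuda.mem_get_info()  # device-wide free/total memory in bytes
    print(f"GPU Memory Allocated: {allocated / 1024**3:.2f} GB")
    print(f"GPU Memory Reserved: {reserved / 1024**3:.2f} GB")
    # Held by PyTorch's caching allocator but not currently backing any tensors
    print(f"GPU Memory Reserved but Unallocated: {(reserved - allocated) / 1024**3:.2f} GB")
    print(f"GPU Memory Free (device): {free / 1024**3:.2f} GB of {total / 1024**3:.2f} GB")
```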

If you need to free up memory, you can delete large objects and run garbage collection:

```python
import gc

import torch

# Delete large objects if needed
# del model
# del circuit_tracer_session

# Move large modules that are still referenced to CPU
THRESHOLD_GB = 0.1
for obj in gc.get_objects():
    try:
        if isinstance(obj, torch.nn.Module):
            # Approximate parameter size using each parameter's actual element size
            param_bytes = sum(p.numel() * p.element_size() for p in obj.parameters())
            if param_bytes / 1024**3 > THRESHOLD_GB:
                obj.cpu()
    except Exception:
        # Some objects tracked by gc can raise on inspection; skip them
        pass

# Force garbage collection, then release unoccupied cached CUDA memory
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```

This approach helps manage memory when working with large language models during circuit analysis.

</details>

#### Imports

```python
# Import circuit tracer and required modules
from transformer_lens import ActivationCache  # noqa: F401
from pprint import pformat
from datetime import datetime

import interpretune as it  # registered analysis ops will be available as it.<op> when analysis is imported
from it_examples import _ACTIVE_PATCHES  # noqa: F401  # TODO: add note about this unless patched in SL before release
from it_examples.example_module_registry import MODULE_EXAMPLE_REGISTRY  # TODO: move to hub once implemented
from interpretune import ITSessionConfig, ITSession
from interpretune.base.call import it_init
```

### Configure our IT Session


Here we define or customize our session configuration, which includes:
1. The experiment/task module and datamodule (in this case, `rte` for the RTE task).
    * We can customize any module, datamodule, or adapter-specific configuration options we want to use. In this case, we set the `circuit_tracer_cfg` we want to use for our analysis. We could also customize generation parameters, tokenization, the pretrained/config-based model we want to use (in this case, a Gemma 2 model, per the `gemma2.rte_demo.circuit_tracer` example key), etc.
2. The adapter context we want to use. In this case, `core` PyTorch (vs. e.g. Lightning) and `circuit_tracer` (vs. e.g. `transformer_lens` or `sae_lens`).

When an `ITSession` is created, the selected adapter context will trigger composition of the relevant adapters with our experiment/task module and datamodule. The intention of this abstraction is to enable the same experiment/task logic to be used unchanged across a broad variety of PyTorch framework and analytical package contexts.
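To make the composition idea concrete, here is a minimal, hypothetical sketch of how adapter mixins *could* be combined with a task module at runtime using Python's built-in `type()`. All class names and the `compose_module` helper are illustrative assumptions; this is not Interpretune's actual implementation, only a model of the idea that swapping adapters leaves the task logic untouched.

```python
class RTEModule:
    """Hypothetical task module holding framework-agnostic experiment logic."""

    def analysis_step(self, prompt: str) -> str:
        # The task logic is unchanged regardless of which adapters are composed in
        return f"analyze({prompt!r}) via {self.backend()}"


class CoreAdapter:
    """Hypothetical adapter providing a plain-PyTorch execution loop."""

    def run(self, prompt: str) -> str:
        return self.analysis_step(prompt)


class CircuitTracerAdapter:
    """Hypothetical adapter providing a circuit_tracer analysis backend."""

    def backend(self) -> str:
        return "circuit_tracer"


class TransformerLensAdapter:
    """Hypothetical adapter providing a transformer_lens analysis backend."""

    def backend(self) -> str:
        return "transformer_lens"


def compose_module(module_cls: type, adapters: tuple) -> type:
    # Create a new class whose MRO places the adapters ahead of the task module,
    # so adapter methods can call into the task logic (and override it if needed)
    name = "".join(a.__name__ for a in adapters) + module_cls.__name__
    return type(name, (*adapters, module_cls), {})


ct_cls = compose_module(RTEModule, (CoreAdapter, CircuitTracerAdapter))
tl_cls = compose_module(RTEModule, (CoreAdapter, TransformerLensAdapter))
print(ct_cls.__name__, "->", ct_cls().run("The capital of France is"))
print(tl_cls.__name__, "->", tl_cls().run("The capital of France is"))
```

Swapping `CircuitTracerAdapter` for `TransformerLensAdapter` changes only the backend; `RTEModule.analysis_step` is reused as-is. Listing the adapters first in the bases tuple gives them precedence in method resolution, which is one common way to implement this kind of mixin composition.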


```python
# Load our demo config (this will be done from the hub once that is available)
base_itdm_cfg, base_it_cfg, dm_cls, m_cls = MODULE_EXAMPLE_REGISTRY.get("gemma2.rte_demo.circuit_tracer")

print(pformat(base_it_cfg.circuit_tracer_cfg))

# configure our session with our desired adapter composition, core and circuit_tracer in this case
session_cfg = ITSessionConfig(
    adapter_ctx=(it.Adapter.core, it.Adapter.circuit_tracer),
    datamodule_cfg=base_itdm_cfg,
    module_cfg=base_it_cfg,
    datamodule_cls=dm_cls,
    module_cls=m_cls,
)

# start our session
it_session = ITSession(session_cfg)
print("\nIT Session created successfully!")
```

```python
# manual init for now
it_init(**it_session)
print("\nIT Session initialized successfully!")
```

### Basic Attribution Graph
 

```python
from tqdm.auto import tqdm

limit_analysis_batches = 1  # max number of test batches to sample prompts from (negative for no limit)
test_token_limit = -1  # keep only the last N tokens of each prompt (non-positive to keep all)
force_manual_debug_prompts = True  # Set to True to use the manual debug prompts below instead of test dataloader samples
# specific tokens to analyze, will use tokens associated with top `max_n_logits` if `None`
# analysis_target_tokens: Optional[torch.Tensor] = None

example_prompts = []
ct_module = it_session.module
if not force_manual_debug_prompts:
    dataloader = it_session.datamodule.test_dataloader()
    for epoch_idx in range(1):  # Run for a single epoch for simplicity
        ct_module.current_epoch = epoch_idx
        for batch_idx, batch in tqdm(enumerate(dataloader)):
            # stop after `limit_analysis_batches` batches (a negative limit disables this)
            if batch_idx >= limit_analysis_batches >= 0:
                break
            # fetch the last `test_token_limit` tokens of the first example in the batch
            first_ex_in_batch = batch[:1]
            first_ex_in_batch = first_ex_in_batch["input"]
            first_ex_in_batch.squeeze_()
            if test_token_limit > 0:
                first_ex_in_batch = first_ex_in_batch[-test_token_limit:]
            # drop padding tokens (assumes a pad token id of 0)
            first_ex_in_batch = first_ex_in_batch[first_ex_in_batch != 0]
            example_prompts.append(first_ex_in_batch)
else:
    # Generate attribution graphs for a few example prompts
    example_prompts = [
        # "The capital of France is",
        "The capital of the state containing Dallas is",
        # "When I look at the sky, I see",
    ]
```
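The dataloader branch above stops after `limit_analysis_batches` batches (the chained comparison `batch_idx >= limit_analysis_batches >= 0` means a negative limit never stops early), keeps only the last `test_token_limit` tokens when that value is positive, and then drops padding, assuming a pad token id of `0`. The pure-Python sketch below restates that logic on plain lists so the semantics are easy to check; `should_stop`, `prepare_prompt_tokens` and `PAD_TOKEN_ID` are hypothetical names used only for illustration, and the token ids are arbitrary.

```python
PAD_TOKEN_ID = 0  # assumption: mirrors the `!= 0` padding filter in the cell above


def should_stop(batch_idx: int, limit: int) -> bool:
    # Equivalent to `batch_idx >= limit >= 0`: a negative limit never stops early
    return limit >= 0 and batch_idx >= limit


def prepare_prompt_tokens(token_ids: list[int], token_limit: int = -1) -> list[int]:
    # Keep only the trailing `token_limit` tokens when a positive limit is set
    if token_limit > 0:
        token_ids = token_ids[-token_limit:]
    # Drop padding tokens (applied after truncation, as in the cell above)
    return [t for t in token_ids if t != PAD_TOKEN_ID]


print([should_stop(i, 1) for i in range(3)])
print(prepare_prompt_tokens([0, 0, 464, 3139, 286, 4881, 318], token_limit=4))
```

Because truncation happens before padding is removed, right-padded inputs with a small `test_token_limit` could end up keeping mostly padding; swapping the two steps is one option if that matters for your data.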

```python
print("Generating attribution graphs for example prompts...")
slug_base = "it_circuit_tracer_compute_specific_logits_demo"
results = []

for i, prompt in enumerate(example_prompts):
    print(f"\nProcessing prompt {i + 1}: '{prompt}'")
    # timestamped slug so each generated graph gets a unique name
    slug = f"{slug_base}_{i + 1}_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
    # Generate the graph for this prompt; the adapter handles tokenization and graph generation
    try:
        graph, local_graph_path, _ = ct_module.generate_graph(prompt=prompt, slug=slug)
        results.append(local_graph_path)
    except Exception as e:
        print(f"  - Error processing prompt: {e}")

print(f"\nProcessed {len(results)} prompts successfully")
```

### Saving and Visualizing Attribution Graphs

In this section, we'll serve the attribution graphs generated above from `ct_module.circuit_tracer_cfg.graph_output_dir` using circuit-tracer's local frontend server, so they can be explored in a browser or an inline IFrame. The CircuitTracerAdapter integrates with Interpretune's AnalysisStore to persistently store graph data.
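The server cell below uses a fixed port (`8046`), so startup can fail if that port is already taken, for example by a server left running from an earlier session. A small standard-library check like the following can be run first, passing the suggested value as `port`. `port_is_free` and `find_free_port` are hypothetical helpers, not part of circuit-tracer, and the check is only a heuristic: another process could claim the port between the check and the server starting.

```python
import socket


def port_is_free(port: int, host: str = "127.0.0.1") -> bool:
    # Try to bind the port; success means no other socket currently holds it on `host`
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        try:
            sock.bind((host, port))
        except OSError:
            return False
        return True


def find_free_port(host: str = "127.0.0.1") -> int:
    # Binding to port 0 asks the OS to assign an unused ephemeral port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


candidate_port = 8046
if not port_is_free(candidate_port):
    candidate_port = find_free_port()
print(f"Suggested port for the graph server: {candidate_port}")
```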

```python
import socket

from circuit_tracer.frontend.local_server import serve
from IPython.display import IFrame, display

enable_iframe = False  # whether to enable the IFrame display or not

port = 8046
server = serve(data_dir=ct_module.circuit_tracer_cfg.graph_output_dir, port=port)
# TODO: make this configurable at the top or request the user setup port forwarding and use localhost
port_forwarding = False  # whether to use port forwarding or not
# the hostname of the machine serving the graph files (replace it if your browser reaches that machine by another name)
hostname = socket.gethostname()

if port_forwarding:
    hostname = "localhost"  # use localhost for port forwarding
    print(
        f"Using port forwarding (ensure it is configured) and localhost. Open your graph here at http://{hostname}:{port}/index.html"
    )
else:
    print(
        f"Not using port forwarding. Use the IFrame below (if `enable_iframe` is set), or open your graph here directly at http://{hostname}:{port}/index.html"
    )

if enable_iframe:
    # Display the IFrame with the graph visualization
    print(f"Displaying graph visualization in IFrame at http://{hostname}:{port}/index.html")
    display(IFrame(src=f"http://{hostname}:{port}/index.html", width="100%", height="800px"))
```

```python
# Stop the local graph server when you're done viewing graphs
server.stop()
```

### Next Steps and Future Extensions

This notebook demonstrates the basic scaffolding for the CircuitTracerAdapter. The current implementation provides:

1. **Basic Integration**: CircuitTracerAdapter integrates with Interpretune's session management
2. **Configuration**: CircuitTracerConfig allows customization of attribution parameters
3. **Protocol Support**: CircuitAnalysisBatchProtocol defines the interface for batch processing
4. **Adapter Composition**: Seamless integration with other Interpretune adapters

#### Future Extensions:

1. **Full Implementation**: Complete the adapter methods to actually generate attribution graphs
2. **Batch Processing**: Support for efficient batch attribution analysis
3. **Advanced Analysis**: Integration with AnalysisOp for complex circuit analysis workflows
4. **Visualization**: Built-in support for graph visualization and exploration
5. **Caching**: Intelligent caching of attribution results for faster iteration
6. **Model Support**: Extended support for additional model architectures

#### Resources:

- [Circuit Tracer Documentation](https://github.com/jacobdunefsky/circuit-tracer)
- [Interpretune Documentation](https://github.com/speediedan/interpretune)
- [Attribution Methods Paper](https://arxiv.org/abs/2310.10348)

This scaffold provides a solid foundation for building sophisticated circuit analysis workflows with Interpretune and Circuit Tracer.