# Mechanistic Interpretability

## 1. Find induction heads where we know they exist

Take a GPT-2 or Llama model, visualize its attention patterns, and find the heads that appear to perform induction. Additionally, try to rank the heads automatically by how strongly they show induction-style attention on synthetic sequences `A B X_1 ... X_N A`, i.e. how much the final `A` attends to the token `B` that followed the earlier occurrence of `A`.

## 2. Apply the same two techniques to the ICL-pretrained transformer

In [1]:
import os
from dotenv import load_dotenv

# Read variables from a local .env file into the process environment; the
# trailing ";" stops the notebook from echoing load_dotenv()'s return value.
load_dotenv();

In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt

from icl.analysis.utils import get_unique_run
from devinterp.mechinterp.activations import ActivationProbe




In [3]:
# Look up the run in the sweep config matching these settings: one pretraining
# task, 2 layers, learning rate 1e-3 (get_unique_run presumably expects exactly
# one matching run — confirm).
run = get_unique_run(
    "../sweeps/small-sweep.yaml", 
    task_config={"num_tasks": 1, "num_layers": 2},
    optimizer_config={"lr": 0.001}
)

  ws_hat = torch.linalg.solve(LHS, RHS)   # BKDD^-1 @ BKD1 -> B K D 1


In [None]:
# Confirm that we've loaded in the most recent model
# (run.evaluator's return value is displayed as this cell's output).
run.evaluator(run.model)

In [5]:
# List parameter and buffer names (e.g. the attention causal_mask buffers) to
# locate the modules worth probing.
list(run.model.state_dict().keys())

['token_sequence_transformer.token_embedding.weight',
 'token_sequence_transformer.postn_embedding.weight',
 'token_sequence_transformer.blocks.0.attention.causal_mask',
 'token_sequence_transformer.blocks.0.attention.attention.weight',
 'token_sequence_transformer.blocks.0.compute.0.weight',
 'token_sequence_transformer.blocks.0.compute.0.bias',
 'token_sequence_transformer.blocks.0.compute.2.weight',
 'token_sequence_transformer.blocks.0.compute.2.bias',
 'token_sequence_transformer.blocks.0.layer_norms.0.weight',
 'token_sequence_transformer.blocks.0.layer_norms.0.bias',
 'token_sequence_transformer.blocks.0.layer_norms.1.weight',
 'token_sequence_transformer.blocks.0.layer_norms.1.bias',
 'token_sequence_transformer.blocks.1.attention.causal_mask',
 'token_sequence_transformer.blocks.1.attention.attention.weight',
 'token_sequence_transformer.blocks.1.compute.0.weight',
 'token_sequence_transformer.blocks.1.compute.0.bias',
 'token_sequence_transformer.blocks.1.compute.2.weight',
 

In [None]:
from pathlib import Path
# Root output directory for the activation figures saved by this notebook.
figures = Path("../figures/M=1/across-x-basis-w=0")

# NOTE(review): `hooked_model` is not defined anywhere in this notebook as shown,
# and `x_trick` / `y_trick` are only created in a later cell — this cell
# presumably relies on out-of-order execution; confirm before re-running.
output, activations_ = hooked_model.run_with_cache(x_trick, y_trick)

def separate_attention(qkv, num_heads: int, batch_size: int, head_size: int, num_tokens: int):
    """Split a fused QKV projection output into per-head query, key and value tensors.

    The last axis of ``qkv`` is read as ``num_heads`` consecutive chunks of
    ``3 * head_size`` values, each chunk laid out as ``[q | k | v]``.

    Args:
        qkv: tensor viewable as (batch_size, num_tokens, num_heads, 3 * head_size).
        num_heads: number of attention heads H.
        batch_size: batch size B.
        head_size: per-head width c.
        num_tokens: sequence length T.

    Returns:
        Tuple ``(q, k, v)``, each of shape (batch_size, num_heads, num_tokens, head_size).
    """
    per_head = qkv.view(batch_size, num_tokens, num_heads, 3 * head_size)  # B T H 3c
    heads_first = per_head.transpose(-2, -3)                                # B H T 3c
    return heads_first.split(head_size, dim=-1)                             # (B H T c) * 3

# Sizes used to split the cached fused-QKV activations into heads:
# E = model width (split into H heads of E // H values each),
# T = number of tokens, H = number of heads, B = batch size.
E, T, H, B = 64, 16, 4, 4

def optionally_rotate(x, name):
    """Orient a 2-D array so it is at least as wide as it is tall.

    A tall input (more rows than columns) is returned transposed, and its label
    gets a ``.T`` suffix to record that; anything else is returned unchanged
    together with the original label.

    Args:
        x: two-dimensional tensor/array (anything with ``.shape`` and ``.T``).
        name: label describing ``x``.

    Returns:
        ``(array, label)`` tuple.

    Raises:
        ValueError: if ``x`` does not have exactly two dimensions.
    """
    ndim = len(x.shape)
    if ndim != 2:
        raise ValueError("Tensor should have two dimensions.")

    rows, cols = x.shape
    return (x.T, f"{name}.T") if rows > cols else (x, name)

# Everything to plot, keyed by name: the synthetic inputs plus every cached activation.
activations = {}
activations["x"] = x_trick
activations["y"] = y_trick
activations.update(activations_)

# One folder of figures per batch sample (B samples, matching the batch_size=B
# passed to separate_attention below).
for i in range(B):
    # exist_ok=True replaces the separate exists() check and its check-then-create race.
    os.makedirs(figures / f"{i}", exist_ok=True)

    for location, activation in activations.items():
        activation_slice = activation[i]  # Batch idx

        if location.endswith("attention.attention"):
            # Output of the fused QKV projection: split the whole batch into
            # per-head Q/K/V, then keep only sample i.
            q, k, v = separate_attention(activation, num_heads=H, batch_size=B, head_size=E//H, num_tokens=T)
            qk = q @ k.transpose(-2, -1)  # raw scores: unscaled, unmasked, no softmax
            q, k, qk, v = q[i], k[i], v[i], qk[i]

            # Rows = heads, columns = Q, K, QK, V; squeeze=False keeps `axs`
            # 2-D even when H == 1.
            fig, axs = plt.subplots(H, 4, figsize=(15, 15), squeeze=False)

            for j, (name, x) in enumerate(zip(["Q", "K", "QK", "V"], [q, k, qk, v])):
                for h in range(H):
                    ax = axs[h, j]
                    ax.matshow(x[h].detach().to("cpu").numpy())
                    ax.set_title(f"{h}.{name}")

            plt.suptitle(f"{location}")
            plt.savefig(figures / f"{i}/{location}.png")
            plt.show()
        elif len(activation_slice.shape) == 2:
            x, name = optionally_rotate(activation_slice, location)
            plt.matshow(x.detach().to("cpu").numpy())
            plt.title(f"{i} {name}")
            plt.savefig(figures / f"{i}/{name}.png")
            plt.show()
        elif len(activation_slice.shape) == 3:  # [heads, xs, ys]
            heads = activation_slice.shape[0]
            # squeeze=False: with a single head, plt.subplots would otherwise
            # return a bare Axes that cannot be indexed.
            fig, axs = plt.subplots(1, heads, figsize=(15, 15), squeeze=False)
            for j in range(heads):
                ax = axs[0, j]
                x, name = optionally_rotate(activation_slice[j], str(j))
                ax.matshow(x.detach().to("cpu").numpy())
                ax.set_title(name)
            plt.suptitle(f"{location}.#")
            plt.savefig(figures / f"{i}/{location}.png")
            plt.show()
        else:
            raise ValueError("Unsupported number of dimensions.")


In [None]:
import os
from PIL import Image

# Build side-by-side comparison images: for every figure filename, paste the
# copies found in the per-sample folders next to each other (folder order =
# left-to-right order) and save the result under figures/overview.
folder_paths = [figures / f"{i}" for i in range(4)]  # Add more folders as needed

# exist_ok=True replaces the separate exists() check and its check-then-create race.
os.makedirs(figures / "overview", exist_ok=True)

# filename -> list of images, one per folder that contains that file
images_by_filename = {}

# Load images from each folder and organize them by filename
for folder_path in folder_paths:
    # os.listdir returns entries in arbitrary order; sort for a deterministic run.
    filenames = sorted(f for f in os.listdir(folder_path) if f.endswith('.png'))  # Change the extension as needed
    for filename in filenames:
        # Image.open is lazy and keeps the file open until the data is read;
        # copy() loads the pixels so the file can be closed immediately instead
        # of holding one open handle per image until the end of the cell.
        with Image.open(os.path.join(folder_path, filename)) as img:
            images_by_filename.setdefault(filename, []).append(img.copy())

print(images_by_filename)

# Create one comparison image per unique filename
for filename, image_list in images_by_filename.items():
    # The canvas is as wide as all panels together and as tall as the tallest one.
    width = sum(img.width for img in image_list)
    height = max(img.height for img in image_list)

    result_image = Image.new('RGB', (width, height))

    # Paste images side by side
    x_offset = 0
    for img in image_list:
        result_image.paste(img, (x_offset, 0))
        x_offset += img.width

    result_image.save(figures / f"overview/{filename}")

In [None]:
from copy import deepcopy
import operator
from typing import Callable, List

def patch(module: nn.Module):
    """Wrap ``module`` in the matching patchable wrapper.

    ``nn.ModuleList`` -> ``PatchedList``, ``nn.Sequential`` -> ``PatchedSequential``,
    anything else -> ``Patched``. The container checks are tried in that order.
    """
    dispatch = (
        (nn.ModuleList, PatchedList),
        (nn.Sequential, PatchedSequential),
    )
    wrapper = next((w for container, w in dispatch if isinstance(module, container)), Patched)
    return wrapper(module)


class Patched(nn.Module):
    """Patchable wrapper around a single module.

    ``forward`` delegates to the *current* module, which is initially the
    wrapped one. ``set`` swaps in a replacement, and ``reset`` restores the
    original and recursively resets the patched children registered on this
    wrapper. Attribute reads the wrapper cannot resolve fall through to the
    current module.

    NOTE(review): each child of the wrapped module is wrapped with ``patch`` and
    registered on this wrapper. However, ``forward`` still runs the wrapped
    module's own ``forward``, which uses its own (unwrapped) children, so
    patching a child wrapper does not change this wrapper's output. The wrapped
    module also lives in ``__dict__``, so its parameters are not registered on
    the wrapper.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        # `_current` / `_original` are read-only properties backed by __dict__.
        # Writing __dict__ directly also bypasses nn.Module.__setattr__, so the
        # wrapped module is not registered as a child.
        self.__dict__["_current"] = module
        self.__dict__["_original"] = module

        for n, c in module.named_children():
            # Register via _modules directly. add_module's hasattr() check would
            # resolve `n` through __getattr__ to the wrapped module's child of
            # the same name and refuse with "attribute already exists".
            self._modules[n] = patch(c)

    def forward(self, *args, **kwargs):
        return self._current(*args, **kwargs)

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Try this wrapper's own
        # params/buffers/children first, then the current wrapped module.
        try:
            return super().__getattr__(name)
        except AttributeError:
            current = self.__dict__.get("_current")
            if current is None:  # not initialised yet
                raise
            # getattr (not current.__getattr__) so plain attributes and
            # methods resolve too, not only params/buffers/children.
            return getattr(current, name)

    def __setattr__(self, name, value):
        # Assigning to a name that holds a patched wrapper swaps the module
        # inside that wrapper instead of replacing the wrapper. The default
        # lets brand-new attributes be assigned normally.
        current_value = getattr(self, name, None)

        if isinstance(current_value, (Patched, PatchedList, PatchedSequential)):
            current_value.set(value)
        else:
            super().__setattr__(name, value)

    def set(self, new_value):
        """Make ``new_value`` the module that ``forward`` delegates to."""
        # Write __dict__ directly: `_current` is a setter-less property.
        self.__dict__["_current"] = new_value

    def reset(self):
        """Restore the original module and recursively reset patched children."""
        self.__dict__["_current"] = self._original

        # The patched wrappers are registered on this wrapper, not on the
        # wrapped module, so iterate our own children.
        for c in self.children():
            if isinstance(c, (Patched, PatchedList, PatchedSequential)):
                c.reset()

    @property
    def _current(self):
        return self.__dict__["_current"]

    @property
    def _original(self):
        return self.__dict__["_original"]


class PatchedList(nn.Module):
    """Patchable stand-in for an ``nn.ModuleList``.

    Every element of the wrapped list is itself wrapped with ``patch`` and
    registered as child ``"0"``, ``"1"``, ... Indexing, iteration and ``len``
    go through the *current* list. ``patched[i] = m`` swaps the module inside
    element ``i``'s wrapper, ``set`` replaces the whole list, and ``reset``
    restores the original list, resetting every patched element too.

    NOTE(review): ``set`` does not re-register children, so afterwards
    ``children()`` / ``state_dict()`` still reflect the original elements.
    """

    def __init__(self, module_list: nn.ModuleList):
        super().__init__()
        wrapped = [patch(c) for c in module_list]
        # `_current` / `_original` are read-only properties backed by __dict__.
        # `_original` gets its own list object so in-place edits of the current
        # list cannot alter the saved original.
        self.__dict__["_current"] = wrapped
        self.__dict__["_original"] = list(wrapped)

        for i, module in enumerate(wrapped):
            self.add_module(str(i), module)

    def __setattr__(self, name, value):
        # Assigning to a name that holds a patched wrapper swaps the module
        # inside that wrapper instead of replacing the wrapper. The default
        # lets brand-new attributes be assigned normally.
        current_value = getattr(self, name, None)

        if isinstance(current_value, (Patched, PatchedList, PatchedSequential)):
            current_value.set(value)
        else:
            super().__setattr__(name, value)

    def __setitem__(self, index, value):
        self[index].set(value)

    def __getitem__(self, index):
        return self._current[index]

    def __iter__(self):
        return iter(self._current)

    def __len__(self):
        return len(self._current)

    def set(self, new_value):
        """Replace the current list of modules."""
        # Write __dict__ directly: `_current` is a setter-less property.
        self.__dict__["_current"] = new_value

    def reset(self):
        """Restore the original list and recursively reset each patched element."""
        self.__dict__["_current"] = list(self._original)

        for c in self:
            if isinstance(c, (Patched, PatchedList, PatchedSequential)):
                c.reset()

    @property
    def _current(self):
        return self.__dict__["_current"]

    @property
    def _original(self):
        return self.__dict__["_original"]
    
class PatchedSequential(PatchedList):
    """``PatchedList`` that also runs like ``nn.Sequential``.

    ``forward`` feeds the input through each element of the current list in
    order and returns the final output.
    """

    def forward(self, x):
        out = x
        for stage in iter(self):
            out = stage(out)
        return out


def set_head_to(i, output):
    """Return a softmax replacement that pins attention head ``i`` to ``output``.

    The returned ``new_softmax(self, x)`` takes a softmax over the last axis of
    ``x`` and then overwrites ``[:, i, :, :]`` of the result in place with
    ``output``. ``x`` must therefore be at least 4-D with heads on axis 1,
    presumably (batch, heads, queries, keys). ``self`` is accepted but unused.
    NOTE(review): if installed with ``patch_`` on a plain module, the function
    is stored unbound, so the module's input would land in ``self`` — confirm
    how ``Hooked`` calls ``_forward``.
    """
    overrides = {i: output}

    def new_softmax(self, x):
        probs = x.softmax(dim=-1)
        for head, pattern in overrides.items():
            probs[:, head, :, :] = pattern
        return probs

    return new_softmax

def set_heads_to(mappings):
    """Return a softmax replacement that pins several attention heads at once.

    ``mappings`` maps a head index to the pattern it should be forced to. The
    returned ``new_softmax(self, x)`` takes a softmax over the last axis of
    ``x`` and then, for each entry, overwrites ``[:, head, :, :]`` in place.
    ``x`` must therefore be at least 4-D with heads on axis 1, presumably
    (batch, heads, queries, keys). ``self`` is accepted but unused.
    NOTE(review): if installed with ``patch_`` on a plain module, the function
    is stored unbound, so the module's input would land in ``self`` — confirm
    how ``Hooked`` calls ``_forward``.
    """
    def new_softmax(self, x):
        probs = x.softmax(dim=-1)
        for head in mappings:
            probs[:, head, :, :] = mappings[head]
        return probs

    return new_softmax

class Patch(nn.Module):
    """Wrap an arbitrary callable as an ``nn.Module`` whose ``forward`` calls it."""

    def __init__(self, callable: Callable):
        # nn.Module.__init__ must run first. It creates the parameter, module
        # and hook registries that attribute assignment and Module.__call__
        # rely on; without it, calling the module raises AttributeError, and
        # passing an nn.Module as `callable` fails on assignment.
        super().__init__()
        self.callable = callable

    def forward(self, *args, **kwargs):
        return self.callable(*args, **kwargs)


def patch_(module: nn.Module, callable: Callable):
    """Swap ``module``'s forward computation for ``callable`` in place.

    For ``Hooked`` modules the ``_forward`` attribute is replaced (presumably
    the computation the hook wrapper calls — confirm against ``Hooked``). For
    every other module, ``forward`` itself is replaced. The previous value is
    saved as ``module._original``, and a zero-argument ``module.reset()`` is
    attached. ``reset`` restores the saved value and then removes
    ``_original`` and ``reset`` again.

    ``callable`` is stored as a plain instance attribute, so it is not bound:
    for a regular module, ``module(x)`` ends up calling ``callable(x)``.

    Args:
        module: the module to patch (modified in place).
        callable: replacement for ``forward`` / ``_forward``.

    Returns:
        ``module`` itself, for convenience.
    """
    attr = "_forward" if isinstance(module, Hooked) else "forward"

    module._original = getattr(module, attr)
    setattr(module, attr, callable)

    def reset():
        # Put the saved callable back, then drop only the bookkeeping
        # attributes; the restored `forward` / `_forward` must stay in place.
        setattr(module, attr, module._original)
        del module._original
        del module.reset

    module.reset = reset
    return module

# Work on a deep copy so the patching experiments leave run.model untouched.
model = deepcopy(run.model)
# The Patched wrappers are currently bypassed: the plain copy is used as-is.
patched_model = model #  patch(model)
# NOTE(review): `hook` (and `Hooked`, used by patch_) are not defined in this
# notebook as shown — presumably imported or defined elsewhere; confirm.
hooked_patched_model = hook(patched_model) # hook(patch(run.model))

def apply_binop_dicts(d1, d2, op):
    """Combine two dicts key by key with a binary operator.

    The result has exactly ``d1``'s keys, in ``d1``'s order, with values
    ``op(d1[key], d2[key])``. Extra keys in ``d2`` are ignored; a key missing
    from ``d2`` raises ``KeyError``.
    """
    combined = {}
    for key, left in d1.items():
        combined[key] = op(left, d2[key])
    return combined

# probe = ActivationProbe(masked_model, "token_sequence_transformer.blocks.0.attention.attention_softmax")
# probe.register_hook()
# Baseline: block-0 attention patterns on the first two synthetic samples and
# the evaluation metrics of the (so far unpatched) copy.
# NOTE(review): plot_activations is defined in a later cell, so this cell
# relies on out-of-order execution.
_, activations = hooked_patched_model.run_with_cache(x_trick[:2], y_trick[:2])
plot_activations([activations["token_sequence_transformer.blocks.0.attention.attention_softmax"]])

evals_1 = run.evaluator(hooked_patched_model)

# Hand-built 16x16 pattern (rows = queries) that replaces head index 1. Even
# rows attend only to themselves; each odd row splits attention 0.5 / 0.5
# between the preceding even position and itself. The names head_2 / head_3
# follow plot_activations' 1-based "Head N" titles.
# NOTE(review): device "mps" is hard-coded (Apple GPUs only).
head_2 = torch.zeros((16, 16), device="mps")
for i in range(0,16, 2):
    head_2[i, i] = 1

for i in range(0, 16, 2):
    head_2[i+1, i] = head_2[i+1, i+1] = 0.5


# Replacement for head index 2: a pure previous-token pattern (row 0, which has
# no predecessor, attends to itself).
head_3 = torch.zeros((16, 16), device="mps")
for i in range(15):
    head_3[i+1, i] = 1
head_3[0, 0] = 1

# Overwrite heads 1 and 2 of block 0's attention softmax with the patterns above.
# NOTE(review): `layer` is unused, and the patch is never undone (call
# `.reset()` on the patched module); later uses of hooked_patched_model stay patched.
layer = patch_(hooked_patched_model.token_sequence_transformer.blocks[0].attention.attention_softmax, set_heads_to({1: head_2, 2: head_3}))
_, activations = hooked_patched_model.run_with_cache(x_trick[:2], y_trick[:2])
plot_activations([activations["token_sequence_transformer.blocks.0.attention.attention_softmax"]])

evals_2 = run.evaluator(hooked_patched_model)

# Relative change of every metric after patching: (patched - baseline) / baseline.
# NOTE(review): divides by zero if a baseline metric is exactly 0.
apply_binop_dicts(evals_1, evals_2, lambda x, y: (y-x)/x)


In [None]:
# Ready to investigate
# Display the model's module tree (its nn.Module repr) as this cell's output.
run.model

In [None]:
from typing import List
from torchtyping import TensorType


def get_attention(model, xs, ys):
    """Run ``model(xs, ys)`` once and capture each block's attention-softmax output.

    An ``ActivationProbe`` is hooked onto
    ``token_sequence_transformer.blocks.{b}.attention.attention_softmax`` for
    every block ``b``. After the forward pass the probes are removed, and their
    captured ``activation`` values are returned in block order. Removal happens
    in a ``finally``, so a failing forward pass (or probe setup) does not leave
    stale hooks on ``model``.

    Returns:
        List with one entry per block, holding whatever ``probe.activation``
        contains. ``plot_activations`` treats each entry as
        (batch, heads, tokens, tokens); confirm against ``ActivationProbe``.
    """
    num_layers = len(model.token_sequence_transformer.blocks)
    probes = []

    try:
        for b in range(num_layers):
            probe = ActivationProbe(model, f"token_sequence_transformer.blocks.{b}.attention.attention_softmax")
            probe.register_hook()
            probes.append(probe)

        # Run the model; only the probes' captured activations are needed.
        model(xs, ys)
    finally:
        for probe in probes:
            probe.unregister_hook()

    # Get the activations
    return [probe.activation for probe in probes]

def plot_activations(activations: List[TensorType["batch", "heads", "tokens", "tokens"]]):
    """Show one figure per sample containing a grid of attention heatmaps.

    Grid rows are layers (one per entry of ``activations``) and columns are
    heads. Each panel draws ``activation[sample, head]`` with queries on the
    y-axis and keys on the x-axis; titles use 1-based layer/head numbers. The
    sample and head counts are read from the first entry, which must be 4-D,
    so all entries are assumed to share them.
    """
    n_layers = len(activations)
    n_samples, n_heads, _, _ = activations[0].shape

    for sample in range(n_samples):
        plt.figure(figsize=(15, 4 * n_layers))

        panel = 0
        for layer_no, layer_act in enumerate(activations, start=1):
            for head_no in range(1, n_heads + 1):
                panel += 1
                pattern = layer_act[sample, head_no - 1].detach().cpu().numpy()

                ax = plt.subplot(n_layers, n_heads, panel)
                ax.imshow(pattern, cmap='viridis', aspect='auto')
                ax.set_title(f'Layer {layer_no}, Head {head_no}')
                ax.set_xlabel('Keys')
                ax.set_ylabel('Queries')

        plt.tight_layout()
        plt.show()

def compose2(f, g):
    """Return ``f ∘ g``: a callable computing ``f(g(*args, **kwargs))``."""
    def composed(*args, **kwargs):
        return f(g(*args, **kwargs))
    return composed

def compose(*fs):
    """Compose functions right to left: ``compose(f, g, h)(x) == f(g(h(x)))``.

    Only the right-most function receives the call's arguments, keyword
    arguments included; every function to its left receives the previous
    result. A single function is returned as-is; calling with no functions
    raises ``TypeError`` (from ``functools.reduce``).
    """
    import functools
    return functools.reduce(compose2, fs)

# get_and_plot_activations(model, xs=..., ys=...) is
# plot_activations(get_attention(model, xs=..., ys=...)).
get_and_plot_activations = compose(plot_activations, get_attention)

In [None]:
# Task vectors of the pretraining distribution (num_tasks=1 for this run,
# presumably one row — confirm the shape).
ws = run.pretrain_dist.task_distribution.tasks
print(ws)

In [None]:
from icl.tasks import apply_transformations

# Earlier variant, kept for reference:
# x_trick = torch.zeros((4, 8, 4))
# x_trick[:, :, 0] = torch.arange(0, 8)
# # y_trick = torch.zeros((1, 8, 1))
# x_trick = x_trick.to("mps")
# y_trick = apply_transformations(ws, x_trick, run.pretrain_dist.std, device="mps")

# Synthetic probe inputs: sample i (of 4) has 8 rows whose coordinate i runs
# 0, 1, ..., 7 while every other coordinate is zero.
x_trick = torch.zeros((4, 8, 4))
for i in range(4):
    x_trick[i, :, i] = torch.arange(0, 8)

# y_trick = torch.zeros((1, 8, 1))
x_trick = x_trick.to("mps")  # NOTE(review): hard-codes the Apple "mps" device
# Targets for every sample come from the first task vector ws[0].
# NOTE(review): if ws[0] is 1-D, `.repeat(4)` tiles it end to end (length 4*D)
# rather than stacking 4 rows; this presumably relies on apply_transformations
# reshaping it per batch item — confirm, or use ws[0].repeat(4, 1).
y_trick = apply_transformations(ws[0].repeat(4), x_trick, run.pretrain_dist.std, device="mps")

for i in range(4):
    plt.matshow(x_trick[i].T.detach().cpu().numpy())

# NOTE(review): matshow needs a 2-D array; confirm y_trick's shape.
plt.matshow(y_trick.detach().cpu().numpy())

In [None]:
# xs, ys = run.pretrain_dist.get_batch(8, 1)
# get_and_plot_activations(run.model, xs=xs, ys=ys)
# Per-layer, per-head attention patterns of run.model on the synthetic inputs.
get_and_plot_activations(run.model, xs=x_trick, ys=y_trick)

In [None]:
# Same architecture and learning rate, but pretrained on 65536 tasks: compare
# its attention patterns on a freshly sampled batch and on the synthetic inputs.
run_2 = get_unique_run(
    "../sweeps/small-sweep.yaml", 
    task_config={"num_tasks": 65536, "num_layers": 2},
    optimizer_config={"lr": 0.001}
)

xs_2, ys_2 = run_2.pretrain_dist.get_batch(8, 1)
get_and_plot_activations(run_2.model, xs=xs_2, ys=ys_2)
get_and_plot_activations(run_2.model, xs=x_trick, ys=y_trick)

In [None]:
# Same comparison for the run pretrained on 64 tasks.
run_3 = get_unique_run(
    "../sweeps/small-sweep.yaml", 
    task_config={"num_tasks": 64, "num_layers": 2},
    optimizer_config={"lr": 0.001}
)

xs_3, ys_3 = run_3.pretrain_dist.get_batch(8, 1)
get_and_plot_activations(run_3.model, xs=xs_3, ys=ys_3)
get_and_plot_activations(run_3.model, xs=x_trick, ys=y_trick)