## Tutorial: Interchange Intervention Training

```python
__author__ = "Zhengxuan Wu"
__version__ = "01/11/2024"
```

### Overview

[Interchange Intervention Training](https://arxiv.org/abs/2112.00826) (IIT) is a technique for training neural networks to be interpretable in a data-driven fashion. As its name suggests, IIT uses interventions as a training signal: activations are swapped between runs on different inputs (interchange interventions), and the network is optimized to produce the counterfactual outputs that a high-level causal model predicts for those swaps. As a result, the network's activations become interpretable in the sense that we can intervene on them at inference time and obtain predictable counterfactual behavior.
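To make the idea concrete, here is a minimal pure-Python sketch of a single interchange intervention on a hypothetical two-layer toy network; it is not pyvene code, and `hidden`, `readout`, and `interchange` are illustrative names. The network computes `x[0] + x[1]`, and its hidden unit 0 happens to store `x[0]`, so swapping that unit in from a source input gives exactly the counterfactual output a causal model would predict.

```python
# Hypothetical toy network (not pyvene): y = x[0] + x[1], computed through
# a hidden layer whose unit 0 stores x[0] and unit 1 stores x[1].

def hidden(x):
    """Layer 1: write each input into its own hidden unit, rescaled."""
    return [2.0 * x[0], 3.0 * x[1]]

def readout(h):
    """Layer 2: undo the scaling and add, so the output is x[0] + x[1]."""
    return h[0] / 2.0 + h[1] / 3.0

def forward(x):
    return readout(hidden(x))

def interchange(base_x, source_x, unit):
    """Run the base input, but overwrite one hidden unit with its value from the source run."""
    h_base = hidden(base_x)
    h_source = hidden(source_x)
    h_base[unit] = h_source[unit]
    return readout(h_base)

base_x, source_x = [1.0, 10.0], [5.0, 20.0]
print(forward(base_x))                        # 11.0
print(interchange(base_x, source_x, unit=0))  # 15.0 = source x[0] + base x[1]
```

Because unit 0 behaves like the causal variable `x[0]`, the swap yields the counterfactual answer `5 + 10`. IIT aims to train a real network so that its chosen activations behave this way.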

This library supports IIT because, mechanically, IIT is just a vanilla intervention with gradients enabled for all model parameters.
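Since IIT amounts to an intervention plus gradients on all model parameters, a training step can be sketched without any library. The pure-Python toy below is a hypothetical illustration, not pyvene's implementation: a tiny linear network is trained only with an interchange loss, where hidden unit 0 is swapped in from a source input and the target is the counterfactual label `x_source[0] + x_base[1]` from the causal model "y = A + B". Gradients for every weight are derived by hand; a real setup would usually also keep the ordinary task loss.

```python
import random

# Hypothetical toy (not pyvene): a linear network h = W x, y = v . h, trained only
# with an interchange loss so that hidden unit 0 aligns with A = x[0] and hidden
# unit 1 with B = x[1] in the causal model "y = A + B".
random.seed(0)
W = [[0.5, 0.3], [0.2, 0.4]]   # hidden-layer weights (trainable)
v = [0.6, 0.7]                 # readout weights (trainable)
lr = 0.05

def hidden(x):
    return [W[i][0] * x[0] + W[i][1] * x[1] for i in range(2)]

def iit_step(x_base, x_source):
    """One SGD step on 0.5 * (intervened output - counterfactual label) ** 2; returns the loss."""
    h_source, h_base = hidden(x_source), hidden(x_base)
    h_int = [h_source[0], h_base[1]]               # vanilla interchange of hidden unit 0
    y_int = v[0] * h_int[0] + v[1] * h_int[1]
    target = x_source[0] + x_base[1]               # causal model: swap A, keep B
    err = y_int - target
    # Hand-derived gradients for *all* parameters (computed before any update):
    # unit 0's weights see the source input, unit 1's weights see the base input.
    grad_v = [err * h_int[0], err * h_int[1]]
    grad_W = [[err * v[0] * x_source[j] for j in range(2)],
              [err * v[1] * x_base[j] for j in range(2)]]
    for i in range(2):
        v[i] -= lr * grad_v[i]
        for j in range(2):
            W[i][j] -= lr * grad_W[i][j]
    return 0.5 * err ** 2

for _ in range(3000):
    x_b = [random.uniform(-1, 1) for _ in range(2)]
    x_s = [random.uniform(-1, 1) for _ in range(2)]
    iit_step(x_b, x_s)

# Effective weight from input j to the output through hidden unit i.
# After training this should be close to the identity: unit 0 carries only x[0].
print([[v[i] * W[i][j] for j in range(2)] for i in range(2)])
```

The only IIT-specific piece is the swap inside `iit_step`; everything else is ordinary gradient descent on all weights, mirroring the "vanilla intervention plus gradients" view above.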

### Set-up

```python
try:
    # If pyvene can be imported, the required packages are already installed.
    import pyvene

except ModuleNotFoundError:
    # Otherwise, install pyvene from GitHub (notebook shell command).
    !pip install git+https://github.com/frankaging/pyvene.git
```

```python
from pyvene.models.basic_utils import (
    embed_to_distrib,
    top_vals,
    format_token,
    count_parameters
)

from pyvene import create_gpt2
from pyvene import (
    IntervenableModel, RotatedSpaceIntervention,
    IntervenableConfig, RepresentationConfig, VanillaIntervention
)

config, tokenizer, gpt = create_gpt2()
```

```
loaded model
```


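```python
# Intervene on the MLP activation of layer 2 at one token position,
# using a vanilla (swap) intervention.
config = IntervenableConfig(
    model_type=type(gpt),
    representations=[
        RepresentationConfig(
            2,                 # layer
            "mlp_activation",  # component
            "pos",             # unit: token positions
            1,                 # max number of units per intervention
        ),
    ],
    intervention_types=VanillaIntervention,
)
intervenable = IntervenableModel(config, gpt)

# One base prompt and one source prompt that differ only in the country.
base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [
    tokenizer("The capital of Italy is", return_tensors="pt"),
]
```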
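The intervention below targets position 3 because that is where the base and source prompts differ. Here is a small sketch of locating such positions, assuming (hypothetically) that a whitespace split mirrors GPT-2's tokenization of these two prompts; both give five tokens, consistent with the five rows in the difference tensor printed further down. `differing_positions` is an illustrative helper, not part of pyvene.

```python
def differing_positions(base_tokens, source_tokens):
    """Return the indices at which two equal-length token sequences differ."""
    if len(base_tokens) != len(source_tokens):
        raise ValueError("token sequences must have the same length")
    return [i for i, (b, s) in enumerate(zip(base_tokens, source_tokens)) if b != s]

# Assumption: a whitespace split stands in for GPT-2's tokenizer on these prompts.
base_words = "The capital of Spain is".split()
source_words = "The capital of Italy is".split()
print(differing_positions(base_words, source_words))  # [3]
```

With the real tokenizer, one could pass `base["input_ids"][0].tolist()` and `sources[0]["input_ids"][0].tolist()` instead of the word lists.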

```python
intervenable.count_parameters()
```

```
0
```

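Out of the box, the trainable-parameter count is 0: a vanilla intervention has no parameters of its own, and the model's weights are not counted as trainable. For IIT, we just need to turn on gradients for all model parameters: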

```python
intervenable.enable_model_gradients()
intervenable.count_parameters()
```

```
124439808
```
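The count 124439808 matches the transformer body of GPT-2 small (no LM head) exactly, which suggests `create_gpt2()` loads the 12-layer, 768-dimensional model. A pure-Python check of that arithmetic, assuming the standard GPT-2 small hyperparameters as defaults (`gpt2_param_count` is an illustrative helper):

```python
def gpt2_param_count(vocab=50257, n_ctx=1024, d=768, n_layer=12, d_ff=3072):
    """Count parameters of a GPT-2 transformer body (no LM head), layer by layer."""
    embeddings = vocab * d + n_ctx * d          # token (wte) + position (wpe) embeddings
    layer_norm = 2 * d                          # weight + bias
    attn = (d * 3 * d + 3 * d) + (d * d + d)    # c_attn (q, k, v) + c_proj
    mlp = (d * d_ff + d_ff) + (d_ff * d + d)    # c_fc + c_proj
    per_block = 2 * layer_norm + attn + mlp     # ln_1 + ln_2 + attention + MLP
    return embeddings + n_layer * per_block + layer_norm  # plus final ln_f

print(gpt2_param_count())  # 124439808
```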

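```python
# Run the base and the intervened forward passes. The location spec is
# (source positions, base positions): the layer-2 MLP activation at position 3
# (the country token) of the source is copied into position 3 of the base.
base_outputs, counterfactual_outputs = intervenable(
    base, sources, {"sources->base": ([[[3]]], [[[3]]])}
)
```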
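One way to read the location spec `([[[3]]], [[[3]]])`, stated here as an assumption about its layout: a pair (source locations, base locations), each nested as [intervention][batch example][unit]. The hypothetical helper below mimics a vanilla interchange on plain nested lists under that reading; pyvene itself operates on tensors inside the model's forward pass, so this is only an illustration.

```python
def vanilla_interchange(base_states, source_states, locations):
    """Copy source activations into (a copy of) the base activations at given positions.

    base_states, source_states: nested lists shaped [batch][position][hidden].
    locations: (source_locations, base_locations), each shaped [intervention][batch][unit].
    """
    source_locs, base_locs = locations
    patched = [[list(vec) for vec in example] for example in base_states]
    for src_per_int, base_per_int in zip(source_locs, base_locs):       # each intervention
        for b, (src_units, base_units) in enumerate(zip(src_per_int, base_per_int)):  # each example
            for sp, bp in zip(src_units, base_units):                     # each position pair
                patched[b][bp] = list(source_states[b][sp])
    return patched

base_states = [[[0.0, 0.0] for _ in range(5)]]               # batch 1, 5 positions, hidden 2
source_states = [[[float(p), float(p)] for p in range(5)]]   # position p holds [p, p]
patched = vanilla_interchange(base_states, source_states, ([[[3]]], [[[3]]]))
print(patched[0])  # [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [3.0, 3.0], [0.0, 0.0]]
```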

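```python
counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state
```

```
tensor([[[ 0.0438,  0.1204,  0.3694,  ..., -0.2660,  0.0809,  0.0310],
         [-0.0778, -0.0170, -0.2844,  ...,  0.0151,  0.0190,  0.1998],
         [ 0.1443, -0.5990,  0.2823,  ..., -0.1331, -0.1422,  0.1267],
         [-0.2162, -0.2819,  0.1670,  ..., -0.1039, -0.1112, -0.0366],
         [ 0.6421, -0.1228, -0.2224,  ..., -0.0918, -0.0167, -0.0540]]],
       grad_fn=<SubBackward0>)
```

Note that the rows for positions 0 to 2 are also non-zero, even though causal attention prevents an intervention at position 3 from affecting earlier positions. These differences most likely come from dropout being active (the model running in training mode), so the two forward passes use different dropout masks.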
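In a causal decoder, each position attends only to itself and earlier positions, so in a deterministic forward pass (no dropout) an intervention at position 3 can only change outputs at positions 3 and later. A toy pure-Python sketch of that locality, using a hypothetical prefix mean as a stand-in for masked attention:

```python
def causal_prefix_mean(xs):
    """Toy causal mixing: output t averages inputs 0..t only, like masked attention."""
    out, total = [], 0.0
    for t, x in enumerate(xs, start=1):
        total += x
        out.append(total / t)
    return out

base_inputs = [1.0, 2.0, 3.0, 4.0, 5.0]
patched_inputs = list(base_inputs)
patched_inputs[3] = 10.0   # interchange at position 3 only
diff = [p - b for p, b in zip(causal_prefix_mean(patched_inputs), causal_prefix_mean(base_inputs))]
print(diff)  # first three entries are exactly 0.0; positions 3 and 4 change
```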

```python
# A placeholder scalar loss, used here only to check that gradients flow.
counterfactual_outputs.last_hidden_state.sum().backward()
```

Check that gradients reached the model weights, for example the first MLP projection in layer 0:

```python
gpt.h[0].mlp.c_fc.weight.grad
```

```
tensor([[ 0.5090, -0.0050, -0.0039,  ..., -0.0109, -0.0106, -0.1192],
        [-0.2290,  0.0316,  0.0467,  ..., -0.0318, -0.0221,  0.0374],
        [-0.1379, -0.0110, -0.0248,  ...,  0.0145, -0.0232, -0.1983],
        ...,
        [-0.3359,  0.0410, -0.0045,  ...,  0.0035, -0.0556, -0.0470],
        [-0.1536,  0.0064, -0.0127,  ...,  0.0150,  0.0037,  0.1006],
        [-0.5015,  0.0190, -0.0021,  ...,  0.0194,  0.0125,  0.0355]])
```
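The layer-0 weight receives a gradient even though the intervention happens at layer 2: positions that are not swapped keep activations computed from the base input, and, if the swapped-in activation is not detached (as assumed in this toy), the source run contributes gradient too. A hypothetical scalar sketch (not pyvene code) that checks this with a finite-difference gradient:

```python
def intervened_output(w, u, x_base, x_source, swap_pos=1):
    """Two-position toy: h[p] = w * x[p]; position swap_pos takes the source's hidden value."""
    h_base = [w * x for x in x_base]
    h_source = [w * x for x in x_source]
    h = list(h_base)
    h[swap_pos] = h_source[swap_pos]   # vanilla interchange at one position (not detached)
    return u * sum(h)                  # stand-in for `.sum()` on the final states

w, u = 0.7, 1.3
x_base, x_source = [2.0, 4.0], [-1.0, -0.5]
# Analytic gradient w.r.t. w: the kept base position and the swapped source position both contribute.
analytic = u * (x_base[0] + x_source[1])
eps = 1e-6
numeric = (intervened_output(w + eps, u, x_base, x_source)
           - intervened_output(w - eps, u, x_base, x_source)) / (2 * eps)
print(analytic, numeric)
```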