## Tutorial: Interventions on Non-transformer Models (MLPs)

In [1]:
# Notebook authorship / revision metadata.
__author__, __version__ = "Zhengxuan Wu", "12/20/2023"

### Overview

This tutorial shows how to use this library with non-transformer models, such as MLPs. The set-up is essentially the same as for standard transformer-based models.

### Set-up

In [2]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import pyvene

except ModuleNotFoundError:
    !pip install git+https://github.com/frankaging/pyvene.git

In [3]:
import torch
import pandas as pd
from pyvene import embed_to_distrib, top_vals, format_token
from pyvene import (
    IntervenableModel,
    VanillaIntervention,
    RotatedSpaceIntervention,
    LowRankRotatedSpaceIntervention,
    IntervenableRepresentationConfig,
    IntervenableConfig,
)
from pyvene.models.mlp.modelings_mlp import MLPConfig
from pyvene import create_mlp_classifier

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
)

# Seed torch before anything random happens (any random initialisation inside
# the MLP / interventions, and the `torch.rand` inputs drawn in later cells)
# so the notebook reproduces the same numbers on Restart & Run All.
torch.manual_seed(0)
# Single-layer (n_layer=1) MLP classifier with a 32-dim hidden size and
# 5 output classes.
config, tokenizer, mlp = create_mlp_classifier(MLPConfig(h_dim=32, n_layer=1, num_classes=5))

loaded model


### Intervene in middle layer by partitioning representations into subspaces

An MLP may produce only a single "token" representation per layer. As a result, we often want to intervene on a subspace of this "token" representation to localize a concept.

In [5]:
# Split the 32-dim "token" representation into two 16-dim subspaces
# (subspace index 0 = dims [0, 16), index 1 = dims [16, 32)).
subspace_halves = [
    [0, 16],
    [16, 32],
]
# NOTE(review): the positional args (0, "block_output", "pos", 1) presumably
# mean layer, component, unit type and max number of units — confirm against
# the pyvene IntervenableRepresentationConfig signature.
representation = IntervenableRepresentationConfig(
    0,
    "block_output",
    "pos",  # the MLP produces a single "token" representation per layer
    1,
    subspace_partition=subspace_halves,
)
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(mlp),
    intervenable_representations=[representation],
    intervenable_interventions_type=RotatedSpaceIntervention,
)
intervenable = IntervenableModel(intervenable_config, mlp)

# One random input each; shape assumed (batch, position, hidden) — TODO confirm.
base = {"inputs_embeds": torch.rand(1, 1, 32)}
source = {"inputs_embeds": torch.rand(1, 1, 32)}
for label, inputs in (("base", base), ("source", source)):
    print(label, intervenable(inputs))

base ((tensor([[ 0.1161,  0.3125, -0.0301, -0.0866, -0.2650]]),), None)
source ((tensor([[-0.0336,  0.2797, -0.1006, -0.1071, -0.1748]]),), None)


In [6]:
# Route source 0 at position 0 into the base at position 0, and swap in both
# subspaces (indices 1 and 0), i.e. the whole 32-dim representation.
unit_locations = {"sources->base": ([[[0]]], [[[0]]])}
swapped_subspaces = [[[1, 0]]]
_, counterfactual_outputs = intervenable(
    base,
    [source],
    unit_locations,
    subspaces=swapped_subspaces,
)

In [7]:
# Both subspaces were swapped in from `source`, so the counterfactual output
# should match running the model on `source` alone. Check it rather than
# relying on eyeballing the printed numbers; the small tolerance absorbs
# float round-off from rotating into and out of the intervention space.
source_logits = intervenable(source)[0][0]
torch.testing.assert_close(
    counterfactual_outputs[0], source_logits, rtol=1e-4, atol=1e-4
)
counterfactual_outputs  # this should be the same as source.

(tensor([[-0.0336,  0.2797, -0.1006, -0.1071, -0.1748]],
        grad_fn=<SqueezeBackward1>),)

### Intervene on subspaces with multiple sources

In [8]:
# Two linked representations on the same layer output. Both split the 32-dim
# representation into two 16-dim subspaces and share `intervention_link_key=0`,
# so they target the same subspace and each can draw a different partition
# from its own source. They are built in one comprehension (a fresh config per
# iteration) instead of two copy-pasted literals, so they cannot drift apart.
num_linked = 2
linked_representations = [
    IntervenableRepresentationConfig(
        0,
        "block_output",
        "pos",  # mlp layer creates a single token reprs
        1,
        intervenable_low_rank_dimension=32,  # equal to the full hidden size
        subspace_partition=[
            [0, 16],
            [16, 32],
        ],  # partition into two sets of subspaces
        intervention_link_key=0,  # linked ones target the same subspace
    )
    for _ in range(num_linked)
]
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(mlp),
    intervenable_representations=linked_representations,
    intervenable_interventions_type=LowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(intervenable_config, mlp)

batch_size = 10
base = {"inputs_embeds": torch.rand(batch_size, 1, 32)}
source = {"inputs_embeds": torch.rand(batch_size, 1, 32)}
print("base", intervenable(base))
print("source", intervenable(source))

base ((tensor([[ 0.0451,  0.2397, -0.0225, -0.1322, -0.1702],
        [-0.1273,  0.1582, -0.1209,  0.0013, -0.0770],
        [ 0.0576,  0.1957, -0.0486, -0.1787, -0.1354],
        [-0.0354,  0.1983, -0.0879, -0.0428, -0.1376],
        [ 0.0157,  0.2301,  0.1082, -0.0964, -0.1618],
        [ 0.0272,  0.1491, -0.0361, -0.0419, -0.1055],
        [ 0.0647,  0.1679, -0.0025, -0.1478, -0.1277],
        [ 0.0646,  0.1757,  0.0718, -0.0831, -0.2018],
        [-0.0333,  0.1445, -0.0088, -0.0406, -0.1593],
        [-0.0250,  0.1420, -0.0297, -0.0605, -0.0992]]),), None)
source ((tensor([[ 0.0345,  0.2049,  0.0838, -0.0803, -0.1454],
        [-0.0685,  0.0413,  0.0412, -0.0442, -0.1103],
        [ 0.0321,  0.1516,  0.0290, -0.1087, -0.1988],
        [ 0.0186,  0.2466, -0.0577, -0.0954, -0.1735],
        [-0.0058,  0.2023, -0.0193,  0.0034, -0.1716],
        [-0.0757,  0.1766,  0.0303, -0.1014, -0.2228],
        [ 0.0605,  0.2042, -0.0676, -0.1082, -0.2676],
        [ 0.0101,  0.2911, -0.0020,  0.

In [9]:
# One source per linked intervention: the first swaps in subspace 1 and the
# second swaps in subspace 0, so together they replace the whole
# representation and the result should equal the plain `source` output.
# The batch size is read from the inputs instead of being hard-coded.
batch_size = base["inputs_embeds"].shape[0]
_, counterfactual_outputs = intervenable(
    base,
    [source, source],
    {
        "sources->base": (
            [[[0]] * batch_size, [[0]] * batch_size],
            [[[0]] * batch_size, [[0]] * batch_size],
        )
    },
    subspaces=[[[1]] * batch_size, [[0]] * batch_size],
)
print(counterfactual_outputs)  # this should be the same as the source output
torch.testing.assert_close(
    counterfactual_outputs[0], intervenable(source)[0][0], rtol=1e-4, atol=1e-4
)
# Fake backward call to make sure gradients can be populated.
counterfactual_outputs[0].sum().backward()

(tensor([[ 0.0345,  0.2049,  0.0838, -0.0803, -0.1454],
        [-0.0685,  0.0413,  0.0412, -0.0442, -0.1103],
        [ 0.0321,  0.1516,  0.0290, -0.1087, -0.1988],
        [ 0.0186,  0.2466, -0.0577, -0.0954, -0.1735],
        [-0.0058,  0.2023, -0.0193,  0.0034, -0.1716],
        [-0.0757,  0.1766,  0.0303, -0.1014, -0.2228],
        [ 0.0605,  0.2042, -0.0676, -0.1082, -0.2676],
        [ 0.0101,  0.2911, -0.0020,  0.0106, -0.2927],
        [ 0.0720,  0.2538, -0.0988, -0.0858, -0.1759],
        [ 0.0441,  0.2026,  0.0353,  0.0047, -0.1161]],
       grad_fn=<SqueezeBackward1>),)
