## Tutorial on Interventions for Non-Transformer Models: MLPs

In [1]:
__author__ = "Zhengxuan Wu"
# Date stamp in MM/DD/YYYY form, not a semantic version number.
__version__ = "12/20/2023"

### Overview

This tutorial shows how to use this library on non-transformer models, such as MLPs. The setup is much the same as for standard transformer-based models.

### Set-up

In [2]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys
    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys
    sys.path.append("align-transformers/")

In [3]:
import sys
sys.path.append("../..")

import torch
import pandas as pd
from models.basic_utils import embed_to_distrib, top_vals, format_token
from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.interventions import VanillaIntervention, RotatedSpaceIntervention
from models.mlp.modelings_mlp import MLPConfig
from models.mlp.modelings_alignable_mlp import create_mlp_classifier

%config InlineBackend.figure_formats = ['svg']
from plotnine import ggplot, geom_tile, aes, facet_wrap, theme, element_text, \
                     geom_bar, geom_hline, scale_y_log10

# Fix the global torch RNG so that the random inputs drawn with torch.rand in
# later cells (and any random weight initialisation inside the factory --
# TODO confirm) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# h_dim=32 matches the 32-wide `inputs_embeds` tensors and the
# [[0,16],[16,32]] subspace partition used in the cells below.
config, tokenizer, mlp = create_mlp_classifier(MLPConfig(h_dim=32))

loaded model


### Intervene in the middle layer by partitioning representations into subspaces

An MLP may contain only a single "token" representation at each layer. As a result, we often want to intervene on a subspace of this "token" representation to localize a concept.

In [4]:
# One intervention site on the MLP, using a rotated-space intervention.
alignable_config = AlignableConfig(
    alignable_model_type=type(mlp),
    alignable_interventions_type=RotatedSpaceIntervention,
    alignable_representations=[
        AlignableRepresentationConfig(
            0,
            "block_output",
            # An MLP layer produces a single "token" representation, so one
            # position-based unit is all there is to address.
            "pos",
            1,
            # Split the representation into two sets of subspaces.
            subspace_partition=[[0, 16], [16, 32]],
        ),
    ],
)
alignable = AlignableModel(alignable_config, mlp)

# Two independent random single-example inputs.
base = {"inputs_embeds": torch.rand(1, 1, 32)}
source = {"inputs_embeds": torch.rand(1, 1, 32)}

# Called with inputs only (no sources), these give the reference outputs that
# the counterfactual run is compared against.
print("base", alignable(base))
print("source", alignable(source))

base ((tensor([[[-0.0552, -0.1310]]]),), None)
source ((tensor([[[-0.0507, -0.1348]]]),), None)


In [5]:
# Swap subspaces 0 and 1 of the single intervention from `source` into `base`.
_, counterfactual_outputs = alignable(
    base,
    [source],
    {
        "sources->base": (
            [[[0]]],
            [[[0]]],
        )
    },
    subspaces=[[[0, 1]]],
)

In [6]:
# Both subspaces (0 and 1) were taken from the source, so this should equal
# the "source" output printed above.
counterfactual_outputs

(tensor([[[-0.0507, -0.1348]]], grad_fn=<UnsafeViewBackward0>),)

### Intervene on the subspace with multiple sources

In [4]:
# Two identically configured interventions on the same site; both carry
# intervention_link_key=0.
alignable_config = AlignableConfig(
    alignable_model_type=type(mlp),
    alignable_interventions_type=VanillaIntervention,
    alignable_representations=[
        AlignableRepresentationConfig(
            0,
            "block_output",
            # An MLP layer produces a single "token" representation.
            "pos",
            1,
            # Split the representation into two sets of subspaces.
            subspace_partition=[[0, 16], [16, 32]],
            intervention_link_key=0,
        )
        for _ in range(2)
    ],
)
alignable = AlignableModel(alignable_config, mlp)

# Batches of ten random examples each.
base = {"inputs_embeds": torch.rand(10, 1, 32)}
source = {"inputs_embeds": torch.rand(10, 1, 32)}

# Reference outputs without any intervention.
print("base", alignable(base))
print("source", alignable(source))

base ((tensor([[[-0.0609, -0.0470]],

        [[-0.0662, -0.0392]],

        [[-0.0639, -0.0455]],

        [[-0.0566, -0.0371]],

        [[-0.0627, -0.0445]],

        [[-0.0520, -0.0422]],

        [[-0.0615, -0.0410]],

        [[-0.0547, -0.0478]],

        [[-0.0602, -0.0447]],

        [[-0.0608, -0.0444]]]),), None)
source ((tensor([[[-0.0647, -0.0400]],

        [[-0.0558, -0.0537]],

        [[-0.0605, -0.0355]],

        [[-0.0554, -0.0457]],

        [[-0.0621, -0.0395]],

        [[-0.0621, -0.0408]],

        [[-0.0619, -0.0448]],

        [[-0.0599, -0.0390]],

        [[-0.0525, -0.0408]],

        [[-0.0595, -0.0523]]]),), None)


In [5]:
# Feed a source to the first intervention only (the second slot is None) and
# swap subspace 0 for each of the ten batch examples.
_, counterfactual_outputs = alignable(
    base,
    [source, None],
    {
        "sources->base": (
            [[[0]] * 10, None],
            [[[0]] * 10, None],
        )
    },
    subspaces=[[[0]] * 10, None],
)

In [6]:
# Only subspace 0 came from the source, so the result matches neither the
# "base" nor the "source" output printed above.
counterfactual_outputs

(tensor([[[-0.0657, -0.0415]],
 
         [[-0.0656, -0.0449]],
 
         [[-0.0637, -0.0350]],
 
         [[-0.0602, -0.0413]],
 
         [[-0.0630, -0.0443]],
 
         [[-0.0593, -0.0396]],
 
         [[-0.0680, -0.0461]],
 
         [[-0.0586, -0.0447]],
 
         [[-0.0550, -0.0440]],
 
         [[-0.0588, -0.0496]]]),)

In [7]:
# Mirror image of the previous call: the source now goes to the second
# intervention (the first slot is None), again swapping subspace 0.
_, counterfactual_outputs = alignable(
    base,
    [None, source],
    {
        "sources->base": (
            [None, [[0]] * 10],
            [None, [[0]] * 10],
        )
    },
    subspaces=[None, [[0]] * 10],
)

In [8]:
# Same values as when the source was routed through the first intervention --
# presumably because both interventions share intervention_link_key=0.
counterfactual_outputs

(tensor([[[-0.0657, -0.0415]],
 
         [[-0.0656, -0.0449]],
 
         [[-0.0637, -0.0350]],
 
         [[-0.0602, -0.0413]],
 
         [[-0.0630, -0.0443]],
 
         [[-0.0593, -0.0396]],
 
         [[-0.0680, -0.0461]],
 
         [[-0.0586, -0.0447]],
 
         [[-0.0550, -0.0440]],
 
         [[-0.0588, -0.0496]]]),)

In [9]:
# Use both interventions at once: the first swaps subspace 0 and the second
# swaps subspace 1, each taking its values from `source`.
_, counterfactual_outputs = alignable(
    base,
    [source, source],
    {
        "sources->base": (
            [[[0]] * 10, [[0]] * 10],
            [[[0]] * 10, [[0]] * 10],
        )
    },
    subspaces=[[[0]] * 10, [[1]] * 10],
)

In [10]:
# Subspaces 0 and 1 were both taken from the source, so this should equal the
# "source" output printed above.
counterfactual_outputs

(tensor([[[-0.0647, -0.0400]],
 
         [[-0.0558, -0.0537]],
 
         [[-0.0605, -0.0355]],
 
         [[-0.0554, -0.0457]],
 
         [[-0.0621, -0.0395]],
 
         [[-0.0621, -0.0408]],
 
         [[-0.0619, -0.0448]],
 
         [[-0.0599, -0.0390]],
 
         [[-0.0525, -0.0408]],
 
         [[-0.0595, -0.0523]]]),)