## Tutorial on Group-based Interventions

In [1]:
__author__ = "Zhengxuan Wu"
__version__ = "12/14/2023"

### Overview

Sometimes a group of interventions should share the same source example. For instance, when a set of subcomponents acts together as a single causal variable, we need to intervene on all of them at once, possibly with the same source example.
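As a toy illustration of intervening on several sites together with one source, the sketch below uses a made-up two-layer "model" over plain Python lists. Nothing here is part of align-transformers: `run`, `intervene_group`, and the mixing rule are hypothetical and chosen only to keep the example small.

```python
# Toy illustration only; these names are not from align-transformers.

def run(tokens, patches=None):
    """Run a two-layer toy model and return (final_states, per-layer cache).

    patches maps (layer, position) -> value written after that layer runs.
    """
    patches = patches or {}
    states = list(tokens)
    cache = {}
    for layer in range(2):
        # Each layer adds the left neighbour's state to every position.
        states = [
            states[i] + (states[i - 1] if i > 0 else 0)
            for i in range(len(states))
        ]
        # Overwrite any sites registered for this layer.
        for (patch_layer, pos), value in patches.items():
            if patch_layer == layer:
                states[pos] = value
        cache[layer] = list(states)
    return states, cache


def intervene_group(base, source, sites):
    """Copy source activations into the base run at every site in the group."""
    _, source_cache = run(source)
    patches = {(layer, pos): source_cache[layer][pos] for layer, pos in sites}
    final, _ = run(base, patches)
    return final


base = [1, 2, 3]
source = [10, 20, 30]
# Two sites treated as one causal variable, both patched from the same source.
result = intervene_group(base, source, sites=[(0, 1), (1, 2)])
print(result)
```

Both sites read from a single cached source run, which is the behaviour that grouping is meant to express.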

### Set-up

In [2]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys
    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys
    sys.path.append("align-transformers/")

In [3]:
import sys
sys.path.append("..")

import torch
import pandas as pd
from models.utils import embed_to_distrib, top_vals, format_token
from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.alignable_base import AlignableModel
from models.interventions import VanillaIntervention
from models.gpt2.modelings_alignable_gpt2 import create_gpt2

%config InlineBackend.figure_formats = ['svg']
from plotnine import ggplot, geom_tile, aes, facet_wrap, theme, element_text, \
                     geom_bar, geom_hline, scale_y_log10


### Non-group-based Interventions vs. Group-based Interventions

In [4]:
config, tokenizer, gpt = create_gpt2()

loaded model


Without grouping, each intervention needs its own source, so the same source example is passed twice, once per intervention location.

In [5]:
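alignable_config = AlignableConfig(
    alignable_model_type=type(gpt),
    alignable_representations=[
        AlignableRepresentationConfig(
            0,               # layer
            "block_output",  # representation type
            "pos",           # intervention unit (token position)
            1,               # max number of units
        ),
        AlignableRepresentationConfig(
            2,               # layer
            "block_output",  # representation type
            "pos",           # intervention unit (token position)
            1,               # max number of units
        ),
    ],
    alignable_interventions_type=VanillaIntervention,
)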
alignable = AlignableModel(alignable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [
    tokenizer("The capital of Italy is", return_tensors="pt"),
    tokenizer("The capital of Italy is", return_tensors="pt")
]

In [6]:
_, counterfactual_outputs_no_group = alignable(
    base,
    sources,
    {"sources->base": ([[[3]], [[4]]], [[[3]], [[4]]])}
)
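The nested lists passed under `"sources->base"` can be hard to read. The sketch below unpacks them under the assumption that the nesting order is intervention, then batch example, then positions, with the first tuple element holding source locations and the second holding base locations. `describe_locations` is a hypothetical helper written for this tutorial, not a library function.

```python
# Assumed nesting: intervention -> batch example -> list of positions.
unit_locations = {"sources->base": ([[[3]], [[4]]], [[[3]], [[4]]])}


def describe_locations(unit_locations):
    """Flatten the nested location lists into one record per (intervention, example)."""
    source_locs, base_locs = unit_locations["sources->base"]
    rows = []
    for i, (src_batch, base_batch) in enumerate(zip(source_locs, base_locs)):
        for b, (src_pos, base_pos) in enumerate(zip(src_batch, base_batch)):
            rows.append(
                {
                    "intervention": i,
                    "example": b,
                    "source_positions": src_pos,
                    "base_positions": base_pos,
                }
            )
    return rows


rows = describe_locations(unit_locations)
for row in rows:
    print(row)
```

Read this way, the first intervention swaps position 3 and the second swaps position 4, each for the single example in the batch.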

With a shared `group_key`, a single source is used for all interventions in the group.

In [7]:
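alignable_config = AlignableConfig(
    alignable_model_type=type(gpt),
    alignable_representations=[
        AlignableRepresentationConfig(
            0,               # layer
            "block_output",  # representation type
            "pos",           # intervention unit (token position)
            1,               # max number of units
            group_key=0      # interventions with the same key share a source
        ),
        AlignableRepresentationConfig(
            2,               # layer
            "block_output",  # representation type
            "pos",           # intervention unit (token position)
            1,               # max number of units
            group_key=0      # same group as the intervention above
        ),
    ],
    alignable_interventions_type=VanillaIntervention,
)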
alignable = AlignableModel(alignable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [
    tokenizer("The capital of Italy is", return_tensors="pt")
]

In [8]:
_, counterfactual_outputs_group = alignable(
    base,
    sources,
    {"sources->base": ([[[3]], [[4]]], [[[3]], [[4]]])}
)
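One way to picture what grouping does: each intervention carries a group key, and every intervention with the same key reads from the same source. The sketch below is a conceptual illustration in plain Python, not the library's implementation; `expand_sources` and the explicit key-to-source dictionary are assumptions made for this example.

```python
def expand_sources(group_keys, sources_by_group):
    """Return one source per intervention by looking up its group key.

    group_keys: group key of each intervention, in intervention order.
    sources_by_group: hypothetical mapping from group key to source example.
    """
    missing = [key for key in group_keys if key not in sources_by_group]
    if missing:
        raise KeyError(f"no source provided for group keys: {sorted(set(missing))}")
    return [sources_by_group[key] for key in group_keys]


# Both interventions belong to group 0, so one source serves both of them.
per_intervention = expand_sources(
    group_keys=[0, 0],
    sources_by_group={0: "The capital of Italy is"},
)
print(per_intervention)
```

Under this picture, the grouped call with one source corresponds to the ungrouped call with the same source listed twice.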

In [9]:
torch.equal(
    counterfactual_outputs_no_group.last_hidden_state, 
    counterfactual_outputs_group.last_hidden_state
)

True