In [1]:
import torch
import pyvene as pv

#### Set Activations to Zeros with Subspaces

In [2]:
# built-in helper to get a HuggingFace model
_, tokenizer, gpt2 = pv.create_gpt2()

# create with dict-based config
pv_config = pv.IntervenableConfig({
  "layer": 0,
  "component": "mlp_output"},
  intervention_types=pv.VanillaIntervention
)

#initialize model
# mode = parallel
pv_gpt2 = pv.IntervenableModel(pv_config, model=gpt2)

# run an intervened forward pass
intervened_outputs = pv_gpt2(
  # the intervening base input
  base=tokenizer("The capital of Spain is", return_tensors="pt"), 
  # the location to intervene at (3rd token)
  # used as key to get desired activation at given layer
  unit_locations={"base": 3}, # -> {'sources->base': (None, [[[3]]])}, 1st val sources, 2nd val base
  # the individual dimensions targetted
  subspaces=[10,11,12], # -> [[[10, 11, 12]]], replace at only these dims
  # source_reps = {intervention_name -> source_rep}
  source_representations=torch.zeros(gpt2.config.n_embd) 
)



loaded model


#### Interchange Interventions

##### Intro

In [6]:
# built-in helper to get a HuggingFace model
_, tokenizer, gpt2 = pv.create_gpt2()
# create with dict-based config
pv_config = pv.IntervenableConfig({
  "layer": 0,
  "component": "mlp_output"},
  intervention_types=pv.VanillaIntervention
)
#initialize model
pv_gpt2 = pv.IntervenableModel(
  pv_config, model=gpt2)
# run an interchange intervention 
intervened_outputs = pv_gpt2(
  # the base input
  base=tokenizer(
    "The capital of Spain is", 
    return_tensors = "pt"), 
  # the source input
  sources=tokenizer(
    "The capital of Italy is", 
    return_tensors = "pt"), 
  # the location to intervene at (3rd token)
  unit_locations={"sources->base": 3},
  # the individual dimensions targeted
  subspaces=[10,11,12]
)



loaded model


##### Factual recall

In [9]:
import pandas as pd
import pyvene
from pyvene import embed_to_distrib, top_vals, format_token
from pyvene import RepresentationConfig, IntervenableConfig, IntervenableModel
from pyvene import VanillaIntervention

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
)

In [10]:
config, tokenizer, gpt = pyvene.create_gpt2()

base = "The capital of Spain is"
source = "The capital of Italy is"
inputs = [tokenizer(base, return_tensors="pt"), tokenizer(source, return_tensors="pt")]

# Print GPT-2's top-10 next-token predictions for each prompt, with no
# intervention applied. This is inference only, so autograd is disabled.
with torch.no_grad():
    for i, (text, encoded) in enumerate(zip([base, source], inputs)):
        if i:
            print()  # blank line separating the two listings
        print(text)
        res = gpt(**encoded)
        distrib = embed_to_distrib(gpt, res.last_hidden_state, logits=False)
        # distrib[0][-1]: first (only) batch item, last token position
        top_vals(tokenizer, distrib[0][-1], n=10)



loaded model
The capital of Spain is
_Madrid              0.10501297563314438
_the                 0.09497053176164627
_Barcelona           0.07027736306190491
_a                   0.04010061174631119
_now                 0.028243165463209152
_in                  0.02760007046163082
_Spain               0.022992383688688278
_Catalonia           0.01882333680987358
_also                0.018688397482037544
_not                 0.017356621101498604

The capital of Italy is
_Rome                0.1573489010334015
_the                 0.07316398620605469
_Milan               0.04687740281224251
_a                   0.03449936583638191
_now                 0.032003238797187805
_in                  0.023065846413373947
_also                0.022748125717043877
_home                0.019202813506126404
_not                 0.016405250877141953
_Italy               0.01577123813331127


For every layer, we patch activations (interchange interventions) at two sites, one token position at a time:
- MLP output (the MLP output is taken from the source example)

- MHA input (the self-attention module input is taken from the source example)


In [11]:
def simple_position_config(model_type, component, layer):
    """Build an IntervenableConfig holding one position-level VanillaIntervention.

    Args:
        model_type: class of the wrapped model (e.g. ``type(gpt)``).
        component: module site to intervene on, e.g. "mlp_output" or
            "attention_input".
        layer: index of the layer that owns ``component``.

    Returns:
        An IntervenableConfig with a single representation on the "pos"
        (token position) unit of ``component`` in ``layer``, limited to one unit.
    """
    # positional args: layer, component, intervention unit, max number of units
    position_rep = RepresentationConfig(layer, component, "pos", 1)
    return IntervenableConfig(
        model_type=model_type,
        representations=[position_rep],
        intervention_types=VanillaIntervention,
    )


# tokenized base prompt, plus a one-element list of source prompts to patch from
base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]

In [None]:
tokens = tokenizer.encode(" Madrid Rome")
# " Madrid" and " Rome" each appear as single tokens in the top-10 listings
# above, so the encoding should yield exactly two ids.
assert len(tokens) == 2, tokens

# (component to patch, prefix for the "layer" label):
# "f" labels mlp_output rows, "a" labels attention_input rows.
PATCH_SITES = [("mlp_output", "f"), ("attention_input", "a")]

# loop invariants, hoisted out of the sweep
n_positions = len(base.input_ids[0])
token_labels = [format_token(tokenizer, token) for token in tokens]

data = []
# Every forward pass below is inference only, so autograd is disabled.
with torch.no_grad():
    for layer_i in range(gpt.config.n_layer):
        for component, prefix in PATCH_SITES:
            config = simple_position_config(type(gpt), component, layer_i)
            intervenable = IntervenableModel(config, gpt)
            for pos_i in range(n_positions):
                # No `subspaces` argument: presumably the whole hidden vector at
                # `pos_i` is swapped in from the source — verify against pyvene docs.
                _, counterfactual_outputs = intervenable(
                    base, sources, {"sources->base": pos_i}
                )
                distrib = embed_to_distrib(
                    gpt, counterfactual_outputs.last_hidden_state, logits=False
                )
                # first (only) batch item, last token position of the base prompt
                last_probs = distrib[0][-1]
                for token, label in zip(tokens, token_labels):
                    data.append(
                        {
                            "token": label,
                            "prob": float(last_probs[token]),
                            "layer": f"{prefix}{layer_i}",
                            "pos": pos_i,
                            "type": component,
                        }
                    )
df = pd.DataFrame(data)