# Introduction to Causal Abstraction Analysis with Distributed Alignment Search 

In [1]:
__author__ = "Atticus Geiger"

## Contents

1. [The hierarchical equality task](#The-hierarchical-equality-task)
    1. [An Algorithm that Solves the Equality Task](#An-Algorithm-that-Solves-the-Equality-Task)
        1. [The algorithm with no intervention](#The-algorithm-with-no-intervention)
        1. [The algorithm with an intervention](#The-algorithm-with-an-intervention)
        1. [The algorithm with an interchange intervention](#The-algorithm-with-an-interchange-intervention)
    1. [Hand Crafting an MLP to Solve Hierarchical Equality](#Hand-Crafting-an-MLP-to-Solve-Hierarchical-Equality)        
    1. [Training an MLP to Solve Hierarchical Equality](#Training-an-MLP-to-Solve-Hierarchical-Equality)
1. [Causal abstraction analysis](#Causal-abstraction)
    1. [Basic intervention: zeroing out part of a hidden layer](#Basic-intervention:-zeroing-out-part-of-a-hidden-layer)
    1. [An interchange intervention](#An-interchange-intervention)
    1. [Alignment](#Alignment)
    1. [Evaluating an Alignment](#Evaluation)
1. [Interchange Intervention Training (IIT)](#Interchange-Intervention-Training-(IIT))
1. [Distributed Alignment Search (DAS)](#Distributed-Alignment-Search-(DAS))

## Set-up

This notebook is a hands-on introduction to __causal abstraction analysis__ using __distributed alignment search__ with neural networks.

In causal abstraction analysis, we assess whether trained models conform to high-level causal models that we specify, not just in terms of their input–output behavior, but also in terms of their internal dynamics. The core technique is the __interchange intervention__, in which a causal model is run on one input (the __base__) while some of its intermediate variables are fixed to the values they would take for a second input (the __source__).

To motivate and illustrate these concepts, we're going to focus on a hierarchical equality task, building on work by [Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968).

In [2]:
import sys, os
sys.path.append(os.path.join('..', '..'))

In [3]:
import torch
from torch.utils.data import DataLoader
import random
import copy
import itertools
import numpy as np
from tqdm import tqdm, trange
from highlevel_models.causal_model import CausalModel
from sklearn.metrics import classification_report
from transformers import get_linear_schedule_with_warmup
from models.mlp.modelings_mlp import MLPConfig
from models.mlp.modelings_alignable_mlp import create_mlp_classifier
from models.configuration_alignable_model import AlignableRepresentationConfig, AlignableConfig
from models.interventions import VanillaIntervention, LowRankRotatedSpaceIntervention, RotatedIntervention
from models.alignable_base import AlignableModel

In [None]:
seed = 42
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
embedding_dim = 2
number_of_entities = 10

## The hierarchical equality task

This section builds on results presented in [Geiger, Carstensen, Frank, and Potts (2020)](https://arxiv.org/abs/2006.07968). We will use a hierarchical equality task ([Premack 1983](https://www.cambridge.org/core/services/aop-cambridge-core/content/view/7DF6F2D22838F7546AF7279679F3571D/S0140525X00015077a.pdf/div-class-title-the-codes-of-man-and-beasts-div.pdf)) to present interchange intervention training (IIT). 

We define the hierarchical equality task as follows: the input is two pairs of objects, and the output is **True** if both pairs consist of identical objects or both pairs consist of distinct objects, and **False** otherwise. For example, `AABB` and `ABCD` are both labeled **True**, while `ABCC` and `BBCD` are both labeled **False**.
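
As a minimal sketch of this labeling rule, the function below treats each object as a single character. The name `hierarchical_equality_label` is hypothetical and not part of the notebook's code; the notebook itself works with entity vectors rather than letters.

```python
def hierarchical_equality_label(objects: str) -> bool:
    """Label a four-object input such as "AABB" for the hierarchical equality task.

    True when the two pairs agree on equality: both pairs consist of identical
    objects, or both pairs consist of distinct objects.
    """
    if len(objects) != 4:
        raise ValueError("expected exactly four objects, e.g. 'AABB'")
    w, x, y, z = objects
    left_equal = w == x   # is the first pair made of one repeated object?
    right_equal = y == z  # is the second pair made of one repeated object?
    return left_equal == right_equal


# The four examples from the text.
for example in ["AABB", "ABCD", "ABCC", "BBCD"]:
    print(example, hierarchical_equality_label(example))
```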

## An Algorithm that Solves the Equality Task

Let $\mathcal{A}$ be the tree-structured algorithm that solves this task by applying an equality relation three times: compute whether the first two inputs are equal (variable **WX**), compute whether the second two inputs are equal (variable **YZ**), then compute whether the truth-valued outputs of these first two computations are equal (the output **O**).

And here's a Python implementation of $\mathcal{A}$ that supports the interventions we'll want to do:

In [None]:
def randvec(n=50, lower=-1, upper=1):
    return np.array([random.uniform(lower, upper) for i in range(n)])

variables =  ["W", "X", "Y", "Z", "WX", "YZ", "O"]

reps = [randvec(embedding_dim, lower=-1, upper=1) for _ in range(number_of_entities)]
values = {variable:reps for variable in ["W","X", "Y", "Z"]}
values["WX"] = [True, False]
values["YZ"] = [True, False]
values["O"] = [True, False]

parents = {"W":[],"X":[], "Y":[], "Z":[], 
           "WX":["W", "X"], "YZ":["Y", "Z"], 
           "O":["WX", "YZ"]}

def FILLER():
    return reps[0]

functions = {"W":FILLER,"X":FILLER, "Y":FILLER, "Z":FILLER, 
             "WX": lambda x,y: np.array_equal(x,y), 
             "YZ":lambda x,y: np.array_equal(x,y), 
             "O": lambda x,y: x==y}

pos = {"W":(0.2,0),"X":(1,0.1), "Y":(2,0.2), "Z":(2.8,0), 
           "WX":(1,2), "YZ":(2,2), 
           "O":(1.5,3)}

equiv_classes = {}

equality_model = CausalModel(variables, values, parents, functions, pos = pos)

Here's a visual depiction of the algorithm:

In [None]:
equality_model.print_structure()
print("Timesteps:", equality_model.timesteps)

### The algorithm with no intervention

Let's first observe the behavior of the algorithm when we provide the input `BBCD` with no interventions. Here is a visual depiction:

In [None]:
setting = equality_model.run_forward({"W":reps[0], "X":reps[0], "Y":reps[1], "Z":reps[3]})
print("No intervention:\n", setting, "\n")
equality_model.print_setting(setting)

### The algorithm with an intervention

Let's now see the behavior of the algorithm when we provide the input `BBCD` with an intervention setting **WX** to **False**. First, a visual depiction:

<img src="fig/IIT/PremackIntervention.png" width="500"/>

And then the same computation with `equality_model.run_forward`:

In [None]:
print("Intervention setting WX to FALSE:\n")
equality_model.print_setting(equality_model.run_forward({"W":reps[0], "X":reps[0], "Y":reps[1], "Z":reps[3], "WX":False}))

Notice that, in this example, even though the left two inputs are the same, the intervention has changed the intermediate value of **WX** from **True** to **False**. Since **YZ** is also **False** (the right two inputs differ), the algorithm now outputs **True**, whereas without the intervention it outputs **False**: its output is determined entirely by **WX** and **YZ**.
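
The same intervention can be sketched in plain Python, independent of `CausalModel`. The helper `compute_A` below is hypothetical: it represents inputs as four letters and lets an `interventions` dictionary override the computed values of **WX** and/or **YZ** before the output is computed.

```python
from typing import Dict, Optional, Sequence


def compute_A(inputs: Sequence[str],
              interventions: Optional[Dict[str, bool]] = None) -> Dict[str, bool]:
    """Run the tree-structured algorithm A on four objects.

    `interventions` maps "WX" and/or "YZ" to a value that overrides what the
    algorithm would otherwise compute for that variable.
    """
    interventions = interventions or {}
    w, x, y, z = inputs
    wx = interventions.get("WX", w == x)  # first equality
    yz = interventions.get("YZ", y == z)  # second equality
    return {"WX": wx, "YZ": yz, "O": wx == yz}  # equality of equalities


print(compute_A("BBCD"))                 # no intervention
print(compute_A("BBCD", {"WX": False}))  # intervention fixing WX to False
```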

### The algorithm with an interchange intervention

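Finally, let's observe the behavior of the algorithm when we provide the base input `BBCD` with an interchange intervention setting **WX** to the value it would take for the source input `ABCC`. Since the first two objects of `ABCC` differ, that value is **False**, so this interchange intervention has the same effect as the intervention above: the output becomes **True**.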

In [None]:
base = {"W":reps[0], "X":reps[0], "Y":reps[1], "Z":reps[3]}
source = {"W":reps[0], "X":reps[1], "Y":reps[2], "Z":reps[2]}
setting = equality_model.run_interchange(base, {"WX":source})
equality_model.print_setting(setting)
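
An interchange intervention is just an intervention whose value comes from a second run of the algorithm. The plain-Python sketch below makes that two-step structure explicit; `compute_A` and `interchange` are hypothetical helpers (not `CausalModel.run_interchange`), with one source input per intervened variable.

```python
from typing import Dict, Optional, Sequence


def compute_A(inputs: Sequence[str],
              interventions: Optional[Dict[str, bool]] = None) -> Dict[str, bool]:
    """Tree-structured algorithm A; `interventions` overrides WX and/or YZ."""
    interventions = interventions or {}
    w, x, y, z = inputs
    wx = interventions.get("WX", w == x)
    yz = interventions.get("YZ", y == z)
    return {"WX": wx, "YZ": yz, "O": wx == yz}


def interchange(base: Sequence[str], sources: Dict[str, Sequence[str]]) -> Dict[str, bool]:
    """Fix each variable in `sources` to the value it takes when A runs on that
    variable's source input, then run A on the base input."""
    fixed = {var: compute_A(src)[var] for var, src in sources.items()}
    return compute_A(base, fixed)


print(interchange("BBCD", {"WX": "ABCC"}))                  # WX taken from ABCC
print(interchange("BBCD", {"WX": "ABCC", "YZ": "AABB"}))    # WX and YZ both swapped
```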

# Hand Crafting an MLP to Solve Hierarchical Equality

In [None]:
n_examples = 2048*16

X, y = equality_model.generate_factual_dataset(n_examples,equality_model.sample_input_tree_balanced)

X = X.unsqueeze(1)

In [None]:
config = MLPConfig(h_dim=embedding_dim*4,
                   activation_function="relu",
                   n_layer=2,
                   n_labels=2,
                   pdrop=0.0)

In [None]:
config, tokenizer, handcrafted = create_mlp_classifier(config)

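The first layer of our handcrafted model computes the following, where $\mathbf{a}, \mathbf{b}, \mathbf{c}, \mathbf{d}$ are the four 2-dimensional entity vectors and $max$ is taken elementwise: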

$ReLU(W_1[\mathbf{a}, \mathbf{b}, \mathbf{c}, \mathbf{d}]) = [max(\mathbf{a}-\mathbf{b}, 0), max(\mathbf{b}-\mathbf{a}, 0), max(\mathbf{c}-\mathbf{d}, 0), max(\mathbf{d}-\mathbf{c}, 0)]$
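
A quick numerical check of this formula, using numpy and a copy of the matrix `W1` from the next cell (the random vectors are illustrative). Note that hidden units 0-3 depend only on $\mathbf{a}$ and $\mathbf{b}$, and units 4-7 only on $\mathbf{c}$ and $\mathbf{d}$, which makes splitting this layer into dimensions 0-3 and 4-7 a natural candidate alignment for **WX** and **YZ**.

```python
import numpy as np

# Copy of the handcrafted first-layer matrix W1 (rows are output units).
W1 = np.array([
    [ 1,  0, -1,  0,  0,  0,  0,  0],
    [ 0,  1,  0, -1,  0,  0,  0,  0],
    [-1,  0,  1,  0,  0,  0,  0,  0],
    [ 0, -1,  0,  1,  0,  0,  0,  0],
    [ 0,  0,  0,  0,  1,  0, -1,  0],
    [ 0,  0,  0,  0,  0,  1,  0, -1],
    [ 0,  0,  0,  0, -1,  0,  1,  0],
    [ 0,  0,  0,  0,  0, -1,  0,  1],
], dtype=float)

rng = np.random.default_rng(0)
a, b, c, d = rng.uniform(-1, 1, size=(4, 2))  # four 2-dimensional entity vectors
x = np.concatenate([a, b, c, d])              # the 8-dimensional network input

hidden = np.maximum(W1 @ x, 0.0)              # ReLU(W1 [a, b, c, d])
expected = np.concatenate([
    np.maximum(a - b, 0), np.maximum(b - a, 0),
    np.maximum(c - d, 0), np.maximum(d - c, 0),
])
print(np.allclose(hidden, expected))
```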


In [None]:
W1 = [[ 1,  0, -1,  0,  0,  0,  0,  0],
      [ 0,  1,  0, -1,  0,  0,  0,  0],
      [-1,  0,  1,  0,  0,  0,  0,  0],
      [ 0, -1,  0,  1,  0,  0,  0,  0],
      [ 0,  0,  0,  0,  1,  0, -1,  0],
      [ 0,  0,  0,  0,  0,  1,  0, -1],
      [ 0,  0,  0,  0, -1,  0,  1,  0],
      [ 0,  0,  0,  0,  0, -1,  0,  1]]
handcrafted.mlp.h[0].ff1.weight = torch.nn.Parameter(torch.FloatTensor(W1))
handcrafted.mlp.h[0].ff1.bias = torch.nn.Parameter(torch.FloatTensor([0,0,0,0,0,0,0,0]))
handcrafted.mlp.h[0].ff2.weight = torch.nn.Parameter(torch.eye(embedding_dim*4))
handcrafted.mlp.h[0].ff2.bias = torch.nn.Parameter(torch.FloatTensor([0,0,0,0,0,0,0,0]))

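The second layer of our handcrafted model computes the following, where $|\mathbf{a}-\mathbf{b}|$ denotes the $L_1$ distance $\sum_i |a_i - b_i|$ between two entity vectors (and likewise for $|\mathbf{c}-\mathbf{d}|$):

$ReLU(W_2ReLU(W_1[\mathbf{a}, \mathbf{b}, \mathbf{c}, \mathbf{d}])) = [max(|\mathbf{a}-\mathbf{b}| - |\mathbf{c}-\mathbf{d}|, 0), max(|\mathbf{c}-\mathbf{d}|-|\mathbf{a}-\mathbf{b}|, 0), |\mathbf{c}-\mathbf{d}|, |\mathbf{a}-\mathbf{b}|,0,0,0,0]$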


In [None]:
W2 = [[ 1,-1, 0, 1, 0, 0, 0, 0],
      [ 1,-1, 0, 1, 0, 0, 0, 0],
      [ 1,-1, 0, 1, 0, 0, 0, 0],
      [ 1,-1, 0, 1, 0, 0, 0, 0],
      [-1, 1, 1, 0, 0, 0, 0, 0],
      [-1, 1, 1, 0, 0, 0, 0, 0],
      [-1, 1, 1, 0, 0, 0, 0, 0],
      [-1, 1, 1, 0, 0, 0, 0, 0]]
handcrafted.mlp.h[1].ff1.weight = torch.nn.Parameter(torch.FloatTensor(W2).transpose(0,1))
handcrafted.mlp.h[1].ff1.bias = torch.nn.Parameter(torch.FloatTensor([0,0,0,0,0,0,0,0]))
handcrafted.mlp.h[1].ff2.weight = torch.nn.Parameter(torch.eye(embedding_dim*4))
handcrafted.mlp.h[1].ff2.bias = torch.nn.Parameter(torch.FloatTensor([0,0,0,0,0,0,0,0]))

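The third layer of our handcrafted model computes the logits, including the output bias $[0, 10^{-14}]$:

$W_3 ReLU(W_2ReLU(W_1[\mathbf{a}, \mathbf{b}, \mathbf{c}, \mathbf{d}])) + [0, 10^{-14}] = [||\mathbf{a}-\mathbf{b}| - |\mathbf{c}-\mathbf{d}|| -0.999999|\mathbf{a}-\mathbf{b}|-0.999999|\mathbf{c}-\mathbf{d}|, 10^{-14}]$

When both distances are zero, the first logit is $0$ and the tiny bias makes the second class (**True**) win. When exactly one distance is zero, the first logit equals $10^{-6}$ times the other distance, which exceeds $10^{-14}$ unless the two entity vectors are nearly identical, so the first class (**False**) wins. When both distances are positive and of comparable size, the first logit is negative, so **True** wins.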

In [None]:
W3 = [[        1, 0],
      [        1, 0],
      [-0.999999, 0],
      [-0.999999, 0],
      [        0, 0],
      [        0, 0],
      [        0, 0],
      [        0, 0]]
handcrafted.score.weight = torch.nn.Parameter(torch.FloatTensor(W3).transpose(0,1))
handcrafted.score.bias = torch.nn.Parameter(torch.FloatTensor([0,0.00000000000001]))
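
The logit formula reduces to a decision rule over the two $L_1$ distances. The sketch below, a hypothetical helper `handcrafted_predicts_true` written for illustration, evaluates that rule on one representative pair of distances per input pattern. The last case shows a limitation of these weights: if the two objects in a "different" pair are nearly identical vectors, the rule misfires, which random continuous entity vectors almost never produce.

```python
def handcrafted_predicts_true(left_dist: float, right_dist: float) -> bool:
    """Decision rule implied by the output layer of the handcrafted MLP.

    `left_dist` and `right_dist` stand for the L1 distances |a - b| and |c - d|.
    The network picks class 1 (True) when the second logit, 1e-14, is larger.
    """
    logit_false = abs(left_dist - right_dist) - 0.999999 * left_dist - 0.999999 * right_dist
    logit_true = 1e-14
    return logit_true > logit_false


cases = {
    "both pairs equal (like AABB)": (0.0, 0.0),
    "both pairs different (like ABCD)": (0.8, 0.5),
    "only the right pair equal (like ABCC)": (0.8, 0.0),
    "only the left pair equal (like BBCD)": (0.0, 0.5),
    "both different, right pair nearly equal": (1.0, 1e-7),
}
for name, (p, q) in cases.items():
    print(f"{name}: {handcrafted_predicts_true(p, q)}")
```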

In [None]:
preds = handcrafted.forward(inputs_embeds=X)

print("Train Results")
print(classification_report(y, preds[0].squeeze(1).argmax(1)))

# Causal abstraction

The formal theory of **causal abstraction** describes the conditions that must hold for the high-level tree-structured algorithm to be a **simplified and faithful description** of the neural network.

In essence, a high-level model is a causal abstraction of a neural network if and only if there is some alignment between the two models (an assignment of each high-level variable to part of the network's hidden representations) such that, for all base and source inputs, aligned interchange interventions on the algorithm and on the network produce the same output.

Below, we define an alignment between the neural network and the algorithm and a function to compute the **interchange intervention accuracy** (II accuracy) for a high-level variable: the percentage of aligned interchange interventions on which the network and algorithm produce the same output. When the II accuracy is 100%, the causal abstraction relation holds between the network and a simplified version of the algorithm in which only that one high-level variable is represented.
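
To make II accuracy concrete without the `models` package, here is a hypothetical numpy re-implementation of the handcrafted network, with an alignment that maps **WX** to hidden units 0-3 of the first layer and **YZ** to units 4-7. This mirrors `subspace_partition=[[0, 4], [4, 8]]` in the config below, but uses a plain coordinate swap instead of the library's rotated intervention. Inputs are four indices into a small set of distinct entity vectors; all names are illustrative.

```python
import itertools
import random

import numpy as np

# numpy copies of the handcrafted weights (W2 and W3 are used as x @ W, matching weight = W.T in torch).
W1 = np.array([
    [ 1,  0, -1,  0,  0,  0,  0,  0],
    [ 0,  1,  0, -1,  0,  0,  0,  0],
    [-1,  0,  1,  0,  0,  0,  0,  0],
    [ 0, -1,  0,  1,  0,  0,  0,  0],
    [ 0,  0,  0,  0,  1,  0, -1,  0],
    [ 0,  0,  0,  0,  0,  1,  0, -1],
    [ 0,  0,  0,  0, -1,  0,  1,  0],
    [ 0,  0,  0,  0,  0, -1,  0,  1],
], dtype=float)
W2 = np.array([[1, -1, 0, 1, 0, 0, 0, 0]] * 4 + [[-1, 1, 1, 0, 0, 0, 0, 0]] * 4, dtype=float)
W3 = np.array([[1, 0], [1, 0], [-0.999999, 0], [-0.999999, 0]] + [[0, 0]] * 4)
b3 = np.array([0.0, 1e-14])

ENTITIES = [np.array([k / 10.0, -k / 10.0]) for k in range(4)]  # distinct 2-d vectors


def layer1(objs):
    """First hidden layer on an input given as four entity indices."""
    x = np.concatenate([ENTITIES[i] for i in objs])
    return np.maximum(W1 @ x, 0.0)


def rest_of_network(h1):
    """Second and output layers; True when class 1 (True) wins."""
    h2 = np.maximum(h1 @ W2, 0.0)
    return bool(np.argmax(h2 @ W3 + b3) == 1)


def algorithm_interchange(base, source, var):
    """Output of A when `var` ("WX" or "YZ") takes its value from `source`."""
    wx, yz = base[0] == base[1], base[2] == base[3]
    if var == "WX":
        wx = source[0] == source[1]
    else:
        yz = source[2] == source[3]
    return wx == yz


def network_interchange(base, source, dims):
    """Copy hidden units `dims` of layer 1 from the source run into the base run."""
    h_base, h_source = layer1(base), layer1(source)
    h_base[dims] = h_source[dims]
    return rest_of_network(h_base)


def ii_accuracy(var, dims, n_pairs=2000, seed=0):
    """Fraction of random (base, source) pairs on which network and algorithm agree."""
    rng = random.Random(seed)
    inputs = list(itertools.product(range(len(ENTITIES)), repeat=4))
    agree = 0
    for _ in range(n_pairs):
        base, source = rng.choice(inputs), rng.choice(inputs)
        agree += network_interchange(base, source, dims) == algorithm_interchange(base, source, var)
    return agree / n_pairs


print(ii_accuracy("WX", slice(0, 4)))  # WX <-> units 0-3
print(ii_accuracy("YZ", slice(4, 8)))  # YZ <-> units 4-7
print(ii_accuracy("WX", slice(4, 8)))  # a deliberately wrong alignment
```

The first two alignments reach an II accuracy of 1.0: after the swap, the network compares exactly the two distances whose zero/non-zero pattern the algorithm compares. The mismatched alignment falls well short of 1.0.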

In [None]:
alignable_config = AlignableConfig(
    alignable_model_type=type(handcrafted),
    alignable_representations=[
        AlignableRepresentationConfig(
            0,               # layer
            "block_output",  # intervention type
            "pos",           # intervention unit
            1,               # max number of units
            subspace_partition=[[0, 4], [4, 8]],
            intervention_link_key=0,  # create sym link across interventions
        ),
        AlignableRepresentationConfig(
            0,               # layer
            "block_output",  # intervention type
            "pos",           # intervention unit
            1,               # max number of units
            subspace_partition=[[0, 4], [4, 8]],
            intervention_link_key=0,  # create sym link across interventions
        ),
    ],
    alignable_interventions_type=RotatedIntervention,
)
alignable_handcrafted = AlignableModel(alignable_config, handcrafted)

Next we create a counterfactual equality dataset that includes interchange intervention examples:

In [None]:
# Ids for the intermediate variable(s) targeted by an interchange intervention
WX = 0
YZ = 1
WXYZ = 2

def intervention_id(intervention):
    """Map the set of intervened variables to one of the ids above."""
    if "WX" in intervention and "YZ" in intervention:
        return WXYZ
    if "WX" in intervention:
        return WX
    if "YZ" in intervention:
        return YZ

In [None]:
data_size = 2048*16
batch_size = 16
dataset = equality_model.generate_counterfactual_dataset(data_size,
                                                        intervention_id,
                                                        batch_size,
                                                        device = "cuda:0",
                                                        sampler=equality_model.sample_input_tree_balanced)

In [None]:
print(dataset[0]["input_ids"])
print(dataset[0]["source_input_ids"])
print(dataset[0]["labels"])
print(dataset[0]["base_labels"])
print(dataset[0]["intervention_id"])

Each example in this dataset is a dictionary with the following components:

* `input_ids`: the base input (the concatenated entity vectors)
* `source_input_ids`: the source input(s) for the interchange intervention
* `labels`: the counterfactual label, i.e., the output of the algorithm when the intervened variable(s) take their values from the source input(s)
* `base_labels`: the label of the base input with no intervention
* `intervention_id`: which variable(s) are intervened on: `0` (**WX**), `1` (**YZ**), or `2` (both)
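
The internals of `generate_counterfactual_dataset` are not shown here, so the following is only a hypothetical sketch of how such an example could be built from the algorithm. It uses entity indices instead of vectors, samples uniformly (without the class balancing that `sample_input_tree_balanced` appears to provide), and assumes that for id `2` **WX** comes from the first source and **YZ** from the second.

```python
import random
from typing import Dict, List, Tuple

WX, YZ, WXYZ = 0, 1, 2  # same intervention ids as in the notebook


def pair_equalities(objs: List[int]) -> Tuple[bool, bool]:
    """Values of WX and YZ for an input given as four entity indices."""
    return objs[0] == objs[1], objs[2] == objs[3]


def counterfactual_example(rng: random.Random, n_entities: int = 10) -> Dict:
    """Sample a base input, two source inputs, an intervention id, and both labels."""
    def draw() -> List[int]:
        return [rng.randrange(n_entities) for _ in range(4)]

    base, sources = draw(), [draw(), draw()]
    intervention_id = rng.choice([WX, YZ, WXYZ])
    wx, yz = pair_equalities(base)
    base_label = wx == yz                    # factual label, no intervention
    if intervention_id == WX:
        wx = pair_equalities(sources[0])[0]
    elif intervention_id == YZ:
        yz = pair_equalities(sources[0])[1]
    else:                                    # assumption: WX from source 0, YZ from source 1
        wx = pair_equalities(sources[0])[0]
        yz = pair_equalities(sources[1])[1]
    return {
        "input_ids": base,
        "source_input_ids": sources,
        "labels": int(wx == yz),             # counterfactual label
        "base_labels": int(base_label),
        "intervention_id": intervention_id,
    }


print(counterfactual_example(random.Random(0)))
```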

In [None]:
handcrafted.to("cuda:0")
# Move the intervention parameters too; calling .to() on each parameter
# returns a copy and leaves the original parameter where it was.
alignable_handcrafted.set_device("cuda:0")
preds = []
for batch in DataLoader(dataset, batch_size):
    batch["input_ids"] = batch["input_ids"].unsqueeze(1)    
    batch["source_input_ids"] = batch["source_input_ids"].unsqueeze(2)    
    if batch["intervention_id"][0] == 2:
        _, counterfactual_outputs = alignable_handcrafted(
                {"inputs_embeds":batch["input_ids"]},
                [{"inputs_embeds":batch["source_input_ids"][:, 0]}, 
                 {"inputs_embeds":batch["source_input_ids"][:,1]}],
                {"sources->base": ([[[0]]*batch_size, [[0]]*batch_size], [[[0]]*batch_size, [[0]]*batch_size])},
            subspaces=[[[0]]*batch_size, [[1]]*batch_size]
            )
    elif batch["intervention_id"][0] == 1:
        _, counterfactual_outputs = alignable_handcrafted(
                {"inputs_embeds":batch["input_ids"]},
                [{"inputs_embeds":batch["source_input_ids"][:,0]}, None],
                {"sources->base": ([[[0]]*batch_size, None], [[[0]]*batch_size, None])},
                subspaces=[[[0]]*batch_size, None]
            )
    elif batch["intervention_id"][0] == 0:
        _, counterfactual_outputs = alignable_handcrafted(
                {"inputs_embeds":batch["input_ids"]},
                [None, {"inputs_embeds":batch["source_input_ids"][:,0]}],
                {"sources->base": ([None, [[0]]*batch_size], [None, [[0]]*batch_size])},
                subspaces=[None, [[1]]*batch_size]
            )
    preds.append(counterfactual_outputs)
preds = torch.cat(preds)

In [None]:
print(classification_report(torch.tensor([x["labels"] for x in dataset]).cpu(), preds.argmax(1).cpu()))

# Training an MLP to Solve Hierarchical Equality

We've now seen how interventions work in our high-level causal model and in a handcrafted network. We turn now to a network that learns the task from data: a fully-connected feed-forward neural network configured with `num_layers=2`.

We train it on the factual equality dataset `X`, `y` generated above with `equality_model.generate_factual_dataset`.

The examples in this dataset are 8-dimensional vectors: the concatenation of four 2-dimensional entity vectors (stored with an extra singleton dimension, since `X` was unsqueezed to shape `(n, 1, 8)`). Here's the first example with its label:

In [None]:
X[0], y[0]

The label for this example is determined by whether the equality value for the first two inputs matches the equality value for the second two inputs:

In [None]:
x0 = X[0].flatten()  # X[0] has shape (1, 8) because X was unsqueezed above
left = torch.equal(
    x0[: embedding_dim],
    x0[embedding_dim: embedding_dim*2])

left

In [None]:
right = torch.equal(
    x0[embedding_dim*2: embedding_dim*3],
    x0[embedding_dim*3: ])

right

In [None]:
int(left == right)

Let's train an MLP classifier on this task:

In [None]:
trained = MLPClassifier(
    hidden_dim=embedding_dim*4, 
    hidden_activation=torch.nn.ReLU(), 
    num_layers=2,
    input_dim=embedding_dim*4,
    input_len=4,
    n_classes=2,
    warm_start=True,
    max_iter=100,
    batch_size=64,
    n_iter_no_change=10000,
    shuffle_train=False,
    eta=0.001)

In [None]:
_ = trained.fit(X, y)

This neural network achieves near-perfect performance on its train set:

In [None]:
preds = trained.predict(X, device="cpu")


print("Train Results")
print(classification_report(y, preds.cpu()))

And it generalizes perfectly to a test set built from 100 new entity vectors that are not used in training:

In [None]:
variables =  ["W", "X", "Y", "Z", "WX", "YZ", "O"]

number_of_test_entities = 100

# New entity vectors, kept in their own list so that the global `reps`
# (read by the training model's FILLER) is not overwritten.
test_reps = [randvec(embedding_dim).round(2) for _ in range(number_of_test_entities)]
values = {variable:test_reps for variable in ["W","X", "Y", "Z"]}
values["WX"] = [True, False]
values["YZ"] = [True, False]
values["O"] = [True, False]

parents = {"W":[],"X":[], "Y":[], "Z":[], 
           "WX":["W", "X"], "YZ":["Y", "Z"], 
           "O":["WX", "YZ"]}

def TEST_FILLER():
    return test_reps[0]

functions = {"W":TEST_FILLER,"X":TEST_FILLER, "Y":TEST_FILLER, "Z":TEST_FILLER, 
             "WX": lambda x,y: np.array_equal(x,y), "YZ":lambda x,y: np.array_equal(x,y), 
             "O": lambda x,y: x==y}

pos = {"W":(0,0),"X":(1,0.1), "Y":(2,0.2), "Z":(3,0), 
           "WX":(1,2), "YZ":(2,2), 
           "O":(1.5,3)}

test_equality_model = CausalModel(variables, values, parents, functions, pos = pos)

In [None]:
X_test, y_test = test_equality_model.generate_factual_dataset(10000, test_equality_model.sample_input_tree_balanced)
print("Test Results")

test_preds = trained.predict(X_test, device="cpu")

print(classification_report(y_test, test_preds.cpu()))

In [None]:
test_dataset = test_equality_model.generate_counterfactual_dataset(data_size,
                                                             intervention_id,
                                                             16,
                                                             device = "cuda:0",
                                                             sampler=test_equality_model.sample_input_tree_balanced)

Does it implement our high-level model of the problem, though?

# Distributed Alignment Search

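Interchange Intervention Training (IIT) is a method for training a neural network to conform to the causal structure of a high-level algorithm: instead of **evaluating** whether the network and algorithm produce the same outputs under aligned interchange interventions, IIT **trains** the network to produce the algorithm's output under those interventions. Distributed Alignment Search (DAS), used below, takes the complementary approach. The network's weights stay frozen (`disable_model_gradients`), and what gets trained is the alignment itself: a rotation of a hidden representation. Interchange interventions are performed on subspaces of the rotated representation (here, the two halves given by `subspace_partition`), and the rotation is optimized so that these interventions produce the algorithm's counterfactual outputs.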

IIT was developed by [Geiger\*, Wu\*, Lu\*, Rozner, Kreiss, Icard, Goodman, and Potts (2021)](https://arxiv.org/abs/2112.00826), and it has been used for model distillation by [Wu\*, Geiger\*, Rozner, Kreiss, Lu, Icard, Goodman, and Potts (2022)](https://arxiv.org/abs/2112.02505).
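
The core operation that DAS learns can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the `LowRankRotatedSpaceIntervention` implementation. A learnable orthogonal matrix $R$ (kept orthogonal with `torch.nn.utils.parametrizations.orthogonal`) rotates the base and source hidden vectors. The first `k` rotated coordinates, playing the role of one high-level variable like the first block of `subspace_partition`, are copied from source to base, and the result is rotated back with $R^T$. `hidden_dim`, `k`, and the helper `rotated_interchange` are illustrative choices.

```python
import torch

torch.manual_seed(0)

hidden_dim = 8  # size of the representation being intervened on (assumed)
k = 4           # rotated dimensions aligned with one high-level variable (assumed)

# Learnable orthogonal matrix R; the parametrization keeps it orthogonal during training.
rotation = torch.nn.utils.parametrizations.orthogonal(
    torch.nn.Linear(hidden_dim, hidden_dim, bias=False)
)


def rotated_interchange(h_base, h_source, rotation, k):
    """Copy the first k rotated coordinates of the source into the base."""
    R = rotation.weight                 # (hidden_dim, hidden_dim), orthogonal
    r_base = h_base @ R                 # rotate into the learned basis
    r_source = h_source @ R
    r_mixed = torch.cat([r_source[..., :k], r_base[..., k:]], dim=-1)
    return r_mixed @ R.T                # rotate back (R^{-1} = R^T)


h_base = torch.randn(16, hidden_dim)
h_source = torch.randn(16, hidden_dim)
h_new = rotated_interchange(h_base, h_source, rotation, k)

# Any downstream loss sends gradients into the rotation's parameters.
h_new.pow(2).sum().backward()
print(h_new.shape)
print(any(p.grad is not None for p in rotation.parameters()))
```

In DAS, `h_new` would be passed through the rest of the frozen network and compared against the counterfactual label, so that only $R$ is updated. With `k = 0` the base is returned unchanged, and with `k = hidden_dim` the source replaces it entirely.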

In [None]:
alignable_config = AlignableConfig(
    alignable_model_type=type(trained),
    alignable_representations=[
        AlignableRepresentationConfig(
            0,             # layer
            "block_output", # intervention type
            "pos",    # intervention unit is now aligned with tokens
            4,                 # max number of units
            alignable_low_rank_dimension = 4 * embedding_dim, # the full space
            subspace_partition=[[0,2 * embedding_dim],[2 * embedding_dim,4 * embedding_dim]]      # binary partition with equal sizes
        ),
    ],
    alignable_interventions_type=LowRankRotatedSpaceIntervention,
)

In [None]:
alignable = AlignableModel(alignable_config, trained)
alignable.set_device("cuda")
alignable.disable_model_gradients()

In [None]:
t_total = int(len(dataset) * 3)
warm_up_steps = 0.1 * t_total
optimizer_params = []
for k, v in alignable.interventions.items():
    optimizer_params += [{'params': v[0].rotate_layer.parameters()}]
optimizer = torch.optim.Adam(
    optimizer_params,
    lr=1e-4
)

def compute_metrics(eval_preds, eval_labels):
    """Accuracy of argmax predictions against gold labels, pooled over batches."""
    total_count = 0
    correct_count = 0
    for eval_pred, eval_label in zip(eval_preds, eval_labels):
        actual_test_labels = eval_label.reshape(-1)          # gold class per example
        pred_test_labels = torch.argmax(eval_pred, dim=-1)   # logits (batch, n_labels) -> class
        correct_labels = (actual_test_labels==pred_test_labels)
        total_count += len(correct_labels)
        correct_count += correct_labels.sum().tolist()
    accuracy = round(correct_count/total_count, 2)
    return {"accuracy" : accuracy}

epochs = 2
gradient_accumulation_steps = 4
total_step = 0
target_total_step = len(dataset) * epochs

In [None]:
def count_parameters(model):
    """Number of trainable (requires_grad=True) parameters in a PyTorch module."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

alignable.model.train() # enables dropout; the MLP's own weights stay frozen
print("MLP trainable parameters: ", count_parameters(alignable.model))
print("intervention trainable parameters: ", alignable.count_parameters())
train_iterator = trange(
    0, int(epochs), desc="Epoch"
)
for epoch in train_iterator:
    epoch_iterator = tqdm(
        DataLoader(dataset, batch_size), desc=f"Epoch: {epoch}", position=0, leave=True
    )
    for step, batch in enumerate(epoch_iterator):
        for k, v in batch.items():
            if v is not None and isinstance(v, torch.Tensor):
                batch[k] = v.to("cuda")
        batch_size = batch["input_ids"].shape[0]
        if batch["intervention_id"][0] == 2: 
            continue
#             _, counterfactual_outputs = alignable(
#                 {"input_ids":batch["input_ids"]},
#                 [{"input_ids":batch["source_input_ids"][:, 0]}, 
#                  {"input_ids":batch["source_input_ids"][:, 1]}],
#                 {"sources->base": ([[[0,1,2,3]]*batch_size, [[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size, [[0,1,2,3]]*batch_size])},
#                 None,
#                 [[[0,1]]*batch_size]
#             )
        elif batch["intervention_id"][0] ==1:
            _, counterfactual_outputs = alignable(
                {"input_ids":batch["input_ids"]},
                [{"input_ids":batch["source_input_ids"][:,0]}],
                {"sources->base": ([[[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size])},
                None,
                [[[1]]*batch_size]
            )
        elif batch["intervention_id"][0] ==0:
            _, counterfactual_outputs = alignable(
                {"input_ids":batch["input_ids"]},
                [{"input_ids":batch["source_input_ids"][:,0]}],
                {"sources->base": ([[[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size])},
                None,
                [[[0]]*batch_size]
            )
        eval_metrics = compute_metrics(
            [counterfactual_outputs], [batch['labels']]
        )
        
        # loss and backprop
        loss = alignable.model.loss(
            counterfactual_outputs, batch["labels"].squeeze().long()
        )
        loss_str = round(loss.item(), 2)
        epoch_iterator.set_postfix({'loss': loss_str, 'acc': eval_metrics["accuracy"]})
        
        # Backpropagate on every batch so gradients accumulate, and update
        # the rotation once every `gradient_accumulation_steps` batches.
        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps
        loss.backward()
        total_step += 1
        if total_step % gradient_accumulation_steps == 0:
            optimizer.step()
            alignable.set_zero_grad()

In [None]:
eval_labels = []
eval_preds = []
with torch.no_grad():
    epoch_iterator = tqdm(DataLoader(dataset, batch_size), desc="Train")
    for step, batch in enumerate(epoch_iterator):
        for k, v in batch.items():
            if v is not None and isinstance(v, torch.Tensor):
                batch[k] = v.to("cuda")
        batch_size = batch["input_ids"].shape[0]
        if batch["intervention_id"][0] == 2: 
            continue
#             _, counterfactual_outputs = alignable(
#                 {"input_ids":batch["input_ids"]},
#                 [{"input_ids":batch["source_input_ids"][:, 0]}, 
#                  {"input_ids":batch["source_input_ids"][:,1]}],
#                 {"sources->base": ([[[0,1]]*batch_size, [[2,3]]*batch_size], [[[0,1]]*batch_size, [[2,3]]*batch_size])} 
#             )
        elif batch["intervention_id"][0] == 1:
            _, counterfactual_outputs = alignable(
                {"input_ids":batch["input_ids"]},
                [{"input_ids":batch["source_input_ids"][:,0]}],
                {"sources->base": ([[[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size])},
                None,
                [[[1]]*batch_size]
            )
        elif batch["intervention_id"][0] == 0:
            _, counterfactual_outputs = alignable(
                {"input_ids":batch["input_ids"]},
                [{"input_ids":batch["source_input_ids"][:,0]}],
                {"sources->base": ([[[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size])},
                None,
                [[[0]]*batch_size]
            )
        eval_labels += [batch['labels']]
        eval_preds += [torch.argmax(counterfactual_outputs,dim=1)]


In [None]:
print(classification_report(torch.cat(eval_labels).cpu(), torch.cat(eval_preds).cpu()))

In [None]:
eval_labels = []
eval_preds = []
with torch.no_grad():
    epoch_iterator = tqdm(DataLoader(test_dataset, batch_size), desc=f"Test")
    for step, batch in enumerate(epoch_iterator):
        for k, v in batch.items():
            if v is not None and isinstance(v, torch.Tensor):
                batch[k] = v.to("cuda")
        batch_size = batch["input_ids"].shape[0]
        if batch["intervention_id"][0] == 2: 
            continue
#             _, counterfactual_outputs = alignable(
#                 {"input_ids":batch["input_ids"]},
#                 [{"input_ids":batch["source_input_ids"][:, 0]}, 
#                  {"input_ids":batch["source_input_ids"][:,1]}],
#                 {"sources->base": ([[[0,1]]*batch_size, [[2,3]]*batch_size], [[[0,1]]*batch_size, [[2,3]]*batch_size])} 
#             )
        elif batch["intervention_id"][0] ==1:
            _, counterfactual_outputs = alignable(
                {"input_ids":batch["input_ids"]},
                [{"input_ids":batch["source_input_ids"][:,0]}],
                {"sources->base": ([[[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size])},
                None,
                [[[1]]*batch_size]
            )
        elif batch["intervention_id"][0] ==0:
            _, counterfactual_outputs = alignable(
                {"input_ids":batch["input_ids"]},
                [{"input_ids":batch["source_input_ids"][:,0]}],
                {"sources->base": ([[[0,1,2,3]]*batch_size], [[[0,1,2,3]]*batch_size])},
                None,
                [[[0]]*batch_size]
            )
        eval_labels += [batch['labels']]
        eval_preds += [torch.argmax(counterfactual_outputs,dim=1)]
print(classification_report(torch.cat(eval_labels).cpu(), torch.cat(eval_preds).cpu()))

# Trying to use `find_alignment`

In [None]:
def add_locations(example):
    example["source_0->base.0.pos"] = [0,1,2,3] # target the whole layer
    example["source_0->base.1.pos"] = [0,1,2,3] 
    example["source_1->base.0.pos"] = [0,1,2,3] # target the whole layer
    example["source_1->base.1.pos"] = [0,1,2,3] 
    if example["intervention_id"] == 0:
        example["source_0->subspaces"] = [[0]]
    if example["intervention_id"] == 1:
        example["source_0->subspaces"] = [[1]]   
    if example["intervention_id"] == 2:
        example["source_0->subspaces"] = [[0]]
        example["source_1->subspaces"] = [[1]] 
    return example

for example in dataset:
    add_locations(example)



In [None]:
def inputs_collator(inputs):
    for k, v in inputs.items():
        if "subspace" in k:
            inputs[k] = [v]
        elif v is not None and isinstance(v, torch.Tensor):
            inputs[k] = v.to("cuda")
    return inputs

In [None]:
alignable.find_alignment(
    train_dataloader=DataLoader(dataset, batch_size),
    compute_loss=alignable.model.loss,
    compute_metrics=classification_report,
    inputs_collator=inputs_collator
)

To evaluate this model, we create a fresh counterfactual equality dataset of 10,000 examples drawn from the test entities:

In [None]:
data_size = 10000

data = test_equality_model.generate_counterfactual_dataset(data_size,
                                                           intervention_id,
                                                           64,
                                                           sampler=test_equality_model.sample_input_tree_balanced)
X_base_test, y_base_test, X_sources_test, y_II_test, interventions_test = data

print(X_base_test.shape)
print(y_base_test.shape)
print(X_sources_test.shape)
print(y_II_test.shape)
print(interventions_test.shape)

In [None]:
base_preds_test = LIM_trainer.predict(X_base_test.cpu(),device="cpu")

II_preds_test = LIM_trainer.iit_predict(X_base_test.cpu(),
                                    X_sources_test.cpu(),
                                    interventions_test.cpu(),
                                    id_to_coords,
                                    device="cpu")

This IIT-trained model does well on a standard behavioral test:

In [None]:
print(classification_report(y_base_test, base_preds_test.cpu()))

Importantly, it _also_ performs perfectly on counterfactual examples, a marked improvement over the model we studied above that was not trained with IIT:

In [None]:
print(classification_report(y_II_test, II_preds_test.cpu()))