In [None]:
from ..CausalAbstraction.tasks.IOI_task.ioi_task import get_task, get_token_positions
from ..CausalAbstraction.experiments.aggregate_experiments import ioi_baselines

In [None]:
# Load the IOI task and show one example plus the counterfactual datasets.
# hf=True presumably selects the Hugging Face-hosted data and size=None requests
# no cap — confirm in get_task.
task = get_task(hf=True, size=None)
raw_inputs = task.raw_all_data["input"]
print("Raw input:")
print(raw_inputs[0])
task.display_counterfactual_data()

In [None]:
from ..CausalAbstraction.pipeline import LMPipeline
import torch
from transformers import GPT2Config

# Prefer the second GPU when more than one is visible; fall back to the only
# GPU on single-GPU machines ("cuda:1" would raise an invalid-device error
# there), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

def checker(pred, expected):
    """Return True when ``expected`` occurs in ``pred`` (substring test for strings, membership for containers)."""
    is_match = expected in pred
    return is_match

# Alternative checkpoints that have been used with this notebook:
# model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"
# model_name = "google/gemma-2-2b"
model_name = 'openai-community/gpt2'

# Only GPT-2 gets a custom config (forcing the eager attention implementation);
# every other checkpoint passes config=None instead of an undefined name.
# NOTE(review): assumes LMPipeline treats config=None as "use the checkpoint's
# default config" — confirm.
config = None
if "gpt2" in model_name:
    config = GPT2Config.from_pretrained(model_name)
    config._attn_implementation = "eager"

pipeline = LMPipeline(
    model_name,
    max_new_tokens=1,
    device=device,
    dtype=torch.float32,
    max_length=32,
    logit_labels=True,
    position_ids=True,
    config=config,
)
# Pad on the left so every prompt in a batch ends at the same position.
pipeline.tokenizer.padding_side = "left"
batch_size = 1024 * 2
print("DEVICE:", pipeline.model.device)
# Print a size summary instead of dumping the entire dataset into the output.
print("NUM EXAMPLES:", len(task.raw_all_data["input"]))
print("INPUT:", task.raw_all_data["input"][0])
print("LABEL:", task.raw_all_data["label"][0])
print("PREDICTION:", pipeline.dump(pipeline.generate(task.raw_all_data["input"][0])))

# Filter the task using the string-match `checker` above (presumably drops
# examples the model does not answer correctly — confirm in task.filter).
task.filter(pipeline, checker, verbose=True, batch_size=batch_size)

In [None]:
# Token positions used by the intervention experiments below, computed by the
# IOI task helper for this pipeline/task pair.
token_positions = get_token_positions(pipeline, task)

In [None]:
# Identity output dumper. NOTE(review): presumably makes the task keep raw model
# outputs rather than decoding them to text — confirm in the task class.
task.output_dumper = lambda x: x
# NOTE(review): presumably makes the pipeline return scores/logits, which
# log_diff/checker below consume — confirm in LMPipeline.
pipeline.return_scores = True

def log_diff(logits, params, tokenizer=None):
    """
    Compute the IO-minus-S logit difference for each example.

    For every example the subject (S) is whichever of ``name_A`` / ``name_B``
    has the same first token id as ``name_C``; the indirect object (IO) is the
    other name. Names are compared and scored by their first token only.

    Args:
        logits: Logit scores as a tensor, or a tuple whose first element is
            that tensor. Supported ranks: 1 (vocab,), 2 (batch, vocab), or 3
            with a singleton second axis (batch, 1, vocab).
        params: Dictionary containing 'name_A', 'name_B' and 'name_C'. Each
            value is a string or a list of strings (one per batch element).
        tokenizer: Object with an ``encode(text, add_special_tokens=False)``
            method. Defaults to the notebook-level ``pipeline.tokenizer``.

    Returns:
        logit(IO) - logit(S): shape (batch,) for rank-2/3 logits, a scalar for
        rank-1 logits (computed from the first example's names).

    Raises:
        ValueError: if the three name lists differ in length, if ``name_C``
            matches neither ``name_A`` nor ``name_B`` for some example (such an
            example used to be dropped silently, misaligning S/IO ids with the
            batch), or if ``logits`` has an unsupported shape.
    """
    if tokenizer is None:
        tokenizer = pipeline.tokenizer

    def _first_token_ids(names):
        # Accept a single name or a per-example list of names.
        if not isinstance(names, list):
            names = [names]
        return [tokenizer.encode(name, add_special_tokens=False)[0] for name in names]

    token_id_A = _first_token_ids(params["name_A"])
    token_id_B = _first_token_ids(params["name_B"])
    token_id_C = _first_token_ids(params["name_C"])
    if not (len(token_id_A) == len(token_id_B) == len(token_id_C)):
        raise ValueError(
            f"name_A/name_B/name_C lengths differ: "
            f"{len(token_id_A)}, {len(token_id_B)}, {len(token_id_C)}"
        )

    token_id_IO, token_id_S = [], []
    for i, (id_A, id_B, id_C) in enumerate(zip(token_id_A, token_id_B, token_id_C)):
        if id_A == id_C:
            token_id_S.append(id_A)
            token_id_IO.append(id_B)
        elif id_B == id_C:
            token_id_S.append(id_B)
            token_id_IO.append(id_A)
        else:
            raise ValueError(
                f"Example {i}: name_C matches neither name_A nor name_B (by first token id)"
            )

    if isinstance(logits, tuple):
        logits = logits[0]
    if logits.ndim == 3:
        # Drop a singleton step axis: (batch, 1, vocab) -> (batch, vocab).
        logits = logits.squeeze(1)

    if logits.ndim == 2:
        # Build the row index on the logits' device to avoid a host/device mix.
        batch_indices = torch.arange(logits.shape[0], device=logits.device)
        logit_S = logits[batch_indices, token_id_S]
        logit_IO = logits[batch_indices, token_id_IO]
    elif logits.ndim == 1:
        logit_S = logits[token_id_S[0]]
        logit_IO = logits[token_id_IO[0]]
    else:
        raise ValueError(
            f"Unsupported logits shape {tuple(logits.shape)}; expected (vocab,), "
            f"(batch, vocab) or (batch, 1, vocab)"
        )

    return logit_IO - logit_S

def checker(logits, params):
    """
    Per-example squared error between the model's IO-minus-S logit difference
    and the target logit difference.

    This definition shadows the string-matching ``checker`` defined earlier in
    the notebook (which is what ``task.filter`` was called with).

    Args:
        logits: Logit scores, or a list whose first element holds them.
        params: Dictionary with 'name_A', 'name_B', 'name_C' (read by
            ``log_diff``) and 'logit_diff' (the target value).

    Returns:
        ``(log_diff(logits, params) - params["logit_diff"]) ** 2``.
    """
    scores = logits[0] if isinstance(logits, list) else logits
    target = params["logit_diff"]
    predicted = log_diff(scores, params)
    # Align tensor targets with the prediction's device and dtype before subtracting.
    if isinstance(target, torch.Tensor):
        target = target.to(predicted.device).to(predicted.dtype)
    return (predicted - target) ** 2


In [None]:
import copy

# Build the "same" dataset: every s1_io_flip_train example paired with its own
# input as the single counterfactual (an identity-intervention baseline).
source_name = "s1_io_flip_train"
same_examples = []
same_raw_examples = []
for raw_ex, ex in zip(task.raw_counterfactual_datasets[source_name], task.counterfactual_datasets[source_name]):
    ex_copy = ex.copy()
    ex_copy["counterfactual_inputs"] = [ex_copy["input"].copy()]
    same_examples.append(ex_copy)

    raw_copy = raw_ex.copy()
    raw_copy["counterfactual_inputs"] = [raw_copy["input"]]
    same_raw_examples.append(raw_copy)

task.counterfactual_datasets["same"] = same_examples
task.raw_counterfactual_datasets["same"] = same_raw_examples

In [None]:
# diffs = []
# losses = []
# for raw_example, example in zip(task.raw_counterfactual_datasets["same"], task.counterfactual_datasets["same"]):
#     logits = pipeline.generate(raw_example["input"])[0]
#     params = task.causal_model.run_forward(example["input"])
#     diff = log_diff(logits, params)
#     loss = checker(logits, params)
#     losses.append(loss)
#     diffs.append(diff)
# #take the average
# diff = sum(diffs) / len(diffs)
# loss = sum(losses) / len(losses)
# print("AVERAGE DIFF:", diff)
# print("AVERAGE LOSS:", loss)

In [None]:
from ..CausalAbstraction.experiments.LM_experiments import PatchIOIHeads



# Regression features for each counterfactual dataset: a (position, token) pair
# of ±1 signals. NOTE(review): presumably +1 means the signal matches the base
# input and -1 means it is flipped ("same" is (1, 1)) — confirm.
data_to_X = {
    "same": {"position": 1, "token": 1},
    "s1_io_flip_train": {"position": -1, "token": 1},
    "s2_io_flip_train": {"position": -1, "token": -1},
    "s1_ioi_flip_s2_ioi_flip_train": {"position": 1, "token": -1},
}
X, y = [], []
total_loss = 0
evaluated = 0

for counterfactual in data_to_X:
    experiment = PatchIOIHeads(pipeline, task, list(range(0, 1)), None, token_positions, checker, config={"evaluation_batch_size": batch_size, "output_scores": True})
    raw_results = experiment.perform_interventions([counterfactual], verbose=False)

    # Results are nested two levels under the dataset name; keep the raw scores
    # of the last entry (presumably there is exactly one for this setup).
    raw_outputs = None
    for per_group in raw_results["dataset"][counterfactual].values():
        for result in per_group.values():
            raw_outputs = result["raw_outputs"][0]
    if raw_outputs is None:
        raise RuntimeError(f"No raw outputs returned for counterfactual '{counterfactual}'")

    losses, labels, counterfactual_y = [], [], []
    # Column order (position, token) is what the linear fit below relies on.
    features = (data_to_X[counterfactual]["position"], data_to_X[counterfactual]["token"])
    for raw_logits, example in zip(raw_outputs, task.counterfactual_datasets[counterfactual]):
        # Measured IO-minus-S logit difference under the intervention.
        actual_diff = log_diff(raw_logits, task.causal_model.run_forward(example["input"]))
        # High-level prediction with both output variables taken from the counterfactual input.
        source = example["counterfactual_inputs"][0]
        high_level_output = task.causal_model.run_interchange(example["input"], {"output_token": source, "output_position": source})
        losses.append(checker(raw_logits, high_level_output))
        labels.append(high_level_output["logit_diff"])
        y.append(actual_diff)
        counterfactual_y.append(actual_diff)
        X.append(features)

    if not counterfactual_y:
        print(f"No examples evaluated for counterfactual '{counterfactual}'")
        continue

    n_examples = len(counterfactual_y)
    avg_y = sum(counterfactual_y) / n_examples
    avg_loss = sum(losses) / n_examples
    print(f"Average y for counterfactual '{counterfactual}': {avg_y}")
    print(f"Average label for counterfactual '{counterfactual}': {sum(labels) / n_examples}")
    print(f"Average loss for counterfactual '{counterfactual}': {avg_loss}")
    total_loss += avg_loss
    evaluated += 1

if evaluated:
    print("TOTAL LOSS:", total_loss / evaluated)

In [None]:
#TESTING BATCHES

# from experiments.LM_experiments import PatchIOIHeads
# from torch.utils.data import DataLoader



# data_to_X = {"same":{"position":1, "token":1}, 
#              "s1_io_flip_train":{"position": -1,"token":1},
#              "s2_io_flip_train":{"position":-1, "token":-1},
#              "s1_ioi_flip_s2_ioi_flip_train":{"position":1,
#                                               "token":-1}}
# X, y, = [], []
# total_loss = 0

# for counterfactual in data_to_X:
#     experiment = PatchIOIHeads(pipeline, task, list(range(0, 1)), token_positions, checker, config={"evaluation_batch_size": batch_size, "output_scores":True})
#     raw_results = experiment.perform_interventions([counterfactual], verbose=False)
#     raw_outputs = None
#     losses, labels, counterfactual_y = [],[],[]  # Collect y values for the current counterfactual
#     for v in raw_results["dataset"][counterfactual].values():
#         for v2 in v.values():
#             raw_outputs = v2["raw_outputs"][0]
#     for raw_logits, input in zip(raw_outputs, task.counterfactual_datasets[counterfactual]):
#         print(input)
#         print(task.counterfactual_datasets[counterfactual])
#         actual_diff = log_diff(raw_logits, task.causal_model.run_forward(input["input"]))
#         high_level_output = task.causal_model.run_interchange(input["input"], {"output_token":input["counterfactual_inputs"][0], "output_position":input["counterfactual_inputs"][0]})
#         loss = checker(raw_logits, high_level_output)
#         label = high_level_output["logit_diff"]
#         break
#         # print(input)
#         # print(actual_diff)
#         # print(label)
#         # print(loss)
#         # print(raw_logits)
#         # print()
#     dataloader = DataLoader(
#         task.label_counterfactual_data(task.counterfactual_datasets[counterfactual], ["output_token", "output_position"]),
#         batch_size=8,
#     )
#     for i, batch in enumerate(dataloader):
#         raw_logits = raw_outputs[i*8:i*8+8]
#         # print()
#         # print()
#         # print(batch)
#         actual_diff = log_diff(raw_logits, batch["label"])
#         loss = checker(raw_logits, batch["label"])
#         label = batch["label"]["logit_diff"]
#         # print(actual_diff)
#         # print(label)
#         # print(loss)
#         # print(raw_logits)
#         # print(awefa)

#         y += actual_diff
#         counterfactual_y += actual_diff  # Append to the counterfactual-specific list
#         losses += loss
#         labels += label
    
#     # Compute and print the average y for the current counterfactual
#     avg_y = sum(counterfactual_y) / len(counterfactual_y) if counterfactual_y else 0
#     print(f"Average y for counterfactual '{counterfactual}': {avg_y}")
#     print(f"Average label for counterfactual '{counterfactual}': {sum(labels) / len(labels)}")    
#     print(f"Average loss for counterfactual '{counterfactual}': {sum(losses) / len(losses)}")
#     total_loss += sum(losses) / len(losses)
# print("TOTAL LOSS:", total_loss/4)




In [None]:
#fit a linear model to the data
from sklearn.linear_model import LinearRegression

# Fit  logit_diff ≈ intercept + coef[0] * position_signal + coef[1] * token_signal
# on the measurements collected above. New names keep this cell re-runnable
# (X and y stay plain lists).
X_features = torch.tensor(X, dtype=torch.float64).numpy()
# Each y entry is a one-element tensor (or scalar); flatten to a 1-D float
# vector so LinearRegression fits a single output and coef_ is 1-D.
y_targets = torch.tensor([float(d) for d in y], dtype=torch.float64).numpy()

model = LinearRegression()
model.fit(X_features, y_targets)

predictions = model.predict(X_features)
mse = float(((predictions - y_targets) ** 2).mean())
# LinearRegression.score is R^2 (coefficient of determination), not an MSE loss.
r2 = model.score(X_features, y_targets)
print("Coefficients:", model.coef_)
print("Intercept:", model.intercept_)
print("MSE:", mse)
print("R^2:", r2)

In [None]:
def get_logit_diff(name_A, name_B, name_C, output_token, output_position):
    """
    Causal-model mechanism for ``logit_diff``: the linear-regression estimate
    of the IO-minus-S logit difference.

    token_signal is +1 when ``output_token`` is the name that differs from
    ``name_C`` and -1 when it equals it. position_signal is +1 for
    (name_C == name_A, output_position == 1) or (name_C == name_B,
    output_position == 0), and -1 for the swapped combinations.

    Returns:
        ``model.intercept_ + coef[1] * token_signal + coef[0] * position_signal``,
        where ``model`` is the notebook-level LinearRegression fitted on
        (position, token) features.

    Raises:
        ValueError: if either signal cannot be determined (e.g. name_C matches
            neither name, output_token is neither name, or output_position is
            not 0/1) instead of failing later with a TypeError on ``None``.
    """
    token_signal = None
    if (name_C == name_A and output_token == name_B) or (name_C == name_B and output_token == name_A):
        token_signal = 1
    elif (name_C == name_A and output_token == name_A) or (name_C == name_B and output_token == name_B):
        token_signal = -1

    position_signal = None
    if (name_C == name_A and output_position == 1) or (name_C == name_B and output_position == 0):
        position_signal = 1
    elif (name_C == name_A and output_position == 0) or (name_C == name_B and output_position == 1):
        position_signal = -1

    if token_signal is None or position_signal is None:
        raise ValueError(
            f"Cannot derive logit_diff signals from name_A={name_A!r}, name_B={name_B!r}, "
            f"name_C={name_C!r}, output_token={output_token!r}, output_position={output_position!r}"
        )

    # Flatten so a single-output fit indexes the same whether coef_ is 1-D or (1, 2).
    coef = model.coef_.reshape(-1)
    return model.intercept_ + coef[1] * token_signal + coef[0] * position_signal

# Replace the causal model's `logit_diff` mechanism with the regression-based
# estimate, so run_forward / run_interchange produce fitted targets from here on.
task.causal_model.mechanisms["logit_diff"] = get_logit_diff
def custom_loss(logits, params):
    """Scalar training loss: the mean over examples of ``checker``'s squared errors."""
    squared_errors = checker(logits, params)
    return squared_errors.mean()

In [None]:
# Re-evaluate every counterfactual dataset now that `logit_diff` comes from the
# fitted regression. Features are (position, token) ±1 signals per dataset.
data_to_X = {
    "same": {"position": 1, "token": 1},
    "s1_io_flip_train": {"position": -1, "token": 1},
    "s2_io_flip_train": {"position": -1, "token": -1},
    "s1_ioi_flip_s2_ioi_flip_train": {"position": 1, "token": -1},
}
X, y = [], []
total_loss = 0
evaluated = 0

for counterfactual in data_to_X:
    experiment = PatchIOIHeads(pipeline, task, list(range(0, 1)), None, token_positions, checker, config={"evaluation_batch_size": batch_size, "output_scores": True})
    raw_results = experiment.perform_interventions([counterfactual], verbose=False)

    # Results are nested two levels under the dataset name; keep the raw scores
    # of the last entry (presumably there is exactly one for this setup).
    raw_outputs = None
    for per_group in raw_results["dataset"][counterfactual].values():
        for result in per_group.values():
            raw_outputs = result["raw_outputs"][0]
    if raw_outputs is None:
        raise RuntimeError(f"No raw outputs returned for counterfactual '{counterfactual}'")

    losses, labels, counterfactual_y = [], [], []
    features = (data_to_X[counterfactual]["position"], data_to_X[counterfactual]["token"])
    for raw_logits, example in zip(raw_outputs, task.counterfactual_datasets[counterfactual]):
        # Measured IO-minus-S logit difference under the intervention.
        actual_diff = log_diff(raw_logits, task.causal_model.run_forward(example["input"]))
        # High-level prediction with both output variables taken from the counterfactual input.
        source = example["counterfactual_inputs"][0]
        high_level_output = task.causal_model.run_interchange(example["input"], {"output_token": source, "output_position": source})
        losses.append(checker(raw_logits, high_level_output))
        labels.append(high_level_output["logit_diff"])
        y.append(actual_diff)
        counterfactual_y.append(actual_diff)
        X.append(features)

    if not counterfactual_y:
        print(f"No examples evaluated for counterfactual '{counterfactual}'")
        continue

    n_examples = len(counterfactual_y)
    avg_y = sum(counterfactual_y) / n_examples
    avg_loss = sum(losses) / n_examples
    print(f"Average y for counterfactual '{counterfactual}': {avg_y}")
    print(f"Average label for counterfactual '{counterfactual}': {sum(labels) / n_examples}")
    print(f"Average loss for counterfactual '{counterfactual}': {avg_loss}")
    total_loss += avg_loss
    evaluated += 1

# Average over the datasets actually evaluated (was a hard-coded 4).
if evaluated:
    print("TOTAL LOSS:", total_loss / evaluated)


In [None]:
import gc
# Free memory left over from the exploratory cells before the long runs.
gc.collect()
torch.cuda.empty_cache()

# Layer range handed to ioi_baselines: 0 .. pipeline.get_num_layers().
start = 0
end = pipeline.get_num_layers()
config = {
    "evaluation_batch_size": batch_size,
    "batch_size": 256,
    "training_epoch": 2,
    "n_features": 32,
    "regularization_coefficient": 0.0,
    "output_scores": True,
    "shuffle": True,
    "temperature_schedule": (1.0, 0.01),
    "init_lr": 1.0,
}
counterfactuals = ["s1_io_flip", "s2_io_flip", "s1_ioi_flip_s2_ioi_flip"]
train_data = [f"{name}_train" for name in counterfactuals]
# All "_test" splits first, then all "_testprivate" splits.
test_data = [f"{name}{suffix}" for suffix in ("_test", "_testprivate") for name in counterfactuals]
verbose = True
results_dir = "ioi_results"
results_dir = "ioi_results"

In [None]:
import itertools
from tqdm import tqdm

# Run the baselines on every non-empty subset of these candidate
# (layer, head) pairs, once per target variable.
candidate_heads = [(7, 3), (7, 9), (8, 6), (8, 10)]
for m in tqdm(range(1, len(candidate_heads) + 1)):
    for layer_heads_list in tqdm(list(itertools.combinations(candidate_heads, m)), desc="Layer heads combinations"):
        print("Running IOI over Layer heads list:", layer_heads_list)
        # Kept separate from `results_dir` so the full runs in the cells below
        # still write to the directory configured in the config cell.
        search_results_dir = f"ioi_results_search_{str(layer_heads_list)}"
        for target_variable in ("output_token", "output_position"):
            ioi_baselines(pipeline=pipeline, task=task, token_positions=token_positions, train_data=train_data, test_data=test_data, config=config, target_variables=[target_variable], checker=checker, custom_loss=custom_loss, start=start, end=end, verbose=verbose, results_dir=search_results_dir, heads_list=layer_heads_list, skip=["DAS", "DBM+PCA", "DBM"])

In [None]:
# Full baseline run (no method skipped) targeting the output_position variable.
ioi_baselines(
    pipeline=pipeline,
    task=task,
    token_positions=token_positions,
    train_data=train_data,
    test_data=test_data,
    config=config,
    target_variables=["output_position"],
    checker=checker,
    custom_loss=custom_loss,
    start=start,
    end=end,
    verbose=verbose,
    results_dir=results_dir,
    skip=[],
)

In [None]:
# Full baseline run (no method skipped) targeting the output_token variable.
ioi_baselines(
    pipeline=pipeline,
    task=task,
    token_positions=token_positions,
    train_data=train_data,
    test_data=test_data,
    config=config,
    target_variables=["output_token"],
    checker=checker,
    custom_loss=custom_loss,
    start=start,
    end=end,
    verbose=verbose,
    results_dir=results_dir,
    skip=[],
)