In [1]:
import torch
from torch import compile
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import os
import time
import sys
import wandb
import random
import numpy as np
import json
from tqdm import tqdm
from datasets import Dataset, load_from_disk
from src.data_utils import get_ravel_collate_fn, generate_ravel_dataset_from_filtered
from src.utils import add_fwd_hooks
import argparse
from pyvene import IntervenableConfig, RepresentationConfig, LowRankRotatedSpaceIntervention, IntervenableModel, count_parameters

from torch import optim
from transformers import AutoTokenizer, LlamaForCausalLM
from transformers import get_scheduler


%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.


In [2]:
model_name_or_path = "./models/llama3-8b"
das_path = "./models/country_baseline/final_das_module.pt"

train_set_path = "./experiments/ravel/data/ravel_city_Country_train"
test_set_path = "./experiments/ravel/data/ravel_city_Country_test"

intervention_location = "last_entity_token"

In [4]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

train_set = load_from_disk(train_set_path)
test_set = load_from_disk(test_set_path)
            
collate_fn = get_ravel_collate_fn(
    tokenizer, 
    source_suffix_visibility=True, 
    base_suffix_visibility=True, 
    add_space_before_target=True,
    contain_entity_position=True,
)

data_loader = DataLoader(
    train_set, batch_size=16, collate_fn=collate_fn, shuffle=True
)

test_data_loader = DataLoader(
    test_set, batch_size=16, collate_fn=collate_fn, shuffle=True
)

model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)
model = model.to("cuda")

intervention_config = IntervenableConfig(
    model_type=type(model),
    representations=[
    RepresentationConfig(
            12,  # layer
            'block_output',  # intervention repr
            "pos",  # intervention unit
            1,  # max number of unit
            128)
    ],
    intervention_types=LowRankRotatedSpaceIntervention,
)

intervenable = IntervenableModel(intervention_config, model)
intervenable.set_device(model.device)
intervenable.disable_model_gradients()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.07it/s]


In [5]:
def forward(
    base_input_ids: torch.Tensor = None,
    base_attention_mask: torch.Tensor = None,
    base_intervention_position: torch.Tensor = None,
    base_position_ids: torch.Tensor = None,
    source_input_ids: torch.Tensor = None,
    source_attention_mask: torch.Tensor = None,
    source_intervention_position: torch.Tensor = None,
    source_position_ids: torch.Tensor = None,
    intervention_layer: int = None,
):
    if intervention_layer is None:
        raise ValueError("intervention_layer must be specified")
    
    if base_position_ids is None:
        # 0 for all the padding tokens and start from 1 for the rest
        base_position_ids = torch.cumsum(base_attention_mask, dim=1) * base_attention_mask
    
    if source_position_ids is None:
        source_position_ids = torch.cumsum(source_attention_mask, dim=1) * source_attention_mask
    
    # print(source_intervention_position.unsqueeze(0).shape, base_intervention_position.unsqueeze(0).shape)
    b_s = base_input_ids.shape[0]
    intervention_locations = {
        "sources->base": (
            source_intervention_position.unsqueeze(0).unsqueeze(-1),
            base_intervention_position.unsqueeze(0).unsqueeze(-1)
        )
    }
    
    _, counterfactual_outputs = intervenable(
        {
            "input_ids": base_input_ids,
            'attention_mask': base_attention_mask,
            'position_ids': base_position_ids
        }, [
            {
                "input_ids": source_input_ids,
                'attention_mask': source_attention_mask,
                'position_ids': source_position_ids
            }
        ] , intervention_locations
    )
    
    return counterfactual_outputs
    
            
def eval_accuracy(test_loader, eval_n_label_tokens=3):
    
    intervenable.eval()
    correct_idxs = []
    is_causal = []
    
    with torch.no_grad():
        for batch_id, batch in enumerate(test_loader):
            
            if intervention_location == "last_entity_token":
                base_intervention_position = batch["base_entity_position_ids"].to("cuda") 
                source_intervention_position = batch["source_entity_position_ids"].to("cuda")
            else:
                base_intervention_position = batch["base_input_ids"].shape[1] - 1
                source_intervention_position = batch["source_input_ids"].shape[1] - 1
                
                base_intervention_position = torch.tensor([base_intervention_position] * batch["base_input_ids"].shape[0]).to("cuda")
                source_intervention_position = torch.tensor([source_intervention_position] * batch["source_input_ids"].shape[0]).to("cuda")
            
            output = forward(
                base_input_ids=batch["base_input_ids"].to("cuda"),
                base_attention_mask=batch["base_attention_mask"].to("cuda"),
                base_intervention_position=base_intervention_position,
                source_input_ids=batch["source_input_ids"].to("cuda"),
                source_attention_mask=batch["source_attention_mask"].to("cuda"),
                source_intervention_position=source_intervention_position,
                intervention_layer=12,
            )
            
            logits = output.logits
                            
            batch_pred_ids = torch.argmax(logits, dim=-1)
            is_causal.extend(batch["is_causal"].cpu().numpy().tolist())
            
            for i, (label, pred_ids) in enumerate(zip(batch["labels"].to("cuda"), batch_pred_ids)):
                label_idx = label != -100
                output_idx = torch.zeros_like(label_idx)
                output_idx[:-1] = label_idx[1:]
                
                label = label[label_idx]
                pred_ids = pred_ids[output_idx]
                
                if eval_n_label_tokens is not None and len(label) > eval_n_label_tokens:
                    label = label[:eval_n_label_tokens]
                    pred_ids = pred_ids[:eval_n_label_tokens]
                
                is_correct = (torch.sum (label == pred_ids) == torch.numel(label)).item()    
                if is_correct:
                    correct_idxs.append(batch_id * len(batch["labels"]) + i)
            
            
    total_causal = sum(is_causal)
    total_isolate = len(is_causal) - total_causal
    
    correct_causal = sum([is_causal[i] for i in correct_idxs])
    correct_isolate = len(correct_idxs) - correct_causal
    
    causal_acc = correct_causal / total_causal if total_causal > 0 else 0.0
    isolate_acc = correct_isolate / total_isolate if total_isolate > 0 else 0.0
    
    disentangle_acc = 0.5 * (causal_acc + isolate_acc) if total_isolate > 0 else causal_acc
    
    accuracies = {
        "causal": causal_acc,
        "isolate": isolate_acc,
        "disentangle": disentangle_acc
    }
                
    return accuracies

In [6]:
inv_key = list(intervenable.interventions.keys())[0]
intervenable.interventions[inv_key][0].load_state_dict(torch.load(das_path))

eval_accuracy(test_data_loader)

  intervenable.interventions[inv_key][0].load_state_dict(torch.load(das_path))
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


  unit_locations = torch.tensor(
  unit_locations = torch.tensor(


{'causal': 0.8653846153846154,
 'isolate': 0.890625,
 'disentangle': 0.8780048076923077}