In [1]:
import json
import random
import os
import math
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any, List, Optional
import nnsight
from nnsight import CONFIG, LanguageModel
import numpy as np
from collections import defaultdict
from einops import einsum
import time
from einops import rearrange, reduce
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import pandas as pd

sys.path.append("../")
from src.dataset import SampleV3, DatasetV3, STORY_TEMPLATES
from src.utils import env_utils
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(10)

CONFIG.set_default_api_key("d9e00ab7d4f74643b3176de0913f24a7")
os.environ["HF_TOKEN"] = "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"

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
CONFIG.APP.REMOTE_LOGGING = False

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
env.yml not found in /disk/u/nikhil/mind!
Setting MODEL_ROOT="". Models will now be downloaded to conda env cache, if not already there
Other defaults are set to:
    DATA_DIR = "data"
    RESULTS_DIR = "results"
    HPARAMS_DIR = "hparams"


# Loading Raw Data

In [2]:
# Entity vocabularies under <DEFAULT_DATA_DIR>/synthetic_entities: the character
# list, plus one JSON file per category in the "states" and "containers" folders.
with open(os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities", "characters.json"), "r") as f:
    all_characters = json.load(f)

all_states = {}
all_containers = {}
for entity_type, entity_dict in {"states": all_states, "containers": all_containers}.items():
    entity_dir = os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities", entity_type)
    # os.listdir order is filesystem-dependent; sort it so dict insertion order
    # (and anything sampled from it downstream) is reproducible.
    for file_name in sorted(os.listdir(entity_dir)):
        # Skip stray non-JSON files (e.g. editor/OS metadata) instead of crashing.
        if not file_name.endswith(".json"):
            continue
        with open(os.path.join(entity_dir, file_name), "r") as f:
            names = json.load(f)
        # splitext keeps dotted stems intact ("a.b.json" -> "a.b"); split(".")[0]
        # would truncate them and could make two files collide on one key.
        entity_dict[os.path.splitext(file_name)[0]] = names

# Loading Model

In [3]:
# Model under study, loaded through nnsight. A larger alternative tried earlier
# was "meta-llama/Meta-Llama-3.1-405B".
# HF_CACHE_DIR overrides the hard-coded cache location so the notebook can run
# on other machines; the default is the original path.
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR", "/disk/u/nikhil/.cache/huggingface/hub/")
model = LanguageModel(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    cache_dir=HF_CACHE_DIR,
    device_map="auto",
    # transformers reports this flag as deprecated; its replacement is
    # quantization_config=BitsAndBytesConfig(load_in_4bit=True).
    load_in_4bit=True,
    torch_dtype=torch.float16,
    dispatch=True,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 30/30 [01:25<00:00,  2.84s/it]


In [4]:
model.eval()
# Freeze every model weight: the only tensor optimised later is the separate DCM mask.
for weight in model.parameters():
    weight.requires_grad = False

# Loading Helper Functions

In [4]:
def get_ques_start_token_idx(tokenizer, prompt, marker=":", occurrence=2):
    """Return the index of the token just before the ``occurrence``-th ``marker`` token.

    With the defaults, this finds the third ':' token in ``prompt`` (zero-based
    occurrence 2) and returns the position one before it. For prompts shaped like
    "Instructions: ...\\n\\nStory: ...\\nQuestion: ...", that is the "Question" token,
    assuming no other ':' token appears earlier in the text.

    Args:
        tokenizer: Object whose ``encode(text, return_tensors="pt")`` returns a
            (1, seq_len) tensor of token ids.
        prompt: Text to search.
        marker: Text whose (last) token id is searched for.
        occurrence: Zero-based occurrence of the marker token to use.

    Returns:
        int: Token position immediately preceding the chosen marker token.

    Raises:
        IndexError: If ``prompt`` contains fewer than ``occurrence + 1`` marker tokens.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt")[0]
    # Take the last id of the encoded marker. This skips a BOS token when the
    # tokenizer adds one, and (unlike `.squeeze()[-1]`, which fails on a 0-d
    # tensor) still works when it does not.
    # NOTE(review): BPE may merge ':' with neighbouring characters in context, in
    # which case those occurrences are not matched.
    marker_id = tokenizer.encode(marker, return_tensors="pt")[0, -1].item()
    positions = (input_ids == marker_id).nonzero(as_tuple=True)[0]
    if positions.numel() <= occurrence:
        raise IndexError(
            f"expected at least {occurrence + 1} {marker!r} tokens in prompt, "
            f"found {positions.numel()}"
        )
    return positions[occurrence].item() - 1

In [5]:
def get_prompt_token_len(tokenizer, prompt):
    """Return how many token ids ``tokenizer.encode`` produces for ``prompt``.

    Any special tokens that ``encode`` adds by default (such as a BOS token, if
    the tokenizer uses one) are included. That makes the count suitable for
    slicing the prompt off the front of generated ids.

    Args:
        tokenizer: Object whose ``encode(text, return_tensors="pt")`` returns a
            (1, seq_len) tensor of token ids.
        prompt: Text to tokenize.

    Returns:
        int: Number of token ids.
    """
    # Read the sequence dimension directly. `.squeeze()` collapses a one-token
    # prompt to a 0-d tensor, and `len()` on that raises TypeError.
    return tokenizer.encode(prompt, return_tensors="pt").shape[-1]

In [6]:
def check_pred(pred, target, verbose=False):
    """Ask the loaded ``model`` whether a prediction matches the ground truth.

    Builds a yes/no judging prompt from ``target`` and ``pred``, then greedily
    generates up to 5 new tokens with the global nnsight ``model``.

    Args:
        pred: Predicted object state (interpolated into the prompt as text).
        target: Ground-truth object state.
        verbose: If True, print the judging prompt before generating.

    Returns:
        str: The generated continuation with special tokens removed and
        surrounding whitespace stripped. The prompt asks for 'Yes' or 'No', but
        the model's output is not validated here.
    """
    prompt = f"Instruction: Check if the following ground truth and prediction of the state of the object mean the same thing or different. If they mean the same, then predict 'Yes', else 'No' \n\nGround truth: {target}\nPrediction: {pred}\nAnswer:"

    if verbose:
        print(prompt)

    with torch.no_grad():
        with model.generate(prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id):
            out = model.generator.output.save()

    # The output ids start with the prompt, so skip its token count.
    prompt_len = get_prompt_token_len(model.tokenizer, prompt)

    # Decode all generated ids and let the tokenizer drop special tokens. The
    # previous `[prompt_len:-1]` slice discarded a real token whenever generation
    # stopped at max_new_tokens instead of ending with an end-of-sequence token.
    return model.tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

# Loading BigToM Dataset

In [7]:
# BigToM forward-belief stories (semicolon-delimited), false-belief and
# true-belief conditions, read in that order.
bigtom_csv_paths = {
    "false": "../data/bigtom/0_forward_belief_false_belief/stories.csv",
    "true": "../data/bigtom/0_forward_belief_true_belief/stories.csv",
}
df_false, df_true = (pd.read_csv(bigtom_csv_paths[condition], delimiter=";") for condition in ("false", "true"))

In [8]:
# For each row in the dataframe extract story, answer, and distractor
# Columns copied from each BigToM row into a story record.
_STORY_FIELDS = ["story", "question", "answer", "distractor"]


def _story_records(stories_df):
    """Return one dict per row of ``stories_df`` holding only the _STORY_FIELDS columns."""
    return stories_df[_STORY_FIELDS].to_dict("records")


# Each condition keeps its own question. The previous false-belief loop read
# `question` from df_true (a copy-paste slip), and it raised IndexError whenever
# df_false had more rows than df_true.
true_stories = _story_records(df_true)
false_stories = _story_records(df_false)

instruction = "1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only the final state of the queried object in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict container or character as the final output."


def _bigtom_prompt(story):
    """Format one story record as an Instructions/Story/Question/Answer prompt."""
    return f"Instructions: {instruction}\n\nStory: {story['story']}\nQuestion: {story['question']}\nAnswer:"


# Pair the i-th true-belief story with the i-th false-belief story. zip stops at
# the shorter list, as the original range(min(...)) loop did.
dataset = [
    {
        "visible_prompt": _bigtom_prompt(true_story),
        "visible_ans": true_story["answer"],
        "invisible_prompt": _bigtom_prompt(false_story),
        "invisible_ans": false_story["answer"],
    }
    for true_story, false_story in zip(true_stories, false_stories)
]

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [9]:
idx = 0
bigtom_example = dataset[idx]
# Print the visible (true-belief) then invisible (false-belief) prompt with its answer.
for prompt_key, ans_key in (("visible_prompt", "visible_ans"), ("invisible_prompt", "invisible_ans")):
    print(bigtom_example[prompt_key], bigtom_example[ans_key])

Instructions: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only the final state of the queried object in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict container or character as the final output.

Story: Noor is working as a barista at a busy coffee shop. Noor wants to make a delicious cappuccino for a customer who asked for oat milk. Noor grabs a milk pitcher and fills it with oat milk. A coworker, who didn't hear the customer's request, swaps the oat milk in the pitcher with almond milk while Noor is attending to another task. Noor sees her coworker swapping the milk.
Question: Does Noor believe the milk pitcher contains oat milk or almond mil

# Loading Custom Data

In [5]:
train_size = 40
valid_size = 20
batch_size = 4


def _make_belief_dataset(n_samples):
    """Generate ``n_samples`` examples with get_visibility_align_exps, using the arguments shared by both splits."""
    return get_visibility_align_exps(
        STORY_TEMPLATES,
        all_characters,
        all_containers,
        all_states,
        n_samples,
        question_type="belief_question",
        diff_visibility=True,
    )


# Train is generated before valid, preserving the order of any RNG draws.
train_dataset = _make_belief_dataset(train_size)
valid_dataset = _make_belief_dataset(valid_size)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
# The whole validation split is scored as a single batch.
valid_dataloader = DataLoader(valid_dataset, batch_size=valid_size, shuffle=False)

In [6]:
idx = 0
train_example = train_dataset[idx]
# Corrupt prompt + answer, clean prompt + answer, then the patching target.
print(train_example['corrupt_prompt'], train_example['corrupt_ans'])
print(train_example['clean_prompt'], train_example['clean_ans'])
print(train_example['target'])

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Max and Karen are working in a busy restaurant. To complete an order, Max grabs an opaque tun and fills it with coffee. Then Karen grabs another opaque dispenser and fills it with cocoa. They are working side by side and can clearly observe each other's actions.
Question: What does Karen believe the tun contains?
Answer: coffee
Instruction: 1. Track the belief of each charact

# Loading Singular Vectors

In [7]:
# Per-layer singular vectors stored as <SING_VEC_DIR>/<layer>.pt. The DCM cell
# masks and projects onto them, so each is presumably (n_vectors, hidden_dim);
# TODO confirm.
# SING_VEC_DIR overrides the original machine-specific location.
SING_VEC_DIR = os.environ.get("SING_VEC_DIR", "/disk/u/nikhil/mind/selected_tokens_diff/singular_vecs")
N_SING_VEC_LAYERS = 41  # loads files 0.pt .. 40.pt

sing_vecs = defaultdict(dict)
for l in range(N_SING_VEC_LAYERS):
    # map_location="cpu": tensors saved from a GPU still load when that device is absent.
    # weights_only=True: torch.load is pickle-based; refuse arbitrary pickled objects.
    sing_vecs[l] = torch.load(os.path.join(SING_VEC_DIR, f"{l}.pt"), map_location="cpu", weights_only=True).cpu()

# DCM (Desiderata-based Component Masking): Learning a Sparse Mask over Each Layer's Singular Vectors

In [14]:
patch_layers = list(range(0, 40, 10)) + list(range(22, 30, 2))  # coarse sweep, then a finer sweep over 22-28

In [18]:
# DCM: for each layer, learn a mask over that layer's singular vectors (rows of
# sing_vecs[layer]). At the last 8 token positions, the masked subspace of the
# clean run's layer output is replaced by the corrupt run's component. The mask
# is trained to raise the logit of ``target``, with an L1 penalty (weight
# ``lamb``) favouring few active vectors. The rounded mask is then scored on the
# validation split.

# Results accumulate across re-runs of this cell (e.g. for different layer
# ranges), but must also exist on a fresh kernel.
if "valid_accs" not in globals():
    valid_accs = {}
if "rank" not in globals():
    rank = {}

n_epochs = 1
lamb = 0.1  # weight of the L1 penalty on the mask

for layer_idx in range(32, 34, 2):
    # One mask entry per singular vector of *this* layer. Sizing it from layer 0
    # silently assumed every layer stores the same number of vectors.
    mask = torch.ones(sing_vecs[layer_idx].shape[0], requires_grad=True, device="cuda", dtype=torch.bfloat16)
    optimizer = torch.optim.Adam([mask], lr=1e-1)

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompt = batch["corrupt_prompt"]
            org_prompt = batch["clean_prompt"]
            target = batch["target"]
            # NOTE(review): scores only the last token id of each tokenized target;
            # confirm this matches the token the model emits after "Answer:".
            target_token = model.tokenizer(target, return_tensors="pt").to("cuda").input_ids[:, -1]
            # Local count, so the global `batch_size` config from the data cell is not overwritten.
            n_examples = target_token.size(0)

            optimizer.zero_grad()

            with model.trace() as tracer:
                # Cache corrupt-run activations at the patched positions
                alt_acts = defaultdict(dict)
                with tracer.invoke(alt_prompt):
                    for l in range(layer_idx, layer_idx+1):
                        for t in range(-8, 0):
                            alt_acts[l][t] = model.model.layers[l].output[0][:, t].clone()

                # Clean run: swap in the corrupt run's component in the masked subspace
                with tracer.invoke(org_prompt):
                    for l in range(layer_idx, layer_idx+1):
                        sing_vec = sing_vecs[l].cuda()
                        # Scaling rows by the mask keeps it in the graph, so gradients reach it
                        masked_vec = sing_vec * mask.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        for t in range(-8, 0):
                            curr_output = model.model.layers[l].output[0][:, t].clone()

                            alt_proj = torch.matmul(alt_acts[l][t], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)

                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[l].output[0][:, t] = modified_out

                        del sing_vec, proj_matrix, masked_vec
                        torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            # Maximise the target logit, plus an L1 sparsity penalty on the mask
            target_logit = logits[torch.arange(n_examples).cuda(), target_token]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss

            epoch_loss += loss.item()

            if bi % 2 == 0:
                mean_loss = epoch_loss / (bi + 1)
                print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
                    f"L1 Loss: {l1_loss.item():.4f}, Mean Total Loss: {mean_loss:.4f}")
                # Report from a detached copy. The mask is already clamped after
                # every step, and it must not be mutated between forward and backward.
                rounded = torch.round(mask.detach().clamp(0, 1))
                print(f"#Causal SVs: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Keep the mask in [0, 1] after each update
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")


    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0
    with torch.inference_mode():
        # Binarise the learned mask: singular vectors whose entry rounds to 1 are kept
        rounded = torch.round(mask.detach().clamp(0, 1))
        print(f"#Causal SVs: {(rounded == 1).sum().item()}")
        rank[layer_idx] = (rounded == 1).sum().item()

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompt = batch["corrupt_prompt"]
            org_prompt = batch["clean_prompt"]
            target = batch["target"]
            target_token = model.tokenizer(target, return_tensors="pt").input_ids[:, -1]
            n_examples = target_token.size(0)

            with model.trace() as tracer:
                # Cache corrupt-run activations at the patched positions
                alt_acts = defaultdict(dict)
                with tracer.invoke(alt_prompt):
                    for l in range(layer_idx, layer_idx+1):
                        for t in range(-8, 0):
                            alt_acts[l][t] = model.model.layers[l].output[0][:, t]

                # Clean run patched with the binarised mask
                with tracer.invoke(org_prompt):
                    for l in range(layer_idx, layer_idx+1):
                        sing_vec = sing_vecs[l].cuda()
                        masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        for t in range(-8, 0):
                            curr_output = model.model.layers[l].output[0][:, t].clone()

                            alt_proj = torch.matmul(alt_acts[l][t], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)

                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[l].output[0][:, t] = modified_out

                        del sing_vec, proj_matrix, masked_vec
                        torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_token.device).cpu()

            # Exact match after lower-casing and stripping whitespace
            for i in range(n_examples):
                pred_token = model.tokenizer.decode(pred[i])
                if pred_token.lower().strip() == target[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompt, org_prompt, target, target_token, logits, pred
            torch.cuda.empty_cache()

    # Guard against an empty validation loader instead of dividing by zero
    val_acc = correct / total if total else float("nan")
    print(f"Validation accuracy: {val_acc:.2f}")
    valid_accs[layer_idx] = round(val_acc, 2)

Training layer: 32


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 0, Batch: 0, Task Loss: -19.2500, L1 Loss: 250.0000, Total Loss: 230.7500
#Causal SVs: 2500


 20%|██        | 2/10 [00:41<02:47, 20.95s/it]

Epoch: 0, Batch: 2, Task Loss: -18.7812, L1 Loss: 199.0000, Total Loss: 205.4688
#Causal SVs: 2500


 40%|████      | 4/10 [01:25<02:08, 21.48s/it]

Epoch: 0, Batch: 4, Task Loss: -18.5469, L1 Loss: 150.0000, Total Loss: 180.5500
#Causal SVs: 2500


 60%|██████    | 6/10 [02:08<01:26, 21.63s/it]

Epoch: 0, Batch: 6, Task Loss: -19.0000, L1 Loss: 99.5000, Total Loss: 155.4509
#Causal SVs: 6


 80%|████████  | 8/10 [02:52<00:43, 21.69s/it]

Epoch: 0, Batch: 8, Task Loss: -19.1719, L1 Loss: 49.7500, Total Loss: 130.5451
#Causal SVs: 5


100%|██████████| 10/10 [03:35<00:00, 21.55s/it]


Training complete for 32!
Validation started for 32
#Causal SVs: 4


100%|██████████| 1/1 [01:02<00:00, 62.78s/it]

Validation accuracy: 1.00





In [19]:
# Sort valid_accs and rank by key
# Re-key both result dicts in ascending layer order for display.
valid_accs = {layer: valid_accs[layer] for layer in sorted(valid_accs)}
rank = {layer: rank[layer] for layer in sorted(rank)}
valid_accs, rank

({0: 0.0,
  10: 0.0,
  20: 0.05,
  22: 0.05,
  24: 0.4,
  26: 0.85,
  28: 1.0,
  30: 1.0,
  32: 1.0},
 {0: 0, 10: 0, 20: 2, 22: 3, 24: 3, 26: 5, 28: 4, 30: 3, 32: 4})

In [21]:
# Find the indices in rounded which are 1
# Indices of singular vectors kept by the binarised mask (entries equal to 1) for
# the most recently validated layer. nonzero(as_tuple=True)[0] is always 1-D,
# whereas `.nonzero().squeeze()` collapsed a single hit into a 0-d tensor that
# cannot be iterated or passed to len().
idx = rounded.nonzero(as_tuple=True)[0]
idx

tensor([0, 1, 2, 5], device='cuda:7')