In [1]:
import json
import random
import os
import math
import sys
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any, List, Optional
import nnsight
from nnsight import CONFIG, LanguageModel
import numpy as np
from collections import defaultdict
from einops import einsum
import time
from einops import rearrange, reduce
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import pandas as pd

sys.path.append("../")
from src.dataset import SampleV3, DatasetV3, STORY_TEMPLATES
from src.utils import env_utils
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(10)

CONFIG.set_default_api_key("d9e00ab7d4f74643b3176de0913f24a7")
os.environ["HF_TOKEN"] = "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"

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
CONFIG.APP.REMOTE_LOGGING = False

%load_ext autoreload
%autoreload 2

env.yml not found in /home/local_nikhil/Projects/mind!
Setting MODEL_ROOT="". Models will now be downloaded to conda env cache, if not already there
Other defaults are set to:
    DATA_DIR = "data"
    RESULTS_DIR = "results"
    HPARAMS_DIR = "hparams"


# Loading Raw Data

In [2]:
all_states = {}
all_containers= {}
all_characters = json.load(open(os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities", "characters.json"), "r"))

for TYPE, DCT in {"states": all_states, "containers": all_containers}.items():
    ROOT = os.path.join(
        env_utils.DEFAULT_DATA_DIR, "synthetic_entities", TYPE
    )
    for file in os.listdir(ROOT):
        file_path = os.path.join(ROOT, file)
        with open(file_path, "r") as f:
            names = json.load(f)
        DCT[file.split(".")[0]] = names

# Loading model

In [3]:
# model = LanguageModel("meta-llama/Meta-Llama-3.1-405B")
model = LanguageModel("meta-llama/Meta-Llama-3-8B-Instruct", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16, dispatch=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

# Loading Helper Functions

In [4]:
def get_ques_start_token_idx(tokenizer, prompt):
    input_tokens = tokenizer.encode(prompt, return_tensors="pt").squeeze()
    corrolary_token = tokenizer.encode(":", return_tensors="pt").squeeze()[-1].item()
    ques_start_idx = (input_tokens == corrolary_token).nonzero()[2].item()

    return ques_start_idx-1

In [5]:
def get_prompt_token_len(tokenizer, prompt):
    input_tokens = tokenizer.encode(prompt, return_tensors="pt").squeeze()
    return len(input_tokens)

In [6]:
def check_pred(pred, target, verbose=False):
    prompt = f"Instruction: Check if the following ground truth and prediction of the state of the object mean the same thing or different. If they mean the same, then predict 'Yes', else 'No' \n\nGround truth: {target}\nPrediction: {pred}\nAnswer:"
    
    if verbose:
        print(prompt)

    with torch.no_grad():
        with model.generate(prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id):
            out = model.generator.output.save()

    prompt_len = get_prompt_token_len(model.tokenizer, prompt)

    return model.tokenizer.decode(out[0][prompt_len:-1]).strip()

# Loading BigToM dataset

In [3]:
# Read a csv file
df_false = pd.read_csv("../data/bigtom/0_forward_belief_false_belief/stories.csv", delimiter=";")
df_true = pd.read_csv("../data/bigtom/0_forward_belief_true_belief/stories.csv", delimiter=";")

In [4]:
# For each row in the dataframe extract story, answer, and distractor
true_stories, false_stories = [], []
for i in range(len(df_true)):
    story = df_true.iloc[i]['story']
    question = df_true.iloc[i]['question']
    answer = df_true.iloc[i]['answer']
    distractor = df_true.iloc[i]['distractor']
    true_stories.append({"story": story, "question": question, "answer": answer, "distractor": distractor})

for i in range(len(df_false)):
    story = df_false.iloc[i]['story']
    question = df_true.iloc[i]['question']
    answer = df_false.iloc[i]['answer']
    distractor = df_false.iloc[i]['distractor']
    false_stories.append({"story": story, "question": question, "answer": answer, "distractor": distractor})

dataset = []
instruction = "1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only the final state of the queried object in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict container or character as the final output."

for i in range(min(len(true_stories), len(false_stories))):
    question = true_stories[i]['question']
    alt_prompt = f"Instructions: {instruction}\n\nStory: {true_stories[i]['story']}\nQuestion: {question}\nAnswer:"

    question = false_stories[i]['question']
    org_prompt = f"Instructions: {instruction}\n\nStory: {false_stories[i]['story']}\nQuestion: {question}\nAnswer:"

    dataset.append({
        "visible_prompt": alt_prompt,
        "visible_ans": true_stories[i]['answer'],
        "invisible_prompt": org_prompt,
        "invisible_ans": false_stories[i]['answer'],
    })

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [6]:
idx = 0
print(dataset[idx]['visible_prompt'], dataset[idx]['visible_ans'])
print(dataset[idx]['invisible_prompt'], dataset[idx]['invisible_ans'])

Instructions: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only the final state of the queried object in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict container or character as the final output.

Story: Noor is working as a barista at a busy coffee shop. Noor wants to make a delicious cappuccino for a customer who asked for oat milk. Noor grabs a milk pitcher and fills it with oat milk. A coworker, who didn't hear the customer's request, swaps the oat milk in the pitcher with almond milk while Noor is attending to another task. Noor sees her coworker swapping the milk.
Question: Does Noor believe the milk pitcher contains oat milk or almond mil

# Loading Custom Data

In [7]:
train_size = 100
valid_size = 20
batch_size = 1

train_dataset = get_visibility_align_exps(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size,
                             question_type="belief_question",
                             diff_visibility=True)
valid_dataset = get_visibility_align_exps(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             valid_size,
                             question_type="belief_question",
                             diff_visibility=True)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

# Loading Singular Vectors

In [8]:
sing_vecs = defaultdict(dict)
for l in range(41):
    # sing_vecs[l] = torch.load(f"/media/sda/bigtom/singular_vecs/{l}.pt").cpu()
    sing_vecs[l] = torch.load(f"../svd_results/selected_tokens_diff/singular_vecs/{l}.pt").cpu()

# DCM

In [11]:
modules = [i for i in range(sing_vecs[0].shape[0])]
mask = torch.ones(len(modules), requires_grad=True, device="cuda")
optimizer = torch.optim.Adam([mask], lr=1e-1)
n_epochs = 1
lamb = 0.01

In [19]:
layer_idx = 40
for epoch in range(n_epochs):
    epoch_loss = 0

    for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        alt_prompt = batch["corrupt_prompt"]
        org_prompt = batch["clean_prompt"]
        target = batch["target"]
        target_token = model.tokenizer(target, return_tensors="pt").input_ids
        batch_size = target_token.size(0)

        mask.data.clamp_(0, 1)
        optimizer.zero_grad()

        alt_acts = defaultdict(dict)
        with model.trace() as tracer:
            with tracer.invoke(alt_prompt):
                for l in range(layer_idx+1):
                    for t in range(-8, 0, -1):
                        alt_acts[l][t] = model.model.layers[l].output[0][:, t].clone()

            with tracer.invoke(org_prompt):
                for l in range(layer_idx+1):
                    # Select the singular vectors (#vecs, hidden_size) layer using the mask (#vecs)
                    sing_vec = sing_vecs[l].cuda() # (#vecs, hidden_size)
                    sing_vec = sing_vec * mask.unsqueeze(-1) # (#vecs, hidden_size)
                    sing_vec = sing_vec.t().cpu() # (hidden_size, #vecs)

                    proj_matrix = torch.matmul(sing_vec, sing_vec.t()) # (hidden_size, hidden_size)

                    for t in range(-8, 0, -1):
                        alt_proj = torch.matmul(alt_acts[l][t], proj_matrix) # (batch_size, #hidden_size)
                        org_proj = torch.matmul(model.model.layers[l].output[0][:, t], proj_matrix) # (batch_size, #hidden_size)

                        model.model.layers[l].output[0][:, t] = (model.model.layers[l].output[0][:, t] - org_proj) + alt_proj

                    del sing_vec, proj_matrix
                    torch.cuda.empty_cache()

                logits = model.lm_head.output[:, -1].save()

        target_logit = logits[torch.arange(batch_size), target_token].sum()
        loss = -target_logit + lamb * torch.sum(1 - mask)
        epoch_loss += loss.item()

        if bi % 10 == 0:
            mean_loss = epoch_loss / (bi+1)
            print(f"Epoch: {epoch}, Batch: {bi}, Loss: {mean_loss}")

        loss.backward()
        optimizer.step()

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch: 0, Batch: 0, Loss: -10.6015625


 10%|█         | 10/100 [03:02<27:09, 18.10s/it]

Epoch: 0, Batch: 10, Loss: -7.9577414772727275


 20%|██        | 20/100 [06:07<24:41, 18.52s/it]

Epoch: 0, Batch: 20, Loss: -7.878627232142857


 30%|███       | 30/100 [09:16<21:48, 18.69s/it]

Epoch: 0, Batch: 30, Loss: -8.268460181451612


 35%|███▌      | 35/100 [10:55<20:17, 18.73s/it]


KeyboardInterrupt: 

In [None]:
# (#vecs, hidden_size)
sing_vec = sing_vec * mask.unsqueeze(-1) # (#vecs, hidden_size)

proj_matrix = torch.matmul(sing_vec, sing_vec.T)

for t in range(-8, 0):
    alt_proj = torch.matmul(alt_acts[l][t], proj_matrix)