In [1]:
import json
import random
import os
import math
import sys
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any, List, Optional
import nnsight
from nnsight import CONFIG, LanguageModel
from collections import defaultdict
from einops import rearrange, reduce

sys.path.append("../")
from src.dataset import SampleV3, DatasetV3, STORY_TEMPLATES
from src.utils import env_utils
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Seed every RNG this notebook relies on: `random` drives entity sampling,
# torch drives probe initialisation and DataLoader shuffling.
random.seed(10)
torch.manual_seed(10)

# SECURITY: credentials must never be stored in the notebook. Any key or token
# that was previously hard-coded in this cell has to be treated as leaked and
# revoked. Export NDIF_API_KEY and HF_TOKEN in the shell that launches Jupyter.
ndif_api_key = os.environ.get("NDIF_API_KEY")
if ndif_api_key:
    CONFIG.set_default_api_key(ndif_api_key)
if not os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is not set; gated Hugging Face checkpoints may fail to download.")

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
CONFIG.APP.REMOTE_LOGGING = False

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
env.yml not found in /home/local_nikhil/Projects/mind!
Setting MODEL_ROOT="". Models will now be downloaded to conda env cache, if not already there
Other defaults are set to:
    DATA_DIR = "data"
    RESULTS_DIR = "results"
    HPARAMS_DIR = "hparams"


# Loading Data

In [2]:
# Entity vocabularies, all read from <DEFAULT_DATA_DIR>/synthetic_entities.
# `all_states` and `all_containers` map a file stem to the JSON content of
# that file; `all_characters` is the content of characters.json.
all_states = {}
all_containers = {}

# Use a context manager so the file handle is closed deterministically
# (the bare `json.load(open(...))` form leaves it to the garbage collector).
characters_path = os.path.join(
    env_utils.DEFAULT_DATA_DIR, "synthetic_entities", "characters.json"
)
with open(characters_path, "r") as f:
    all_characters = json.load(f)

for TYPE, DCT in {"states": all_states, "containers": all_containers}.items():
    ROOT = os.path.join(
        env_utils.DEFAULT_DATA_DIR, "synthetic_entities", TYPE
    )
    for file in os.listdir(ROOT):
        file_path = os.path.join(ROOT, file)
        with open(file_path, "r") as f:
            names = json.load(f)
        # Key by everything before the first dot of the file name.
        DCT[file.split(".")[0]] = names

# Loading Model

In [3]:
# Alternative checkpoint, kept for reference:
# model = LanguageModel("meta-llama/Meta-Llama-3.1-405B")
# NOTE(review): the loader output below reports `load_in_4bit` as deprecated in
# favour of a `BitsAndBytesConfig` passed via `quantization_config`.
model = LanguageModel(
    "meta-llama/Meta-Llama-3-70B-Instruct",
    device_map="auto",
    load_in_4bit=True,
    torch_dtype=torch.float16,
    dispatch=True,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 30/30 [01:03<00:00,  2.10s/it]


# Dataset creation

In [150]:
# Size of the visibility-alignment dataset and the loader batch size.
# Later cells read `batch[...][0]`, i.e. they take the first item of each batch.
n_samples = 40
batch_size = 1

dataset = get_visibility_align_exps(
    STORY_TEMPLATES,
    all_characters,
    all_containers,
    all_states,
    n_samples,
    question_type="belief_question",
    diff_visibility=True,
)
# Keep the sample order fixed so index `i` in later loops is stable.
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [151]:
# Print one sample: corrupt prompt/answer, clean prompt/answer, then the target.
idx = 0
for variant in ("corrupt", "clean"):
    print(dataset[idx][f"{variant}_prompt"], dataset[idx][f"{variant}_ans"])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Kyle and Ben are working in a busy restaurant side by side and can clearly observe each other's actions. To complete an order, Kyle grabs an opaque tun and fills it with cocoa. Then Ben grabs another opaque bottle and fills it with bourbon.
Question: What does Ben believe the tun contains?
Answer: cocoa
Instruction: 1. Track the belief of each character as described in the st

In [152]:
# Hand-picked token positions into the tokenized prompt. Positive values count
# from the start of the sequence, negative values from the end.
# NOTE(review): these presumably hold for every sample only if all prompts
# tokenize to the same layout — TODO confirm.
visibility_sent = list(range(129, 152))
content_sent = list(range(152, 180))
first_sent = list(range(152, 168))
second_sent = list(range(168, 179))
first_charac = [157, 158]
second_charac = [169, 170]
query_sent = [-8, -7]

# Sanity check: decode the tokens selected by `second_charac` for sample `idx`.
encoded = model.tokenizer(dataset[idx]['corrupt_prompt'], return_tensors="pt")
input_tokens = encoded.input_ids
print(model.tokenizer.decode(input_tokens[0][second_charac]))

 Ben grabs


In [159]:
# Cache layer outputs at the `query_sent` token positions for both prompt
# variants of every sample.
# NOTE(review): only layers 0..n_layers-1 are read; confirm that 40 is the
# intended depth for the loaded checkpoint.
n_layers = 40
# Buffers are indexed as (sample, token, layer, hidden_size).
vis_acts = torch.zeros(n_samples, len(query_sent), n_layers, model.config.hidden_size) # sample, token, layer, hidden_size
no_vis_acts = torch.zeros(n_samples, len(query_sent), n_layers, model.config.hidden_size)

for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    # The loader yields batches; only the first item of each batch is used.
    vis_prompt = batch['corrupt_prompt'][0]
    no_vis_prompt = batch['clean_prompt'][0]

    with torch.no_grad():

        with model.trace() as tracer:

            with tracer.invoke(vis_prompt):
                for l in range(n_layers):
                    # Iterate over `query_sent` (not a repeated literal) so the
                    # positions always match the buffer's token dimension.
                    for t_idx, t in enumerate(query_sent):
                        vis_acts[i, t_idx, l] = model.model.layers[l].output[0][0, t].cpu().save()

            with tracer.invoke(no_vis_prompt):
                for l in range(n_layers):
                    for t_idx, t in enumerate(query_sent):
                        no_vis_acts[i, t_idx, l] = model.model.layers[l].output[0][0, t].cpu().save()

  0%|          | 0/40 [00:00<?, ?it/s]

100%|██████████| 40/40 [02:36<00:00,  3.91s/it]


In [160]:
# Per-layer mean difference between the two cached activation sets, averaged
# over the sample and token dimensions. The result has shape
# (n_layers, hidden_size), so `mean_acts[l]` is the difference vector for layer l.
# Computing it in one reduction avoids the previous fixed-size, uninitialised
# `torch.empty(50, ...)` buffer whose row count was unrelated to `n_layers`.
mean_acts = (vis_acts - no_vis_acts).mean(dim=(0, 1))

In [163]:
# Steering evaluation: on the clean (no-visibility) prompt, add `mean_acts[l]`
# to the output of every layer l in range(n_layers) at the `second_charac`
# token positions, then check whether the next-token prediction equals the
# answer of the corrupt (visibility) prompt.
correct, total = 0, 0
for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    vis_prompt = batch['corrupt_prompt'][0]
    no_vis_prompt = batch['clean_prompt'][0]
    vis_ans = batch['corrupt_ans'][0]
    no_vis_ans = batch['clean_ans'][0]

    with torch.no_grad():

        with model.trace() as tracer:

            with tracer.invoke(no_vis_prompt):
                for l in range(n_layers):
                    for t_idx, t in enumerate(second_charac):
                        model.model.layers[l].output[0][0, t] = model.model.layers[l].output[0][0, t] + mean_acts[l]

            # Greedy prediction at the last position of the patched run.
            pred = model.lm_head.output[0, -1].argmax(dim=-1).save()

        # Decode once and reuse for both the log line and the comparison.
        pred_text = model.tokenizer.decode([pred]).lower().strip()
        print(f"Pred: {pred_text} | Target: {vis_ans}")
        if pred_text == vis_ans:
            correct += 1
        total += 1

        del pred
        torch.cuda.empty_cache()

# All layers 0..n_layers-1 are patched in the same forward pass, so report the
# whole range instead of the last value left in the loop variable `l`.
accuracy = round(correct / total, 2) if total else float("nan")
print(f"Layers: 0-{n_layers - 1} | Accuracy: {accuracy}")

  0%|          | 0/40 [00:00<?, ?it/s]

  2%|▎         | 1/40 [00:02<01:34,  2.41s/it]

Pred: unknown | Target: cocoa


  5%|▌         | 2/40 [00:04<01:31,  2.40s/it]

Pred: unknown | Target: cocktail


  8%|▊         | 3/40 [00:07<01:28,  2.40s/it]

Pred: unknown | Target: beer


 10%|█         | 4/40 [00:09<01:26,  2.40s/it]

Pred: unknown | Target: bourbon


 12%|█▎        | 5/40 [00:12<01:24,  2.40s/it]

Pred: unknown | Target: tea


 15%|█▌        | 6/40 [00:14<01:21,  2.41s/it]

Pred: soda | Target: soda


 18%|█▊        | 7/40 [00:16<01:19,  2.41s/it]

Pred: unknown | Target: cocoa


 20%|██        | 8/40 [00:19<01:17,  2.42s/it]

Pred: unknown | Target: porter


 22%|██▎       | 9/40 [00:21<01:15,  2.42s/it]

Pred: unknown | Target: tea


 25%|██▌       | 10/40 [00:24<01:12,  2.43s/it]

Pred: unknown | Target: gin


 28%|██▊       | 11/40 [00:26<01:10,  2.43s/it]

Pred: unknown | Target: sprite


 30%|███       | 12/40 [00:29<01:08,  2.44s/it]

Pred: unknown | Target: milk


 32%|███▎      | 13/40 [00:31<01:05,  2.44s/it]

Pred: unknown | Target: monster


 35%|███▌      | 14/40 [00:33<01:03,  2.44s/it]

Pred: unknown | Target: cocoa


 38%|███▊      | 15/40 [00:36<01:01,  2.45s/it]

Pred: unknown | Target: punch


 40%|████      | 16/40 [00:38<00:58,  2.45s/it]

Pred: milk | Target: milk


 42%|████▎     | 17/40 [00:41<00:56,  2.46s/it]

Pred: unknown | Target: ale


 45%|████▌     | 18/40 [00:43<00:54,  2.46s/it]

Pred: unknown | Target: rum


 48%|████▊     | 19/40 [00:46<00:51,  2.47s/it]

Pred: unknown | Target: coffee


 50%|█████     | 20/40 [00:48<00:49,  2.47s/it]

Pred: unknown | Target: cocktail


 52%|█████▎    | 21/40 [00:51<00:46,  2.47s/it]

Pred: unknown | Target: bourbon


 55%|█████▌    | 22/40 [00:53<00:44,  2.47s/it]

Pred: unknown | Target: espresso


 57%|█████▊    | 23/40 [00:56<00:42,  2.48s/it]

Pred: unknown | Target: milk


 60%|██████    | 24/40 [00:58<00:40,  2.53s/it]

Pred: unknown | Target: cocoa


 62%|██████▎   | 25/40 [01:01<00:37,  2.52s/it]

Pred: unknown | Target: milk


 65%|██████▌   | 26/40 [01:03<00:35,  2.51s/it]

Pred: unknown | Target: beer


 68%|██████▊   | 27/40 [01:06<00:32,  2.50s/it]

Pred: unknown | Target: water


 70%|███████   | 28/40 [01:08<00:30,  2.50s/it]

Pred: unknown | Target: cocktail


 72%|███████▎  | 29/40 [01:11<00:27,  2.50s/it]

Pred: punch | Target: punch


 75%|███████▌  | 30/40 [01:13<00:25,  2.50s/it]

Pred: coffee | Target: coffee


 78%|███████▊  | 31/40 [01:16<00:22,  2.50s/it]

Pred: unknown | Target: cocktail


 80%|████████  | 32/40 [01:18<00:20,  2.50s/it]

Pred: unknown | Target: punch


 82%|████████▎ | 33/40 [01:21<00:17,  2.50s/it]

Pred: unknown | Target: bourbon


 85%|████████▌ | 34/40 [01:23<00:15,  2.50s/it]

Pred: unknown | Target: port


 88%|████████▊ | 35/40 [01:26<00:12,  2.51s/it]

Pred: unknown | Target: bourbon


 90%|█████████ | 36/40 [01:28<00:10,  2.51s/it]

Pred: unknown | Target: beer


 92%|█████████▎| 37/40 [01:31<00:07,  2.51s/it]

Pred: unknown | Target: stout


 95%|█████████▌| 38/40 [01:33<00:05,  2.51s/it]

Pred: soda | Target: soda


 98%|█████████▊| 39/40 [01:36<00:02,  2.51s/it]

Pred: unknown | Target: champagne


100%|██████████| 40/40 [01:38<00:00,  2.47s/it]

Pred: unknown | Target: water
Layer: 39 | Accuracy: 0.12





# Probing

In [4]:
# Build the probing dataset: `n_samples` stories from the first template, each
# with two randomly drawn characters, containers and states, visibility off.
n_samples = 100
batch_size = 1

template = STORY_TEMPLATES['templates'][0]
event_idx, event_noticed = None, False

configs = []
for _ in range(n_samples):
    # Draw order (characters, containers, states) is kept so the seeded
    # `random` stream yields the same samples.
    characters = random.sample(all_characters, 2)
    containers = random.sample(all_containers[template["container_type"]], 2)
    states = random.sample(all_states[template["state_type"]], 2)
    configs.append(
        SampleV3(
            template=template,
            characters=characters,
            containers=containers,
            states=states,
            visibility=False,
            event_idx=event_idx,
            event_noticed=event_noticed,
        )
    )

dataset = DatasetV3(configs)
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

In [5]:
# Show the prompt and target of one probing sample.
idx = 0
print(*(dataset[idx][field] for field in ('prompt', 'target')))

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Max and Karen are working in a busy restaurant. To complete an order, Max grabs an opaque tun and fills it with port. Then Karen grabs another opaque dispenser and fills it with water. They are working in the entirely separate sections, with no visibility between them.
Question: What does Karen believe the tun contains?
Answer: water


In [38]:
# Sanity check: decode the two token positions expected to hold the character
# names in the prompt of sample `idx`.
tokenizer = model.tokenizer
input_tokens = tokenizer(dataset[idx]['prompt'], return_tensors="pt").input_ids
character_positions = [146, 158]
print(tokenizer.decode(input_tokens[0][character_positions]))

 Max Karen


In [39]:
# Layer whose activations the probe is meant to read, and how many layers to cache.
probing_layer = 20
n_layers = 40

# Hand-picked token positions inside the story for this template.
charac_indices = [146, 158]
object_indices = [150, 162]
state_indices = [155, 167]
first_sent = list(range(141, 157))
second_sent = list(range(158, 169))

# Positions to cache, in the order they are stored along the token axis.
token_positions = first_sent + second_sent

# acts is indexed as (sample, layer, token, hidden_size); rows stay zero for
# any sample that is not processed.
acts = torch.zeros(n_samples, n_layers, len(token_positions), model.config.hidden_size)

for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    # Only the first item of each batch is used.
    prompt = batch['prompt'][0]
    target = batch['target'][0]

    with torch.no_grad():
        with model.trace(prompt):
            for l in range(n_layers):
                for t_idx, t in enumerate(token_positions):
                    acts[i, l, t_idx] = model.model.layers[l].output[0][0, t].cpu().save()

 30%|███       | 30/100 [01:24<03:17,  2.82s/it]


KeyboardInterrupt: 

In [None]:
class ProbingClassifier(torch.nn.Module):
    """Linear probe: one affine layer followed by an element-wise sigmoid."""

    def __init__(self, input_dim, output_dim):
        """Create the probe.

        Args:
            input_dim: size of the last dimension of the input features.
            output_dim: number of sigmoid outputs per input row.
        """
        super().__init__()
        self.fc = torch.nn.Linear(input_dim, output_dim)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(fc(x)); every output value lies in (0, 1)."""
        logits = self.fc(x)
        probabilities = self.sigmoid(logits)
        return probabilities

In [143]:
# Create a training data using cached activations
# Each selected (sample, token) pair contributes the `probing_layer` activation
# and a binary label: 1 for the positive position, 0 for the negative ones.
training_samples = []
labels = []

positive_position = 161
# NOTE(review): 173 presumably lies outside first_sent + second_sent, in which
# case it never matches below — TODO confirm it is intended.
negative_positions = set([173] + charac_indices + state_indices)

# Iterate over the samples actually allocated in `acts`; a fixed `range(200)`
# overruns the buffer whenever fewer samples were cached.
for i in range(acts.shape[0]):
    for t_idx, t in enumerate(first_sent+second_sent):
        if t == positive_position:
            label = 1
        elif t in negative_positions:
            label = 0
        else:
            continue
        labels.append(torch.tensor(label))
        # Use `probing_layer` instead of repeating the literal layer index.
        training_samples.append(acts[i, probing_layer, t_idx, :])

In [144]:
# Create a dataloader using training_samples and labels
# The stacked tensors get their own names so that `training_samples` and
# `labels` stay lists; rebinding them to tensors made this cell fail when run
# a second time (torch.stack expects a sequence of tensors).
probe_inputs = torch.stack(training_samples)
probe_labels = torch.stack(labels)
train_data = torch.utils.data.TensorDataset(probe_inputs, probe_labels)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=10, shuffle=True)

In [146]:
# Train a probing classifier
# One sigmoid output trained with binary cross-entropy on the cached activations.
classifier = ProbingClassifier(model.config.hidden_size, 1)
classifier.to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

n_epochs = 50
for epoch in range(n_epochs):
    # Batch variables use their own names so the notebook-level `labels`
    # collection is neither overwritten nor deleted by this cell.
    for batch_inputs, batch_labels in train_loader:
        batch_inputs = batch_inputs.to(device)
        # BCELoss needs targets shaped like the (batch, 1) outputs, as floats.
        batch_labels = batch_labels.to(device).unsqueeze(1)

        optimizer.zero_grad()
        outputs = classifier(batch_inputs)
        loss = criterion(outputs, batch_labels.float())
        loss.backward()
        optimizer.step()

    if epoch % 10 == 0:
        # Loss of the last batch of the epoch, not an epoch average.
        print(f"Epoch: {epoch} | Loss: {loss.item()}")

    del batch_inputs, batch_labels, outputs
    torch.cuda.empty_cache()

Epoch: 0 | Loss: 0.010615534149110317


Epoch: 10 | Loss: 0.002030023140832782
Epoch: 20 | Loss: 0.0003757626691367477
Epoch: 30 | Loss: 0.00026621154393069446
Epoch: 40 | Loss: 8.729223918635398e-05


In [147]:
# Small evaluation set: five stories from the second template with
# visibility switched on. Rebinds `dataset` and `dataloader`.
batch_size = 1

template = STORY_TEMPLATES['templates'][1]
event_idx, event_noticed = None, False

configs = []
for _ in range(5):
    # Same draw order as before: characters, then containers, then states.
    characters = random.sample(all_characters, 2)
    containers = random.sample(all_containers[template["container_type"]], 2)
    states = random.sample(all_states[template["state_type"]], 2)
    configs.append(
        SampleV3(
            template=template,
            characters=characters,
            containers=containers,
            states=states,
            visibility=True,
            event_idx=event_idx,
            event_noticed=event_noticed,
        )
    )

dataset = DatasetV3(configs)
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

In [148]:
# Show the prompt and target of the second evaluation sample.
idx = 1
print(*(dataset[idx][field] for field in ('prompt', 'target')))

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Lee and Laura are working in a busy restaurant side by side and can clearly observe each other's actions. To complete an order, Lee grabs an opaque quart and fills it with beer. Then Laura grabs another opaque pint and fills it with rum.
Question: What does Laura believe the quart contains?
Answer: beer


In [149]:
# Predict classifier output for the query character
# For layers 0, 10, 20 and 30: read the layer output at token position -8 of
# each prompt, run it through the probe, and print the mean probe output.
# NOTE(review): the probe's training data was taken from a single layer, yet it
# is applied to several layers here — confirm this comparison is intended.
with torch.no_grad():
    for l in range(0, 40, 10):
        preds = []
        for data in tqdm(dataloader, total=len(dataloader)):
            prompt = data['prompt'][0]
            target = data['target'][0]

            with model.trace(prompt):
                query_charac_act = model.model.layers[l].output[0][0, -8].save()

            # Add a batch dimension, move to the probe's device, cast to float32.
            probe_input = query_charac_act.unsqueeze(0).to(device).float()
            preds.append(classifier(probe_input))

            del query_charac_act, probe_input
            torch.cuda.empty_cache()

        mean_output = sum(preds)/len(preds)
        print(f"Layer: {l} | Output: {mean_output}")

  0%|          | 0/5 [00:00<?, ?it/s]

100%|██████████| 5/5 [00:12<00:00,  2.43s/it]


Layer: 0 | Output: tensor([[0.4794]], device='cuda:0')


100%|██████████| 5/5 [00:12<00:00,  2.46s/it]


Layer: 10 | Output: tensor([[0.2065]], device='cuda:0')


100%|██████████| 5/5 [00:12<00:00,  2.48s/it]


Layer: 20 | Output: tensor([[0.0002]], device='cuda:0')


100%|██████████| 5/5 [00:12<00:00,  2.49s/it]

Layer: 30 | Output: tensor([[0.0056]], device='cuda:0')





In [84]:
# Print whatever `prompt` currently holds in the kernel namespace; the name is
# assigned inside earlier loops, not in this cell.
# NOTE(review): this cell's execution count is lower than the cells above it,
# so the stored output presumably comes from an earlier run — re-run to refresh.
print(prompt)

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Scott and Olivia are working in the entirely separate sections of a busy restaurant, with no visibility between them. To complete an order, Scott grabs an opaque bottle and fills it with gin. Then Olivia grabs another opaque mug and fills it with cocoa.
Question: What does Olivia believe the mug contains?
Answer:
