In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3,4"

import json
import random
import math
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any, List, Optional
import nnsight
from nnsight import CONFIG, LanguageModel
import numpy as np
from collections import defaultdict
from einops import einsum
import time
from einops import rearrange, reduce
import pandas as pd

sys.path.append("../")
from src.dataset import SampleV3, DatasetV3, STORY_TEMPLATES
from src.utils import env_utils
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(10)

CONFIG.set_default_api_key("d9e00ab7d4f74643b3176de0913f24a7")
os.environ["HF_TOKEN"] = "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"

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
CONFIG.APP.REMOTE_LOGGING = False

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm
env.yml not found in /disk/u/nikhil/mind!
Setting MODEL_ROOT="". Models will now be downloaded to conda env cache, if not already there
Other defaults are set to:
    DATA_DIR = "data"
    RESULTS_DIR = "results"
    HPARAMS_DIR = "hparams"


# Loading Raw Data

In [2]:
all_states = {}
all_containers= {}
all_characters = json.load(open(os.path.join(env_utils.DEFAULT_DATA_DIR, "synthetic_entities", "characters.json"), "r"))

for TYPE, DCT in {"states": all_states, "containers": all_containers}.items():
    ROOT = os.path.join(
        env_utils.DEFAULT_DATA_DIR, "synthetic_entities", TYPE
    )
    for file in os.listdir(ROOT):
        file_path = os.path.join(ROOT, file)
        with open(file_path, "r") as f:
            names = json.load(f)
        DCT[file.split(".")[0]] = names

# Loading model

In [3]:
# model = LanguageModel("meta-llama/Meta-Llama-3.1-405B-Instruct")
model = LanguageModel("meta-llama/Meta-Llama-3-70B-Instruct", cache_dir="/disk/u/nikhil/.cache/huggingface/hub/", device_map="auto", torch_dtype=torch.float16, dispatch=True)

Loading checkpoint shards: 100%|██████████| 30/30 [00:36<00:00,  1.23s/it]


In [4]:
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

# Loading Singular Vectors

In [5]:
sing_vecs = defaultdict(dict)
for l in range(model.config.num_hidden_layers):
    sing_vecs[l] = torch.load(f"../svd_results/belief_tracking/second_visibility_sent/singular_vecs/{l}.pt").cpu()

# DCM

In [6]:
charac_indices = [131, 133, 146, 147, 158, 159]
object_indices = [150, 151, 162, 163]
state_indices = [155, 156, 167, 168]
reversed_state_indices = [167, 168, 155, 156]
reversed_object_indices = [162, 163, 150, 151]
reversed_charac_indices = [133, 131, 158, 159, 146, 147]
query_sent = [i for i in range(169, 181)]
first_visibility_sent = [i for i in range(169, 176)]
second_visibility_sent = [i for i in range(176, 183)]

## State Subspace

In [7]:
train_size = 80
valid_size = 40
batch_size = 4

dataset = get_state_pos_exps(STORY_TEMPLATES,
                                all_characters,
                                all_containers,
                                all_states,
                                train_size+valid_size,
                                question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [8]:
idx = 1
print(train_dataset[idx]['corrupt_prompt'], train_dataset[idx]['corrupt_ans'])
print(train_dataset[idx]['clean_prompt'], train_dataset[idx]['clean_ans'])
print(f"Target: '{train_dataset[idx]['target']}'")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Nancy and Tony are working in a busy restaurant. To complete an order, Nancy grabs an opaque can and fills it with juice. Then Tony grabs another opaque dispenser and fills it with stout.
Question: What does Nancy believe the can contains?
Answer: juice
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when th

In [9]:
idx = 0
tokens = model.tokenizer.encode(train_dataset[idx]['corrupt_prompt'], return_tensors="pt").to(device)
print(model.tokenizer.decode(tokens[0][state_indices]))

tokens = model.tokenizer.encode(train_dataset[idx]['clean_prompt'], return_tensors="pt").to(device)
print(model.tokenizer.decode(tokens[0][state_indices]))

 water. port.

 port. water.



In [61]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(0, 40, 10):
    modules = [i for i in range(sing_vecs[layer_idx].shape[0])]
    mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 1

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate(state_indices):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec * mask.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate(reversed_state_indices):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[layer_idx].output[0][:, t] = modified_out

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         mask.data.clamp_(0, 1)
            #         rounded = torch.round(mask)
            #         print(f"#Causal SVs: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0

    with torch.inference_mode():
        mask_data = mask.data.clone()
        mask_data.clamp_(0, 1)
        rounded = torch.round(mask_data)
        print(f"#Rank: {(rounded == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        torch.save(rounded, f"../masks/toy/state_oid/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate(state_indices):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate(reversed_state_indices):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[layer_idx].output[0][:, t] = modified_out

                        # model.model.layers[layer_idx].output[0][:, t] = alt_acts[t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()
            
            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    valid_accs[layer_idx] = correct / total

Training layer: 0


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:31<00:00,  1.59s/it]


Training complete for 0!
Validation started for 0
#Rank: 0


100%|██████████| 10/10 [00:09<00:00,  1.01it/s]


Layer: 0 | Validation accuracy: 0.00

Training layer: 10


100%|██████████| 20/20 [00:31<00:00,  1.57s/it]


Training complete for 10!
Validation started for 10
#Rank: 0


100%|██████████| 10/10 [00:09<00:00,  1.01it/s]


Layer: 10 | Validation accuracy: 0.00

Training layer: 20


100%|██████████| 20/20 [00:31<00:00,  1.58s/it]


Training complete for 20!
Validation started for 20
#Rank: 2


100%|██████████| 10/10 [00:09<00:00,  1.00it/s]


Layer: 20 | Validation accuracy: 0.00

Training layer: 30


100%|██████████| 20/20 [00:31<00:00,  1.57s/it]


Training complete for 30!
Validation started for 30
#Rank: 3


100%|██████████| 10/10 [00:10<00:00,  1.02s/it]

Layer: 30 | Validation accuracy: 0.30






In [63]:
# sort valid_accs by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{0: 0.0,
 10: 0.0,
 20: 0.0,
 30: 0.3,
 32: 0.6,
 34: 1.0,
 36: 0.7,
 38: 0.425,
 40: 0.025}

## Object Subspace

In [33]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = get_obj_pos_exps(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size,
                             question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [34]:
idx = 1
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_ans'])
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_ans'])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Uma and Jake are working in a busy restaurant. To complete an order, Uma grabs an opaque dispenser and fills it with port. Then Jake grabs another opaque glass and fills it with cocktail.
Question: What does Jake believe the glass contains?
Answer: cocktail
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only whe

In [35]:
idx = 0
tokens = model.tokenizer.encode(train_dataset[idx]['corrupt_prompt'], return_tensors="pt").to(device)
print(model.tokenizer.decode(tokens[0][object_indices]))

tokens = model.tokenizer.encode(train_dataset[idx]['clean_prompt'], return_tensors="pt").to(device)
print(model.tokenizer.decode(tokens[0][object_indices]))

 flask and cup and
 cup and flask and


In [54]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(12, 20, 2):
    mask = torch.ones(layer_idx+1, sing_vecs[layer_idx].shape[0], requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.01

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, alt_acts_charac, org_acts_state, org_acts_query_charac = defaultdict(dict), defaultdict(dict), defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for l in range(layer_idx + 1):
                        for t_idx, t in enumerate(charac_indices):
                            alt_acts_charac[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                        for t_idx, t in enumerate(object_indices):
                            alt_acts[l][t_idx] = model.model.layers[l].output[0][:, t].clone()


                with tracer.invoke(org_prompts):
                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(state_indices):
                            org_acts_state[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                        for t_idx, t in enumerate([-8, -7]):
                            org_acts_query_charac[l][t_idx] = model.model.layers[l].output[0][:, t].clone()


                with tracer.invoke(org_prompts):
                    for l in range(layer_idx+1):
                        for t_idx, t in enumerate(reversed_charac_indices):
                            model.model.layers[l].output[0][:, t] = alt_acts_charac[l][t_idx]

                        sing_vec = sing_vecs[l].cuda()
                        masked_vec = sing_vec * mask[l].unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()
                        for t_idx, t in enumerate(reversed_object_indices):
                            curr_output = model.model.layers[l].output[0][:, t].clone()
                            alt_proj = torch.matmul(alt_acts[l][t_idx], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)
                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[l].output[0][:, t] = modified_out

                        for t_idx, t in enumerate([-8, -7]):
                            model.model.layers[l].output[0][:, t] = org_acts_query_charac[l][t_idx]

                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(state_indices):
                            model.model.layers[l].output[0][:, t] = org_acts_state[l][t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         for l in range(layer_idx+1):
            #             mask[l].data.clamp_(0, 1)
            #             rounded = torch.round(mask[l])
            #             print(f"#Rank: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                for l in range(layer_idx+1):
                    mask[l].data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0
    rounded = torch.zeros(layer_idx+1, sing_vecs[layer_idx].shape[0], device="cuda", dtype=torch.bfloat16)
    with torch.inference_mode():
        for l in range(layer_idx+1):
            mask_data = mask[l].data.clone()
            mask_data.clamp_(0, 1)
            rounded[l] = torch.round(mask_data)
            print(f"Layer: {l} | #Rank: {(rounded[l] == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        torch.save(rounded, f"../masks/toy/object_oid/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, alt_acts_charac, org_acts_state, org_acts_query_charac = defaultdict(dict), defaultdict(dict), defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for l in range(layer_idx+1):
                        for t_idx, t in enumerate(charac_indices):
                            alt_acts_charac[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                        for t_idx, t in enumerate(object_indices):
                            alt_acts[l][t_idx] = model.model.layers[l].output[0][:, t].clone()


                with tracer.invoke(org_prompts):
                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(state_indices):
                            org_acts_state[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                        for t_idx, t in enumerate([-8, -7]):
                            org_acts_query_charac[l][t_idx] = model.model.layers[l].output[0][:, t].clone()


                with tracer.invoke(org_prompts):
                    for l in range(layer_idx+1):
                        for t_idx, t in enumerate(reversed_charac_indices):
                            model.model.layers[l].output[0][:, t] = alt_acts_charac[l][t_idx]

                        sing_vec = sing_vecs[l].cuda()
                        masked_vec = sing_vec.to(rounded[l].device) * rounded[l].unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()
                        for t_idx, t in enumerate(reversed_object_indices):
                            curr_output = model.model.layers[l].output[0][:, t].clone()
                            alt_proj = torch.matmul(alt_acts[l][t_idx], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)
                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[l].output[0][:, t] = modified_out

                        for t_idx, t in enumerate([-8, -7]):
                            model.model.layers[l].output[0][:, t] = org_acts_query_charac[l][t_idx]

                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(state_indices):
                            model.model.layers[l].output[0][:, t] = org_acts_state[l][t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()

            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    valid_accs[layer_idx] = correct / total

Training layer: 12


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [01:17<00:00,  3.88s/it]


Training complete for 12!
Validation started for 12
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 0
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 0
Layer: 5 | #Rank: 0
Layer: 6 | #Rank: 0
Layer: 7 | #Rank: 0
Layer: 8 | #Rank: 1
Layer: 9 | #Rank: 1
Layer: 10 | #Rank: 0
Layer: 11 | #Rank: 2
Layer: 12 | #Rank: 8


100%|██████████| 20/20 [00:42<00:00,  2.13s/it]


Layer: 12 | Validation accuracy: 0.06

Training layer: 14


100%|██████████| 20/20 [01:19<00:00,  3.98s/it]


Training complete for 14!
Validation started for 14
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 0
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 1
Layer: 4 | #Rank: 0
Layer: 5 | #Rank: 0
Layer: 6 | #Rank: 1
Layer: 7 | #Rank: 0
Layer: 8 | #Rank: 1
Layer: 9 | #Rank: 2
Layer: 10 | #Rank: 1
Layer: 11 | #Rank: 2
Layer: 12 | #Rank: 1
Layer: 13 | #Rank: 14
Layer: 14 | #Rank: 31


100%|██████████| 20/20 [00:44<00:00,  2.21s/it]


Layer: 14 | Validation accuracy: 0.38

Training layer: 16


100%|██████████| 20/20 [01:21<00:00,  4.07s/it]


Training complete for 16!
Validation started for 16
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 0
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 2
Layer: 5 | #Rank: 2
Layer: 6 | #Rank: 0
Layer: 7 | #Rank: 0
Layer: 8 | #Rank: 1
Layer: 9 | #Rank: 6
Layer: 10 | #Rank: 0
Layer: 11 | #Rank: 1
Layer: 12 | #Rank: 0
Layer: 13 | #Rank: 11
Layer: 14 | #Rank: 4
Layer: 15 | #Rank: 8
Layer: 16 | #Rank: 57


100%|██████████| 20/20 [00:45<00:00,  2.29s/it]


Layer: 16 | Validation accuracy: 0.69

Training layer: 18


100%|██████████| 20/20 [01:23<00:00,  4.18s/it]


Training complete for 18!
Validation started for 18
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 0
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 0
Layer: 5 | #Rank: 0
Layer: 6 | #Rank: 0
Layer: 7 | #Rank: 0
Layer: 8 | #Rank: 1
Layer: 9 | #Rank: 2
Layer: 10 | #Rank: 0
Layer: 11 | #Rank: 0
Layer: 12 | #Rank: 1
Layer: 13 | #Rank: 6
Layer: 14 | #Rank: 0
Layer: 15 | #Rank: 3
Layer: 16 | #Rank: 13
Layer: 17 | #Rank: 2
Layer: 18 | #Rank: 21


100%|██████████| 20/20 [00:46<00:00,  2.34s/it]

Layer: 18 | Validation accuracy: 0.90






In [56]:
# Sort the valid_accs dict by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{0: 0.025,
 5: 0.025,
 10: 0.025,
 12: 0.0625,
 14: 0.375,
 15: 0.375,
 16: 0.6875,
 18: 0.9,
 20: 0.9125,
 25: 0.9375,
 30: 0.9125}

## Character Subspace

In [18]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = get_charac_pos_exp(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size,
                             question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [8]:
idx = 1
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_ans'])
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_ans'])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Liz and Ray are working in a busy restaurant. To complete an order, Liz grabs an opaque pitcher and fills it with sprite. Then Ray grabs another opaque container and fills it with champagne.
Question: What does Ray believe the container contains?
Answer: champagne
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed o

In [9]:
idx = 0
tokens = model.tokenizer.encode(train_dataset[idx]['corrupt_prompt'], return_tensors="pt").to(device)
print(model.tokenizer.decode(tokens[0][object_indices]))

tokens = model.tokenizer.encode(train_dataset[idx]['clean_prompt'], return_tensors="pt").to(device)
print(model.tokenizer.decode(tokens[0][object_indices]))

 dispenser and tun and
 tun and dispenser and


In [29]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(12, 20, 2):
    mask = torch.ones(layer_idx+1, sing_vecs[layer_idx].shape[0], requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.01

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0
        
        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)
            
            optimizer.zero_grad()

            alt_acts, org_acts_state, org_acts_query_obj = defaultdict(dict), defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for l in range(layer_idx + 1):
                        for t_idx, t in enumerate(charac_indices):
                            alt_acts[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(object_indices + state_indices):
                            org_acts_state[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                        for t_idx, t in enumerate([-5, -4]):
                            org_acts_query_obj[l][t_idx] = model.model.layers[l].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    for l in range(layer_idx+1):
                        sing_vec = sing_vecs[l].cuda()
                        masked_vec = sing_vec * mask[l].unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        for t_idx, t in enumerate(reversed_charac_indices):
                            curr_output = model.model.layers[l].output[0][:, t].clone()
                            alt_proj = torch.matmul(alt_acts[l][t_idx], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)

                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[l].output[0][:, t] = modified_out
                        
                        for t_idx, t in enumerate([-5, -4]):
                            model.model.layers[l].output[0][:, t] = org_acts_query_obj[l][t_idx]
                        
                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(object_indices + state_indices):
                            model.model.layers[l].output[0][:, t] = org_acts_state[l][t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()
                    
                    logits = model.lm_head.output[:, -1].save()
            
            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         for l in range(layer_idx+1):
            #             mask[l].data.clamp_(0, 1)
            #             rounded = torch.round(mask[l])
            #             print(f"#Rank: {(rounded == 1).sum().item()}")
            
            loss.backward()
            optimizer.step()
            
            # Clamp after optimizer step
            with torch.no_grad():
                for l in range(layer_idx+1):
                    mask[l].data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0
    rounded = torch.zeros(layer_idx+1, sing_vecs[layer_idx].shape[0], device="cuda", dtype=torch.bfloat16)
    with torch.inference_mode():
        for l in range(layer_idx+1):
            mask_data = mask[l].data.clone()
            mask_data.clamp_(0, 1)
            rounded[l] = torch.round(mask_data)
            print(f"Layer: {l} | #Rank: {(rounded[l] == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        torch.save(rounded, f"../masks/toy/charac_oid/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, org_acts_state, org_acts_query_charac = defaultdict(dict), defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for l in range(layer_idx+1):
                        for t_idx, t in enumerate(charac_indices):
                            alt_acts[l][t_idx] = model.model.layers[l].output[0][:, t].clone()


                with tracer.invoke(org_prompts):
                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(object_indices + state_indices):
                            org_acts_state[l][t_idx] = model.model.layers[l].output[0][:, t].clone()
                        
                        for t_idx, t in enumerate([-5, -4]):
                            org_acts_query_charac[l][t_idx] = model.model.layers[l].output[0][:, t].clone()


                with tracer.invoke(org_prompts):
                    for l in range(layer_idx+1):
                        sing_vec = sing_vecs[l].cuda()
                        masked_vec = sing_vec.to(rounded[l].device) * rounded[l].unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        for t_idx, t in enumerate(reversed_charac_indices):
                            curr_output = model.model.layers[l].output[0][:, t].clone()
                            alt_proj = torch.matmul(alt_acts[l][t_idx], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)

                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[l].output[0][:, t] = modified_out

                        for t_idx, t in enumerate([-5, -4]):
                            model.model.layers[l].output[0][:, t] = org_acts_query_charac[l][t_idx]

                    for l in range(model.config.num_hidden_layers):
                        for t_idx, t in enumerate(object_indices + state_indices):
                            model.model.layers[l].output[0][:, t] = org_acts_state[l][t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()
            
            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    valid_accs[layer_idx] = correct / total

Training layer: 12


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [01:27<00:00,  4.35s/it]


Training complete for 12!
Validation started for 12
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 1
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 2
Layer: 5 | #Rank: 0
Layer: 6 | #Rank: 1
Layer: 7 | #Rank: 2
Layer: 8 | #Rank: 1
Layer: 9 | #Rank: 2
Layer: 10 | #Rank: 5
Layer: 11 | #Rank: 4
Layer: 12 | #Rank: 7


100%|██████████| 20/20 [00:49<00:00,  2.46s/it]


Layer: 12 | Validation accuracy: 0.41

Training layer: 14


100%|██████████| 20/20 [01:30<00:00,  4.50s/it]


Training complete for 14!
Validation started for 14
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 1
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 1
Layer: 5 | #Rank: 1
Layer: 6 | #Rank: 0
Layer: 7 | #Rank: 0
Layer: 8 | #Rank: 0
Layer: 9 | #Rank: 1
Layer: 10 | #Rank: 1
Layer: 11 | #Rank: 3
Layer: 12 | #Rank: 1
Layer: 13 | #Rank: 10
Layer: 14 | #Rank: 25


100%|██████████| 20/20 [00:51<00:00,  2.56s/it]


Layer: 14 | Validation accuracy: 0.78

Training layer: 16


100%|██████████| 20/20 [01:33<00:00,  4.66s/it]


Training complete for 16!
Validation started for 16
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 0
Layer: 2 | #Rank: 0
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 0
Layer: 5 | #Rank: 0
Layer: 6 | #Rank: 0
Layer: 7 | #Rank: 0
Layer: 8 | #Rank: 0
Layer: 9 | #Rank: 1
Layer: 10 | #Rank: 1
Layer: 11 | #Rank: 1
Layer: 12 | #Rank: 2
Layer: 13 | #Rank: 5
Layer: 14 | #Rank: 5
Layer: 15 | #Rank: 1
Layer: 16 | #Rank: 7


100%|██████████| 20/20 [00:52<00:00,  2.64s/it]


Layer: 16 | Validation accuracy: 0.86

Training layer: 18


100%|██████████| 20/20 [01:36<00:00,  4.83s/it]


Training complete for 18!
Validation started for 18
Layer: 0 | #Rank: 0
Layer: 1 | #Rank: 1
Layer: 2 | #Rank: 1
Layer: 3 | #Rank: 0
Layer: 4 | #Rank: 1
Layer: 5 | #Rank: 1
Layer: 6 | #Rank: 1
Layer: 7 | #Rank: 1
Layer: 8 | #Rank: 2
Layer: 9 | #Rank: 3
Layer: 10 | #Rank: 3
Layer: 11 | #Rank: 4
Layer: 12 | #Rank: 3
Layer: 13 | #Rank: 7
Layer: 14 | #Rank: 5
Layer: 15 | #Rank: 0
Layer: 16 | #Rank: 3
Layer: 17 | #Rank: 2
Layer: 18 | #Rank: 2


100%|██████████| 20/20 [00:54<00:00,  2.73s/it]

Layer: 18 | Validation accuracy: 0.76






In [30]:
# Sort valid_accs by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{0: 0.0625,
 5: 0.0625,
 10: 0.3375,
 12: 0.4125,
 14: 0.775,
 15: 0.85,
 16: 0.8625,
 18: 0.7625,
 20: 0.6875,
 25: 0.675,
 30: 0.5875}

## Query Character

In [21]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = query_charac_pos(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size,
                             question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [22]:
idx = 4
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_ans'])
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_ans'])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Nick and Ruth are working in a busy restaurant. To complete an order, Nick grabs an opaque cup and fills it with bourbon. Then Ruth grabs another opaque pint and fills it with juice.
Question: What does Nick believe the cup contains?
Answer: bourbon
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they p

In [38]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(34, 36, 2):
    modules = [i for i in range(sing_vecs[layer_idx].shape[0])]
    mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.025

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate([-8, -7]):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec * mask.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate([-8, -7]):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[l].output[0][:, t] = modified_out

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         mask.data.clamp_(0, 1)
            #         rounded = torch.round(mask)
            #         print(f"#Causal SVs: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0

    with torch.inference_mode():
        mask_data = mask.data.clone()
        mask_data.clamp_(0, 1)
        rounded = torch.round(mask_data)
        print(f"#Rank: {(rounded == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        torch.save(rounded, f"../masks/toy/query_charac_oid/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate([-8, -7]):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate([-8, -7]):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[layer_idx].output[0][:, t] = modified_out

                        model.model.layers[layer_idx].output[0][:, t] = alt_acts[t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()
            
            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    valid_accs[layer_idx] = correct / total

Training layer: 34


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch: 0, Batch: 0, Task Loss: -16.3281, L1 Loss: 25.0000, Total Loss: 8.6719
#Causal SVs: 1000


 10%|█         | 2/20 [00:03<00:27,  1.54s/it]

Epoch: 0, Batch: 2, Task Loss: -16.4219, L1 Loss: 19.8750, Total Loss: 5.7760
#Causal SVs: 1000


 20%|██        | 4/20 [00:06<00:24,  1.55s/it]

Epoch: 0, Batch: 4, Task Loss: -17.8125, L1 Loss: 14.8750, Total Loss: 2.7969
#Causal SVs: 1000


 30%|███       | 6/20 [00:09<00:21,  1.55s/it]

Epoch: 0, Batch: 6, Task Loss: -18.6406, L1 Loss: 9.8750, Total Loss: -0.1585
#Causal SVs: 1


 40%|████      | 8/20 [00:12<00:18,  1.55s/it]

Epoch: 0, Batch: 8, Task Loss: -17.1250, L1 Loss: 4.9375, Total Loss: -2.6997
#Causal SVs: 1


 50%|█████     | 10/20 [00:15<00:15,  1.55s/it]

Epoch: 0, Batch: 10, Task Loss: -16.9844, L1 Loss: 0.0325, Total Loss: -5.1292
#Causal SVs: 1


 60%|██████    | 12/20 [00:18<00:12,  1.56s/it]

Epoch: 0, Batch: 12, Task Loss: -16.9688, L1 Loss: 0.0236, Total Loss: -6.8436
#Causal SVs: 1


 70%|███████   | 14/20 [00:21<00:09,  1.55s/it]

Epoch: 0, Batch: 14, Task Loss: -17.6875, L1 Loss: 0.0234, Total Loss: -8.3686
#Causal SVs: 1


 80%|████████  | 16/20 [00:24<00:06,  1.55s/it]

Epoch: 0, Batch: 16, Task Loss: -18.9219, L1 Loss: 0.0237, Total Loss: -9.5523
#Causal SVs: 1


 90%|█████████ | 18/20 [00:27<00:03,  1.55s/it]

Epoch: 0, Batch: 18, Task Loss: -18.8281, L1 Loss: 0.0238, Total Loss: -10.3814
#Causal SVs: 1


100%|██████████| 20/20 [00:31<00:00,  1.55s/it]


Training complete for 34!
Validation started for 34
#Rank: 1


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]

Layer: 34 | Validation accuracy: 0.07






In [46]:
# Sort valid_accs by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{0: 0.075,
 10: 0.0375,
 12: 0.225,
 14: 1.0,
 16: 1.0,
 18: 1.0,
 20: 0.9,
 30: 0.4625,
 32: 0.1625,
 34: 0.075}

## Query Object

In [7]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = query_obj_pos(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size,
                             question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [8]:
idx = 1
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_ans'])
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_ans'])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Liz and Ray are working in a busy restaurant. To complete an order, Liz grabs an opaque pitcher and fills it with sprite. Then Ray grabs another opaque container and fills it with champagne.
Question: What does Ray believe the container contains?
Answer: champagne
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed o

In [18]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(22, 30, 2):
    modules = [i for i in range(sing_vecs[layer_idx].shape[0])]
    mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.05

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate([-5, -4]):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec * mask.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate([-5, -4]):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[l].output[0][:, t] = modified_out

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         mask.data.clamp_(0, 1)
            #         rounded = torch.round(mask)
            #         print(f"#Causal SVs: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0

    with torch.inference_mode():
        mask_data = mask.data.clone()
        mask_data.clamp_(0, 1)
        rounded = torch.round(mask_data)
        print(f"#Rank: {(rounded == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        torch.save(rounded, f"../masks/toy/query_obj_oid/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate([-5, -4]):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate([-5, -4]):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[layer_idx].output[0][:, t] = modified_out

                        model.model.layers[layer_idx].output[0][:, t] = alt_acts[t_idx]

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()
            
            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    valid_accs[layer_idx] = correct / total

Training layer: 22


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:30<00:00,  1.55s/it]


Training complete for 22!
Validation started for 22
#Rank: 1


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


Layer: 22 | Validation accuracy: 0.91

Training layer: 24


100%|██████████| 20/20 [00:30<00:00,  1.54s/it]


Training complete for 24!
Validation started for 24
#Rank: 1


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


Layer: 24 | Validation accuracy: 0.93

Training layer: 26


100%|██████████| 20/20 [00:30<00:00,  1.54s/it]


Training complete for 26!
Validation started for 26
#Rank: 1


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


Layer: 26 | Validation accuracy: 0.90

Training layer: 28


100%|██████████| 20/20 [00:30<00:00,  1.54s/it]


Training complete for 28!
Validation started for 28
#Rank: 3


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]

Layer: 28 | Validation accuracy: 0.75






In [19]:
# Sort valid_accs by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{0: 0.0,
 10: 0.0,
 12: 0.0,
 14: 0.6,
 16: 0.6,
 18: 0.8625,
 20: 0.9,
 22: 0.9125,
 24: 0.925,
 26: 0.9,
 28: 0.75,
 30: 0.2625}

## Correct State Fetcher

In [7]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = get_pos_trans_exps(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size,
                             question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [8]:
idx = 0
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_ans'])
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_ans'])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Karen and Max are working in a busy restaurant. To complete an order, Karen grabs an opaque dispenser and fills it with coffee. Then Max grabs another opaque tun and fills it with cocoa.
Question: What does Karen believe the dispenser contains?
Answer: coffee
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only w

In [9]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(40, 65, 5):
    modules = [i for i in range(sing_vecs[layer_idx].shape[0])]
    mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.1

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    alt_acts = model.model.layers[layer_idx].output[0][:, -1].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec * mask.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    curr_output = model.model.layers[layer_idx].output[0][:, -1].clone()
                    alt_proj = torch.matmul(alt_acts, proj_matrix)
                    org_proj = torch.matmul(curr_output, proj_matrix)

                    modified_out = curr_output - org_proj + alt_proj
                    model.model.layers[layer_idx].output[0][:, -1] = modified_out

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         mask.data.clamp_(0, 1)
            #         rounded = torch.round(mask)
            #         print(f"#Causal SVs: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0

    with torch.inference_mode():
        mask_data = mask.data.clone()
        mask_data.clamp_(0, 1)
        rounded = torch.round(mask_data)
        print(f"#Rank: {(rounded == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        # torch.save(rounded, f"../masks/toy/correct_state_oid/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    alt_acts = model.model.layers[layer_idx].output[0][:, -1].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    curr_output = model.model.layers[layer_idx].output[0][:, -1].clone()
                    alt_proj = torch.matmul(alt_acts, proj_matrix)
                    org_proj = torch.matmul(curr_output, proj_matrix)

                    modified_out = curr_output - org_proj + alt_proj
                    model.model.layers[layer_idx].output[0][:, -1] = modified_out

                    # model.model.layers[layer_idx].output[0][:, -1] = alt_acts

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()
            
            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    # valid_accs[layer_idx] = correct / total

Training layer: 40


  0%|          | 0/20 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
100%|██████████| 20/20 [00:31<00:00,  1.58s/it]


Training complete for 40!
Validation started for 40
#Rank: 12


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Layer: 40 | Validation accuracy: 0.86

Training layer: 45


100%|██████████| 20/20 [00:29<00:00,  1.46s/it]


Training complete for 45!
Validation started for 45
#Rank: 14


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Layer: 45 | Validation accuracy: 0.86

Training layer: 50


100%|██████████| 20/20 [00:27<00:00,  1.38s/it]


Training complete for 50!
Validation started for 50
#Rank: 14


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Layer: 50 | Validation accuracy: 0.84

Training layer: 55


100%|██████████| 20/20 [00:26<00:00,  1.31s/it]


Training complete for 55!
Validation started for 55
#Rank: 16


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Layer: 55 | Validation accuracy: 0.24

Training layer: 60


100%|██████████| 20/20 [00:24<00:00,  1.24s/it]


Training complete for 60!
Validation started for 60
#Rank: 14


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]

Layer: 60 | Validation accuracy: 0.11






In [48]:
# Sort valid_accs by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{30: 0.0125,
 32: 0.0,
 34: 0.975,
 36: 0.8,
 38: 0.9625,
 40: 0.0,
 45: 0.0,
 50: 0.0,
 55: 0.025,
 60: 0.0625}

## Value Fetcher

In [12]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = get_value_fetcher_exps(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size,
                             question_type="belief_question")
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [13]:
idx = 0
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_target'])
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_target'])

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Max and Karen are working in a busy restaurant. To complete an order, Max grabs an opaque tun and fills it with port. Then Karen grabs another opaque dispenser and fills it with water.
Question: What does Max believe the tun contains?
Answer:  port
Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they pe

In [22]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(56, 60, 2):
    modules = [i for i in range(sing_vecs[layer_idx].shape[0])]
    mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.4

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["corrupt_target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    alt_acts = model.model.layers[layer_idx].output[0][:, -1].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec * mask.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    curr_output = model.model.layers[layer_idx].output[0][:, -1].clone()
                    alt_proj = torch.matmul(alt_acts, proj_matrix)
                    org_proj = torch.matmul(curr_output, proj_matrix)

                    modified_out = curr_output - org_proj + alt_proj
                    model.model.layers[layer_idx].output[0][:, -1] = modified_out

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            # if bi % 2 == 0:
            #     mean_loss = epoch_loss / (bi + 1)
            #     print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
            #         f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
            #     with torch.no_grad():
            #         mask.data.clamp_(0, 1)
            #         rounded = torch.round(mask)
            #         print(f"#Causal SVs: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

    print(f"Validation started for {layer_idx}")
    correct, total = 0, 0

    with torch.inference_mode():
        mask_data = mask.data.clone()
        mask_data.clamp_(0, 1)
        rounded = torch.round(mask_data)
        print(f"#Rank: {(rounded == 1).sum().item()}")

        # rank[layer_idx] = (rounded == 1).sum().item()
        torch.save(rounded, f"../masks/toy/value_fetcher/{layer_idx}.pt")

        for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["corrupt_target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    alt_acts = model.model.layers[layer_idx].output[0][:, -1].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    curr_output = model.model.layers[layer_idx].output[0][:, -1].clone()
                    alt_proj = torch.matmul(alt_acts, proj_matrix)
                    org_proj = torch.matmul(curr_output, proj_matrix)

                    modified_out = curr_output - org_proj + alt_proj
                    model.model.layers[layer_idx].output[0][:, -1] = modified_out

                    # model.model.layers[layer_idx].output[0][:, -1] = alt_acts

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()

            for i in range(batch_size):
                pred_token = model.tokenizer.decode(pred[i])
                # print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
                if pred_token.lower().strip() == targets[i].lower().strip():
                    correct += 1
                total += 1

            del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
            torch.cuda.empty_cache()

    print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
    valid_accs[layer_idx] = correct / total

Training layer: 56


  0%|          | 0/20 [00:00<?, ?it/s]

100%|██████████| 20/20 [00:19<00:00,  1.02it/s]


Training complete for 56!
Validation started for 56
#Rank: 23


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]


Layer: 56 | Validation accuracy: 0.65

Training layer: 58


100%|██████████| 20/20 [00:19<00:00,  1.03it/s]


Training complete for 58!
Validation started for 58
#Rank: 25


100%|██████████| 20/20 [00:19<00:00,  1.04it/s]

Layer: 58 | Validation accuracy: 0.86






In [23]:
# Sort valid_accs by key
valid_accs = dict(sorted(valid_accs.items()))
valid_accs

{50: 0.0625,
 55: 0.2625,
 56: 0.65,
 58: 0.8625,
 60: 0.8375,
 65: 0.95,
 70: 0.9375,
 75: 0.9875}

## Visibility Parser

In [11]:
train_size = 80
valid_size = 80
batch_size = 4

dataset = get_unidirectional_visibility_exps(STORY_TEMPLATES,
                             all_characters,
                             all_containers,
                             all_states,
                             train_size+valid_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [12]:
idx = 0
print(dataset[idx]['corrupt_prompt'], dataset[idx]['corrupt_ans'])
print(dataset[idx]['clean_prompt'], dataset[idx]['clean_ans'])
print(f"Target: {dataset[idx]['target']}")

Instruction: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any beliefs about the container and its contents which they cannot observe. 4. To answer the question, predict only what is inside the queried container, strictly based on the belief of the character, mentioned in the question. 5. If the queried character has no belief about the container in question, then predict 'unknown'. 6. Do not predict container or character as the final output.

Story: Tony and Nancy are working in a busy restaurant. To complete an order, Tony grabs an opaque tote and fills it with porter. Then Nancy grabs another opaque flask and fills it with cocoa. Nancy cannot observe Tony's actions. Tony can observe Nancy's actions.
Question: What does Tony believe the flask contains?
Answer: cocoa
Instruction: 1. Track the belief of each character as 

In [13]:
# valid_accs, rank = {}, {}
model.tokenizer.padding_side = "left"

for layer_idx in range(18, 20, 2):
    modules = [i for i in range(sing_vecs[layer_idx].shape[0])]
    mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)

    optimizer = torch.optim.Adam([mask], lr=1e-1)
    n_epochs = 1
    lamb = 0.001

    print(f"Training layer: {layer_idx}")
    for epoch in range(n_epochs):
        epoch_loss = 0

        for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
            alt_prompts = batch["corrupt_prompt"]
            org_prompts = batch["clean_prompt"]
            targets = batch["target"]
            target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
            batch_size = target_tokens.size(0)

            optimizer.zero_grad()

            alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
            with model.trace() as tracer:
                with tracer.invoke(alt_prompts):
                    for t_idx, t in enumerate(second_visibility_sent):
                        alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

                with tracer.invoke(org_prompts):
                    sing_vec = sing_vecs[layer_idx].cuda()
                    masked_vec = sing_vec * mask.unsqueeze(-1)
                    proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                    for t_idx, t in enumerate(second_visibility_sent):
                        curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                        alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                    modified_out = curr_output - org_proj + alt_proj
                    model.model.layers[layer_idx].output[0][:, t] = modified_out

                    del sing_vec, proj_matrix, masked_vec
                    torch.cuda.empty_cache()

                    logits = model.lm_head.output[:, -1].save()

            target_logit = logits[torch.arange(batch_size), target_tokens]
            task_loss = -torch.mean(target_logit)
            l1_loss = lamb * torch.norm(mask, p=1)
            loss = task_loss + l1_loss.to(task_loss.device)

            epoch_loss += loss.item()

            if bi % 2 == 0:
                mean_loss = epoch_loss / (bi + 1)
                print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
                    f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
                with torch.no_grad():
                    mask.data.clamp_(0, 1)
                    rounded = torch.round(mask)
                    print(f"Rank: {(rounded == 1).sum().item()}")

            loss.backward()
            optimizer.step()

            # Clamp after optimizer step
            with torch.no_grad():
                mask.data.clamp_(0, 1)

    print(f"Training complete for {layer_idx}!")

Training layer: 18


  0%|          | 0/20 [00:00<?, ?it/s]

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch: 0, Batch: 0, Task Loss: -15.2891, L1 Loss: 3.5000, Total Loss: -11.7891
Rank: 3500


 10%|█         | 2/20 [00:04<00:40,  2.27s/it]

Epoch: 0, Batch: 2, Task Loss: -15.3984, L1 Loss: 2.9844, Total Loss: -12.1328
Rank: 3500


 20%|██        | 4/20 [00:08<00:35,  2.20s/it]

Epoch: 0, Batch: 4, Task Loss: -15.5938, L1 Loss: 2.5312, Total Loss: -12.5359
Rank: 3500


 30%|███       | 6/20 [00:13<00:29,  2.13s/it]

Epoch: 0, Batch: 6, Task Loss: -16.8594, L1 Loss: 2.0781, Total Loss: -13.0748
Rank: 1283


 40%|████      | 8/20 [00:17<00:25,  2.14s/it]

Epoch: 0, Batch: 8, Task Loss: -17.3281, L1 Loss: 1.6406, Total Loss: -13.5547
Rank: 1187


 50%|█████     | 10/20 [00:21<00:21,  2.14s/it]

Epoch: 0, Batch: 10, Task Loss: -17.4375, L1 Loss: 1.2266, Total Loss: -13.9574
Rank: 1145


 60%|██████    | 12/20 [00:25<00:17,  2.14s/it]

Epoch: 0, Batch: 12, Task Loss: -17.5156, L1 Loss: 1.1328, Total Loss: -14.2987
Rank: 1133


 70%|███████   | 14/20 [00:30<00:12,  2.14s/it]

Epoch: 0, Batch: 14, Task Loss: -17.6406, L1 Loss: 1.1172, Total Loss: -14.5958
Rank: 1116


 80%|████████  | 16/20 [00:34<00:08,  2.10s/it]

Epoch: 0, Batch: 16, Task Loss: -17.4688, L1 Loss: 1.1016, Total Loss: -14.8396
Rank: 1103


 90%|█████████ | 18/20 [00:38<00:04,  2.12s/it]

Epoch: 0, Batch: 18, Task Loss: -17.4062, L1 Loss: 1.0938, Total Loss: -15.0613
Rank: 1098


100%|██████████| 20/20 [00:42<00:00,  2.15s/it]

Training complete for 18!





In [34]:
print(f"Validation started for {layer_idx}")
correct, total = 0, 0

with torch.inference_mode():
    mask_data = mask.data.clone()
    mask_data.clamp_(0, 1)
    rounded = torch.round(mask_data)
    print(f"#Rank: {(rounded == 1).sum().item()}")

    # rank[layer_idx] = (rounded == 1).sum().item()
    # torch.save(rounded, f"../masks/toy/value_fetcher/{layer_idx}.pt")

    for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
        alt_prompts = batch["corrupt_prompt"]
        org_prompts = batch["clean_prompt"]
        targets = batch["target"]
        target_tokens = model.tokenizer(targets, return_tensors="pt").input_ids[:, -1]
        batch_size = target_tokens.size(0)

        alt_acts, org_acts_state = defaultdict(dict), defaultdict(dict)
        with model.trace() as tracer:
            with tracer.invoke(alt_prompts):
                for t_idx, t in enumerate(second_visibility_sent):
                    alt_acts[t_idx] = model.model.layers[layer_idx].output[0][:, t].clone()

            with tracer.invoke(org_prompts):
                sing_vec = sing_vecs[layer_idx].cuda()
                masked_vec = sing_vec.to(rounded.device) * rounded.unsqueeze(-1)
                proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                for t_idx, t in enumerate(second_visibility_sent):
                    curr_output = model.model.layers[layer_idx].output[0][:, t].clone()
                    alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                    org_proj = torch.matmul(curr_output, proj_matrix)

                    modified_out = curr_output - org_proj + alt_proj
                    model.model.layers[layer_idx].output[0][:, t] = modified_out

                    # model.model.layers[layer_idx].output[0][:, t] = alt_acts[t_idx]

                del sing_vec, proj_matrix, masked_vec
                torch.cuda.empty_cache()

                logits = model.lm_head.output[:, -1].save()

        pred = torch.argmax(logits, dim=-1).to(target_tokens.device).cpu()

        for i in range(batch_size):
            pred_token = model.tokenizer.decode(pred[i])
            print(f"Predicted: {pred_token.lower().strip()}, Target: {targets[i].lower().strip()}")
            if pred_token.lower().strip() == targets[i].lower().strip():
                correct += 1
            total += 1

        del alt_acts, alt_prompts, org_prompts, targets, target_tokens, logits, pred
        torch.cuda.empty_cache()

print(f"Layer: {layer_idx} | Validation accuracy: {correct / total:.2f}\n")
# valid_accs[layer_idx] = correct / total

Validation started for 18
#Rank: 1136


  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:01<00:24,  1.28s/it]

Predicted: bourbon, Target: bourbon
Predicted: sprite, Target: sprite
Predicted: unknown, Target: champagne
Predicted: bourbon, Target: bourbon


 10%|█         | 2/20 [00:02<00:21,  1.19s/it]

Predicted: porter, Target: porter
Predicted: port, Target: port
Predicted: unknown, Target: stout
Predicted: espresso, Target: espresso


 15%|█▌        | 3/20 [00:03<00:19,  1.16s/it]

Predicted: unknown, Target: porter
Predicted: water, Target: water
Predicted: unknown, Target: porter
Predicted: unknown, Target: soda


 20%|██        | 4/20 [00:04<00:17,  1.12s/it]

Predicted: unknown, Target: porter
Predicted: ale, Target: ale
Predicted: gin, Target: gin
Predicted: ale, Target: ale


 25%|██▌       | 5/20 [00:05<00:16,  1.12s/it]

Predicted: cocktail, Target: sprite
Predicted: port, Target: port
Predicted: porter, Target: porter
Predicted: cocktail, Target: cocktail


 30%|███       | 6/20 [00:06<00:15,  1.12s/it]

Predicted: rum, Target: rum
Predicted: tea, Target: tea
Predicted: porter, Target: porter
Predicted: port, Target: port


 35%|███▌      | 7/20 [00:07<00:14,  1.13s/it]

Predicted: porter, Target: porter
Predicted: unknown, Target: espresso
Predicted: punch, Target: punch
Predicted: beer, Target: beer


 40%|████      | 8/20 [00:09<00:13,  1.13s/it]

Predicted: unknown, Target: stout
Predicted: champagne, Target: champagne
Predicted: unknown, Target: milk
Predicted: port, Target: port


 45%|████▌     | 9/20 [00:10<00:12,  1.13s/it]

Predicted: champagne, Target: champagne
Predicted: unknown, Target: milk
Predicted: stout, Target: stout
Predicted: soda, Target: soda


 50%|█████     | 10/20 [00:11<00:11,  1.13s/it]

Predicted: ale, Target: ale
Predicted: milk, Target: milk
Predicted: unknown, Target: stout
Predicted: coffee, Target: coffee


 55%|█████▌    | 11/20 [00:12<00:10,  1.13s/it]

Predicted: monster, Target: monster
Predicted: tea, Target: tea
Predicted: tea, Target: tea
Predicted: water, Target: water


 60%|██████    | 12/20 [00:13<00:09,  1.13s/it]

Predicted: port, Target: port
Predicted: espresso, Target: port
Predicted: unknown, Target: port
Predicted: soda, Target: soda


 65%|██████▌   | 13/20 [00:14<00:07,  1.13s/it]

Predicted: champagne, Target: champagne
Predicted: water, Target: water
Predicted: punch, Target: rum
Predicted: unknown, Target: rum


 70%|███████   | 14/20 [00:15<00:06,  1.13s/it]

Predicted: tea, Target: tea
Predicted: unknown, Target: bourbon
Predicted: port, Target: port
Predicted: monster, Target: monster


 75%|███████▌  | 15/20 [00:17<00:05,  1.13s/it]

Predicted: punch, Target: punch
Predicted: espresso, Target: espresso
Predicted: espresso, Target: espresso
Predicted: milk, Target: milk


 80%|████████  | 16/20 [00:18<00:04,  1.13s/it]

Predicted: tea, Target: tea
Predicted: unknown, Target: wine
Predicted: water, Target: water
Predicted: sprite, Target: sprite


 85%|████████▌ | 17/20 [00:19<00:03,  1.13s/it]

Predicted: espresso, Target: espresso
Predicted: monster, Target: monster
Predicted: unknown, Target: bourbon
Predicted: soda, Target: soda


 90%|█████████ | 18/20 [00:20<00:02,  1.13s/it]

Predicted: stout, Target: stout
Predicted: espresso, Target: espresso
Predicted: unknown, Target: espresso
Predicted: float, Target: float


 95%|█████████▌| 19/20 [00:21<00:01,  1.13s/it]

Predicted: milk, Target: wine
Predicted: sprite, Target: sprite
Predicted: unknown, Target: cocoa
Predicted: monster, Target: monster


100%|██████████| 20/20 [00:22<00:00,  1.13s/it]

Predicted: espresso, Target: espresso
Predicted: water, Target: water
Predicted: punch, Target: punch
Predicted: unknown, Target: bourbon
Layer: 18 | Validation accuracy: 0.71




