In [1]:
import json
import random
import os
import math
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any, List, Optional
import nnsight
from nnsight import CONFIG, LanguageModel
import numpy as np
from collections import defaultdict
from einops import einsum
import time
from einops import rearrange, reduce
import pandas as pd

sys.path.append("../")
from src.dataset import SampleV3, DatasetV3, STORY_TEMPLATES
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(10)

CONFIG.set_default_api_key("d9e00ab7d4f74643b3176de0913f24a7")
os.environ["HF_TOKEN"] = "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"

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
CONFIG.APP.REMOTE_LOGGING = False

# Define random seed
seed = 10
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


# Loading model

In [2]:
# model = LanguageModel("meta-llama/Meta-Llama-3.1-405B")
model = LanguageModel("meta-llama/Meta-Llama-3-70B-Instruct", cache_dir="/disk/u/nikhil/.cache/huggingface/hub/", device_map="auto", load_in_4bit=True, torch_dtype=torch.float16, dispatch=True)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 30/30 [00:37<00:00,  1.25s/it]


In [3]:
model.eval()
for param in model.parameters():
    param.requires_grad_(False)

# Loading Helper Functions

In [4]:
def get_ques_start_token_idx(batch_size, tokenizer, prompt, padding_side="right"):
    input_tokens = tokenizer(prompt, return_tensors="pt", padding=True, padding_side=padding_side).input_ids
    colon_token = tokenizer.encode(":", return_tensors="pt").squeeze()[-1].item()
    ques_start_idx = (input_tokens == colon_token).nonzero()[torch.arange(2, 4*batch_size, 4)][:, 1] - 1

    return ques_start_idx

In [5]:
def get_prompt_token_len(tokenizer, prompt, padding_side="right"):
    input_tokens = tokenizer(prompt, return_tensors="pt", padding=True, padding_side=padding_side)
    return input_tokens.attention_mask.sum(dim=-1)

In [6]:
def check_pred(pred, target, verbose=False):
    prompt = f"Instruction: Check if the following ground truth and prediction are similar or not. If they are the same, then predict 'Yes', else 'No'.\nGround truth: {target}\nPrediction: {pred}\nAnswer:"

    with torch.no_grad():
        with model.generate(prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id):
            out = model.generator.output.save()

    prompt_len = get_prompt_token_len(model.tokenizer, prompt)
    
    if verbose:
        print(model.tokenizer.decode(out[0]))

    return out, prompt_len

# Loading BigToM dataset

In [7]:
# Read a csv file
df_false = pd.read_csv("../data/bigtom/0_forward_belief_false_belief/stories.csv", delimiter=";")
df_true = pd.read_csv("../data/bigtom/0_forward_belief_true_belief/stories.csv", delimiter=";")

# DCM

## Answer Variable

In [24]:
sing_vecs = defaultdict(dict)
for l in range(model.config.num_hidden_layers):
    sing_vecs[l] = torch.load(f"../svd_results/BigToM/last_token/singular_vecs/{l}.pt").cpu()

In [25]:
train_size = 80
valid_size = 40
batch_size = 1

dataset = get_bigtom_value_fetcher_exps(df_false, df_true, train_size+valid_size)
train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [26]:
idx=0
print(dataset[idx]['alt_prompt'], dataset[idx]['alt_ans'])
print(dataset[idx]['org_prompt'], dataset[idx]['org_ans'])
print(f"Target: {dataset[idx]['target']}")

Instructions: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any belief about the container or its content which they cannot observe directly. 4. To answer the question, predict only the final state of the queried container in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict the entire sentence with character or container as the final output.

Story: Fatima is a talented painter who has been commissioned to create a mural on a local community center's wall. She wants to choose the perfect color palette to represent the vibrant energy of the community. Fatima notices that the paint colors she has selected are bright and lively. However, while she is away for lunch, a mischievous cat knocks over the paint cans, causing the colors to mix into dull, muted shades

In [27]:
model.tokenizer.padding_side = "left"

valid_accs = defaultdict(dict)
for layer_idx in range(70, 72, 2):

    n_epochs = 2
    lambs = [0.05]
    for lamb in lambs:
        modules = [i for i in range(sing_vecs[layer_idx].size(0))]
        mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)
        optimizer = torch.optim.Adam([mask], lr=1e-1)

        print(f"Training layer: {layer_idx}, lambda: {lamb}")
        for epoch in range(n_epochs):
            epoch_loss = 0

            for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
                alt_prompt = batch["alt_prompt"]
                org_prompt = batch["org_prompt"]
                target = batch["target"]
                target_token = model.tokenizer(target, return_tensors="pt", padding=True, padding_side="right")
                target_input_ids = target_token.input_ids[:, 1:]
                batch_size = target_input_ids.size(0)

                alt_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, alt_prompt, padding_side="right")
                alt_prompt_len = get_prompt_token_len(model.tokenizer, alt_prompt, padding_side="right")
                org_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, org_prompt, padding_side="right")
                org_prompt_len = get_prompt_token_len(model.tokenizer, org_prompt, padding_side="right")

                optimizer.zero_grad()

                with model.trace() as tracer:

                    with tracer.invoke(alt_prompt):
                        alt_acts = model.model.layers[layer_idx].output[0][0, -1].clone().save()

                    with tracer.invoke(org_prompt):
                        sing_vec = sing_vecs[layer_idx].cuda()
                        masked_vec = sing_vec * mask.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        curr_output = model.model.layers[layer_idx].output[0][0, -1].clone()

                        alt_proj = torch.matmul(alt_acts, proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[layer_idx].output[0][0, -1] = modified_out

                        logits = model.lm_head.output[0, -1].save()

                        del sing_vec, proj_matrix, masked_vec
                        torch.cuda.empty_cache()

                target_logit = logits[target_input_ids[0]].sum()

                task_loss = -(target_logit/batch_size)
                l1_loss = lamb * torch.norm(mask, p=1)
                loss = task_loss + l1_loss.to(task_loss.device)
                
                epoch_loss += loss.item()
                
                if bi % 4 == 0:
                    mean_loss = epoch_loss / (bi + 1)
                    print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
                        f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
                    with torch.no_grad():
                        mask.data.clamp_(0, 1)
                        rounded = torch.round(mask)
                        print(f"#Rank: {(rounded == 1).sum().item()}")
                
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    mask.data.clamp_(0, 1)

        print(f"Training finished for layer: {layer_idx}, lambda: {lamb}")

        print(f"Validation started for layer: {layer_idx}, lambda: {lamb}")
        correct, total = 0, 0

        with torch.inference_mode():
            mask_data = mask.data.clone()
            mask_data.clamp_(0, 1)
            rounded = torch.round(mask)

            print(f"Rank: {(rounded == 1).sum().item()}")

            # Save the mask
            # torch.save(mask_data, f"../masks/bigtom/{layer_idx}.pt")

            for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
                alt_prompt = batch["alt_prompt"]
                org_prompt = batch["org_prompt"]
                alt_ans = batch["alt_ans"]
                target = batch["target"][0]
                batch_size = len(alt_ans)

                alt_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, alt_prompt, padding_side="left")
                alt_prompt_len = get_prompt_token_len(model.tokenizer, alt_prompt, padding_side="left")
                org_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, org_prompt, padding_side="left")
                org_prompt_len = get_prompt_token_len(model.tokenizer, org_prompt, padding_side="left")

                with model.session() as session:

                    with model.trace(alt_prompt):
                        alt_acts = model.model.layers[layer_idx].output[0][0, -1].save()

                    with model.generate(org_prompt, max_new_tokens=2, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id, eos_token_id=model.tokenizer.eos_token_id):
                        sing_vec = sing_vecs[layer_idx].cuda()
                        masked_vec = sing_vec * rounded.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        curr_output = model.model.layers[layer_idx].output[0][0, -1].clone()

                        alt_proj = torch.matmul(alt_acts, proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj

                        model.model.layers[layer_idx].output[0][0, -1] = modified_out

                        out = model.generator.output.save()

                    del alt_acts
                    torch.cuda.empty_cache()

                pred = model.tokenizer.decode(out[0][org_prompt_len:-1]).strip()
                print(f"Prediction: {pred} | Target: {target}")
                if pred.lower() in target.lower():
                    correct += 1
                total += 1

            print(f"Validation accuracy: {correct / total:.2f} | Correct: {correct} | Total: {total}\n")
            valid_accs[lamb][layer_idx] = round(correct / total, 2)

Training layer: 70, lambda: 0.05


  0%|          | 0/80 [00:00<?, ?it/s]

  1%|▏         | 1/80 [00:04<06:16,  4.77s/it]

Epoch: 0, Batch: 0, Task Loss: -38.0938, L1 Loss: 20.0000, Total Loss: -18.0938
#Rank: 400


  6%|▋         | 5/80 [00:25<06:31,  5.22s/it]

Epoch: 0, Batch: 4, Task Loss: -30.2344, L1 Loss: 16.1250, Total Loss: -18.0219
#Rank: 400


 11%|█▏        | 9/80 [00:46<06:08,  5.19s/it]

Epoch: 0, Batch: 8, Task Loss: -56.2500, L1 Loss: 13.6875, Total Loss: -26.2118
#Rank: 274


 16%|█▋        | 13/80 [01:07<05:51,  5.24s/it]

Epoch: 0, Batch: 12, Task Loss: -32.6875, L1 Loss: 11.3750, Total Loss: -22.4507
#Rank: 234


 21%|██▏       | 17/80 [01:27<05:22,  5.13s/it]

Epoch: 0, Batch: 16, Task Loss: -30.1250, L1 Loss: 10.2500, Total Loss: -20.2206
#Rank: 213


 26%|██▋       | 21/80 [01:48<05:05,  5.18s/it]

Epoch: 0, Batch: 20, Task Loss: -18.6250, L1 Loss: 9.4375, Total Loss: -19.9182
#Rank: 200


 31%|███▏      | 25/80 [02:09<04:49,  5.27s/it]

Epoch: 0, Batch: 24, Task Loss: -16.7812, L1 Loss: 8.8125, Total Loss: -19.1106
#Rank: 183


 36%|███▋      | 29/80 [02:30<04:28,  5.27s/it]

Epoch: 0, Batch: 28, Task Loss: -70.0000, L1 Loss: 8.3750, Total Loss: -19.9065
#Rank: 170


 41%|████▏     | 33/80 [02:51<04:07,  5.26s/it]

Epoch: 0, Batch: 32, Task Loss: -30.3750, L1 Loss: 8.0625, Total Loss: -21.6357
#Rank: 167


 46%|████▋     | 37/80 [03:12<03:40,  5.13s/it]

Epoch: 0, Batch: 36, Task Loss: -44.8750, L1 Loss: 7.8438, Total Loss: -24.0990
#Rank: 161


 51%|█████▏    | 41/80 [03:32<03:15,  5.02s/it]

Epoch: 0, Batch: 40, Task Loss: -35.5312, L1 Loss: 7.6875, Total Loss: -25.6317
#Rank: 153


 56%|█████▋    | 45/80 [03:53<03:03,  5.23s/it]

Epoch: 0, Batch: 44, Task Loss: -32.8750, L1 Loss: 7.5625, Total Loss: -25.0398
#Rank: 152


 61%|██████▏   | 49/80 [04:14<02:44,  5.29s/it]

Epoch: 0, Batch: 48, Task Loss: -31.3281, L1 Loss: 7.4062, Total Loss: -24.5120
#Rank: 150


 66%|██████▋   | 53/80 [04:35<02:19,  5.17s/it]

Epoch: 0, Batch: 52, Task Loss: -20.9375, L1 Loss: 7.4062, Total Loss: -24.6826
#Rank: 152


 71%|███████▏  | 57/80 [04:56<01:57,  5.13s/it]

Epoch: 0, Batch: 56, Task Loss: -22.5625, L1 Loss: 7.4062, Total Loss: -24.8930
#Rank: 154


 76%|███████▋  | 61/80 [05:17<01:39,  5.24s/it]

Epoch: 0, Batch: 60, Task Loss: -41.3750, L1 Loss: 7.4062, Total Loss: -25.8090
#Rank: 154


 81%|████████▏ | 65/80 [05:38<01:19,  5.28s/it]

Epoch: 0, Batch: 64, Task Loss: -33.0625, L1 Loss: 7.4375, Total Loss: -25.8843
#Rank: 154


 86%|████████▋ | 69/80 [05:59<00:58,  5.29s/it]

Epoch: 0, Batch: 68, Task Loss: -16.5156, L1 Loss: 7.4062, Total Loss: -25.3851
#Rank: 154


 91%|█████████▏| 73/80 [06:20<00:36,  5.28s/it]

Epoch: 0, Batch: 72, Task Loss: -15.8750, L1 Loss: 7.3438, Total Loss: -25.1125
#Rank: 152


 96%|█████████▋| 77/80 [06:41<00:15,  5.28s/it]

Epoch: 0, Batch: 76, Task Loss: -44.6250, L1 Loss: 7.2500, Total Loss: -25.8136
#Rank: 151


100%|██████████| 80/80 [06:57<00:00,  5.22s/it]
  1%|▏         | 1/80 [00:05<07:01,  5.33s/it]

Epoch: 1, Batch: 0, Task Loss: -36.1250, L1 Loss: 7.1562, Total Loss: -28.9688
#Rank: 150


  6%|▋         | 5/80 [00:26<06:33,  5.24s/it]

Epoch: 1, Batch: 4, Task Loss: -29.3594, L1 Loss: 7.0938, Total Loss: -26.2719
#Rank: 150


 11%|█▏        | 9/80 [00:46<06:07,  5.17s/it]

Epoch: 1, Batch: 8, Task Loss: -59.0000, L1 Loss: 7.1562, Total Loss: -35.1337
#Rank: 150


 16%|█▋        | 13/80 [01:07<05:51,  5.24s/it]

Epoch: 1, Batch: 12, Task Loss: -30.9375, L1 Loss: 7.0938, Total Loss: -29.9724
#Rank: 147


 21%|██▏       | 17/80 [01:28<05:25,  5.17s/it]

Epoch: 1, Batch: 16, Task Loss: -31.1406, L1 Loss: 7.0938, Total Loss: -26.8722
#Rank: 146


 26%|██▋       | 21/80 [01:48<05:06,  5.20s/it]

Epoch: 1, Batch: 20, Task Loss: -19.0312, L1 Loss: 7.0625, Total Loss: -26.0424
#Rank: 146


 31%|███▏      | 25/80 [02:10<04:50,  5.27s/it]

Epoch: 1, Batch: 24, Task Loss: -17.3750, L1 Loss: 7.0000, Total Loss: -24.7078
#Rank: 145


 36%|███▋      | 29/80 [02:31<04:28,  5.27s/it]

Epoch: 1, Batch: 28, Task Loss: -69.1250, L1 Loss: 7.0000, Total Loss: -25.0482
#Rank: 144


 41%|████▏     | 33/80 [02:52<04:07,  5.27s/it]

Epoch: 1, Batch: 32, Task Loss: -30.5469, L1 Loss: 7.0000, Total Loss: -26.3857
#Rank: 142


 46%|████▋     | 37/80 [03:13<03:41,  5.15s/it]

Epoch: 1, Batch: 36, Task Loss: -44.2500, L1 Loss: 7.0000, Total Loss: -28.5353
#Rank: 144


 51%|█████▏    | 41/80 [03:33<03:16,  5.03s/it]

Epoch: 1, Batch: 40, Task Loss: -35.9375, L1 Loss: 7.0000, Total Loss: -29.8306
#Rank: 146


 56%|█████▋    | 45/80 [03:54<03:03,  5.23s/it]

Epoch: 1, Batch: 44, Task Loss: -32.3125, L1 Loss: 6.9375, Total Loss: -28.9106
#Rank: 146


 61%|██████▏   | 49/80 [04:15<02:43,  5.28s/it]

Epoch: 1, Batch: 48, Task Loss: -31.7500, L1 Loss: 6.9062, Total Loss: -28.1143
#Rank: 145


 66%|██████▋   | 53/80 [04:36<02:19,  5.16s/it]

Epoch: 1, Batch: 52, Task Loss: -21.1875, L1 Loss: 6.9062, Total Loss: -28.1263
#Rank: 143


 71%|███████▏  | 57/80 [04:56<01:58,  5.13s/it]

Epoch: 1, Batch: 56, Task Loss: -22.2500, L1 Loss: 6.9062, Total Loss: -28.1575
#Rank: 143


 76%|███████▋  | 61/80 [05:18<01:39,  5.26s/it]

Epoch: 1, Batch: 60, Task Loss: -41.6875, L1 Loss: 6.9375, Total Loss: -28.8836
#Rank: 143


 81%|████████▏ | 65/80 [05:39<01:19,  5.30s/it]

Epoch: 1, Batch: 64, Task Loss: -32.9375, L1 Loss: 7.0000, Total Loss: -28.7963
#Rank: 142


 86%|████████▋ | 69/80 [06:00<00:58,  5.30s/it]

Epoch: 1, Batch: 68, Task Loss: -16.1406, L1 Loss: 7.0000, Total Loss: -28.1541
#Rank: 141


 91%|█████████▏| 73/80 [06:21<00:37,  5.30s/it]

Epoch: 1, Batch: 72, Task Loss: -16.0469, L1 Loss: 6.9375, Total Loss: -27.7548
#Rank: 140


 96%|█████████▋| 77/80 [06:42<00:15,  5.29s/it]

Epoch: 1, Batch: 76, Task Loss: -44.5000, L1 Loss: 6.8438, Total Loss: -28.3524
#Rank: 139


100%|██████████| 80/80 [06:58<00:00,  5.23s/it]


Training finished for layer: 70, lambda: 0.05
Validation started for layer: 70, lambda: 0.05
Rank: 139


  2%|▎         | 1/40 [00:05<03:44,  5.77s/it]

Prediction: N | Target: murky and polluted


  5%|▌         | 2/40 [00:11<03:40,  5.80s/it]

Prediction: Frag | Target: fragile branches


  8%|▊         | 3/40 [00:17<03:35,  5.81s/it]

Prediction: Healthy | Target: healthy state


 10%|█         | 4/40 [00:23<03:31,  5.87s/it]

Prediction: wilt | Target: wilted and less ideal


 12%|█▎        | 5/40 [00:29<03:26,  5.89s/it]

Prediction: Dis | Target: disrupted by fallen leaves


 15%|█▌        | 6/40 [00:35<03:21,  5.93s/it]

Prediction: hot | Target: hot sauce


 18%|█▊        | 7/40 [00:41<03:15,  5.93s/it]

Prediction: Takes | Target: nearly empty


 20%|██        | 8/40 [00:47<03:10,  5.95s/it]

Prediction: with | Target: withered


 22%|██▎       | 9/40 [00:53<03:03,  5.91s/it]

Prediction: Frag | Target: fragile branches


 25%|██▌       | 10/40 [00:59<02:57,  5.93s/it]

Prediction: Wet | Target: wet and difficult to ignite


 28%|██▊       | 11/40 [01:04<02:51,  5.92s/it]

Prediction: dry | Target: dry and stale


 30%|███       | 12/40 [01:10<02:45,  5.93s/it]

Prediction: St | Target: stained and damaged


 32%|███▎      | 13/40 [01:16<02:40,  5.93s/it]

Prediction: Green | Target: water


 35%|███▌      | 14/40 [01:22<02:34,  5.96s/it]

Prediction:  | Target: 18°C


 38%|███▊      | 15/40 [01:28<02:27,  5.90s/it]

Prediction: Off | Target: turned off


 40%|████      | 16/40 [01:34<02:21,  5.89s/it]

Prediction: wilt | Target: wilted flowers


 42%|████▎     | 17/40 [01:40<02:15,  5.91s/it]

Prediction: Bur | Target: buried under the sand


 45%|████▌     | 18/40 [01:46<02:10,  5.92s/it]

Prediction: hard | Target: hard and brittle


 48%|████▊     | 19/40 [01:52<02:04,  5.95s/it]

Prediction: washed | Target: washed away and diluted


 50%|█████     | 20/40 [01:58<01:59,  5.96s/it]

Prediction: Enrique | Target: sultanas


 52%|█████▎    | 21/40 [02:04<01:52,  5.91s/it]

Prediction: Frag | Target: fragile branches


 55%|█████▌    | 22/40 [02:10<01:45,  5.89s/it]

Prediction: torn | Target: torn apart


 57%|█████▊    | 23/40 [02:15<01:40,  5.89s/it]

Prediction: The | Target: fox has stolen the eggs


 60%|██████    | 24/40 [02:21<01:34,  5.88s/it]

Prediction: K | Target: Mount Fuji is covered by fog


 62%|██████▎   | 25/40 [02:27<01:28,  5.90s/it]

Prediction: flour | Target: flour


 65%|██████▌   | 26/40 [02:33<01:22,  5.88s/it]

Prediction: Am | Target: very high temperature


 68%|██████▊   | 27/40 [02:39<01:16,  5.87s/it]

Prediction: Dam | Target: damaged by the monkey


 70%|███████   | 28/40 [02:45<01:10,  5.87s/it]

Prediction: severely | Target: severely damaged


 72%|███████▎  | 29/40 [02:51<01:04,  5.84s/it]

Prediction: Dr | Target: a state of disrepair


 75%|███████▌  | 30/40 [02:56<00:58,  5.86s/it]

Prediction: rough | Target: rough and choppy due to the storm


 78%|███████▊  | 31/40 [03:02<00:52,  5.83s/it]

Prediction: Has | Target: has shifted


 80%|████████  | 32/40 [03:08<00:46,  5.83s/it]

Prediction: She | Target: knows about the concealed door


 82%|████████▎ | 33/40 [03:14<00:40,  5.83s/it]

Prediction: Fl | Target: been flattened


 85%|████████▌ | 34/40 [03:20<00:35,  5.85s/it]

Prediction: Ex | Target: affected by the power outage and has cooled down


 88%|████████▊ | 35/40 [03:26<00:29,  5.83s/it]

Prediction: Not | Target: has a broken string


 90%|█████████ | 36/40 [03:31<00:23,  5.83s/it]

Prediction: wilt | Target: wilted and less ideal


 92%|█████████▎| 37/40 [03:37<00:17,  5.83s/it]

Prediction: Bian | Target: almond milk


 95%|█████████▌| 38/40 [03:43<00:11,  5.82s/it]

Prediction: Coll | Target: has collapsed


 98%|█████████▊| 39/40 [03:49<00:05,  5.83s/it]

Prediction: wet | Target: wet and wilted


100%|██████████| 40/40 [03:55<00:00,  5.88s/it]

Prediction: Ol | Target: underripe
Validation accuracy: 0.72 | Correct: 29 | Total: 40






## Answer State OID Variable

In [8]:
sing_vecs = defaultdict(dict)
for l in range(model.config.num_hidden_layers):
    sing_vecs[l] = torch.load(f"../svd_results/BigToM/last_token/singular_vecs/{l}.pt").cpu()

In [9]:
train_size = 80
valid_size = 40
batch_size = 1

dataset = get_bigtom_answer_state_exps(df_false, df_true, train_size+valid_size)
train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [22]:
idx = 0
print(train_dataset[idx]['alt_prompt'], train_dataset[idx]['alt_ans'])
print(train_dataset[idx]['org_prompt'], train_dataset[idx]['org_ans'])
print(f"Target: {train_dataset[idx]['target']}")

Instructions: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any belief about the container or its content which they cannot observe directly. 4. To answer the question, predict only the final state of the queried container in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict the entire sentence with character or container as the final output.

Story: Amara is a botanist exploring a dense rainforest in search of a rare orchid species. Amara's goal is to find the rare orchid and study its unique characteristics. She spots an orchid with vibrant purple petals that she thinks might be the one she is searching for. As she continues to explore, a sudden downpour washes away the purple pigment from the orchid, revealing that it is actually a common white orchid. Am

In [32]:
model.tokenizer.padding_side = "left"

valid_accs = defaultdict(dict)
for layer_idx in range(40, 42, 2):

    n_epochs = 2
    lambs = [0.005]
    for lamb in lambs:
        modules = [i for i in range(sing_vecs[layer_idx].size(0))]
        mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)
        optimizer = torch.optim.Adam([mask], lr=1e-1)

        print(f"Training layer: {layer_idx}, lambda: {lamb}")
        for epoch in range(n_epochs):
            epoch_loss = 0

            for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
                alt_prompt = batch["alt_prompt"]
                org_prompt = batch["org_prompt"]
                target = batch["target"]
                target_token = model.tokenizer(target, return_tensors="pt", padding=True, padding_side="right")
                target_input_ids = target_token.input_ids[:, 1:]
                batch_size = target_input_ids.size(0)

                alt_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, alt_prompt, padding_side="right")
                alt_prompt_len = get_prompt_token_len(model.tokenizer, alt_prompt, padding_side="right")
                org_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, org_prompt, padding_side="right")
                org_prompt_len = get_prompt_token_len(model.tokenizer, org_prompt, padding_side="right")

                optimizer.zero_grad()

                with model.trace() as tracer:

                    with tracer.invoke(alt_prompt):
                        alt_acts = model.model.layers[layer_idx].output[0][0, -1].clone().save()

                    with tracer.invoke(org_prompt):
                        sing_vec = sing_vecs[layer_idx].cuda()
                        masked_vec = sing_vec * mask.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        curr_output = model.model.layers[layer_idx].output[0][0, -1].clone()

                        alt_proj = torch.matmul(alt_acts, proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj
                        model.model.layers[layer_idx].output[0][0, -1] = modified_out

                        logits = model.lm_head.output[0, -1].save()

                        del sing_vec, proj_matrix, masked_vec
                        torch.cuda.empty_cache()

                target_logit = logits[target_input_ids[0]].sum()

                task_loss = -(target_logit/batch_size)
                l1_loss = lamb * torch.norm(mask, p=1)
                loss = task_loss + l1_loss.to(task_loss.device)
                
                epoch_loss += loss.item()
                
                if bi % 4 == 0:
                    mean_loss = epoch_loss / (bi + 1)
                    print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
                        f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
                    with torch.no_grad():
                        mask.data.clamp_(0, 1)
                        rounded = torch.round(mask)
                        print(f"#Rank: {(rounded == 1).sum().item()}")
                
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    mask.data.clamp_(0, 1)

        print(f"Training finished for layer: {layer_idx}, lambda: {lamb}")

        print(f"Validation started for layer: {layer_idx}, lambda: {lamb}")
        correct, total = 0, 0

        with torch.inference_mode():
            mask_data = mask.data.clone()
            mask_data.clamp_(0, 1)
            rounded = torch.round(mask)

            print(f"Rank: {(rounded == 1).sum().item()}")

            # Save the mask
            # torch.save(mask_data, f"../masks/bigtom/{layer_idx}.pt")

            for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
                alt_prompt = batch["alt_prompt"]
                org_prompt = batch["org_prompt"]
                alt_ans = batch["alt_ans"]
                target = batch["target"][0]
                batch_size = len(alt_ans)

                alt_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, alt_prompt, padding_side="left")
                alt_prompt_len = get_prompt_token_len(model.tokenizer, alt_prompt, padding_side="left")
                org_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, org_prompt, padding_side="left")
                org_prompt_len = get_prompt_token_len(model.tokenizer, org_prompt, padding_side="left")

                with model.session() as session:

                    with model.trace(alt_prompt):
                        alt_acts = model.model.layers[layer_idx].output[0][0, -1].save()

                    with model.generate(org_prompt, max_new_tokens=2, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id, eos_token_id=model.tokenizer.eos_token_id):
                        sing_vec = sing_vecs[layer_idx].cuda()
                        masked_vec = sing_vec * rounded.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        curr_output = model.model.layers[layer_idx].output[0][0, -1].clone()

                        alt_proj = torch.matmul(alt_acts, proj_matrix)
                        org_proj = torch.matmul(curr_output, proj_matrix)

                        modified_out = curr_output - org_proj + alt_proj

                        model.model.layers[layer_idx].output[0][0, -1] = modified_out

                        out = model.generator.output.save()

                    del alt_acts
                    torch.cuda.empty_cache()

                pred = model.tokenizer.decode(out[0][org_prompt_len:-1]).strip()
                print(f"Prediction: {pred} | Target: {target}")
                if pred.lower() in target.lower():
                    correct += 1
                total += 1

            print(f"Validation accuracy: {correct / total:.2f} | Correct: {correct} | Total: {total}\n")
            valid_accs[lamb][layer_idx] = round(correct / total, 2)

Training layer: 40, lambda: 0.005


  0%|          | 0/80 [00:00<?, ?it/s]

Epoch: 0, Batch: 0, Task Loss: -18.4531, L1 Loss: 2.0000, Total Loss: -16.4531
#Rank: 400


  5%|▌         | 4/80 [00:26<08:26,  6.67s/it]

Epoch: 0, Batch: 4, Task Loss: -39.3750, L1 Loss: 1.6719, Total Loss: -24.4047
#Rank: 400


 10%|█         | 8/80 [00:53<07:56,  6.61s/it]

Epoch: 0, Batch: 8, Task Loss: -24.1562, L1 Loss: 1.5078, Total Loss: -24.4002
#Rank: 327


 15%|█▌        | 12/80 [01:20<07:40,  6.78s/it]

Epoch: 0, Batch: 12, Task Loss: -107.9375, L1 Loss: 1.3672, Total Loss: -30.2296
#Rank: 288


 20%|██        | 16/80 [01:48<07:22,  6.91s/it]

Epoch: 0, Batch: 16, Task Loss: -36.9375, L1 Loss: 1.2578, Total Loss: -32.8718
#Rank: 269


 25%|██▌       | 20/80 [02:16<06:54,  6.90s/it]

Epoch: 0, Batch: 20, Task Loss: -20.2344, L1 Loss: 1.1719, Total Loss: -32.0822
#Rank: 247


 30%|███       | 24/80 [02:44<06:30,  6.97s/it]

Epoch: 0, Batch: 24, Task Loss: -20.5625, L1 Loss: 1.1016, Total Loss: -30.5388
#Rank: 224


 35%|███▌      | 28/80 [03:12<06:03,  6.98s/it]

Epoch: 0, Batch: 28, Task Loss: -91.8750, L1 Loss: 1.0312, Total Loss: -32.5135
#Rank: 217


 40%|████      | 32/80 [03:40<05:34,  6.97s/it]

Epoch: 0, Batch: 32, Task Loss: -41.0000, L1 Loss: 0.9766, Total Loss: -35.4163
#Rank: 199


 45%|████▌     | 36/80 [04:07<05:01,  6.86s/it]

Epoch: 0, Batch: 36, Task Loss: -50.3125, L1 Loss: 0.9297, Total Loss: -35.6271
#Rank: 196


 50%|█████     | 40/80 [04:34<04:29,  6.73s/it]

Epoch: 0, Batch: 40, Task Loss: -87.3750, L1 Loss: 0.8906, Total Loss: -38.2053
#Rank: 185


 55%|█████▌    | 44/80 [05:01<04:05,  6.82s/it]

Epoch: 0, Batch: 44, Task Loss: -62.0625, L1 Loss: 0.8516, Total Loss: -39.9546
#Rank: 173


 60%|██████    | 48/80 [05:28<03:36,  6.75s/it]

Epoch: 0, Batch: 48, Task Loss: -30.7969, L1 Loss: 0.8164, Total Loss: -39.8522
#Rank: 162


 65%|██████▌   | 52/80 [05:56<03:13,  6.92s/it]

Epoch: 0, Batch: 52, Task Loss: -76.1875, L1 Loss: 0.7891, Total Loss: -40.6981
#Rank: 155


 70%|███████   | 56/80 [06:24<02:46,  6.93s/it]

Epoch: 0, Batch: 56, Task Loss: -33.0938, L1 Loss: 0.7695, Total Loss: -40.2714
#Rank: 151


 75%|███████▌  | 60/80 [06:52<02:19,  6.95s/it]

Epoch: 0, Batch: 60, Task Loss: -110.0625, L1 Loss: 0.7539, Total Loss: -41.1557
#Rank: 152


 80%|████████  | 64/80 [07:20<01:51,  6.99s/it]

Epoch: 0, Batch: 64, Task Loss: -24.4219, L1 Loss: 0.7344, Total Loss: -40.7297
#Rank: 155


 85%|████████▌ | 68/80 [07:48<01:24,  7.04s/it]

Epoch: 0, Batch: 68, Task Loss: -44.5000, L1 Loss: 0.7148, Total Loss: -40.8648
#Rank: 153


 90%|█████████ | 72/80 [08:15<00:54,  6.85s/it]

Epoch: 0, Batch: 72, Task Loss: -56.7500, L1 Loss: 0.6992, Total Loss: -41.6181
#Rank: 146


 95%|█████████▌| 76/80 [08:42<00:27,  6.80s/it]

Epoch: 0, Batch: 76, Task Loss: -9.6875, L1 Loss: 0.6836, Total Loss: -41.6486
#Rank: 139


100%|██████████| 80/80 [09:10<00:00,  6.89s/it]
  0%|          | 0/80 [00:00<?, ?it/s]

Epoch: 1, Batch: 0, Task Loss: -18.6406, L1 Loss: 0.6719, Total Loss: -17.9688
#Rank: 140


  5%|▌         | 4/80 [00:27<08:34,  6.77s/it]

Epoch: 1, Batch: 4, Task Loss: -46.3125, L1 Loss: 0.6562, Total Loss: -28.9742
#Rank: 138


 10%|█         | 8/80 [00:54<07:54,  6.60s/it]

Epoch: 1, Batch: 8, Task Loss: -24.0625, L1 Loss: 0.6484, Total Loss: -27.9371
#Rank: 133


 15%|█▌        | 12/80 [01:21<07:39,  6.76s/it]

Epoch: 1, Batch: 12, Task Loss: -110.5625, L1 Loss: 0.6445, Total Loss: -33.3098
#Rank: 135


 20%|██        | 16/80 [01:49<07:22,  6.92s/it]

Epoch: 1, Batch: 16, Task Loss: -37.2188, L1 Loss: 0.6367, Total Loss: -35.8511
#Rank: 132


 25%|██▌       | 20/80 [02:16<06:54,  6.90s/it]

Epoch: 1, Batch: 20, Task Loss: -19.8906, L1 Loss: 0.6289, Total Loss: -34.4278
#Rank: 132


 30%|███       | 24/80 [02:44<06:29,  6.96s/it]

Epoch: 1, Batch: 24, Task Loss: -19.9375, L1 Loss: 0.6133, Total Loss: -32.6528
#Rank: 127


 35%|███▌      | 28/80 [03:12<06:02,  6.97s/it]

Epoch: 1, Batch: 28, Task Loss: -95.8750, L1 Loss: 0.6133, Total Loss: -34.5940
#Rank: 129


 40%|████      | 32/80 [03:40<05:33,  6.95s/it]

Epoch: 1, Batch: 32, Task Loss: -41.2500, L1 Loss: 0.6094, Total Loss: -37.8800
#Rank: 128


 45%|████▌     | 36/80 [04:07<05:01,  6.85s/it]

Epoch: 1, Batch: 36, Task Loss: -54.4062, L1 Loss: 0.6094, Total Loss: -38.0631
#Rank: 129


 50%|█████     | 40/80 [04:34<04:29,  6.73s/it]

Epoch: 1, Batch: 40, Task Loss: -89.8750, L1 Loss: 0.6094, Total Loss: -40.7811
#Rank: 130


 55%|█████▌    | 44/80 [05:01<04:05,  6.81s/it]

Epoch: 1, Batch: 44, Task Loss: -63.6562, L1 Loss: 0.6094, Total Loss: -42.4852
#Rank: 125


 60%|██████    | 48/80 [05:28<03:35,  6.74s/it]

Epoch: 1, Batch: 48, Task Loss: -30.5469, L1 Loss: 0.6055, Total Loss: -42.2420
#Rank: 129


 65%|██████▌   | 52/80 [05:56<03:12,  6.86s/it]

Epoch: 1, Batch: 52, Task Loss: -76.0625, L1 Loss: 0.6016, Total Loss: -42.9844
#Rank: 126


 70%|███████   | 56/80 [06:23<02:44,  6.85s/it]

Epoch: 1, Batch: 56, Task Loss: -33.4062, L1 Loss: 0.6016, Total Loss: -42.4169
#Rank: 129


 75%|███████▌  | 60/80 [06:51<02:17,  6.88s/it]

Epoch: 1, Batch: 60, Task Loss: -111.5000, L1 Loss: 0.6016, Total Loss: -43.2264
#Rank: 128


 80%|████████  | 64/80 [07:18<01:50,  6.94s/it]

Epoch: 1, Batch: 64, Task Loss: -24.4219, L1 Loss: 0.5938, Total Loss: -42.6788
#Rank: 126


 85%|████████▌ | 68/80 [07:46<01:23,  6.98s/it]

Epoch: 1, Batch: 68, Task Loss: -44.4375, L1 Loss: 0.5859, Total Loss: -42.7212
#Rank: 124


 90%|█████████ | 72/80 [08:14<00:54,  6.83s/it]

Epoch: 1, Batch: 72, Task Loss: -57.1250, L1 Loss: 0.5820, Total Loss: -43.4340
#Rank: 124


 95%|█████████▌| 76/80 [08:41<00:27,  6.78s/it]

Epoch: 1, Batch: 76, Task Loss: -9.6328, L1 Loss: 0.5820, Total Loss: -43.3821
#Rank: 122


100%|██████████| 80/80 [09:09<00:00,  6.87s/it]


Training finished for layer: 40, lambda: 0.005
Validation started for layer: 40, lambda: 0.005
Rank: 121


  2%|▎         | 1/40 [00:05<03:42,  5.71s/it]

Prediction: Aim | Target: is aimed at the correct location


  5%|▌         | 2/40 [00:11<03:40,  5.79s/it]

Prediction: too | Target: set at the correct temperature for baking biscotti


  8%|▊         | 3/40 [00:17<03:34,  5.81s/it]

Prediction: R | Target: ripe and ready to be picked


 10%|█         | 4/40 [00:23<03:30,  5.83s/it]

Prediction: Em | Target: the fishing net is strong and without holes


 12%|█▎        | 5/40 [00:29<03:25,  5.88s/it]

Prediction: fresh | Target: her ingredients are fresh and suitable for baking the cake


 15%|█▌        | 6/40 [00:35<03:20,  5.90s/it]

Prediction: Fresh | Target: fresh and of high quality


 18%|█▊        | 7/40 [00:41<03:15,  5.92s/it]

Prediction: calm | Target: sea is calm and full of fish


 20%|██        | 8/40 [00:46<03:08,  5.90s/it]

Prediction: Cal | Target: waters near the shore are calm and full of fish


 22%|██▎       | 9/40 [00:52<03:03,  5.91s/it]

Prediction: strong | Target: is strong and ready to be used


 25%|██▌       | 10/40 [00:58<02:57,  5.92s/it]

Prediction: Te | Target: teeming with fish


 28%|██▊       | 11/40 [01:04<02:52,  5.95s/it]

Prediction: In | Target: in excellent condition


 30%|███       | 12/40 [01:10<02:46,  5.94s/it]

Prediction: Kw | Target: is sturdy and suitable for carving


 32%|███▎      | 13/40 [01:16<02:40,  5.94s/it]

Prediction: Green | Target: blue glaze


 35%|███▌      | 14/40 [01:22<02:33,  5.91s/it]

Prediction: Vine | Target: vinegar


 38%|███▊      | 15/40 [01:28<02:26,  5.87s/it]

Prediction: valuable | Target: her valuable violin


 40%|████      | 16/40 [01:34<02:21,  5.88s/it]

Prediction: free | Target: is free from cracks


 42%|████▎     | 17/40 [01:40<02:14,  5.87s/it]

Prediction: clean | Target: clean and ready for use


 45%|████▌     | 18/40 [01:46<02:09,  5.91s/it]

Prediction: filled | Target: clean and in excellent condition


 48%|████▊     | 19/40 [01:51<02:03,  5.89s/it]

Prediction: clean | Target: clean and ready for use


 50%|█████     | 20/40 [01:57<01:58,  5.92s/it]

Prediction: Int | Target: intact


 52%|█████▎    | 21/40 [02:03<01:51,  5.89s/it]

Prediction: functioning | Target: functioning properly


 55%|█████▌    | 22/40 [02:09<01:45,  5.87s/it]

Prediction: crisp | Target: crisp and colorful


 57%|█████▊    | 23/40 [02:15<01:40,  5.89s/it]

Prediction: Full | Target: full of honey


 60%|██████    | 24/40 [02:21<01:34,  5.92s/it]

Prediction: ripe | Target: ripe and undamaged


 62%|██████▎   | 25/40 [02:27<01:28,  5.89s/it]

Prediction: less | Target: rare, valuable one


 65%|██████▌   | 26/40 [02:33<01:22,  5.86s/it]

Prediction: ideal | Target: still ideal for taking pictures


 68%|██████▊   | 27/40 [02:38<01:16,  5.85s/it]

Prediction: half | Target: unbaked cookies


 70%|███████   | 28/40 [02:44<01:10,  5.86s/it]

Prediction: adorned | Target: still unadorned


 72%|███████▎  | 29/40 [02:50<01:04,  5.83s/it]

Prediction: T | Target: tied to the dock


 75%|███████▌  | 30/40 [02:56<00:58,  5.86s/it]

Prediction: Clear | Target: obstructed by leaves


 78%|███████▊  | 31/40 [03:02<00:52,  5.82s/it]

Prediction: soaked | Target: dry


 80%|████████  | 32/40 [03:08<00:46,  5.83s/it]

Prediction: Fresh | Target: fresh


 82%|████████▎ | 33/40 [03:13<00:40,  5.83s/it]

Prediction: moved | Target: is covering the entrance


 85%|████████▌ | 34/40 [03:19<00:35,  5.85s/it]

Prediction: patched | Target: patched


 88%|████████▊ | 35/40 [03:25<00:29,  5.84s/it]

Prediction: Not | Target: preheating at 350°F


 90%|█████████ | 36/40 [03:31<00:23,  5.82s/it]

Prediction: Im | Target: rare and exotic


 92%|█████████▎| 37/40 [03:37<00:17,  5.82s/it]

Prediction: Healthy | Target: weakened state


 95%|█████████▌| 38/40 [03:43<00:11,  5.83s/it]

Prediction: sturdy | Target: sturdy branches


 98%|█████████▊| 39/40 [03:48<00:05,  5.83s/it]

Prediction: severely | Target: healthy


100%|██████████| 40/40 [03:54<00:00,  5.87s/it]

Prediction: Area | Target: area shown on the map
Validation accuracy: 0.65 | Correct: 26 | Total: 40






In [None]:
# 0.005 -> 0.75
# 0.0025 -> 0.75

## Query Character

In [8]:
sing_vecs = defaultdict(dict)
for l in range(model.config.num_hidden_layers):
    sing_vecs[l] = torch.load(f"../svd_results/BigToM/query_charac_new/singular_vecs/{l}.pt").cpu()

In [9]:
train_size = 80
valid_size = 40
batch_size = 1

dataset = get_bigtom_query_charac(df_false, df_true, train_size+valid_size)
train_dataset = dataset[:train_size]
valid_dataset = dataset[train_size:]

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

In [10]:
idx=0
print(dataset[idx]['alt_prompt'], dataset[idx]['alt_ans'])
print(dataset[idx]['org_prompt'], dataset[idx]['org_ans'])
print(f"Target: {dataset[idx]['target']}")

Instructions: 1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any belief about the container or its content which they cannot observe directly. 4. To answer the question, predict only the final state of the queried container in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict the entire sentence with character or container as the final output.

Story: Noor is working as a barista at a busy coffee shop. Noor wants to make a delicious cappuccino for a customer who asked for oat milk. Noor grabs a milk pitcher and fills it with oat milk. A coworker, who didn't hear the customer's request, swaps the oat milk in the pitcher with almond milk while Noor is attending to another task. Noor sees her coworker swapping the milk.
Question: Does Noor believe the milk pitch

In [16]:
model.tokenizer.padding_side = "left"

valid_accs = defaultdict(dict)
for layer_idx in range(14, 16, 2):

    n_epochs = 1
    lambs = [0.005]
    for lamb in lambs:
        modules = [i for i in range(sing_vecs[layer_idx].size(0))]
        mask = torch.ones(len(modules), requires_grad=True, device="cuda", dtype=torch.bfloat16)
        optimizer = torch.optim.Adam([mask], lr=1e-1)

        print(f"Training layer: {layer_idx}, lambda: {lamb}")
        for epoch in range(n_epochs):
            epoch_loss = 0

            for bi, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
                alt_prompt = batch["alt_prompt"]
                org_prompt = batch["org_prompt"]
                target = batch["target"]
                target_token = model.tokenizer(target, return_tensors="pt", padding=True, padding_side="right")
                target_input_ids = target_token.input_ids[:, 1:]
                batch_size = target_input_ids.size(0)

                alt_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, alt_prompt, padding_side="right")
                alt_prompt_len = get_prompt_token_len(model.tokenizer, alt_prompt, padding_side="right")
                org_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, org_prompt, padding_side="right")
                org_prompt_len = get_prompt_token_len(model.tokenizer, org_prompt, padding_side="right")

                optimizer.zero_grad()

                with model.trace() as tracer:
                    alt_acts = defaultdict(dict)
                    with tracer.invoke(alt_prompt):
                        for t_idx, t in enumerate([i for i in range(alt_ques_idx+3, alt_ques_idx+5)]):
                            alt_acts[t_idx] = model.model.layers[layer_idx].output[0][0, t].clone().save()

                    with tracer.invoke(org_prompt):
                        sing_vec = sing_vecs[layer_idx].cuda()
                        masked_vec = sing_vec * mask.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()

                        for t_idx, t in enumerate([i for i in range(org_ques_idx+3, org_ques_idx+5)]):
                            curr_output = model.model.layers[layer_idx].output[0][0, t].clone()

                            alt_proj = torch.matmul(alt_acts[t_idx], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)

                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[layer_idx].output[0][0, t] = modified_out

                        logits = model.lm_head.output[0, -1].save()

                        del sing_vec, proj_matrix, masked_vec
                        torch.cuda.empty_cache()

                target_logit = logits[target_input_ids[0]].sum()

                task_loss = -(target_logit/batch_size)
                l1_loss = lamb * torch.norm(mask, p=1)
                loss = task_loss + l1_loss.to(task_loss.device)
                
                epoch_loss += loss.item()
                
                if bi % 4 == 0:
                    mean_loss = epoch_loss / (bi + 1)
                    print(f"Epoch: {epoch}, Batch: {bi}, Task Loss: {task_loss.item():.4f}, "
                        f"L1 Loss: {l1_loss.item():.4f}, Total Loss: {mean_loss:.4f}")
                    with torch.no_grad():
                        mask.data.clamp_(0, 1)
                        rounded = torch.round(mask)
                        print(f"#Rank: {(rounded == 1).sum().item()}")
                
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    mask.data.clamp_(0, 1)

        print(f"Training finished for layer: {layer_idx}, lambda: {lamb}")

        layer_idx = 14
        print(f"Validation started for layer: {layer_idx}, lambda: {lamb}")
        correct, total = 0, 0

        with torch.inference_mode():
            mask_data = mask.data.clone()
            mask_data.clamp_(0, 1)
            rounded = torch.round(mask)

            print(f"Rank: {(rounded == 1).sum().item()}")

            # Save the mask
            # torch.save(mask_data, f"../masks/bigtom/{layer_idx}.pt")

            for bi, batch in tqdm(enumerate(valid_dataloader), total=len(valid_dataloader)):
                alt_prompt = batch["alt_prompt"]
                org_prompt = batch["org_prompt"]
                alt_ans = batch["alt_ans"]
                target = batch["target"][0]
                batch_size = len(alt_ans)

                alt_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, alt_prompt, padding_side="left")
                alt_prompt_len = get_prompt_token_len(model.tokenizer, alt_prompt, padding_side="left")
                org_ques_idx = get_ques_start_token_idx(batch_size, model.tokenizer, org_prompt, padding_side="left")
                org_prompt_len = get_prompt_token_len(model.tokenizer, org_prompt, padding_side="left")

                with model.session() as session:
                    alt_layer_out = defaultdict(dict)
                    with model.trace(alt_prompt):
                        for t_idx, t in enumerate([i for i in range(alt_ques_idx+3, alt_ques_idx+5)]):
                            alt_layer_out[t_idx] = model.model.layers[layer_idx].output[0][0, t].save()

                    with model.generate(org_prompt, max_new_tokens=2, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id, eos_token_id=model.tokenizer.eos_token_id):
                        sing_vec = sing_vecs[layer_idx].cuda()
                        masked_vec = sing_vec * rounded.unsqueeze(-1)
                        proj_matrix = torch.matmul(masked_vec.t(), masked_vec).half()
                        
                        for t_idx, t in enumerate([i for i in range(org_ques_idx+3, org_ques_idx+5)]):
                            curr_output = model.model.layers[layer_idx].output[0][0, t].clone()
                            alt_proj = torch.matmul(alt_layer_out[t_idx], proj_matrix)
                            org_proj = torch.matmul(curr_output, proj_matrix)
                            modified_out = curr_output - org_proj + alt_proj
                            model.model.layers[layer_idx].output[0][0, t] = modified_out

                        out = model.generator.output.save()

                    del alt_layer_out
                    torch.cuda.empty_cache()

                pred = model.tokenizer.decode(out[0][org_prompt_len:-1]).strip()
                print(f"Prediction: {pred} | Target: {target}")
                if pred.lower() in target.lower():
                    correct += 1
                total += 1

            print(f"Validation accuracy: {correct / total:.2f} | Correct: {correct} | Total: {total}\n")
            valid_accs[lamb][layer_idx] = round(correct / total, 2)

Training layer: 14, lambda: 0.005


  0%|          | 0/80 [00:00<?, ?it/s]

Epoch: 0, Batch: 0, Task Loss: -14.8203, L1 Loss: 4.0000, Total Loss: -10.8203
#Rank: 800


  5%|▌         | 4/80 [00:32<10:24,  8.22s/it]

Epoch: 0, Batch: 4, Task Loss: -10.4141, L1 Loss: 3.4844, Total Loss: -8.1102
#Rank: 800


 10%|█         | 8/80 [01:07<09:57,  8.30s/it]

Epoch: 0, Batch: 8, Task Loss: -5.6055, L1 Loss: 3.1562, Total Loss: -13.5634
#Rank: 699


 15%|█▌        | 12/80 [01:42<09:46,  8.62s/it]

Epoch: 0, Batch: 12, Task Loss: -21.2031, L1 Loss: 2.9375, Total Loss: -15.6364
#Rank: 584


 20%|██        | 16/80 [02:18<09:30,  8.91s/it]

Epoch: 0, Batch: 16, Task Loss: -49.2188, L1 Loss: 2.7344, Total Loss: -23.9784
#Rank: 554


 25%|██▌       | 20/80 [02:53<08:53,  8.89s/it]

Epoch: 0, Batch: 20, Task Loss: -10.2578, L1 Loss: 2.5469, Total Loss: -20.2484
#Rank: 525


 30%|███       | 24/80 [03:30<08:21,  8.96s/it]

Epoch: 0, Batch: 24, Task Loss: -30.8125, L1 Loss: 2.4062, Total Loss: -18.9494
#Rank: 495


 35%|███▌      | 28/80 [04:06<07:46,  8.97s/it]

Epoch: 0, Batch: 28, Task Loss: -13.6172, L1 Loss: 2.2656, Total Loss: -20.5725
#Rank: 465


 40%|████      | 32/80 [04:41<06:58,  8.73s/it]

Epoch: 0, Batch: 32, Task Loss: -89.8750, L1 Loss: 2.1562, Total Loss: -22.7676
#Rank: 432


 45%|████▌     | 36/80 [05:16<06:24,  8.74s/it]

Epoch: 0, Batch: 36, Task Loss: -4.6680, L1 Loss: 2.0625, Total Loss: -21.8538
#Rank: 424


 50%|█████     | 40/80 [05:51<06:03,  9.09s/it]

Epoch: 0, Batch: 40, Task Loss: -36.5000, L1 Loss: 1.9922, Total Loss: -22.0757
#Rank: 404


 55%|█████▌    | 44/80 [06:25<05:16,  8.80s/it]

Epoch: 0, Batch: 44, Task Loss: -36.3750, L1 Loss: 1.9375, Total Loss: -22.6069
#Rank: 396


 60%|██████    | 48/80 [07:03<04:58,  9.33s/it]

Epoch: 0, Batch: 48, Task Loss: -42.7188, L1 Loss: 1.8672, Total Loss: -22.6900
#Rank: 377


 65%|██████▌   | 52/80 [07:39<04:20,  9.30s/it]

Epoch: 0, Batch: 52, Task Loss: -40.5625, L1 Loss: 1.8203, Total Loss: -22.8216
#Rank: 367


 70%|███████   | 56/80 [08:14<03:41,  9.24s/it]

Epoch: 0, Batch: 56, Task Loss: -10.4609, L1 Loss: 1.7812, Total Loss: -22.6715
#Rank: 353


 75%|███████▌  | 60/80 [08:51<03:05,  9.29s/it]

Epoch: 0, Batch: 60, Task Loss: -10.3203, L1 Loss: 1.7500, Total Loss: -21.4731
#Rank: 350


 80%|████████  | 64/80 [09:27<02:27,  9.21s/it]

Epoch: 0, Batch: 64, Task Loss: -59.9375, L1 Loss: 1.7266, Total Loss: -22.4138
#Rank: 344


 85%|████████▌ | 68/80 [10:06<01:55,  9.63s/it]

Epoch: 0, Batch: 68, Task Loss: -11.9141, L1 Loss: 1.7109, Total Loss: -22.0800
#Rank: 340


 90%|█████████ | 72/80 [10:43<01:12,  9.09s/it]

Epoch: 0, Batch: 72, Task Loss: -41.9688, L1 Loss: 1.6875, Total Loss: -21.6590
#Rank: 334


 95%|█████████▌| 76/80 [11:18<00:36,  9.01s/it]

Epoch: 0, Batch: 76, Task Loss: -9.9688, L1 Loss: 1.6719, Total Loss: -21.6988
#Rank: 333


100%|██████████| 80/80 [11:54<00:00,  8.93s/it]


Training finished for layer: 14, lambda: 0.005
Validation started for layer: 14, lambda: 0.005
Rank: 327


  2%|▎         | 1/40 [00:05<03:45,  5.79s/it]

Prediction: N | Target: warming up due to the power outage


  5%|▌         | 2/40 [00:12<04:00,  6.34s/it]

Prediction: has | Target: has a hairline crack


  8%|▊         | 3/40 [00:19<04:04,  6.61s/it]

Prediction: Already | Target: have already been found


 10%|█         | 4/40 [00:26<04:01,  6.72s/it]

Prediction: melted | Target: has melted


 12%|█▎        | 5/40 [00:33<03:56,  6.76s/it]

Prediction: Fish | Target: fish have moved away


 15%|█▌        | 6/40 [00:39<03:50,  6.79s/it]

Prediction: A | Target: corrupted due to a server malfunction


 18%|█▊        | 7/40 [00:46<03:44,  6.80s/it]

Prediction: corrupted | Target: corrupted and difficult to use


 20%|██        | 8/40 [00:53<03:38,  6.81s/it]

Prediction: rough | Target: rough and dangerous


 22%|██▎       | 9/40 [01:00<03:31,  6.82s/it]

Prediction: Not | Target: exposed


 25%|██▌       | 10/40 [01:07<03:24,  6.83s/it]

Prediction: Ru | Target: the paintbrushes are ruined


 28%|██▊       | 11/40 [01:14<03:18,  6.84s/it]

Prediction: Dis | Target: disrupted by wind and leaves


 30%|███       | 12/40 [01:21<03:11,  6.85s/it]

Prediction: De | Target: devoid of fish due to the volcanic eruption


 32%|███▎      | 13/40 [01:27<03:04,  6.83s/it]

Prediction: N | Target: soaked


 35%|███▌      | 14/40 [01:34<02:58,  6.86s/it]

Prediction: N | Target: half-baked cookies


 38%|███▊      | 15/40 [01:41<02:53,  6.94s/it]

Prediction: Not | Target: not heating at all


 40%|████      | 16/40 [01:48<02:45,  6.91s/it]

Prediction: Am | Target: hard and unworkable


 42%|████▎     | 17/40 [01:55<02:37,  6.86s/it]

Prediction: Over | Target: over-fermented


 45%|████▌     | 18/40 [02:02<02:31,  6.89s/it]

Prediction: Mei | Target: become indistinguishable from any other piece of clay


 48%|████▊     | 19/40 [02:09<02:25,  6.91s/it]

Prediction: severely | Target: severely damaged


 50%|█████     | 20/40 [02:16<02:19,  6.97s/it]

Prediction: Enrique | Target: has a broken string


 52%|█████▎    | 21/40 [02:23<02:12,  6.95s/it]

Prediction: extremely | Target: extremely out of tune


 55%|█████▌    | 22/40 [02:30<02:04,  6.94s/it]

Prediction: Kw | Target: washed away and diluted


 57%|█████▊    | 23/40 [02:37<01:57,  6.92s/it]

Prediction: concealed | Target: knows about the concealed door


 60%|██████    | 24/40 [02:44<01:50,  6.90s/it]

Prediction: tangled | Target: tangled and difficult to use


 62%|██████▎   | 25/40 [02:50<01:43,  6.89s/it]

Prediction: Healthy | Target: healthy state


 65%|██████▌   | 26/40 [02:57<01:36,  6.90s/it]

Prediction: Slip | Target: floor is slippery due to the oil spill


 68%|██████▊   | 27/40 [03:04<01:29,  6.90s/it]

Prediction: Ru | Target: the paintbrushes are ruined


 70%|███████   | 28/40 [03:11<01:22,  6.84s/it]

Prediction: torn | Target: torn with a large hole


 72%|███████▎  | 29/40 [03:18<01:16,  6.94s/it]

Prediction: N | Target: damaged by the storm


 75%|███████▌  | 30/40 [03:25<01:09,  6.92s/it]

Prediction: F | Target: filled with sand


 78%|███████▊  | 31/40 [03:32<01:02,  6.92s/it]

Prediction: There | Target: has crashed and become unresponsive


 80%|████████  | 32/40 [03:39<00:55,  6.94s/it]

Prediction: Wash | Target: washed away and diluted


 82%|████████▎ | 33/40 [03:46<00:48,  6.88s/it]

Prediction: W | Target: have wilted due to the rainstorm


 85%|████████▌ | 34/40 [03:53<00:41,  6.90s/it]

Prediction: corrupted | Target: corrupted and difficult to use


 88%|████████▊ | 35/40 [03:59<00:34,  6.85s/it]

Prediction: Rash | Target: has been badly damaged


 90%|█████████ | 36/40 [04:06<00:27,  6.92s/it]

Prediction: Im | Target: that the fishing net has a large hole in it


 92%|█████████▎| 37/40 [04:13<00:20,  6.93s/it]

Prediction: W | Target: wilted


 95%|█████████▌| 38/40 [04:20<00:13,  6.94s/it]

Prediction: weakened | Target: weakened and infested with termites


 98%|█████████▊| 39/40 [04:27<00:06,  6.94s/it]

Prediction: emerging | Target: an emerging artist


100%|██████████| 40/40 [04:34<00:00,  6.87s/it]

Prediction: Not | Target: sold out
Validation accuracy: 0.70 | Correct: 28 | Total: 40




