In [2]:
import json
import random
import os
import math
import sys
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
from typing import Any, List, Optional
import nnsight
from nnsight import CONFIG, LanguageModel
import numpy as np
from collections import defaultdict
from einops import einsum
import time
from einops import rearrange, reduce
import pandas as pd

sys.path.append("../")
from src.dataset import SampleV3, DatasetV3, STORY_TEMPLATES
from utils import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
random.seed(10)

# Credentials are read from the environment so they never live in the notebook
# source or its history. Export NDIF_API_KEY / HF_TOKEN in the shell (or a
# git-ignored .env file) before starting the kernel.
# NOTE(review): the key and token that used to be hardcoded here were committed
# in plain text and should be revoked/rotated.
ndif_api_key = os.environ.get("NDIF_API_KEY")
if ndif_api_key:
    CONFIG.set_default_api_key(ndif_api_key)
else:
    print("NDIF_API_KEY is not set; no nnsight API key was configured.")

if not os.environ.get("HF_TOKEN"):
    print("HF_TOKEN is not set; downloading gated Hugging Face checkpoints may fail.")

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
CONFIG.APP.REMOTE_LOGGING = False

%load_ext autoreload
%autoreload 2

# Loading model

In [None]:
# Alternative checkpoint, kept for reference:
# model = LanguageModel("meta-llama/Meta-Llama-3.1-405B")

# Keyword arguments for the 4-bit, fp16 load of the 70B instruct model.
# NOTE(review): the captured cell output reports `load_in_4bit` as deprecated in
# favour of a `BitsAndBytesConfig` passed via `quantization_config`; the cache
# directory is also an absolute, user-specific path.
model_load_kwargs = dict(
    cache_dir="/disk/u/nikhil/.cache/huggingface/hub",
    device_map="auto",
    load_in_4bit=True,
    torch_dtype=torch.float16,
    dispatch=True,
)
model = LanguageModel("meta-llama/Meta-Llama-3-70B-Instruct", **model_load_kwargs)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards:  27%|██▋       | 8/30 [00:09<00:26,  1.19s/it]


KeyboardInterrupt: 

# Computing Singular Vectors

In [3]:
# Read the cached activation tensor from disk.
query_charac_acts_path = "../caches/Llama-70B-Instruct/BigToM/query_charac_acts.pt"
cached_acts = torch.load(query_charac_acts_path)
# Not loaded at the moment (only needed by the masked variant of the SVD cell):
# visibility_sent_lens = torch.load("../caches/Llama-70B-Instruct/BigToM/visibility_sent_lens.pt")

In [6]:
# Rank of a single layer's slice of the cache (number of tensor dimensions).
cached_acts[:, 0].ndim

3

In [7]:
# Per-layer SVD of the cached activations; singular values and right singular
# vectors (Vh) are written to disk, one file per layer.

# The output directories do not depend on the layer, so create them once
# instead of on every loop iteration.
svd_out_dir = "../svd_results/BigToM/query_charac_new"
os.makedirs(f"{svd_out_dir}/singular_values/", exist_ok=True)
os.makedirs(f"{svd_out_dir}/singular_vecs/", exist_ok=True)

hidden_dim = cached_acts.size(-1)

# 80 is hard-coded; presumably the number of decoder layers in the cached
# model -- TODO confirm against cached_acts.size(1).
for layer in tqdm(range(80), desc="Processing layers"):
    layer_acts = cached_acts[:, layer]

    # A masked variant (keeping only valid tokens via `visibility_sent_lens`)
    # was sketched here previously; it is not used. All positions are kept.
    if layer_acts.ndim > 2:
        # Collapse every leading dimension so each row is one activation vector.
        acts = layer_acts.reshape(-1, hidden_dim)
        if layer == 0:
            # Report the flattened shape once; tqdm.write keeps the bar intact
            # and avoids printing an identical line for every layer.
            tqdm.write(f"Reshaped acts: {acts.shape}")
    else:
        acts = layer_acts

    # U is not needed; only the spectrum and the right singular vectors are kept.
    _, singular_values, Vh = torch.linalg.svd(acts, full_matrices=False)
    torch.save(singular_values, f"{svd_out_dir}/singular_values/{layer}.pt")
    torch.save(Vh, f"{svd_out_dir}/singular_vecs/{layer}.pt")

Processing layers:   0%|          | 0/80 [00:00<?, ?it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   1%|▏         | 1/80 [00:00<00:44,  1.76it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   2%|▎         | 2/80 [00:01<00:50,  1.55it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   4%|▍         | 3/80 [00:01<00:50,  1.54it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   5%|▌         | 4/80 [00:02<00:48,  1.55it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   6%|▋         | 5/80 [00:03<00:47,  1.57it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   8%|▊         | 6/80 [00:03<00:51,  1.45it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:   9%|▉         | 7/80 [00:04<00:49,  1.46it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  10%|█         | 8/80 [00:05<00:46,  1.54it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  11%|█▏        | 9/80 [00:05<00:44,  1.59it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  12%|█▎        | 10/80 [00:06<00:42,  1.64it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  14%|█▍        | 11/80 [00:06<00:42,  1.63it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  15%|█▌        | 12/80 [00:07<00:40,  1.68it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  16%|█▋        | 13/80 [00:08<00:38,  1.73it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  18%|█▊        | 14/80 [00:08<00:38,  1.72it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  19%|█▉        | 15/80 [00:09<00:37,  1.71it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  20%|██        | 16/80 [00:09<00:37,  1.73it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  21%|██▏       | 17/80 [00:10<00:36,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  22%|██▎       | 18/80 [00:10<00:34,  1.78it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  24%|██▍       | 19/80 [00:11<00:33,  1.80it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  25%|██▌       | 20/80 [00:12<00:33,  1.80it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  26%|██▋       | 21/80 [00:12<00:33,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  28%|██▊       | 22/80 [00:13<00:33,  1.73it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  29%|██▉       | 23/80 [00:13<00:32,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  30%|███       | 24/80 [00:14<00:32,  1.74it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  31%|███▏      | 25/80 [00:14<00:31,  1.76it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  32%|███▎      | 26/80 [00:15<00:30,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  34%|███▍      | 27/80 [00:16<00:29,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  35%|███▌      | 28/80 [00:16<00:29,  1.79it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  36%|███▋      | 29/80 [00:17<00:29,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  38%|███▊      | 30/80 [00:17<00:28,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  39%|███▉      | 31/80 [00:18<00:28,  1.70it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  40%|████      | 32/80 [00:18<00:27,  1.71it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  41%|████▏     | 33/80 [00:19<00:27,  1.69it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  42%|████▎     | 34/80 [00:20<00:30,  1.52it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  44%|████▍     | 35/80 [00:20<00:28,  1.59it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  45%|████▌     | 36/80 [00:21<00:26,  1.64it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  46%|████▋     | 37/80 [00:22<00:25,  1.66it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  48%|████▊     | 38/80 [00:22<00:24,  1.69it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  49%|████▉     | 39/80 [00:23<00:23,  1.72it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  50%|█████     | 40/80 [00:23<00:22,  1.74it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  51%|█████▏    | 41/80 [00:24<00:22,  1.73it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  52%|█████▎    | 42/80 [00:24<00:21,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  54%|█████▍    | 43/80 [00:25<00:20,  1.76it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  55%|█████▌    | 44/80 [00:26<00:20,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  56%|█████▋    | 45/80 [00:26<00:20,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  57%|█████▊    | 46/80 [00:27<00:19,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  59%|█████▉    | 47/80 [00:27<00:18,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  60%|██████    | 48/80 [00:28<00:17,  1.78it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  61%|██████▏   | 49/80 [00:28<00:17,  1.81it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  62%|██████▎   | 50/80 [00:29<00:16,  1.81it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  64%|██████▍   | 51/80 [00:29<00:16,  1.80it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  65%|██████▌   | 52/80 [00:30<00:15,  1.79it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  66%|██████▋   | 53/80 [00:31<00:15,  1.74it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  68%|██████▊   | 54/80 [00:31<00:14,  1.78it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  69%|██████▉   | 55/80 [00:32<00:14,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  70%|███████   | 56/80 [00:32<00:13,  1.78it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  71%|███████▏  | 57/80 [00:33<00:14,  1.62it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  72%|███████▎  | 58/80 [00:34<00:13,  1.63it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  74%|███████▍  | 59/80 [00:34<00:12,  1.64it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  75%|███████▌  | 60/80 [00:35<00:11,  1.67it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  76%|███████▋  | 61/80 [00:35<00:11,  1.70it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  78%|███████▊  | 62/80 [00:36<00:10,  1.72it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  79%|███████▉  | 63/80 [00:37<00:09,  1.71it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  80%|████████  | 64/80 [00:37<00:09,  1.71it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  81%|████████▏ | 65/80 [00:38<00:08,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  82%|████████▎ | 66/80 [00:38<00:08,  1.74it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  84%|████████▍ | 67/80 [00:39<00:07,  1.74it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  85%|████████▌ | 68/80 [00:39<00:06,  1.72it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  86%|████████▋ | 69/80 [00:40<00:06,  1.75it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  88%|████████▊ | 70/80 [00:41<00:05,  1.76it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  89%|████████▉ | 71/80 [00:41<00:05,  1.76it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  90%|█████████ | 72/80 [00:42<00:04,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  91%|█████████▏| 73/80 [00:42<00:03,  1.77it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  92%|█████████▎| 74/80 [00:43<00:03,  1.76it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  94%|█████████▍| 75/80 [00:43<00:02,  1.73it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  95%|█████████▌| 76/80 [00:44<00:02,  1.74it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  96%|█████████▋| 77/80 [00:45<00:01,  1.70it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  98%|█████████▊| 78/80 [00:45<00:01,  1.71it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers:  99%|█████████▉| 79/80 [00:46<00:00,  1.71it/s]

Reshaped acts: torch.Size([800, 8192])


Processing layers: 100%|██████████| 80/80 [00:46<00:00,  1.71it/s]


# Load Dataset

In [None]:
n_samples = 20
batch_size = 1

# Build `n_samples` random story configurations. The template is chosen by the
# sampled visibility flag: the second template when visible, the first otherwise.
configs = []
for _ in range(n_samples):
    template_1 = STORY_TEMPLATES['templates'][0]
    template_2 = STORY_TEMPLATES['templates'][1]

    # Sampling order (characters, containers, states, visibility) is kept fixed
    # so the seeded random stream yields the same configurations.
    characters = random.sample(all_characters, 2)
    containers = random.sample(all_containers[template_1["container_type"]], 2)
    states = random.sample(all_states[template_1["state_type"]], 2)
    event_idx, event_noticed = None, False
    visibility = random.choice([True, False])

    chosen_template = template_2 if visibility else template_1
    configs.append(
        SampleV3(
            template=chosen_template,
            characters=characters,
            containers=containers,
            states=states,
            visibility=visibility,
            event_idx=event_idx,
            event_noticed=event_noticed,
        )
    )

dataset = DatasetV3(configs)
# One prompt per batch, in construction order.
dataloader = DataLoader(dataset, shuffle=False, batch_size=1)

In [None]:
# Inspect one generated sample: prompt, expected answer and visibility flag.
idx = 0
sample = dataset[idx]
print(*(sample[field] for field in ('prompt', 'target', 'visibility')))

In [None]:
# Collect, for every sample and every layer, the layer output at two fixed
# token offsets counted from the end of the prompt: -8 (stored in acts_charac)
# and -5 (stored in acts_obj).
# NOTE(review): judging by the variable names these offsets are presumably the
# query-character and query-object tokens -- confirm against the prompt template.
# Buffers are uninitialised (torch.empty); every [sample, layer] row is only
# valid once the loop below has written it.
acts_charac = torch.empty(n_samples, model.config.num_hidden_layers, model.config.hidden_size)
acts_obj = torch.empty(n_samples, model.config.num_hidden_layers, model.config.hidden_size)
# Per-sample labels used later to colour the projection plots.
character_indices, object_indices, visibility = [], [], []

# The dataloader yields batches of size 1, so `bi` doubles as the sample index
# and `[0]` unwraps the single element of each batched field.
for bi, data in tqdm(enumerate(dataloader), total=len(dataloader)):
    prompt = data['prompt'][0]
    character_idx = data['character_idx'][0]
    character_indices.append(character_idx)
    object_indices.append(data['object_idx'][0])
    visibility.append(data['visibility'][0])

    with torch.no_grad():

        with model.trace() as tracer:

            with tracer.invoke(prompt):
                for l in range(model.config.num_hidden_layers):
                    # output[0] is the first element of the layer's output
                    # (presumably the hidden-state tensor -- TODO confirm);
                    # [0, -8] / [0, -5] select batch item 0 at the given offset.
                    acts_charac[bi, l] = model.model.layers[l].output[0][0, -8].cpu().save()
                    acts_obj[bi, l] = model.model.layers[l].output[0][0, -5].cpu().save()

# Projection & Visualization onto Singular Vectors

## Query Character Viz

In [None]:
# Project the cached query-character activations of every second layer
# (0, 2, ..., 38) onto one singular vector and show the projections as
# 1-D scatter plots, coloured by which character is queried.

# 0-based row of Vh to project onto. The figure title is derived from this
# value so the two cannot drift apart (the title used to say "Third" while the
# code projected onto row 1, i.e. the second vector).
sv_idx = 1
viz_layers = range(0, 40, 2)

projected_acts = {}
for l in viz_layers:
    # map_location keeps V on the same device as the activations regardless of
    # the device it was saved from.
    V = torch.load(f"../svd_results/charac_pos/Vh_{l}.pt", map_location=device)
    acts_l = acts_charac[:, l, :].to(device)
    projected_acts[l] = torch.matmul(acts_l, V[sv_idx:sv_idx + 1, :].t()).cpu().numpy()

    del acts_l, V
    torch.cuda.empty_cache()

# 5 x 4 grid: one panel per visualised layer (20 panels).
fig, axs = plt.subplots(5, 4, figsize=(12, 12))
# Red = first character queried (index 0), blue = second character.
point_colors = ['r' if idx == 0 else 'b' for idx in character_indices]
for ax, l in zip(axs.flat, viz_layers):
    proj = projected_acts[l][:, 0]
    # All points sit on y = 0; only the x position carries information.
    ax.scatter(proj, np.zeros_like(proj), c=point_colors, alpha=0.6)
    ax.set_title(f"Layer {l}")
    ax.set_yticks([])
    ax.set_xlabel("Projection values")

# Create a custom legend
custom_legend = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='r', markersize=8, label="First Character"),
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='b', markersize=8, label="Second Character"),
]
fig.legend(handles=custom_legend, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0.95))

plt.suptitle(f"Projection of activations onto singular vector {sv_idx + 1} (1-indexed)")
plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to fit the legend
# To save, uncomment (must run before plt.show(), which may clear the figure):
# fig.savefig("../plots/rep_viz/charac_pos_SV.png", dpi=300)
plt.show()

## Query Object Viz

In [None]:
# Project the cached query-object activations of every second layer
# (0, 2, ..., 38) onto two consecutive singular vectors and show them as
# 2-D scatter plots, coloured by which object is queried.

# 0-based half-open row range of Vh to project onto: rows 2 and 3, i.e. the
# third and fourth singular vectors. Titles/labels are derived from it.
sv_start, sv_stop = 2, 4
viz_layers = range(0, 40, 2)

projected_acts = {}
for l in viz_layers:
    # map_location keeps V on the same device as the activations regardless of
    # the device it was saved from.
    V = torch.load(f"../svd_results/obj_pos/Vh_{l}.pt", map_location=device)
    acts_l = acts_obj[:, l, :].to(device)
    projected_acts[l] = torch.matmul(acts_l, V[sv_start:sv_stop, :].t()).cpu().numpy()

    del acts_l, V
    torch.cuda.empty_cache()

# 5 x 4 grid: one panel per visualised layer (20 panels).
fig, axs = plt.subplots(5, 4, figsize=(12, 12))
# Red = first object queried (index 0), blue = second object.
point_colors = ['r' if idx == 0 else 'b' for idx in object_indices]
for ax, l in zip(axs.flat, viz_layers):
    proj = projected_acts[l]
    ax.scatter(proj[:, 0], proj[:, 1], c=point_colors, alpha=0.6)
    ax.set_title(f"Layer {l}")
    # Both axes carry data here, so the y ticks are kept and both axes labelled.
    ax.set_xlabel(f"SV {sv_start + 1} projection")
    ax.set_ylabel(f"SV {sv_start + 2} projection")

# Create a custom legend
custom_legend = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='r', markersize=8, label="First Object"),
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='b', markersize=8, label="Second Object"),
]
fig.legend(handles=custom_legend, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0.95))

plt.suptitle(f"Projection of activations onto singular vectors {sv_start + 1} and {sv_stop} (1-indexed)")
plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to fit the legend
# To save, uncomment (must run before plt.show(), which may clear the figure):
# fig.savefig("../plots/rep_viz/obj_pos_SV_3_4.png", dpi=300)
plt.show()

## Visibility Viz

### Query Character

In [None]:
# Project the cached query-character activations of every second layer
# (0, 2, ..., 38) onto one singular vector of the "selected_tokens" SVD and
# show the projections as 1-D scatter plots, coloured by visibility.

# 0-based row of Vh to project onto. The figure title is derived from this
# value (the title used to say "third" while the code projected onto row 3,
# i.e. the fourth vector).
sv_idx = 3
viz_layers = range(0, 40, 2)

projected_acts = {}
for l in viz_layers:
    # map_location keeps V on the same device as the activations regardless of
    # the device it was saved from.
    V = torch.load(f"../svd_results/selected_tokens/singular_vecs/{l}.pt", map_location=device)
    acts_l = acts_charac[:, l, :].to(device)
    projected_acts[l] = torch.matmul(acts_l, V[sv_idx:sv_idx + 1, :].t()).cpu().numpy()

    # Free both tensors before the next layer (V was previously left alive).
    del acts_l, V
    torch.cuda.empty_cache()

# 5 x 4 grid: one panel per visualised layer (20 panels).
fig, axs = plt.subplots(5, 4, figsize=(12, 12))
# Red = visibility flag equal to 0 (not visible), blue = visible.
point_colors = ['r' if idx == 0 else 'b' for idx in visibility]
for ax, l in zip(axs.flat, viz_layers):
    proj = projected_acts[l][:, 0]
    # All points sit on y = 0; only the x position carries information.
    ax.scatter(proj, np.zeros_like(proj), c=point_colors, alpha=0.6)
    ax.set_title(f"Layer {l}")
    ax.set_yticks([])
    ax.set_xlabel("Projection values")

# Create a custom legend
custom_legend = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='r', markersize=8, label="Non Visible"),
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='b', markersize=8, label="Visible"),
]
fig.legend(handles=custom_legend, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0.95))

plt.suptitle(f"Projection of activations onto singular vector {sv_idx + 1} (1-indexed)")
plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to fit the legend
# To save, uncomment (must run before plt.show(), which may clear the figure):
# fig.savefig("../plots/rep_viz/visibility_query_charac_SV.png", dpi=300)
plt.show()

### Last token

In [None]:
# Project cached activations of every second layer (0, 2, ..., 38) onto one
# singular vector of the "last_token" SVD and show the projections as 1-D
# scatter plots, coloured by visibility.
# NOTE(review): this section is titled "Last token" and loads the last_token
# singular vectors, but it projects `acts_charac`, which was cached at token
# offset -8. No last-token activations are cached in this notebook -- confirm
# whether a separate cache at offset -1 was intended.

# 0-based row of Vh to project onto. The figure title is derived from this
# value (the title used to say "third" while the code projected onto row 1,
# i.e. the second vector).
sv_idx = 1
viz_layers = range(0, 40, 2)

projected_acts = {}
for l in viz_layers:
    # map_location keeps V on the same device as the activations regardless of
    # the device it was saved from.
    V = torch.load(f"../svd_results/last_token/Vh_{l}.pt", map_location=device)
    acts_l = acts_charac[:, l, :].to(device)
    projected_acts[l] = torch.matmul(acts_l, V[sv_idx:sv_idx + 1, :].t()).cpu().numpy()

    # Free both tensors before the next layer (V was previously left alive).
    del acts_l, V
    torch.cuda.empty_cache()

# 5 x 4 grid: one panel per visualised layer (20 panels).
fig, axs = plt.subplots(5, 4, figsize=(12, 12))
# Red = visibility flag equal to 0 (not visible), blue = visible.
point_colors = ['r' if idx == 0 else 'b' for idx in visibility]
for ax, l in zip(axs.flat, viz_layers):
    proj = projected_acts[l][:, 0]
    # All points sit on y = 0; only the x position carries information.
    ax.scatter(proj, np.zeros_like(proj), c=point_colors, alpha=0.6)
    ax.set_title(f"Layer {l}")
    ax.set_yticks([])
    ax.set_xlabel("Projection values")

# Create a custom legend
custom_legend = [
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='r', markersize=8, label="Non Visible"),
    plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='b', markersize=8, label="Visible"),
]
fig.legend(handles=custom_legend, loc='upper center', ncol=2, bbox_to_anchor=(0.5, 0.95))

plt.suptitle(f"Projection of activations onto singular vector {sv_idx + 1} (1-indexed)")
plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to fit the legend
# To save, uncomment (must run before plt.show(), which may clear the figure):
# fig.savefig("../plots/rep_viz/visibility_last_token_SV.png", dpi=300)
plt.show()

# Causal Intervention for Character Position Info 

In [None]:
# Clean/corrupt prompt pairs for the character-position intervention.
n_samples, batch_size = 20, 1

dataset = query_charac_pos(
    STORY_TEMPLATES,
    all_characters,
    all_containers,
    all_states,
    n_samples,
    question_type="belief_question",
    diff_visibility=False,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Show one corrupt/clean prompt pair with its answers and the patching target.
idx = 0
for variant in ("corrupt", "clean"):
    print(dataset[idx][f"{variant}_prompt"], dataset[idx][f"{variant}_ans"])
print(f"Target: {dataset[idx]['target']}")

## Error Detection

In [None]:
# Baseline check: a sample counts as correct only if the unpatched model
# predicts the expected answer for BOTH the clean and the corrupt prompt.
# Indices of all other samples are collected in `errors`.
correct, total = 0, 0
errors = []
with torch.no_grad():
    for bi, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        clean_prompt, corrupt_prompt = batch['clean_prompt'][0], batch['corrupt_prompt'][0]
        clean_target, corrupt_target = batch['clean_ans'][0], batch['corrupt_ans'][0]

        with model.trace(clean_prompt):
            clean_pred = model.lm_head.output[0, -1].argmax(dim=-1).item().save()

        with model.trace(corrupt_prompt):
            corrupt_pred = model.lm_head.output[0, -1].argmax(dim=-1).item().save()

        # Decode each prediction once and reuse it for logging and comparison.
        clean_text = model.tokenizer.decode([clean_pred]).lower().strip()
        corrupt_text = model.tokenizer.decode([corrupt_pred]).lower().strip()
        print(f"Clean: {clean_text} | Corrupt: {corrupt_text}")

        both_right = clean_text == clean_target and corrupt_text == corrupt_target
        if both_right:
            correct += 1
        else:
            errors.append(bi)
        total += 1

        del clean_pred, corrupt_pred
        torch.cuda.empty_cache()

print(f"Accuracy: {round(correct / total, 2)}")
print(f"correct: {correct} | total: {total}")

## Patching with Singular Vectors

In [None]:
# Right singular vectors (Vh) for layers 0..40, kept on the CPU.
singular_vecs = {
    l: torch.load(f"../svd_results/charac_pos/Vh_{l}.pt").cpu()
    for l in range(41)
}

In [None]:
# Single-layer intervention: at token offset -8 of the clean run, replace the
# component of the layer output along the second singular vector (row 1 of Vh)
# with the same component taken from the corrupt run, then score the clean
# run's next-token prediction against `target`.

# Results accumulate across re-runs of this cell (e.g. with a different layer
# range), but the dict must exist on a fresh kernel -- previously its
# initialisation was commented out, which raised a NameError below.
if "accs_query_charac_1_second_sv" not in globals():
    accs_query_charac_1_second_sv = {}

for layer_idx in range(12, 20, 2):
    correct, total = 0, 0
    for bi, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Skip samples the unpatched model already gets wrong.
        if bi in errors:
            continue
        corrupt_prompt = batch["corrupt_prompt"][0]
        clean_prompt = batch["clean_prompt"][0]
        target = batch["target"][0]

        with torch.no_grad():

            with model.trace() as tracer:
                with tracer.invoke(corrupt_prompt):
                    corrupt_layer_out = model.model.layers[layer_idx].output[0][0, -8].save()

                with tracer.invoke(clean_prompt):
                    vec = singular_vecs[layer_idx][1:2, :].t().half().cuda()
                    # Calculate a projection matrix using the outer product of the singular vector
                    proj_matrix = torch.matmul(vec, vec.t())

                    corrupt_pos = torch.matmul(corrupt_layer_out, proj_matrix.T)
                    clean_pos = torch.matmul(model.model.layers[layer_idx].output[0][0, -8], proj_matrix.T)

                    # Remove the clean component along the vector, add the corrupt one.
                    model.model.layers[layer_idx].output[0][0, -8] = (model.model.layers[layer_idx].output[0][0, -8] - clean_pos) + corrupt_pos

                    del vec, proj_matrix
                    torch.cuda.empty_cache()

                    pred = model.lm_head.output[0, -1].argmax(dim=-1).save()

            # print(f"Pred: {model.tokenizer.decode([pred]).lower().strip()} | Target: {target}")
            if model.tokenizer.decode([pred]).lower().strip() == target:
                correct += 1
            total += 1

            del corrupt_layer_out, pred
            torch.cuda.empty_cache()

    # Every sample may have been filtered out by `errors`; avoid dividing by zero.
    if total == 0:
        print(f"Layer: {layer_idx} | no samples evaluated")
        continue

    acc = round(correct / total, 2)
    accs_query_charac_1_second_sv[layer_idx] = acc
    print(f"Layer: {layer_idx} | Accuracy: {acc}")

In [None]:
# Re-order the per-layer accuracies by ascending layer index and display them.
accs_query_charac_1_second_sv = {
    layer: accs_query_charac_1_second_sv[layer]
    for layer in sorted(accs_query_charac_1_second_sv)
}
accs_query_charac_1_second_sv

In [None]:
# Re-order accs_query_charac_second_sv by ascending layer index and display it.
# NOTE(review): this dict is not assigned anywhere in the visible notebook; the
# cell relies on kernel state from an earlier session.
accs_query_charac_second_sv = {
    layer: accs_query_charac_second_sv[layer]
    for layer in sorted(accs_query_charac_second_sv)
}
accs_query_charac_second_sv

## Result Visualization

In [None]:
# Assemble the inputs for the HTML report: the corrupt and clean stories of the
# first sample, an arrow between the two query-character positions, and the
# per-layer intervention accuracies.
true_stories = [
    {
        field: dataset[0][f"{variant}_{suffix}"]
        for field, suffix in (("story", "story"), ("question", "question"), ("answer", "ans"))
    }
    for variant in ("corrupt", "clean")
]

arrows = [
    dict(
        start=token_pos_coords['e1_query_charac'],
        end=token_pos_coords['e2_query_charac'],
        color='red',
    )
]

plot_data = dict(
    labels=accs_query_charac_second_sv.keys(),
    acc_one_layer=accs_query_charac_second_sv.values(),
    title="Aligning Query Character Position info",
    x_label="Layers",
    y_label="Intervention Accuracy",
)

# De-duplicated entities appearing in either the clean or the corrupt story.
characters, objects, states = (
    list(set(dataset[0][f"clean_{entity}"] + dataset[0][f"corrupt_{entity}"]))
    for entity in ("characters", "objects", "states")
)

generator = StoryGenerator(
    characters=characters,
    objects=objects,
    states=states,
    stories=true_stories,
    target=dataset[0]['target'],
    arrows=arrows,
    plot_data=plot_data,
)
generator.save_html(filename="../plots/belief_exps/second_obj/query_charac_2_sv_2.html")

# Causal Intervention for Object Position Info

In [None]:
# Clean/corrupt prompt pairs for the object-position intervention.
n_samples, batch_size = 20, 1

dataset = query_obj_pos(
    STORY_TEMPLATES,
    all_characters,
    all_containers,
    all_states,
    n_samples,
    question_type="belief_question",
    diff_visibility=False,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Show one corrupt/clean prompt pair with its answers and the patching target.
idx = 0
for variant in ("corrupt", "clean"):
    print(dataset[idx][f"{variant}_prompt"], dataset[idx][f"{variant}_ans"])
print(f"Target: {dataset[idx]['target']}")

## Error Detection

In [None]:
# Baseline check: a sample counts as correct only if the unpatched model
# predicts the expected answer for BOTH the clean and the corrupt prompt.
# Indices of all other samples are collected in `errors`.
correct, total = 0, 0
errors = []
with torch.no_grad():
    for bi, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        clean_prompt, corrupt_prompt = batch['clean_prompt'][0], batch['corrupt_prompt'][0]
        clean_target, corrupt_target = batch['clean_ans'][0], batch['corrupt_ans'][0]

        with model.trace(clean_prompt):
            clean_pred = model.lm_head.output[0, -1].argmax(dim=-1).item().save()

        with model.trace(corrupt_prompt):
            corrupt_pred = model.lm_head.output[0, -1].argmax(dim=-1).item().save()

        # Decode each prediction once and reuse it for logging and comparison.
        clean_text = model.tokenizer.decode([clean_pred]).lower().strip()
        corrupt_text = model.tokenizer.decode([corrupt_pred]).lower().strip()
        print(f"Clean: {clean_text} | Corrupt: {corrupt_text}")

        both_right = clean_text == clean_target and corrupt_text == corrupt_target
        if both_right:
            correct += 1
        else:
            errors.append(bi)
        total += 1

        del clean_pred, corrupt_pred
        torch.cuda.empty_cache()

print(f"Accuracy: {round(correct / total, 2)}")
print(f"correct: {correct} | total: {total}")

## Patching with Singular Vectors

In [None]:
# Right singular vectors (Vh) for layers 0..40, kept on the CPU.
singular_vecs = {
    l: torch.load(f"../svd_results/obj_pos/Vh_{l}.pt").cpu()
    for l in range(41)
}

In [None]:
# Multi-layer intervention at token offset -5: for every layer 0..39 of the
# clean run, the component of the layer output inside the subspace spanned by
# rows 2 and 3 of Vh is replaced with the corresponding component from the
# corrupt run; the clean run's next-token prediction is then compared to `target`.
# NOTE(review): the result dict below is commented out, so accuracies are only
# printed, not stored.
# accs_query_obj_2_sv_2 = {}

# range(0, 10, 10) yields a single value (0). `layer_idx` is not used inside the
# loop body -- all 40 layers are patched on every pass -- so it only labels the
# printed accuracy.
for layer_idx in range(0, 10, 10):
    correct, total = 0, 0
    for bi, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Filtering on baseline errors is disabled here: every sample is evaluated.
        # if bi in errors:
        #     continue
        corrupt_prompt = batch["corrupt_prompt"][0]
        clean_prompt = batch["clean_prompt"][0]
        target = batch["target"][0]

        # corrupt_layer_out maps layer index -> saved corrupt activation;
        # clean_layer_out is created but never used below.
        corrupt_layer_out, clean_layer_out = defaultdict(dict), defaultdict(dict)
        with torch.no_grad():

            with model.trace() as tracer:
                # First invocation: record the corrupt run's layer outputs.
                with tracer.invoke(corrupt_prompt):
                    for l in range(40):
                        corrupt_layer_out[l] = model.model.layers[l].output[0][0, -5].save()

                # Second invocation: patch the clean run layer by layer.
                with tracer.invoke(clean_prompt):
                    for l in range(40):
                        vec = singular_vecs[l][2:4, :].t().half().cuda()
                        # Calculate a projection matrix using the outer product of the singular vector
                        proj_matrix = torch.matmul(vec, vec.t())

                        corrupt_pos = torch.matmul(corrupt_layer_out[l], proj_matrix.T)
                        clean_pos = torch.matmul(model.model.layers[l].output[0][0, -5], proj_matrix.T)

                        # Find cosine similarity between the clean and corrupt position
                        cos_sim = torch.nn.functional.cosine_similarity(clean_pos, corrupt_pos, dim=0)
                        tracer.log(f"cosine_similarity_{l}", cos_sim.item())

                        # Remove the clean component in the subspace, add the corrupt one.
                        model.model.layers[l].output[0][0, -5] = (model.model.layers[l].output[0][0, -5] - clean_pos) + corrupt_pos

                        del vec, proj_matrix
                        torch.cuda.empty_cache()

                    pred = model.lm_head.output[0, -1].argmax(dim=-1).save()

            print(f"Pred: {model.tokenizer.decode([pred]).lower().strip()} | Target: {target}")
            if model.tokenizer.decode([pred]).lower().strip() == target:
                correct += 1
            total += 1

            del corrupt_layer_out, pred
            torch.cuda.empty_cache()

    acc = round(correct / total, 2)
    # accs_query_obj_2_sv_2[layer_idx] = acc
    print(f"Layer: {layer_idx} | Accuracy: {acc}")

In [None]:
# Re-order by ascending layer index and display.
# NOTE(review): this dict is not assigned anywhere in the visible notebook; the
# cell relies on kernel state from an earlier session.
accs_query_obj_2_sv_2_3_from = {
    layer: accs_query_obj_2_sv_2_3_from[layer]
    for layer in sorted(accs_query_obj_2_sv_2_3_from)
}
accs_query_obj_2_sv_2_3_from

In [None]:
# Re-order by ascending layer index and display.
# NOTE(review): this dict is not assigned anywhere in the visible notebook; the
# cell relies on kernel state from an earlier session.
accs_query_obj_2_sv_2_3 = {
    layer: accs_query_obj_2_sv_2_3[layer]
    for layer in sorted(accs_query_obj_2_sv_2_3)
}
accs_query_obj_2_sv_2_3

In [None]:
# Re-order by ascending layer index and display.
# NOTE(review): this dict is not assigned anywhere in the visible notebook; the
# cell relies on kernel state from an earlier session.
accs_query_obj_sv_3_4 = {
    layer: accs_query_obj_sv_3_4[layer]
    for layer in sorted(accs_query_obj_sv_3_4)
}
accs_query_obj_sv_3_4

## Result Visualization

In [None]:
# Assemble the inputs for the HTML report: the corrupt and clean stories of the
# first sample, an arrow between the two query-object positions, and the
# per-layer intervention accuracies ("up to" and "from" variants).
true_stories = [
    {
        field: dataset[0][f"{variant}_{suffix}"]
        for field, suffix in (("story", "story"), ("question", "question"), ("answer", "ans"))
    }
    for variant in ("corrupt", "clean")
]

arrows = [
    dict(
        start=token_pos_coords['e1_query_obj_belief'],
        end=token_pos_coords['e2_query_obj_belief'],
        color='red',
    )
]

plot_data = dict(
    labels=accs_query_obj_2_sv_2_3.keys(),
    acc_upto_layer=accs_query_obj_2_sv_2_3.values(),
    acc_from_layer=accs_query_obj_2_sv_2_3_from.values(),
    title="Aligning Query Object Position info",
    x_label="Layers",
    y_label="Intervention Accuracy",
)

# De-duplicated entities appearing in either the clean or the corrupt story.
characters, objects, states = (
    list(set(dataset[0][f"clean_{entity}"] + dataset[0][f"corrupt_{entity}"]))
    for entity in ("characters", "objects", "states")
)

generator = StoryGenerator(
    characters=characters,
    objects=objects,
    states=states,
    stories=true_stories,
    target=dataset[0]['target'],
    arrows=arrows,
    plot_data=plot_data,
)
generator.save_html(filename="../plots/belief_exps/second_obj/query_obj_2_sv_1_2.html")

# Causal Intervention for Visibility Info

In [None]:
# Clean/corrupt prompt pairs for the visibility intervention
# (built with diff_visibility=True, unlike the position experiments).
n_samples, batch_size = 20, 1

dataset = get_visibility_align_exps(
    STORY_TEMPLATES,
    all_characters,
    all_containers,
    all_states,
    n_samples,
    question_type="belief_question",
    diff_visibility=True,
)
dataloader = DataLoader(dataset, shuffle=False, batch_size=batch_size)

In [None]:
# Show one corrupt/clean prompt pair with its answers and the patching target.
idx = 0
for variant in ("corrupt", "clean"):
    print(dataset[idx][f"{variant}_prompt"], dataset[idx][f"{variant}_ans"])
print(f"Target: {dataset[idx]['target']}")

## Error Detection

In [None]:
# Baseline check: a sample counts as correct only if the unpatched model
# predicts the expected answer for BOTH the clean and the corrupt prompt.
# Both prompts are run as two invocations of a single trace.
# Indices of all other samples are collected in `errors`.
correct, total = 0, 0
errors = []
for bi, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
    clean_prompt, corrupt_prompt = batch['clean_prompt'][0], batch['corrupt_prompt'][0]
    clean_target, corrupt_target = batch['clean_ans'][0], batch['corrupt_ans'][0]

    with torch.no_grad():

        with model.trace() as tracer:

            with tracer.invoke(clean_prompt):
                clean_pred = model.lm_head.output[0, -1].argmax(dim=-1).item().save()

            with tracer.invoke(corrupt_prompt):
                corrupt_pred = model.lm_head.output[0, -1].argmax(dim=-1).item().save()

    # Decode each prediction once and reuse it for logging and comparison.
    clean_text = model.tokenizer.decode([clean_pred]).lower().strip()
    corrupt_text = model.tokenizer.decode([corrupt_pred]).lower().strip()
    print(f"Clean: {clean_text} | Corrupt: {corrupt_text}")

    both_right = clean_text == clean_target and corrupt_text == corrupt_target
    if both_right:
        correct += 1
    else:
        errors.append(bi)
    total += 1

    del clean_pred, corrupt_pred
    torch.cuda.empty_cache()

print(f"Accuracy: {round(correct / total, 2)}")
print(f"correct: {correct} | total: {total}")

## Patching with Singular Vectors

In [None]:
# Load one saved singular-vector object per layer index (0..40) into host memory.
# map_location="cpu" deserialises straight onto the CPU instead of first
# materialising the tensor on the device it was saved from and copying it back
# with .cpu(); this avoids a transient GPU allocation next to the loaded model
# and also works on machines where the original device is unavailable.
sing_vecs = defaultdict(dict)
for l in range(41):
    sing_vecs[l] = torch.load(
        f"../svd_results/selected_tokens_diff/singular_vecs/{l}.pt",
        map_location="cpu",
    )

In [None]:
# Map the last eight negative token offsets (-8 .. -1) onto 0-based slots (0 .. 7).
token_indices = {offset: offset + 8 for offset in range(-8, 0)}

In [None]:
# Per-layer intervention accuracy. The dict is reused when it already exists so
# that re-running this cell with a different layer range extends earlier
# results, and created on a fresh kernel so the cells below do not hit a
# NameError.
visibility_diff_reverse_sv = globals().get("visibility_diff_reverse_sv", {})

for layer_idx in range(32, 34, 2):
    correct, total = 0, 0
    for bi, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # if bi in errors:
        #     continue
        # The roles are swapped relative to the dataset's field names: the
        # dataset's "clean" prompt is the source of the patched activations and
        # the dataset's "corrupt" prompt is the run being intervened on.
        # NOTE(review): presumably this is the "reverse" direction the result
        # dict is named after — confirm.
        corrupt_prompt = batch["clean_prompt"][0]
        clean_prompt = batch["corrupt_prompt"][0]
        target = batch["target"][0]

        corrupt_layer_out, clean_layer_out = defaultdict(dict), defaultdict(dict)
        with torch.no_grad():

            with model.trace() as tracer:

                # Source run: cache the layer output at the selected token offsets.
                with tracer.invoke(corrupt_prompt):
                    for t in [-8, -7, -5, -3, -1]:
                        corrupt_layer_out[t] = model.model.layers[layer_idx].output[0][0, t].save()

                # Target run: at the same offsets, remove the activation's
                # projection through `proj_matrix` and add the source
                # activation's projection instead.
                with tracer.invoke(clean_prompt):
                    for t in [-8, -7, -5, -3, -1]:
                        # First six rows of this layer's saved matrix, transposed.
                        # Assumes rows are (orthonormal) singular vectors, so that
                        # vec @ vec.T projects onto their span — TODO confirm.
                        vec = sing_vecs[layer_idx][:6, :].t().half()
                        proj_matrix = torch.matmul(vec, vec.t())

                        corrupt_pos_charac = torch.matmul(corrupt_layer_out[t], proj_matrix)
                        clean_pos_charac = torch.matmul(model.model.layers[layer_idx].output[0][0, t], proj_matrix)
                        model.model.layers[layer_idx].output[0][0, t] = (model.model.layers[layer_idx].output[0][0, t] - clean_pos_charac) + corrupt_pos_charac
                        # Alternative (full activation patching instead of the subspace swap):
                        # model.model.layers[layer_idx].output[0][0, t] = corrupt_layer_out[t]

                    del vec, proj_matrix, corrupt_pos_charac
                    torch.cuda.empty_cache()

                    # Greedy next-token prediction of the intervened run.
                    pred = model.lm_head.output[0, -1].argmax(dim=-1).save()

            print(f"Pred: {model.tokenizer.decode([pred]).lower().strip()} | Target: {target}")
            # Success is counted against the literal string "unknown", not
            # against `target` (which is only printed above).
            if model.tokenizer.decode([pred]).lower().strip() == "unknown":
                correct += 1
            total += 1

            del corrupt_layer_out, pred
            torch.cuda.empty_cache()

    acc = round(correct / total, 2)
    # Record the result so the sorting and plotting cells have data to consume.
    visibility_diff_reverse_sv[layer_idx] = acc
    print(f"Layer: {layer_idx} | Accuracy: {acc}")

In [None]:
# Sort accs_visibility_full by key
visibility_diff_reverse_sv = dict(sorted(visibility_diff_reverse_sv.items(), key=lambda x: x[0]))
visibility_diff_reverse_sv

## Result Visualization

In [None]:
# Story/question/answer of the first sample, once for the clean and once for
# the corrupt variant, in that order.
true_stories = [
    {
        "story": dataset[0][f"{variant}_story"],
        "question": dataset[0][f"{variant}_question"],
        "answer": dataset[0][f"{variant}_ans"],
    }
    for variant in ("clean", "corrupt")
]

# Kept for parity with the other report cells; the generator below is given an
# empty arrow list instead.
arrows = [
    dict(
        start=token_pos_coords['e1_query_charac'],
        end=token_pos_coords['e2_query_charac'],
        color='red',
    )
]

# Single accuracy series (one value per intervened layer) plus plot labels.
plot_data = dict(
    labels=visibility_diff_reverse_sv.keys(),
    acc_one_layer=visibility_diff_reverse_sv.values(),
    title="Aligning Visibility info",
    x_label="Layers",
    y_label="Intervention Accuracy",
)

# Unique entities pooled across the clean and corrupt variants of the first sample.
characters = list(set(dataset[0]['clean_characters'] + dataset[0]['corrupt_characters']))
objects = list(set(dataset[0]['clean_objects'] + dataset[0]['corrupt_objects']))
states = list(set(dataset[0]['clean_states'] + dataset[0]['corrupt_states']))

generator = StoryGenerator(
    characters=characters,
    objects=objects,
    states=states,
    stories=true_stories,
    target="unknown",
    arrows=[],
    plot_data=plot_data,
)
generator.save_html(filename="../plots/visibility_exps/first_obj/visibility_diff_reverse_sv.html")

# Intervening on BigToM

## Helper Methods

In [None]:
# Load one saved BigToM singular-vector object per layer index (0..40) into
# host memory. map_location="cpu" deserialises straight onto the CPU instead of
# first materialising the tensor on the device it was saved from and copying it
# back with .cpu(); this avoids a transient GPU allocation next to the loaded
# model and also works when the original device is unavailable.
sing_vecs = defaultdict(dict)
for l in range(41):
    sing_vecs[l] = torch.load(
        f"../svd_results/bigtom/singular_vecs/{l}.pt",
        map_location="cpu",
    )
    # Alternative source (vectors from the selected-token difference experiment):
    # sing_vecs[l] = torch.load(f"../svd_results/selected_tokens_diff/singular_vecs/{l}.pt", map_location="cpu")

In [None]:
def get_ques_start_token_idx(tokenizer, prompt):
    """Return the index of the token immediately before the third ':' token in `prompt`.

    The colon token id is taken as the last id produced by encoding ":".
    Raises IndexError when the encoded prompt contains fewer than three
    occurrences of that id.
    """
    prompt_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze()
    colon_id = tokenizer.encode(":", return_tensors="pt").squeeze()[-1].item()

    # Positions of every colon token in the prompt, in order of appearance.
    colon_positions = (prompt_ids == colon_id).nonzero(as_tuple=True)[0]
    third_colon = colon_positions[2].item()

    return third_colon - 1

In [None]:
def get_prompt_token_len(tokenizer, prompt):
    """Return how many token ids `tokenizer.encode` produces for `prompt`."""
    return len(tokenizer.encode(prompt, return_tensors="pt").squeeze())

In [None]:
def check_pred(pred, target, verbose=False):
    """Ask the global `model` whether `pred` and `target` say the same thing.

    Builds a judging prompt containing the ground truth and the prediction,
    greedily generates up to 5 new tokens, and returns the decoded, stripped
    continuation. The prompt asks for 'Yes' or 'No', but the return value is
    whatever text the model generated.

    Args:
        pred: Predicted answer text inserted into the prompt.
        target: Ground-truth answer text inserted into the prompt.
        verbose: When True, print the full judging prompt.

    Returns:
        The decoded tokens after the prompt, excluding the final generated
        token, with surrounding whitespace removed.
    """
    prompt = f"Instruction: Check if the following ground truth and prediction are same or different. If they are the same, then predict 'Yes', else 'No' \n\nGround truth: {target}\nPrediction: {pred}\nAnswer:"
    
    if verbose:
        print(prompt)

    with torch.no_grad():
        # Greedy decoding (do_sample=False), at most 5 new tokens.
        with model.generate(prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id):
            out = model.generator.output.save()

    # The prompt is re-tokenized only to find where the generated continuation begins.
    prompt_len = get_prompt_token_len(model.tokenizer, prompt)

    # NOTE(review): the slice drops the last generated token — presumably an
    # end-of-sequence marker. If generation stops at max_new_tokens instead,
    # a content token is dropped; confirm this is intended.
    return model.tokenizer.decode(out[0][prompt_len:-1]).strip()

## Loading BigToM data

In [None]:
# Read a csv file
df_false = pd.read_csv("../data/bigtom/0_forward_belief_false_belief/stories.csv", delimiter=";")
df_true = pd.read_csv("../data/bigtom/0_forward_belief_true_belief/stories.csv", delimiter=";")

In [None]:
# For each row in the dataframe extract story, answer, and distractor
true_stories, false_stories = [], []
for i in range(len(df_true)):
    story = df_true.iloc[i]['story']
    question = df_true.iloc[i]['question']
    answer = df_true.iloc[i]['answer']
    distractor = df_true.iloc[i]['distractor']
    true_stories.append({"story": story, "question": question, "answer": answer, "distractor": distractor})

for i in range(len(df_false)):
    story = df_false.iloc[i]['story']
    question = df_true.iloc[i]['question']
    answer = df_false.iloc[i]['answer']
    distractor = df_false.iloc[i]['distractor']
    false_stories.append({"story": story, "question": question, "answer": answer, "distractor": distractor})

dataset = []
instruction = "1. Track the belief of each character as described in the story. 2. A character's belief is formed only when they perform an action themselves or can observe the action taking place. 3. A character does not have any belief about the container or its content which they cannot observe directly. 4. To answer the question, predict only the final state of the queried container in fewest tokens possible, strictly based on the belief of the character, mentioned in the question. 5. Do not predict the entire sentence with character or container as the final output."

for i in range(min(len(true_stories), len(false_stories))):
    question = true_stories[i]['question']
    visible_prompt = f"Instructions: {instruction}\n\nStory: {true_stories[i]['story']}\nQuestion: {question}\nAnswer:"

    question = false_stories[i]['question']
    invisible_prompt = f"Instructions: {instruction}\n\nStory: {false_stories[i]['story']}\nQuestion: {question}\nAnswer:"

    visible_ans = true_stories[i]['answer'].split()
    invisible_ans = false_stories[i]['answer'].split()

    # Find the index of first word which is different in both answers
    diff_idx = 0
    for idx, (v, j) in enumerate(zip(visible_ans, invisible_ans)):
        if v != j:
            diff_idx = idx
            break
    
    visible_ans = " ".join(visible_ans[diff_idx:])[:-1]
    invisible_ans = " ".join(invisible_ans[diff_idx:])[:-1]

    dataset.append({
        "visible_story": true_stories[i]['story'],
        "visible_question": true_stories[i]['question'],
        "visible_prompt": visible_prompt,
        "visible_ans": visible_ans,
        "invisible_story": false_stories[i]['story'],
        "invisible_question": false_stories[i]['question'],
        "invisible_prompt": invisible_prompt,
        "invisible_ans": invisible_ans,
    })

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
idx = 4
# Print the visible then the invisible prompt/answer pair of one BigToM sample.
for condition in ("visible", "invisible"):
    print(dataset[idx][f"{condition}_prompt"], dataset[idx][f"{condition}_ans"])

In [None]:
# errors = []
# for bi, batch in enumerate(dataloader):
#     visible_prompt = batch['visible_prompt'][0]
#     visible_ans = batch['visible_ans'][0]
#     invisible_prompt = batch['invisible_prompt'][0]
#     invisible_ans = batch['invisible_ans'][0]

#     with torch.no_grad():
#         with model.generate(visible_prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id):
#             vis_out = model.generator.output.save()

#         with model.generate(invisible_prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id):
#             invis_out = model.generator.output.save()

#     vis_prompt_len = get_prompt_token_len(model.tokenizer, visible_prompt)
#     invis_prompt_len = get_prompt_token_len(model.tokenizer, invisible_prompt)

#     vis_check = check_pred(model.tokenizer.decode(vis_out[0][vis_prompt_len:-1]), visible_ans, verbose=True)
#     invis_check = check_pred(model.tokenizer.decode(invis_out[0][invis_prompt_len:-1]), invisible_ans, verbose=True)

#     print(f"Bi: {bi} | Visible: {vis_check} | Invisible: {invis_check}\n")

#     if vis_check == "No" or invis_check == "No":
#         errors.append(bi)

## Patching Experiments

In [None]:
# Per-layer intervention accuracy. The dict is reused when it already exists so
# that re-running this cell with a different layer range extends earlier
# results, and created on a fresh kernel so the store at the end of the loop
# does not raise NameError.
accs_sv = globals().get("accs_sv", {})

# Only the first `n_eval_batches` samples of the dataloader are evaluated.
n_eval_batches = 20

for layer_idx in range(22, 26, 2):
    correct, total = 0, 0

    for bi, batch in tqdm(enumerate(dataloader), total=n_eval_batches):
        if bi >= n_eval_batches:
            break
        visible_prompt = batch['visible_prompt'][0]
        visible_ans = batch['visible_ans'][0]
        invisible_prompt = batch['invisible_prompt'][0]
        invisible_ans = batch['invisible_ans'][0]

        # Token span to patch in each prompt: from the token before the third
        # ':' token up to the end of the prompt.
        # NOTE(review): the two spans are aligned by offset from their start;
        # this assumes both have the same length — confirm for this data.
        visible_ques_idx = get_ques_start_token_idx(model.tokenizer, visible_prompt)
        visible_prompt_len = get_prompt_token_len(model.tokenizer, visible_prompt)
        invisible_ques_idx = get_ques_start_token_idx(model.tokenizer, invisible_prompt)
        invisible_prompt_len = get_prompt_token_len(model.tokenizer, invisible_prompt)

        with torch.no_grad():
            with model.session() as session:

                # Source run: cache the layer output at every token of the visible span.
                visible_layer_out = defaultdict(dict)
                with model.trace(visible_prompt):
                    for t_idx, t in enumerate(range(visible_ques_idx, visible_prompt_len)):
                        visible_layer_out[t_idx] = model.model.layers[layer_idx].output[0][0, t].save()

                # Target run: generate from the invisible prompt while replacing,
                # at each span position, the activation's projection through
                # `proj_matrix` with the visible run's projection.
                with model.generate(invisible_prompt, max_new_tokens=5, do_sample=False, num_return_sequences=1, pad_token_id=model.tokenizer.pad_token_id, eos_token_id=model.tokenizer.eos_token_id):
                    # First 100 rows of this layer's saved matrix, transposed.
                    # Assumes rows are (orthonormal) singular vectors, so that
                    # vec @ vec.T projects onto their span — TODO confirm.
                    vec = sing_vecs[layer_idx][:100, :].t().half()
                    proj_matrix = torch.matmul(vec, vec.t())

                    for t_idx, t in enumerate(range(invisible_ques_idx, invisible_prompt_len)):
                        corrupt_pos_charac = torch.matmul(visible_layer_out[t_idx], proj_matrix)
                        clean_pos_charac = torch.matmul(model.model.layers[layer_idx].output[0][0, t], proj_matrix)

                        model.model.layers[layer_idx].output[0][0, t] = (model.model.layers[layer_idx].output[0][0, t] - clean_pos_charac) + corrupt_pos_charac

                        # Alternative (full activation patching instead of the subspace swap):
                        # model.model.layers[layer_idx].output[0][0, t] = visible_layer_out[t_idx]

                    out = model.generator.output.save()

                del visible_layer_out, vec, proj_matrix, corrupt_pos_charac, clean_pos_charac
                torch.cuda.empty_cache()

            # The intervention succeeds when the generated continuation is judged
            # to match the *visible* answer.
            out_check = check_pred(model.tokenizer.decode(out[0][invisible_prompt_len:-1]), visible_ans, verbose=True)
            print(f"Output check: {out_check}\n")

            if out_check == "Yes":
                correct += 1
            total += 1

    print(f"Layer: {layer_idx} | Accuracy: {round(correct / total, 2)}")
    accs_sv[layer_idx] = round(correct / total, 2)

In [None]:
# Sort accs by key
accs_sv = dict(sorted(accs_sv.items(), key=lambda x: x[0]))
accs_sv

## Result Visualization

In [None]:
# Story/question/answer of the first BigToM pair, once for the visible and once
# for the invisible condition, in that order.
true_stories = [
    {
        "story": dataset[0][f"{condition}_story"],
        "question": dataset[0][f"{condition}_question"],
        "answer": dataset[0][f"{condition}_ans"],
    }
    for condition in ("visible", "invisible")
]

# Kept for parity with the other report cells; the generator below is given an
# empty arrow list instead.
arrows = [
    dict(
        start=token_pos_coords['e1_query_charac'],
        end=token_pos_coords['e2_query_charac'],
        color='red',
    )
]

# Single accuracy series (one value per intervened layer) plus plot labels.
plot_data = dict(
    labels=accs_sv.keys(),
    acc_one_layer=accs_sv.values(),
    title="Aligning Visibility info",
    x_label="Layers",
    y_label="Intervention Accuracy",
)

# The BigToM samples carry no character/object/state lists, so the entities
# passed to the generator are written out by hand here.
generator = StoryGenerator(
    characters=["Noor"],
    objects=["pitcher"],
    states=["oat", "almond"],
    stories=true_stories,
    target=dataset[0]["visible_ans"],
    arrows=[],
    plot_data=plot_data,
)
generator.save_html(filename="../plots/visibility_exps/bigtom/invisibility_to_visbility_sv.html")