In [1]:
%load_ext autoreload
%autoreload 2

## Takeaways
* Low-rank swapping works!
* The edit is effective over different ranges of layers:
    * actor: layers 13-17 (rank 3)
    * object: layers 24-28 (rank 3)
    * container: successful over a range of layers from 13-28
* Projections aren't effective across types: actor projections aren't causally effective for objects. ==> The subspaces are different.
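
One way to check the "subspaces are different" claim directly, rather than only through patching, is to compare the top-`rank` row spaces of two `Vh` matrices. The sketch below is illustrative only: `subspace_overlap` is a hypothetical helper, the rows of `Vh` are assumed to be orthonormal, and the toy basis stands in for the real `characters` / `objects` projections.

```python
import numpy as np

def subspace_overlap(Vh_a, Vh_b, rank):
    """Cosines of the principal angles between the top-`rank` row spaces."""
    A = np.asarray(Vh_a[:rank], dtype=np.float64)  # (rank, d), rows assumed orthonormal
    B = np.asarray(Vh_b[:rank], dtype=np.float64)
    # The singular values of A @ B.T are the cosines of the principal angles.
    return np.linalg.svd(A @ B.T, compute_uv=False)

# Toy check with a random orthonormal basis (a stand-in for real Vh matrices).
rng = np.random.default_rng(0)
d, rank = 64, 3
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
Vh_toy = Q.T

same = subspace_overlap(Vh_toy, Vh_toy, rank)             # identical subspaces
disjoint = subspace_overlap(Vh_toy, Vh_toy[rank:], rank)  # orthogonal subspaces
print(same, disjoint)
```

Cosines near 1 would mean the two types share directions; cosines near 0 would be consistent with the cross-type patching result noted above.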

TODO: 
* Patch with different entities (ones that don't appear in the SVD extraction set) -- DONE
* Refine the codebase to be consistent with Tamar's templates.
    * Talk with Tamar about the tokenization issues
* Patch across templates
    * Projections do generalize, but objects need higher-rank projections (rank 10).
* Universal mechanisms for binding ID
* Calculate across different templates
* Observation vs. the causal event: how do the binding IDs change?
* Python pseudocode that solves the problem; check equivalence
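
A possible starting point for the last item is a small symbolic solver for the two-object, two-container stories. This is a sketch, not the `SampleV3` implementation: `ToyStory` and its functions are hypothetical names, `event_idx` is assumed to index the tampered container, and the swap rule (the tampered container ends up holding the other object, the other container is untouched) is inferred from the `true_state` printed for the sample in this notebook.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ToyStory:
    objects: List[str]     # objects[i] is initially placed in containers[i]
    containers: List[str]
    event_idx: int         # assumed: index of the container that gets tampered with
    event_noticed: bool    # whether the protagonist observed the event

def initial_state(s: ToyStory) -> Dict[str, str]:
    return dict(zip(s.containers, s.objects))

def true_state(s: ToyStory) -> Dict[str, str]:
    state = initial_state(s)
    # Assumed swap rule (two objects only): the tampered container now holds the other object.
    state[s.containers[s.event_idx]] = s.objects[1 - s.event_idx]
    return state

def protagonist_belief(s: ToyStory) -> Dict[str, str]:
    # Beliefs update only if the protagonist observed the event.
    return true_state(s) if s.event_noticed else initial_state(s)

def answer(s: ToyStory, actor: str, container_idx: int, obj_idx: int) -> str:
    # The perpetrator performed the swap, so they are assumed to know the true state.
    state = true_state(s) if actor == "perpetrator" else protagonist_belief(s)
    return "yes" if state[s.containers[container_idx]] == s.objects[obj_idx] else "no"

story = ToyStory(
    objects=["red velvet handkerchief", "blue silk handkerchief"],
    containers=["hat", "box"],
    event_idx=0,
    event_noticed=False,
)
print(answer(story, "perpetrator", 0, 0), answer(story, "protagonist", 0, 0))
```

Checking equivalence would then amount to comparing `answer(...)` against the dataset's labels for every (actor, container, object) combination.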

Questions:
* Patching the whole `h_other` instead of `h_patch` will also flip the answer (Nikhil's experiment). How do we make sure that we are bringing over only the binding ID and nothing else? 
   * The same projection can be used to remove the binding ID info. If the binding ID is really important for the task, the LM should give random answers without it. (?? verify this) - Not very convincing.
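
The removal experiment could be sketched as below. `ablate` and `random_basis` are hypothetical helpers, the rows of `Vh` are assumed to be orthonormal, and the random-subspace control is a suggestion rather than something already run here: if ablating a random rank-`r` subspace hurts as much as ablating the binding-ID subspace, the projection is probably not specific.

```python
import numpy as np

def ablate(h, Vh, rank):
    """Remove the component of h that lies in the top-`rank` row space of Vh."""
    V = np.asarray(Vh[:rank], dtype=np.float64)  # (rank, d), rows assumed orthonormal
    return h - V.T @ (V @ h)

def random_basis(d, rank, seed=0):
    """Random orthonormal rows, usable as a same-rank control subspace."""
    rng = np.random.default_rng(seed)
    Q, _ = np.linalg.qr(rng.standard_normal((d, rank)))
    return Q.T  # (rank, d)

# Toy data standing in for a real hidden state and a real Vh.
rng = np.random.default_rng(1)
d, rank = 16, 3
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
Vh_toy = Q.T
h = rng.standard_normal(d)

h_ablated = ablate(h, Vh_toy, rank)                  # binding-ID subspace removed
h_control = ablate(h, random_basis(d, rank), rank)   # random subspace removed
```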

In [2]:
import os, sys, time, json
import numpy as np
import pandas as pd
import spacy
import torch
from nnsight import LanguageModel
from openai import OpenAI
from tqdm.auto import tqdm

sys.path.append("../../")  # make the repo-level `src` package importable
from src.functional import free_gpu_cache, predict_next_token
from src.models import load_LM
from src.utils import env_utils



In [3]:
MODEL_KEY = "meta-llama/Meta-Llama-3-70B-Instruct"
svd_path = os.path.join(
    env_utils.DEFAULT_RESULTS_DIR, "SVDs",
    MODEL_KEY.split("/")[-1]
)

In [4]:
svd = np.load(os.path.join(svd_path, "characters", "projections", "model.layers.10.npz"))
Vh = torch.Tensor(svd['Vh'])
Vh.shape

torch.Size([8192, 8192])
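
For context, a square `Vh` like this is what a full SVD of a stack of hidden states returns. The sketch below shows one way such a matrix could be produced; the extraction script itself is not shown in this notebook, so `extract_projection`, the mean-centering step, and the toy data are all assumptions.

```python
import numpy as np

def extract_projection(states, center=True):
    """SVD of stacked hidden states; rows of Vh are directions sorted by singular value."""
    X = np.asarray(states, dtype=np.float64)      # (n_samples, d)
    if center:
        X = X - X.mean(axis=0, keepdims=True)     # assumed preprocessing step
    # full_matrices=True yields a square (d, d) Vh, matching the shape printed above.
    _, S, Vh = np.linalg.svd(X, full_matrices=True)
    return S, Vh

# Toy stand-in: 20 "hidden states" of width 8 instead of 8192.
rng = np.random.default_rng(0)
toy_states = rng.standard_normal((20, 8))
S, Vh_toy = extract_projection(toy_states)
print(Vh_toy.shape)
```

Taking `Vh[:rank]` then selects the `rank` directions with the largest singular values, which is how `rank` is used in the patching cells.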

In [5]:
lm = load_LM(
    model_key=MODEL_KEY,
    torch_dtype=torch.float16,
    load_in_4bit=True,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
Loading checkpoint shards: 100%|██████████| 30/30 [00:59<00:00,  2.00s/it]

loaded /home/local_arnab/Codes/00_MODEL/meta-llama/Meta-Llama-3-70B-Instruct | size: 36650.535 MB





In [6]:
def patch_binding_ID(
    h_orig: torch.Tensor,
    h_other: torch.Tensor,
    Proj: torch.Tensor,
    rank: int = 10
):
    # Note: the prediction also changes if the full h_other is patched in instead of h_orig.
    # Cast to float32 so the projection is not computed in half precision.
    h_orig = h_orig.to(torch.float32)
    h_other = h_other.to(torch.float32)
    Proj = Proj.to(torch.float32)

    # Rank-r projector onto the span of the top `rank` rows of Proj (assumed orthonormal).
    Proj_r = Proj[:rank].to(lm.device)
    Proj_r = Proj_r.T @ Proj_r
    return (
        (torch.eye(Proj_r.shape[0]).to(lm.device) - Proj_r) @ h_orig # remove the binding ID from h_orig
        + Proj_r @ h_other # add the binding ID from h_other
    )


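def load_Vh(entity_type: str, layer_name: str):
    # `entity_type` is the SVD sub-directory, e.g. "characters", "objects" or "containers".
    svd = np.load(os.path.join(svd_path, entity_type, "projections", f"{layer_name}.npz"))
    Vh = torch.Tensor(svd['Vh'])
    return Vh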
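
A quick property check for the swap, written as a NumPy sketch on toy data so it runs without the model. `swap_subspace` is a hypothetical CPU re-implementation of the same formula, not the notebook's `patch_binding_ID`, and it assumes the rows of `Vh` are orthonormal.

```python
import numpy as np

def swap_subspace(h_orig, h_other, Vh, rank):
    """(I - P) h_orig + P h_other, with P = V.T @ V for the top-`rank` rows V of Vh."""
    V = np.asarray(Vh[:rank], dtype=np.float64)
    # Algebraically the same formula, without materialising the (d, d) projector.
    return h_orig + V.T @ (V @ (h_other - h_orig))

rng = np.random.default_rng(0)
d, rank = 16, 3
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))
Vh_toy = Q.T
h_a, h_b = rng.standard_normal(d), rng.standard_normal(d)

patched = swap_subspace(h_a, h_b, Vh_toy, rank)

# Inside the subspace the patched state matches h_b ...
print(np.allclose(Vh_toy[:rank] @ patched, Vh_toy[:rank] @ h_b))
# ... and outside it the patched state is still h_a.
print(np.allclose(Vh_toy[rank:] @ patched, Vh_toy[rank:] @ h_a))
```

Applying the projector as two thin matrix products avoids building an 8192 x 8192 matrix for every patch location, which may matter when many layers and token positions are patched.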

In [7]:
lm.tokenizer.tokenize("fountain pen")

['f', 'ountain', 'Ġpen']

In [8]:
from src.dataset import SampleV3, DatasetV3

sample = SampleV3(
    protagonist="Nora",
    perpetrator="Jon",
    objects=["red velvet handkerchief", "blue silk handkerchief"],
    containers=["hat", "box"],
    event_idx=0,
    event_noticed=False
)

print(f"{sample.true_state=}")
print(f"{sample.protagonist_belief=}")

dataset = DatasetV3(samples = [sample])

sample.true_state={'hat': 'blue silk handkerchief', 'box': 'blue silk handkerchief'}
sample.protagonist_belief={'hat': 'red velvet handkerchief', 'box': 'blue silk handkerchief'}


## Change Actors

In [9]:
from scripts.collect_binding_id_states import collect_token_latent_in_question

prompt_perp, answer_perp = dataset.__getitem__(
    0, 
    set_actor="perpetrator",
    # set_ans="no",
    set_container=0,
    set_obj=0
)
print(prompt_perp)
print(answer_perp)

h_perpetrator = collect_token_latent_in_question(
    lm = lm,
    prompt = prompt_perp,
    answer = answer_perp,
    token_of_interest= sample.perpetrator,
    detensorize=False
)

You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


Instruction: Keep track of people's knowledge defined in the story. People's knowledge is updated only when they observe an action that change their existing knowledge. To answer the question following the story, choose "yes" or "no" after the "Answer:" tag.

Story: Nora is a magician performing at a grand theater. Nora wants to amaze the audience with a trick involving a red velvet handkerchief. Nora places the red velvet handkerchief in a hat and sets it on the stage. Then Nora prepares a backup box and places a blue silk handkerchief inside. An assistant named Jon, who thinks the trick should be different, swaps the red velvet handkerchief in the hat with the blue silk handkerchief while Nora is backstage. Nora didn't see Jon swapping the the contents of hat.
Question: Does Jon believe the hat contains red velvet handkerchief?
Answer:
no
168  Jon




answer='no' | predicted_ans=PredictedToken(token=' No', prob=0.47647860646247864, logit=21.515625, token_id=2360)


In [10]:
prompt_prot, answer_prot = dataset.__getitem__(
    0, 
    set_actor="protagonist",
    # set_ans="no",
    set_container=0,
    set_obj=0
)
print(prompt_prot)
print(answer_prot)

h_protagonist = collect_token_latent_in_question(
    lm = lm,
    prompt = prompt_prot,
    answer = answer_prot,
    token_of_interest= sample.protagonist,
    detensorize=False
)

Instruction: Keep track of people's knowledge defined in the story. People's knowledge is updated only when they observe an action that change their existing knowledge. To answer the question following the story, choose "yes" or "no" after the "Answer:" tag.

Story: Nora is a magician performing at a grand theater. Nora wants to amaze the audience with a trick involving a red velvet handkerchief. Nora places the red velvet handkerchief in a hat and sets it on the stage. Then Nora prepares a backup box and places a blue silk handkerchief inside. An assistant named Jon, who thinks the trick should be different, swaps the red velvet handkerchief in the hat with the blue silk handkerchief while Nora is backstage. Nora didn't see Jon swapping the the contents of hat.
Question: Does Nora believe the hat contains red velvet handkerchief?
Answer:
yes
168  Nora
answer='yes' | predicted_ans=PredictedToken(token=' YES', prob=0.61894690990448, logit=22.71875, token_id=14410)


In [12]:
from src.functional import PatchSpec, free_gpu_cache

LAYER_NAME_FORMAT = "model.layers.{}"
QUESTION_ACTOR_IDX = 168

free_gpu_cache()

patches = []
for layer_idx in range(13, 20):
    layer = LAYER_NAME_FORMAT.format(layer_idx)
    Vh = load_Vh("characters", layer)
    location = (layer, QUESTION_ACTOR_IDX)
    patches.append(PatchSpec(
        location=location,
        patch=patch_binding_ID(
            h_orig=h_protagonist.states[location],
            h_other=h_perpetrator.states[location],
            Proj=Vh,
            rank=3
        )
    ))

predict_next_token(
    lm=lm, 
    input=prompt_prot,
    patches=patches,
    interested_tokens=[2360]
)

([PredictedToken(token=' No', prob=0.4544861316680908, logit=21.1875, token_id=2360),
  PredictedToken(token=' NO', prob=0.41381433606147766, logit=21.09375, token_id=5782),
  PredictedToken(token=' no', prob=0.08949362486600876, logit=19.5625, token_id=912),
  PredictedToken(token=' **', prob=0.030928170308470726, logit=18.5, token_id=3146),
  PredictedToken(token='No', prob=0.0033632824197411537, logit=16.28125, token_id=2822)],
 {2360: (1,
   PredictedToken(token=' No', prob=0.4544861316680908, logit=21.1875, token_id=2360))})

In [13]:
free_gpu_cache()

patches = []
for layer_idx in range(13, 20):
    layer = LAYER_NAME_FORMAT.format(layer_idx)
    Vh = load_Vh("characters", layer)
    location = (layer, QUESTION_ACTOR_IDX)
    patches.append(PatchSpec(
        location=location,
        patch=patch_binding_ID(
            h_orig=h_perpetrator.states[location],
            h_other=h_protagonist.states[location],
            Proj=Vh,
            rank=3
        )
    ))

predict_next_token(
    lm=lm, 
    input=prompt_perp,
    patches=patches,
    interested_tokens=[14410]
)

([PredictedToken(token=' YES', prob=0.6241050362586975, logit=22.640625, token_id=14410),
  PredictedToken(token=' yes', prob=0.1816249042749405, logit=21.40625, token_id=10035),
  PredictedToken(token=' Yes', prob=0.17880909144878387, logit=21.390625, token_id=7566),
  PredictedToken(token=' **', prob=0.007265910971909761, logit=18.1875, token_id=3146),
  PredictedToken(token='Yes', prob=0.003950407262891531, logit=17.578125, token_id=9642)],
 {14410: (1,
   PredictedToken(token=' YES', prob=0.6241050362586975, logit=22.640625, token_id=14410))})
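
Because the answer mass is split across casing variants (' No', ' NO', ' no', ...), it can help to sum them before judging whether a patch flipped the answer. The sketch below uses plain `(token_id, prob)` tuples copied from the first patched output above; `answer_mass` and the two ID sets are hypothetical names, and the IDs are only those that appear in this notebook's outputs.

```python
# Token IDs as printed in the outputs above.
YES_IDS = {14410, 10035, 7566, 9642}  # ' YES', ' yes', ' Yes', 'Yes'
NO_IDS = {2360, 5782, 912, 2822}      # ' No', ' NO', ' no', 'No'

def answer_mass(top_tokens):
    """Sum the probability assigned to yes-like and no-like tokens."""
    yes = sum(p for tid, p in top_tokens if tid in YES_IDS)
    no = sum(p for tid, p in top_tokens if tid in NO_IDS)
    return {"yes": yes, "no": no}

# Top-5 predictions for prompt_prot after the actor patch (protagonist -> perpetrator).
patched_prot = [
    (2360, 0.4544861316680908),
    (5782, 0.41381433606147766),
    (912, 0.08949362486600876),
    (3146, 0.030928170308470726),   # ' **', neither yes nor no
    (2822, 0.0033632824197411537),
]
mass = answer_mass(patched_prot)
print(mass)
```

Only the top five tokens are included here, so each sum is a lower bound on the true mass.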

## Change Objects

In [17]:
from scripts.collect_binding_id_states import collect_token_latent_in_question

prompt_obj0, answer_obj0 = dataset.__getitem__(
    0, 
    set_actor="protagonist",
    # set_ans="no",
    set_container=0,
    set_obj=0
)
print(prompt_obj0)
print(answer_obj0)

h_obj0 = collect_token_latent_in_question(
    lm = lm,
    prompt = prompt_obj0,
    answer = answer_obj0,
    token_of_interest= sample.objects[0],
    detensorize=False
)

Instruction: Keep track of people's knowledge defined in the story. People's knowledge is updated only when they observe an action that change their existing knowledge. To answer the question following the story, choose "yes" or "no" after the "Answer:" tag.

Story: Nora is a magician performing at a grand theater. Nora wants to amaze the audience with a trick involving a red velvet handkerchief. Nora places the red velvet handkerchief in a hat and sets it on the stage. Then Nora prepares a backup box and places a blue silk handkerchief inside. An assistant named Jon, who thinks the trick should be different, swaps the red velvet handkerchief in the hat with the blue silk handkerchief while Nora is backstage. Nora didn't see Jon swapping the the contents of hat.
Question: Does Nora believe the hat contains red velvet handkerchief?
Answer:
yes
(173, 178)  red| velvet| hand|ker|chief
answer='yes' | predicted_ans=PredictedToken(token=' YES', prob=0.61894690990448, logit=22.71875, token_id

In [18]:
prompt_obj1, answer_obj1= dataset.__getitem__(
    0, 
    set_actor="protagonist",
    # set_ans="no",
    set_container=0,
    set_obj=1
)
print(prompt_obj1)
print(answer_obj1)

h_obj1 = collect_token_latent_in_question(
    lm = lm,
    prompt = prompt_obj1,
    answer = answer_obj1,
    token_of_interest= sample.objects[1],
    detensorize=False
)

Instruction: Keep track of people's knowledge defined in the story. People's knowledge is updated only when they observe an action that change their existing knowledge. To answer the question following the story, choose "yes" or "no" after the "Answer:" tag.

Story: Nora is a magician performing at a grand theater. Nora wants to amaze the audience with a trick involving a red velvet handkerchief. Nora places the red velvet handkerchief in a hat and sets it on the stage. Then Nora prepares a backup box and places a blue silk handkerchief inside. An assistant named Jon, who thinks the trick should be different, swaps the red velvet handkerchief in the hat with the blue silk handkerchief while Nora is backstage. Nora didn't see Jon swapping the the contents of hat.
Question: Does Nora believe the hat contains blue silk handkerchief?
Answer:
no
(173, 178)  blue| silk| hand|ker|chief
answer='no' | predicted_ans=PredictedToken(token=' No', prob=0.552524983882904, logit=21.484375, token_id=23

In [19]:
h_obj1.token_position

(173, 178)

In [20]:
free_gpu_cache()

patches = []
for layer_idx in range(20, 24):
    layer = LAYER_NAME_FORMAT.format(layer_idx)
    Vh = load_Vh("objects", layer)
    for t in range(*h_obj1.token_position):
        location = (layer, t)
        patches.append(PatchSpec(
            location=location,
            patch=patch_binding_ID(
                h_orig=h_obj0.states[location],
                h_other=h_obj1.states[location],
                Proj=Vh,
                rank=10
            )
        ))

predict_next_token(
    lm=lm, 
    input=prompt_obj0,
    patches=patches,
    interested_tokens=[2360]
)

([PredictedToken(token=' No', prob=0.47250068187713623, logit=20.28125, token_id=2360),
  PredictedToken(token=' NO', prob=0.3978855013847351, logit=20.109375, token_id=5782),
  PredictedToken(token=' no', prob=0.0629546269774437, logit=18.265625, token_id=912),
  PredictedToken(token=' **', prob=0.04752064496278763, logit=17.984375, token_id=3146),
  PredictedToken(token='No', prob=0.003693138714879751, logit=15.4296875, token_id=2822)],
 {2360: (1,
   PredictedToken(token=' No', prob=0.47250068187713623, logit=20.28125, token_id=2360))})

In [22]:
# patches = []
# for layer_idx in range(20, 24):
#     layer = LAYER_NAME_FORMAT.format(layer_idx)
#     Vh = load_Vh("objects", layer)
#     location = (layer, QUESTION_OBJ_IDX)
#     patches.append(PatchSpec(
#         location=location,
#         patch=patch_binding_ID(
#             h_orig=h_obj1.states[location],
#             h_other=h_obj0.states[location],
#             Proj=Vh,
#             rank=10
#         )
#     ))

patches = []
for layer_idx in range(20, 24):
    layer = LAYER_NAME_FORMAT.format(layer_idx)
    Vh = load_Vh("objects", layer)
    for t in range(*h_obj1.token_position):
        location = (layer, t)
        patches.append(PatchSpec(
            location=location,
            patch=patch_binding_ID(
                h_orig=h_obj1.states[location],
                h_other=h_obj0.states[location],
                Proj=Vh,
                rank=10
            )
        ))


predict_next_token(
    lm=lm, 
    input=prompt_obj1,
    patches=patches,
    interested_tokens=[14410]
)

([PredictedToken(token=' YES', prob=0.6483052968978882, logit=21.953125, token_id=14410),
  PredictedToken(token=' Yes', prob=0.17448900640010834, logit=20.640625, token_id=7566),
  PredictedToken(token=' yes', prob=0.158874049782753, logit=20.546875, token_id=10035),
  PredictedToken(token=' **', prob=0.007666510995477438, logit=17.515625, token_id=3146),
  PredictedToken(token='Yes', prob=0.004168210085481405, logit=16.90625, token_id=9642)],
 {14410: (1,
   PredictedToken(token=' YES', prob=0.6483052968978882, logit=21.953125, token_id=14410))})

In [None]:
predict_next_token(
    lm=lm, 
    input=prompt_obj1,
    # patches=patches,
    interested_tokens=[14410]
)

## Change Containers

In [70]:
prompt_cont0, answer_cont0 = dataset.__getitem__(
    0, 
    set_actor="protagonist",
    # set_ans="no",
    set_container=0,
    set_obj=0
)

print(prompt_cont0)
print(answer_cont0)

h_cont0 = collect_token_latent_in_question(
    lm = lm,
    prompt = prompt_cont0,
    answer = answer_cont0,
    token_of_interest= sample.containers[0],
    detensorize=False
)

Instruction: Keep track of people's knowledge defined in the story. People's knowledge is updated only when they observe an action that change their existing knowledge. To answer the question following the story, choose "yes" or "no" after the "Answer:" tag.

Story: Nora is a magician performing at a grand theater. Nora wants to amaze the audience with a trick involving a knife. Nora places the knife in a hat and sets it on the stage. Then Nora prepares a backup box and places a pen inside. An assistant named Jon, who thinks the trick should be different, swaps the knife in the hat with the pen while Nora is backstage. Nora didn't see Jon swapping the the contents of hat.
Question: Does Nora believe the hat contains knife?
Answer:
yes
151  hat
answer='yes' | predicted_ans=PredictedToken(token=' YES', prob=0.5597575306892395, logit=22.15625, token_id=14410)


In [71]:
prompt_cont1, answer_cont1 = dataset.__getitem__(
    0, 
    set_actor="protagonist",
    # set_ans="no",
    set_container=1,
    set_obj=0
)

print(prompt_cont1)
print(answer_cont1)

h_cont1 = collect_token_latent_in_question(
    lm = lm,
    prompt = prompt_cont1,
    answer = answer_cont1,
    token_of_interest= sample.containers[1],
    detensorize=False
)

Instruction: Keep track of people's knowledge defined in the story. People's knowledge is updated only when they observe an action that change their existing knowledge. To answer the question following the story, choose "yes" or "no" after the "Answer:" tag.

Story: Nora is a magician performing at a grand theater. Nora wants to amaze the audience with a trick involving a knife. Nora places the knife in a hat and sets it on the stage. Then Nora prepares a backup box and places a pen inside. An assistant named Jon, who thinks the trick should be different, swaps the knife in the hat with the pen while Nora is backstage. Nora didn't see Jon swapping the the contents of hat.
Question: Does Nora believe the box contains knife?
Answer:
no
151  box
answer='no' | predicted_ans=PredictedToken(token=' YES', prob=0.5212855339050293, logit=21.25, token_id=14410)


In [76]:
free_gpu_cache()

QUESTION_CONT_IDX = 151

patches = []
for layer_idx in range(24, 28):
    layer = LAYER_NAME_FORMAT.format(layer_idx)
    Vh = load_Vh("containers", layer)
    location = (layer, QUESTION_CONT_IDX)
    patches.append(PatchSpec(
        location=location,
        patch=patch_binding_ID(
            h_orig=h_cont0.states[location],
            h_other=h_cont1.states[location],
            Proj=Vh,
            rank=10
        )
    ))

predict_next_token(
    lm=lm, 
    input=prompt_cont0,
    patches=patches,
    interested_tokens=[2360]
)

([PredictedToken(token=' YES', prob=0.5657712817192078, logit=22.203125, token_id=14410),
  PredictedToken(token=' yes', prob=0.2471671849489212, logit=21.375, token_id=10035),
  PredictedToken(token=' Yes', prob=0.1698753386735916, logit=21.0, token_id=7566),
  PredictedToken(token=' **', prob=0.007011593785136938, logit=17.8125, token_id=3146),
  PredictedToken(token='Yes', prob=0.003995085600763559, logit=17.25, token_id=9642)],
 {2360: (23,
   PredictedToken(token=' No', prob=3.538393139024265e-05, logit=12.5234375, token_id=2360))})

In [77]:
free_gpu_cache()

patches = []
for layer_idx in range(24, 28):
    layer = LAYER_NAME_FORMAT.format(layer_idx)
    Vh = load_Vh("containers", layer)
    location = (layer, QUESTION_CONT_IDX)
    patches.append(PatchSpec(
        location=location,
        patch=patch_binding_ID(
            h_orig=h_cont1.states[location],
            h_other=h_cont0.states[location],
            Proj=Vh,
            rank=10
        )
    ))

predict_next_token(
    lm=lm, 
    input=prompt_cont1,
    patches=patches,
    interested_tokens=[14410]
)

([PredictedToken(token=' YES', prob=0.5225372314453125, logit=21.53125, token_id=14410),
  PredictedToken(token=' yes', prob=0.25071609020233154, logit=20.796875, token_id=10035),
  PredictedToken(token=' Yes', prob=0.20145605504512787, logit=20.578125, token_id=7566),
  PredictedToken(token=' **', prob=0.009570603258907795, logit=17.53125, token_id=3146),
  PredictedToken(token='Yes', prob=0.004812401253730059, logit=16.84375, token_id=9642)],
 {14410: (1,
   PredictedToken(token=' YES', prob=0.5225372314453125, logit=21.53125, token_id=14410))})

In [78]:
predict_next_token(
    lm=lm, 
    input=prompt_cont1,
    # patches=patches,
    interested_tokens=[14410]
)

([PredictedToken(token=' YES', prob=0.5212855339050293, logit=21.25, token_id=14410),
  PredictedToken(token=' yes', prob=0.2462378740310669, logit=20.5, token_id=10035),
  PredictedToken(token=' Yes', prob=0.19785767793655396, logit=20.28125, token_id=7566),
  PredictedToken(token=' **', prob=0.012259461916983128, logit=17.5, token_id=3146),
  PredictedToken(token=' NO', prob=0.005272728856652975, logit=16.65625, token_id=5782)],
 {14410: (1,
   PredictedToken(token=' YES', prob=0.5212855339050293, logit=21.25, token_id=14410))})
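
The container patches above use a single window (layers 24-27) and a single rank (10), and in that setting the patched `prompt_cont0` still predicts ' YES'. A sweep over windows and ranks might locate where the container edit takes effect. The generator below only enumerates configurations; `sweep_configs` is a hypothetical helper, and running each configuration through `predict_next_token` is left to the notebook.

```python
from itertools import product

def sweep_configs(first_layer, last_layer, window, ranks):
    """Yield every contiguous `window`-layer span in [first_layer, last_layer] for each rank."""
    starts = range(first_layer, last_layer - window + 2)
    for start, rank in product(starts, ranks):
        yield {"layers": range(start, start + window), "rank": rank}

configs = list(sweep_configs(first_layer=13, last_layer=28, window=4, ranks=(3, 10)))
print(len(configs), list(configs[0]["layers"]), list(configs[-1]["layers"]))
```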