In [1]:
import sys

sys.path.append('../..')
%load_ext autoreload
%autoreload 2

In [2]:
from torch.utils.data import DataLoader
from datasets import load_from_disk
from src.hyperdas.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset

from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
tokenizer = AutoTokenizer.from_pretrained("/scr-ssd/sjd24/llama3-8b")

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

train_dataset = load_from_disk("../../experiments/RAVEL/data/city_country_train")
test_dataset = load_from_disk("../../experiments/RAVEL/data/city_country_test")

collate_fn = get_ravel_collate_fn(tokenizer, add_space_before_target=True, contain_entity_position=True, source_suffix_visibility=False, base_suffix_visibility=False)
dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn, shuffle=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
LAYER = 15

from src.hyperdas.llama3.model import RavelInterpretorHypernetwork

hypernetwork = RavelInterpretorHypernetwork(
    model_name_or_path="/scr-ssd/sjd24/llama3-8b",
    num_editing_heads=32,
    intervention_layer=LAYER,
    subspace_module="ReflectSelect",
    das_dimension=128,
)
hypernetwork = hypernetwork.to("cuda")
# hypernetwork.load_model(f"/nlp/scr/sjd24/HyperDAS_layers/ravel_layer_{LAYER}/final_model")
hypernetwork.load_model("/scr-ssd/sjd24/city_masked/final_model")

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.27it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.86it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  self.interpretor.hypernetwork.load_state_dict(torch.load(os.path.join(load_dir, "hypernetwork.pt")))
  self.interpretor.das_module.load_state_dict(torch.load(os.path.join(load_dir, "das.pt")))


In [5]:
from analysis_utils import get_run_data, get_max_weight_type
dataset = get_run_data(hypernetwork, tokenizer, test_dataset)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [6]:
source_token_distribution = {
    "BOS Token": [],
    "Subject Tokens": [],
    "Sentence Last Token": [],
    "JSON Syntax": [],
    "Country": [],
    "Others": [],
    "Label": [],
}

base_token_distribution = {
    "BOS Token": [],
    "Subject Tokens": [],
    "Sentence Last Token": [],
    "JSON Syntax": [],
    "Country": [],
    "Others": [],
    "Label": [],
}

for d in dataset:
    source_intervention_token_type, source_intervened_token, base_intervention_token_type , base_intervened_token = get_max_weight_type(d, tokenizer)
    source_token_distribution[source_intervention_token_type].append(source_intervened_token)
    base_token_distribution[base_intervention_token_type].append(base_intervened_token)
    
source_data = {k: len(v) for k, v in source_token_distribution.items() if len(v) > 0}
base_data = {k: len(v) for k, v in base_token_distribution.items() if len(v) > 0}


layer_stat = {
    "Source": source_data,
    "Base": base_data,
    "Source Token": source_token_distribution,
    "Base Token": base_token_distribution,
}

import json
json.dump(layer_stat, open(f"layer_{LAYER}_stat_first_version.json", "w"), indent=4)

### Layer 29 Plot

In [19]:
source_data

{'Subject Tokens': 1462, 'Sentence Last Token': 31, 'Others': 51}

In [20]:
base_data

{'Subject Tokens': 780,
 'Sentence Last Token': 446,
 'JSON Syntax': 263,
 'Others': 55}

In [22]:
base_token_distribution["JSON Syntax"]

[' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ':',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ',',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ':',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 ' "',
 '