In [1]:
from datasets import load_from_disk
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, LlamaForCausalLM

from src.hyperdas.data_utils import (
    filter_dataset,
    generate_ravel_dataset,
    get_ravel_collate_fn,
)

%load_ext autoreload
%autoreload 2

# --- Tokenizer ---------------------------------------------------------------
# Loaded from the local Llama-3 checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained("/scr-ssd/sjd24/llama3-8b")

# Pad on the left and reuse the EOS token (both the string and its id) as the
# padding token.
tokenizer.padding_side = "left"
tokenizer.pad_token, tokenizer.pad_token_id = (
    tokenizer.eos_token,
    tokenizer.eos_token_id,
)

# --- Data --------------------------------------------------------------------
# Both splits are read from disk; only the test split is batched below.
train_dataset, test_dataset = map(
    load_from_disk,
    (
        "./experiments/RAVEL/data/city_country_train",
        "./experiments/RAVEL/data/city_country_test",
    ),
)

# Options forwarded verbatim to the project's collate-function factory.
collate_options = dict(
    add_space_before_target=True,
    contain_entity_position=True,
    source_suffix_visibility=False,
    base_suffix_visibility=False,
)
collate_fn = get_ravel_collate_fn(tokenizer, **collate_options)

# Fixed order (no shuffling) so the first batch is the same on every run.
dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate_fn,
)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [2]:
import os

import torch

from src.hyperdas.llama3.model import RavelInterpretorHypernetwork

# Checkpoint files used by this cell.
mdas_checkpoint_path = os.path.join(
    "/nlp/scr/sjd24/MDAS_dimension/ravel_mdas_128_country", "final_das_module.pt"
)
hypernetwork_checkpoint_path = os.path.join(
    "/nlp/scr/sjd24/HyperDAS_layers/ravel_layer_15/final_model",
    "hypernetwork.pt",
)

# Build the interpretor around the local Llama-3 checkpoint and move it to the GPU.
hypernetwork = RavelInterpretorHypernetwork(
    model_name_or_path="/scr-ssd/sjd24/llama3-8b",
    num_editing_heads=32,
    intervention_layer=15,
    subspace_module="DAS",
    das_dimension=128,
)
hypernetwork = hypernetwork.to("cuda")

# Both files are handed straight to `load_state_dict`, so they are loaded as
# plain tensor mappings:
#   * `map_location="cpu"` stages the tensors on the host, so loading does not
#     depend on the device they were saved from and does not allocate a second
#     copy on the GPU; `load_state_dict` copies them onto the module's device.
#   * `weights_only=True` restricts unpickling to tensors/containers, which is
#     what recent PyTorch versions ask callers to opt into explicitly.
state_dict = torch.load(mdas_checkpoint_path, map_location="cpu", weights_only=True)

# These two entries are removed before loading into `das_module`
# (presumably that module has no parameters/buffers under these names - confirm).
state_dict.pop("embed_dim")
state_dict.pop("interchange_dim")

hypernetwork.interpretor.hypernetwork.load_state_dict(
    torch.load(hypernetwork_checkpoint_path, map_location="cpu", weights_only=True)
)
# Last expression: its result (the key-matching report) is the cell output.
hypernetwork.interpretor.das_module.load_state_dict(state_dict)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.21it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.64it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
  state_dict = torch.load(os.path.join("/nlp/scr/sjd24/MDAS_dimension/ravel_mdas_128_country", "final_das_module.pt"))
  hypernetwork.interpretor.hypernetwork.load_state_dict(torch.load(os.path.join(f"/nlp/scr/sjd24/HyperDAS_layers/ravel_layer_15/final_model", "hypernetwork.pt")))


<All keys matched successfully>

In [40]:
from pyvene import (
    IntervenableConfig,
    IntervenableModel,
    LowRankRotatedSpaceIntervention,
    RepresentationConfig,
    count_parameters,
)

# One low-rank rotated-space intervention on the block output of layer 15,
# applied to a single token position of the target model.
intervention_config = IntervenableConfig(
    model_type=type(hypernetwork.interpretor.target_model),
    representations=[
        RepresentationConfig(
            15,  # layer
            "block_output",  # intervention repr
            "pos",  # intervention unit
            1,  # max number of unit
            # 5th field; presumably the low-rank dimension (same value as the
            # `das_dimension=128` used for the hypernetwork) - confirm against pyvene.
            128,
        )
    ],
    intervention_types=LowRankRotatedSpaceIntervention,
)

# Wrap the same target model the hypernetwork uses, on the same device, and
# turn off gradients for the wrapped model.
intervenable = IntervenableModel(
    intervention_config, hypernetwork.interpretor.target_model
)
intervenable.set_device(hypernetwork.interpretor.target_model.device)
intervenable.disable_model_gradients()

# Only one intervention is configured, so take the first (only) key.
intervention_key = next(iter(intervenable.interventions))

# Load the trained DAS weights into that intervention. The file is loaded as a
# plain tensor mapping: staged on the CPU (`load_state_dict` copies onto the
# module's device) and with unpickling restricted to tensors/containers.
# Last expression: its result (the key-matching report) is the cell output.
intervenable.interventions[intervention_key][0].load_state_dict(
    torch.load(
        os.path.join(
            "/nlp/scr/sjd24/MDAS_dimension/ravel_mdas_128_country",
            "final_das_module.pt",
        ),
        map_location="cpu",
        weights_only=True,
    )
)

  intervenable.interventions[intervention_key][0].load_state_dict(torch.load(os.path.join("/nlp/scr/sjd24/MDAS_dimension/ravel_mdas_128_country", "final_das_module.pt")))


<All keys matched successfully>

In [4]:
# Grab the first batch only. Unlike a `for ...: break` loop, this fails right
# here (StopIteration) if the dataloader is empty instead of leaving `batch`
# undefined for later cells.
batch = next(iter(dataloader))

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


In [41]:
def _position_ids_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Derive position ids from a 0/1 attention mask.

    Entries where the mask is 0 get position id -1; entries where the mask is 1
    are numbered 0, 1, 2, ... in order along dim 1.
    """
    return torch.cumsum(attention_mask, dim=1) * attention_mask - 1


def forward(
    base_input_ids: torch.Tensor = None,
    base_attention_mask: torch.Tensor = None,
    base_intervention_position: torch.Tensor = None,
    base_position_ids: torch.Tensor = None,
    source_input_ids: torch.Tensor = None,
    source_attention_mask: torch.Tensor = None,
    source_intervention_position: torch.Tensor = None,
    source_position_ids: torch.Tensor = None,
    intervention_layer: int = None,
):
    """Run one source->base interchange through the global ``intervenable``.

    Args:
        base_input_ids: Token ids of the base prompts.
        base_attention_mask: Attention mask of the base prompts. Required when
            ``base_position_ids`` is not given, because the position ids are
            then derived from it.
        base_intervention_position: Per-example token position in the base
            prompt to intervene on. A leading and a trailing axis are added
            before it is passed on (assumes shape ``(batch,)`` - confirm).
        base_position_ids: Optional explicit position ids for the base prompts.
        source_input_ids: Token ids of the source prompts.
        source_attention_mask: Attention mask of the source prompts. Required
            when ``source_position_ids`` is not given.
        source_intervention_position: Per-example token position in the source
            prompt to read from; reshaped like ``base_intervention_position``.
        source_position_ids: Optional explicit position ids for the source prompts.
        intervention_layer: Must be given, but is otherwise unused here: the
            layer that is intervened on is whatever ``intervenable`` was
            configured with.

    Returns:
        The second element of the pair returned by ``intervenable(...)``
        (the counterfactual outputs).

    Raises:
        ValueError: If ``intervention_layer`` or any required tensor is missing.
    """
    if intervention_layer is None:
        raise ValueError("intervention_layer must be specified")

    # Fail early with a readable message instead of an opaque TypeError /
    # AttributeError from inside torch when an input was forgotten.
    required = {
        "base_input_ids": base_input_ids,
        "base_intervention_position": base_intervention_position,
        "source_input_ids": source_input_ids,
        "source_intervention_position": source_intervention_position,
    }
    # The masks are only strictly needed when position ids must be derived.
    if base_position_ids is None:
        required["base_attention_mask"] = base_attention_mask
    if source_position_ids is None:
        required["source_attention_mask"] = source_attention_mask

    missing = [name for name, value in required.items() if value is None]
    if missing:
        raise ValueError("missing required argument(s): " + ", ".join(missing))

    if base_position_ids is None:
        base_position_ids = _position_ids_from_mask(base_attention_mask)

    if source_position_ids is None:
        source_position_ids = _position_ids_from_mask(source_attention_mask)

    # (source positions, base positions), each with one extra leading axis and
    # one extra trailing axis.
    intervention_locations = {
        "sources->base": (
            source_intervention_position.unsqueeze(0).unsqueeze(-1),
            base_intervention_position.unsqueeze(0).unsqueeze(-1),
        )
    }

    base_inputs = {
        "input_ids": base_input_ids,
        "attention_mask": base_attention_mask,
        "position_ids": base_position_ids,
    }
    source_inputs = {
        "input_ids": source_input_ids,
        "attention_mask": source_attention_mask,
        "position_ids": source_position_ids,
    }

    # The first element of the returned pair (requested via
    # `output_original_output=True`) is discarded; only the counterfactual
    # outputs are returned to the caller.
    _, counterfactual_outputs = intervenable(
        base_inputs,
        [source_inputs],
        intervention_locations,
        output_original_output=True,
    )

    return counterfactual_outputs

In [7]:
# Show which fields the collate function put into the batch.
batch.keys()

dict_keys(['editor_input_ids', 'is_causal', 'base_input_ids', 'base_attention_mask', 'base_intervention_mask', 'source_input_ids', 'source_attention_mask', 'source_intervention_mask', 'labels', 'source_entity_position_ids', 'base_entity_position_ids'])

In [6]:
# --- Move the batch tensors used below onto the GPU ---------------------------
labels = batch["labels"].to("cuda")
editor_input_ids = batch["editor_input_ids"].to("cuda")
is_causal = batch["is_causal"].to("cuda")

# Base prompt.
base_input_ids = batch["base_input_ids"].to("cuda")
base_attention_mask = batch["base_attention_mask"].to("cuda")
base_intervention_mask = batch["base_intervention_mask"].to("cuda")
base_intervention_position = batch["base_entity_position_ids"].to("cuda")

# Source prompt.
source_input_ids = batch["source_input_ids"].to("cuda")
source_attention_mask = batch["source_attention_mask"].to("cuda")
source_intervention_mask = batch["source_intervention_mask"].to("cuda")
source_intervention_position = batch["source_entity_position_ids"].to("cuda")

# --- Hand-built intervention weight --------------------------------------------
# Shape: (batch, source_len + 1, base_len). The extra last row along dim 1 is
# set to 1 for every base position (presumably the "no source token / keep the
# base representation" slot - confirm against the hypernetwork code).
num_examples = len(batch["editor_input_ids"])
source_len = batch["source_input_ids"].shape[1]
base_len = batch["base_input_ids"].shape[1]

intervention_weight = torch.zeros(
    num_examples, source_len + 1, base_len, device="cuda"
)
intervention_weight[:, -1, :] = 1.0

# For each example, at the base entity position: clear the extra row and put
# the full weight on the source entity position instead.
for i, (source_pos, base_pos) in enumerate(
    zip(batch["source_entity_position_ids"], batch["base_entity_position_ids"])
):
    intervention_weight[i, -1, base_pos] = 0.0
    intervention_weight[i, source_pos, base_pos] = 1.0

In [42]:
# Run the pyvene (MDAS) intervention on the batch.
# `forward` only checks that `intervention_layer` is provided; the layer that is
# actually intervened on is fixed by the IntervenableConfig (layer 15). Pass 15
# so the call no longer suggests a different layer than the one really used.
mdas_output = forward(
    base_input_ids=base_input_ids,
    base_attention_mask=base_attention_mask,
    base_intervention_position=base_intervention_position,
    source_input_ids=source_input_ids,
    source_attention_mask=source_attention_mask,
    source_intervention_position=source_intervention_position,
    intervention_layer=15,
)

torch.Size([16, 29, 4096])
tensor([-0.0096,  0.0107, -0.0063,  ...,  0.0110,  0.0236,  0.0074],
       device='cuda:0', dtype=torch.bfloat16)


In [19]:
# Inputs for the hypernetwork run. The hand-built `intervention_weight` is
# supplied together with inference_mode="groundtruth".
hypernet_inputs = dict(
    editor_input_ids=editor_input_ids,
    base_input_ids=base_input_ids,
    base_attention_mask=base_attention_mask,
    base_intervention_mask=base_intervention_mask,
    source_input_ids=source_input_ids,
    source_attention_mask=source_attention_mask,
    source_intervention_mask=source_intervention_mask,
    labels=labels,
    output_intervention_weight=True,
    inference_mode="groundtruth",
    intervention_weight=intervention_weight,
)

hypernet_output = hypernetwork.forward(**hypernet_inputs)

nnsight is not detected. Please install via 'pip install nnsight' for nnsight backend.
torch.Size([16, 29, 33, 4096])
tensor([-0.0060,  0.0381, -0.0095,  ...,  0.0767, -0.0615,  0.0269],
       device='cuda:0', dtype=torch.bfloat16)


RuntimeError: No active exception to reraise

In [25]:
# Logits of the hypernetwork run for example 0 at sequence index 0, shown for
# comparison with the same slice of `mdas_output`.
# NOTE(review): the cell assigning `hypernet_output` is shown ending in a
# RuntimeError, so this value presumably comes from another execution - re-run to confirm.
hypernet_output.logits[0, 0]

tensor([-4.7500, -2.8594, -3.5469,  ...,  5.0312,  5.0312,  5.0312],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [26]:
# Logits of the pyvene (MDAS) run for example 0 at sequence index 0, shown for
# comparison with the same slice of `hypernet_output`.
mdas_output.logits[0, 0]

tensor([-4.7500, -2.8750, -3.4531,  ...,  5.0000,  5.0000,  5.0000],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [13]:
# Position ids derived from the mask: -1 where the mask is 0, then 0, 1, 2, ...
# over the attended tokens (same expression `forward` uses).
base_prompt_position_ids = (
    torch.cumsum(base_attention_mask, dim=1) * base_attention_mask - 1
)

# Plain forward pass of the target model on the base prompts, keeping the
# hidden states so individual activations can be inspected.
base_prompt_output = hypernetwork.interpretor.target_model(
    input_ids=base_input_ids,
    attention_mask=base_attention_mask,
    position_ids=base_prompt_position_ids,
    output_hidden_states=True,
    return_dict=True,
)

In [95]:
# (batch size, padded sequence length) of the base prompts.
base_input_ids.shape

torch.Size([16, 29])

In [14]:
# Entry 15 of the returned hidden states, example 0, token index 23.
# NOTE(review): Hugging Face models usually put the embedding output at
# hidden_states[0], so entry 15 may not be the same activation as pyvene's
# layer-15 "block_output" - confirm which one each code path intervenes on.
base_prompt_output.hidden_states[15][0, 23]

tensor([-0.0060,  0.0381, -0.0095,  ...,  0.0767, -0.0615,  0.0269],
       device='cuda:0', dtype=torch.bfloat16)

In [41]:
# `base_logits` is never assigned anywhere in this notebook (it only existed as
# leftover kernel state), so read the logits from the base-prompt forward pass
# instead: (batch, sequence length, vocabulary size).
base_prompt_output.logits.shape

torch.Size([16, 29, 128256])

In [100]:
from src.hyperdas.data_utils import (
    filter_dataset,
    generate_ravel_dataset,
    get_ravel_collate_fn,
)

# Keyword options for the RAVEL dataset generator, passed through unchanged.
ravel_dataset_options = dict(
    root_path="./data/RAVEL",
    target_attributes=["Country"],
    isolate_attributes=["Continent"],
    template_split="train",
    entity_split="both",
)

# First positional argument is 1000 (the filter cells below process 1000
# examples, so this is presumably the number of examples to generate).
dataset = generate_ravel_dataset(1000, **ravel_dataset_options)

In [101]:
# Filter the generated examples using the target model, without relative
# position ids.
# NOTE(review): the keyword is spelled `relative_position_ids` to match the
# later-executed filter cell; with autoreload on, the old `relative_position`
# spelling presumably predates a rename and would fail on a fresh kernel -
# confirm against `filter_dataset`'s current signature.
filter_dataset(
    hypernetwork.interpretor.target_model,
    tokenizer,
    dataset,
    relative_position_ids=False,
)

63it [00:02, 24.40it/s]

Accuracy: 0.811; filtered out 189 examples





Dataset({
    features: ['input_prefix', 'input_suffix', 'counterfactual_input_prefix', 'counterfactual_input_suffix', 'edit_instruction', 'entity', 'counterfactual_entity', 'target', 'counterfactual_target', 'attribute_type', 'domain', 'attribute', 'verify_text'],
    num_rows: 811
})

In [105]:
# Same filtering pass as before, this time with relative position ids enabled;
# the returned dataset is the cell output.
target_model = hypernetwork.interpretor.target_model
filter_dataset(target_model, tokenizer, dataset, relative_position_ids=True)

63it [00:02, 26.28it/s]

Accuracy: 0.81; filtered out 190 examples





Dataset({
    features: ['input_prefix', 'input_suffix', 'counterfactual_input_prefix', 'counterfactual_input_suffix', 'edit_instruction', 'entity', 'counterfactual_entity', 'target', 'counterfactual_target', 'attribute_type', 'domain', 'attribute', 'verify_text'],
    num_rows: 810
})