In [3]:
from torch.utils.data import DataLoader
from datasets import load_from_disk
from src.hyperdas.data_utils import generate_ravel_dataset, get_ravel_collate_fn, filter_dataset

from transformers import AutoTokenizer

%load_ext autoreload
%autoreload 2

In [4]:
tokenizer = AutoTokenizer.from_pretrained("/scr-ssd/sjd24/llama3-8b")
tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

train_dataset = load_from_disk("./experiments/RAVEL/data/city_country_train")
test_dataset = load_from_disk("./experiments/RAVEL/data/city_country_test")

collate_fn = get_ravel_collate_fn(tokenizer, add_space_before_target=True, contain_entity_position=True, source_suffix_visibility=False, base_suffix_visibility=False)
dataloader = DataLoader(test_dataset, batch_size=16, collate_fn=collate_fn, shuffle=True)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
from src.hyperdas.llama3.model import RavelInterpretorHypernetwork

hypernetwork = RavelInterpretorHypernetwork(
    model_name_or_path="/scr-ssd/sjd24/llama3-8b",
    num_editing_heads=32,
    intervention_layer=15,
    subspace_module="ReflectSelect",
    das_dimension=128,
)
hypernetwork = hypernetwork.to("cuda")

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s]
Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.37it/s]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [29]:
for batch in dataloader:
    pass

editor_input_ids = batch["editor_input_ids"].to("cuda")
base_input_ids = batch["base_input_ids"].to("cuda")
base_attention_mask = batch["base_attention_mask"].to("cuda")
base_intervention_mask = batch["base_intervention_mask"].to("cuda")
source_input_ids = batch["source_input_ids"].to("cuda")
source_attention_mask = batch["source_attention_mask"].to("cuda")
source_intervention_mask = batch["source_intervention_mask"].to("cuda")
labels = batch["labels"].to("cuda")

In [30]:
is_causal = batch["is_causal"].to("cuda")
is_causal

tensor([False, False, False, False,  True, False,  True,  True,  True,  True,
         True,  True,  True, False], device='cuda:0')

In [27]:
causal_labels =  batch["labels"].to("cuda")[is_causal]
final_list = []
for i in range(causal_labels.shape[0]):
    final_list.extend(causal_labels[i][causal_labels[i] != -100].tolist())

print(final_list)

[24922, 32164, 10384, 15302, 587, 277, 39563, 45606, 37766]

tokenizer.decode(final_list)

# [24922, 32164, 10384, 15302,   587,   277, 32164, 45606,  6890, 22404,
        # 51419, 13936, 39563,    23, 23078, 45606, 16327, 37766]
# [2., 2., 2., 2., 2., 2., 1., 1., 1., 1., 1., 1., 2., 1., 1., 2., 1., 2.]

[24922, 32164, 10384, 15302, 587, 277, 39563, 45606, 37766]


' Indonesia Argentina Africa/DakarAsiaRussia Kenya'

In [31]:
import torch
_pred = hypernetwork.interpretor(
    editor_input_ids=editor_input_ids,
    editor_attention_mask=editor_input_ids != hypernetwork.interpretor_config.eos_token_id,
    base_input_ids=base_input_ids,
    base_attention_mask=base_attention_mask,
    base_intervention_mask=base_intervention_mask,
    source_input_ids=source_input_ids,
    source_attention_mask=source_attention_mask,
    source_intervention_mask=source_intervention_mask,
    output_intervention_weight=True,
    intervention_weight=None,
    inference_mode=None
)

if labels is not None:
    log_prob_predictions = torch.nn.functional.log_softmax(
        _pred.logits.reshape(-1, _pred.logits.shape[-1]),
        dim=1,
    )
    
    if is_causal is not None:
        loss_weight = torch.ones_like(labels, dtype=log_prob_predictions.dtype)
        loss_weight[is_causal, :] = 2.0
        loss_weight[~is_causal, :] = 1
    
    labels = labels.reshape(-1)
    
    if is_causal is not None:
        loss_weight = loss_weight.reshape(-1)

    assert labels.shape == log_prob_predictions.shape[:-1]
    
    # Only consider the tokens that are not -100 in target_labels
    label_indices = labels != -100
    output_idices = torch.zeros_like(label_indices)
    output_idices[:-1] = label_indices[1:]
    
    log_prob_predictions = log_prob_predictions[output_idices, :]

    labels = labels[label_indices]
    print(labels)
    
    # Compute the cross-entropy loss with masking
    
    if is_causal is None:
        criterion = torch.nn.CrossEntropyLoss(reduction="mean")
        loss = criterion(log_prob_predictions, labels.long())
    else:
        loss_weight = loss_weight[label_indices]
        print(loss_weight)
        criterion = torch.nn.CrossEntropyLoss(reduction="none")
        loss = criterion(log_prob_predictions, labels.long())
        print(loss)
        print(loss * loss_weight)
        
        loss = (loss * loss_weight).mean()
        
    _pred["loss"] = loss

tensor([26070,  5270,   220,  1958, 10384,  4606,  8524, 10384,    14, 87995,
         4918, 24664,   263,  8524,  7505,   258, 11876,  8942,    78, 70606,
         8494,    42,  1394, 64847,  1644], device='cuda:0')
tensor([1., 1., 1., 1., 1., 1., 2., 1., 1., 1., 1., 2., 2., 2., 2., 2., 2., 2.,
        2., 2., 2., 2., 2., 2., 1.], device='cuda:0')
tensor([7.9364e-02, 9.1462e-03, 1.5779e-02, 4.1387e-01, 5.6562e-01, 9.8465e-02,
        9.7345e+00, 1.2639e+00, 1.0317e-01, 3.8390e-02, 9.7323e-04, 1.0995e+01,
        2.1405e-02, 7.8050e+00, 1.2287e+01, 5.7046e-02, 1.0867e+01, 3.5980e-02,
        4.7250e-03, 1.6534e+01, 1.2152e+01, 1.1328e+01, 2.8085e+00, 2.6975e-02,
        5.6782e-01], device='cuda:0', grad_fn=<NllLossBackward0>)
tensor([7.9364e-02, 9.1462e-03, 1.5779e-02, 4.1387e-01, 5.6562e-01, 9.8465e-02,
        1.9469e+01, 1.2639e+00, 1.0317e-01, 3.8390e-02, 9.7323e-04, 2.1989e+01,
        4.2810e-02, 1.5610e+01, 2.4573e+01, 1.1409e-01, 2.1735e+01, 7.1960e-02,
        9.4500e-03, 3.3