In [1]:
import sys

sys.path.append("../..")
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from src.hyperdas.data_utils import (
    filter_dataset,
    generate_ravel_dataset,
)

In [5]:
model_name_or_path = "meta-llama/Meta-Llama-3-8B"

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16)
model = model.cuda()

tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
all_attributes = [
    "Country",
    "Continent",
    "Language",
    "Latitude",
    "Longitude",
    "Timezone",
]

for attribute in all_attributes:
    for split, size in [("train", 20000), ("test", 4000)]:
        print(f"Generating data for {attribute} in split {split}")
        all_other_attributes = [a for a in all_attributes if a != attribute]

        dataset = generate_ravel_dataset(
            size,
            root_path="/home/ubuntu/HyperDAS/data/ravel",
            target_attributes=[attribute],
            isolate_attributes=all_other_attributes,
            template_split=split,
            entity_split="both",
        )

        dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)
        dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)
        dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

        dataset.save_to_disk(
            f"/home/ubuntu/HyperDAS/experiments/ravel/data/city_{attribute.lower()}_{split}"
        )

Generating data for Country in split train


FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/HyperDAS/data/ravel/ravel_city_entity_to_split.json'

In [13]:
def generate_ravel_causal_dataset(split):
    dataset = generate_ravel_dataset(
        10000,
        root_path="/workspace/HyperDAS/assets/data/ravel",
        target_attributes=["Country"],
        isolate_attributes=[],
        template_split=split,
        entity_split=split,
    )

    dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)
    dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)
    dataset = filter_dataset(model, tokenizer, dataset, batch_size=16)

    dataset.save_to_disk(
        f"/workspace/HyperDAS/experiments/RAVEL/data/ravel_country_causal_only_{split}"
    )

    return dataset

In [14]:
train_dataset, test_dataset = (
    generate_ravel_causal_dataset("train"),
    generate_ravel_causal_dataset("test"),
)

625it [00:46, 13.33it/s]


Accuracy: 0.7735; filtered out 2265 examples


484it [00:34, 14.22it/s]


Accuracy: 0.9987071751777634; filtered out 10 examples


483it [00:34, 14.14it/s]

Accuracy: 0.9981877022653721; filtered out 14 examples





Saving the dataset (0/1 shards):   0%|          | 0/7711 [00:00<?, ? examples/s]

625it [00:45, 13.83it/s]


Accuracy: 0.7611; filtered out 2389 examples


476it [00:33, 14.42it/s]


Accuracy: 0.9978977795296282; filtered out 16 examples


475it [00:32, 14.62it/s]


Accuracy: 0.9981566820276497; filtered out 14 examples


Saving the dataset (0/1 shards):   0%|          | 0/7581 [00:00<?, ? examples/s]

In [22]:
train_dataset[0]

{'input_prefix': '[{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "Quiemo',
 'input_suffix': '", "country": "',
 'counterfactual_input_prefix': '[{"city": "Paris", "country": "France"}, {"city": "Bongor',
 'counterfactual_input_suffix': '", "country": "',
 'edit_instruction': 'Quiemo ; Bongor - Country',
 'entity': 'Quiemo',
 'counterfactual_entity': 'Bongor',
 'target': 'China',
 'counterfactual_target': 'Chad',
 'attribute_type': 'causal',
 'domain': 'city',
 'attribute': 'Country',
 'verify_text': '[{"city": "Rio de Janeiro", "country": "Brazil"}, {"city": "Bongor", "country": "'}

In [15]:
all_country = [d for d in test_dataset if d["attribute"] == "Country"]
all_country_with_isolation = [
    d
    for d in test_dataset
    if d["attribute"] == "Country" and d["attribute_type"] == "isolate"
]
len(test_dataset), len(all_country), len(all_country_with_isolation)

(7581, 7581, 0)

In [16]:
attr_dict = {}

for d in test_dataset:
    if d["attribute"] not in attr_dict:
        attr_dict[d["attribute"]] = 0

    attr_dict[d["attribute"]] += 1

attr_dict

{'Country': 7581}