In [1]:
import sys

# Make the repository's `src` package importable (needed by the
# `from src.hyperdas.data_utils import ...` cell below). The relative path is
# resolved against the kernel's working directory; it presumably points at the
# repo root two levels up — TODO confirm if this notebook is moved.
sys.path.append("../..")
# Re-import edited project modules automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import os
from collections import defaultdict

import datasets
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from src.hyperdas.data_utils import (
    filter_dataset,
    generate_ravel_dataset,
)

In [3]:
# Inspect which attributes the Nobel-prize-winner prompt file defines.
nobel_prompts_path = "/workspace/HyperDAS/assets/data/ravel/ravel_nobel_prize_winner_attribute_to_prompts.json"
with open(nobel_prompts_path) as f:
    prompts = json.load(f)
print(prompts.keys())

dict_keys(['Field', 'Award Year', 'Birth Year', 'Country of Birth', 'Gender'])


In [4]:
# Inspect which attributes the occupation prompt file defines.
occupation_prompts_path = "/workspace/HyperDAS/assets/data/ravel/ravel_occupation_attribute_to_prompts.json"
with open(occupation_prompts_path) as f:
    prompts = json.load(f)
print(prompts.keys())

dict_keys(['Duty', 'Gender Bias', 'Industry', 'Work Location'])


In [5]:
# Checkpoint passed to `from_pretrained` for both the tokenizer and the model
# that filters the generated datasets below.
model_name_or_path = "meta-llama/Meta-Llama-3-8B"

In [6]:
# Tokenizer, configured for batched inputs: pad on the left and reuse EOS as
# the padding token (presumably because the checkpoint ships without a pad
# token — confirm).
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

# Model weights in bfloat16; device_map {"": 0} places the whole model on GPU 0.
model = LlamaForCausalLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.bfloat16,
    device_map={"": 0},
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# Build the three-domain RAVEL datasets (city, Nobel prize winner, occupation).
#
# For every (domain, split) pair: generate up to N_SAMPLES examples with
# `generate_ravel_dataset`, then run `filter_dataset` over them N_FILTER_PASSES
# times (each pass prints an accuracy and how many examples it filtered out).
# The resulting (dataset, metadata) pairs are collected per split in
# `all_datasets` for the save cell that follows.
#
# Edit-instruction format: CHANGE: {base_entity} → {source_entity} | ATTR: {target_attribute}
# (presumably the default template inside `generate_ravel_dataset` — confirm).

RAVEL_ROOT = "/workspace/HyperDAS/assets/data/ravel"
N_SAMPLES = 10000
FILTER_BATCH_SIZE = 16
# The 2nd and 3rd passes still remove a few examples each (see the printed
# counts below), so a single pass is not a fixed point; 3 matches the original run.
N_FILTER_PASSES = 3

all_attributes = {
    "city": ["Country", "Continent", "Language", "Latitude", "Longitude", "Timezone"],
    "nobel_prize_winner": [
        "Field",
        "Award Year",
        "Birth Year",
        "Country of Birth",
        "Gender",
    ],
    # NOTE(review): the occupation prompt file also defines "Gender Bias" (see the
    # keys printed earlier) but it is not listed here — confirm this is intentional.
    "occupation": ["Duty", "Industry", "Work Location"],
}

target_attributes = {
    "city": ["Country"],
    "nobel_prize_winner": ["Field"],
    "occupation": ["Duty"],
}


domains = ["city", "nobel_prize_winner", "occupation"]
all_datasets = defaultdict(list)

for domain in domains:
    for split in ["train", "test"]:
        # Every non-target attribute is an isolation attribute. Filtering the
        # ordered list (rather than taking a set difference) keeps the order
        # fixed across kernel restarts: iteration order of a set of strings
        # changes with Python's per-process hash randomization, which made the
        # saved metadata (and possibly the generated data) differ between runs.
        isolate_attributes = [
            attribute
            for attribute in all_attributes[domain]
            if attribute not in target_attributes[domain]
        ]
        args = {
            "n_samples": N_SAMPLES,
            "root_path": RAVEL_ROOT,
            "target_attributes": target_attributes[domain],
            "isolate_attributes": isolate_attributes,
            "template_split": split,
            "entity_split": split,
            "domain": domain,
            # Optional override, currently disabled:
            # "edit_instruction_template": "CHANGE: {base_entity} -> {source_entity} | ATTR: {random_target_attribute}",
        }

        dataset = generate_ravel_dataset(**args)

        # Presumably each pass drops examples the model answers incorrectly —
        # TODO confirm against `filter_dataset`.
        for _ in range(N_FILTER_PASSES):
            dataset = filter_dataset(
                model, tokenizer, dataset, batch_size=FILTER_BATCH_SIZE
            )

        # Store the attribute lists as tuples (json.dump still writes them as arrays).
        metadata = {
            **args,
            "target_attributes": tuple(args["target_attributes"]),
            "isolate_attributes": tuple(args["isolate_attributes"]),
        }

        all_datasets[split].append((dataset, metadata))


625it [00:48, 13.00it/s]


Accuracy: 0.6289; filtered out 3711 examples


394it [00:29, 13.27it/s]


Accuracy: 0.9984099220861822; filtered out 10 examples


393it [00:29, 13.27it/s]


Accuracy: 0.9982481286829112; filtered out 11 examples


625it [00:47, 13.18it/s]


Accuracy: 0.6137; filtered out 3863 examples


384it [00:28, 13.37it/s]


Accuracy: 0.9967410787029494; filtered out 20 examples


383it [00:28, 13.36it/s]


Accuracy: 0.998855648193559; filtered out 7 examples


625it [00:49, 12.53it/s]


Accuracy: 0.7337; filtered out 2663 examples


459it [00:37, 12.23it/s]


Accuracy: 0.9970014992503748; filtered out 22 examples


458it [00:37, 12.27it/s]


Accuracy: 0.9986329460013671; filtered out 10 examples


625it [00:49, 12.57it/s]


Accuracy: 0.7145; filtered out 2855 examples


447it [00:36, 12.40it/s]


Accuracy: 0.9960811756473058; filtered out 28 examples


445it [00:36, 12.32it/s]


Accuracy: 0.9978923703807784; filtered out 15 examples


625it [00:52, 11.85it/s]


Accuracy: 0.2814; filtered out 7186 examples


176it [00:14, 11.84it/s]


Accuracy: 0.9832977967306326; filtered out 47 examples


173it [00:14, 11.96it/s]


Accuracy: 0.9916877484640405; filtered out 23 examples


625it [00:52, 11.91it/s]


Accuracy: 0.249; filtered out 7510 examples


156it [00:13, 11.84it/s]


Accuracy: 0.9863453815261044; filtered out 34 examples


154it [00:13, 11.71it/s]

Accuracy: 0.995114006514658; filtered out 12 examples





In [8]:
# Concatenate each split's per-domain datasets and save them, together with the
# per-domain generation arguments as metadata.json in the same directory.
for split, dataset_metadata_pairs in all_datasets.items():
    split_datasets, metadata_list = zip(*dataset_metadata_pairs)
    # Name the output after the domains it actually contains, in build order
    # (here "city_nobel_prize_winner_occupation"), so the path cannot drift
    # from the contents when the domain list is edited.
    domain_tag = "_".join(dict.fromkeys(meta["domain"] for meta in metadata_list))
    path = f"/workspace/HyperDAS/experiments/RAVEL/data/{domain_tag}_{split}"
    combined = datasets.concatenate_datasets(list(split_datasets))
    combined.save_to_disk(path)
    with open(os.path.join(path, "metadata.json"), "w") as f:
        json.dump({"metadata": metadata_list}, f)

Saving the dataset (0/1 shards):   0%|          | 0/16317 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/15656 [00:00<?, ? examples/s]

In [9]:
# Build the two-domain RAVEL datasets (city, Nobel prize winner).
#
# Same procedure as the three-domain cell above, restricted to two domains:
# generate up to N_SAMPLES examples per (domain, split), run `filter_dataset`
# N_FILTER_PASSES times, and collect (dataset, metadata) pairs per split in
# `all_datasets` for the save cell that follows.
# NOTE(review): the accuracies printed below are identical to the first four
# builds of the three-domain cell, so those datasets could be reused instead of
# regenerated.
#
# Edit-instruction format: CHANGE: {base_entity} → {source_entity} | ATTR: {target_attribute}
# (presumably the default template inside `generate_ravel_dataset` — confirm).

RAVEL_ROOT = "/workspace/HyperDAS/assets/data/ravel"
N_SAMPLES = 10000
FILTER_BATCH_SIZE = 16
# Later passes still remove a few examples each (see the printed counts below),
# so a single pass is not a fixed point; 3 matches the original run.
N_FILTER_PASSES = 3

all_attributes = {
    "city": ["Country", "Continent", "Language", "Latitude", "Longitude", "Timezone"],
    "nobel_prize_winner": [
        "Field",
        "Award Year",
        "Birth Year",
        "Country of Birth",
        "Gender",
    ],
}

target_attributes = {
    "city": ["Country"],
    "nobel_prize_winner": ["Field"],
}


domains = ["city", "nobel_prize_winner"]
all_datasets = defaultdict(list)

for domain in domains:
    for split in ["train", "test"]:
        # Every non-target attribute is an isolation attribute. Filtering the
        # ordered list (rather than taking a set difference) keeps the order
        # fixed across kernel restarts: iteration order of a set of strings
        # changes with Python's per-process hash randomization, which made the
        # saved metadata (and possibly the generated data) differ between runs.
        isolate_attributes = [
            attribute
            for attribute in all_attributes[domain]
            if attribute not in target_attributes[domain]
        ]
        args = {
            "n_samples": N_SAMPLES,
            "root_path": RAVEL_ROOT,
            "target_attributes": target_attributes[domain],
            "isolate_attributes": isolate_attributes,
            "template_split": split,
            "entity_split": split,
            "domain": domain,
            # Optional override, currently disabled:
            # "edit_instruction_template": "CHANGE: {base_entity} -> {source_entity} | ATTR: {random_target_attribute}",
        }

        dataset = generate_ravel_dataset(**args)

        # Presumably each pass drops examples the model answers incorrectly —
        # TODO confirm against `filter_dataset`.
        for _ in range(N_FILTER_PASSES):
            dataset = filter_dataset(
                model, tokenizer, dataset, batch_size=FILTER_BATCH_SIZE
            )

        # Store the attribute lists as tuples (json.dump still writes them as arrays).
        metadata = {
            **args,
            "target_attributes": tuple(args["target_attributes"]),
            "isolate_attributes": tuple(args["isolate_attributes"]),
        }

        all_datasets[split].append((dataset, metadata))


625it [00:47, 13.06it/s]


Accuracy: 0.6289; filtered out 3711 examples


394it [00:29, 13.28it/s]


Accuracy: 0.9984099220861822; filtered out 10 examples


393it [00:29, 13.27it/s]


Accuracy: 0.9982481286829112; filtered out 11 examples


625it [00:47, 13.09it/s]


Accuracy: 0.6137; filtered out 3863 examples


384it [00:28, 13.36it/s]


Accuracy: 0.9967410787029494; filtered out 20 examples


383it [00:28, 13.35it/s]


Accuracy: 0.998855648193559; filtered out 7 examples


625it [00:49, 12.54it/s]


Accuracy: 0.7337; filtered out 2663 examples


459it [00:37, 12.25it/s]


Accuracy: 0.9970014992503748; filtered out 22 examples


458it [00:37, 12.27it/s]


Accuracy: 0.9986329460013671; filtered out 10 examples


625it [00:49, 12.56it/s]


Accuracy: 0.7145; filtered out 2855 examples


447it [00:36, 12.32it/s]


Accuracy: 0.9960811756473058; filtered out 28 examples


445it [00:36, 12.33it/s]

Accuracy: 0.9978923703807784; filtered out 15 examples





In [10]:
# Concatenate each split's per-domain datasets and save them, together with the
# per-domain generation arguments as metadata.json in the same directory.
for split, dataset_metadata_pairs in all_datasets.items():
    split_datasets, metadata_list = zip(*dataset_metadata_pairs)
    # Name the output after the domains it actually contains, in build order
    # (here "city_nobel_prize_winner"), so the path cannot drift from the
    # contents when the domain list is edited.
    domain_tag = "_".join(dict.fromkeys(meta["domain"] for meta in metadata_list))
    path = f"/workspace/HyperDAS/experiments/RAVEL/data/{domain_tag}_{split}"
    combined = datasets.concatenate_datasets(list(split_datasets))
    combined.save_to_disk(path)
    with open(os.path.join(path, "metadata.json"), "w") as f:
        json.dump({"metadata": metadata_list}, f)

Saving the dataset (0/1 shards):   0%|          | 0/13573 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/13212 [00:00<?, ? examples/s]