In [33]:
from datasets import load_from_disk, Dataset
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm

In [4]:
def get_model_prediction_labels(model, tokenizer, target_dataset, max_num_tokens=3):
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    def collate_fn(batch):
        input_texts = []

        for b in batch:
            prefix = b["context"]
            suffix = b["prompt"]

            input_text = f"{tokenizer.bos_token} {prefix}{suffix}"
            input_texts.append(input_text)

        inputs = tokenizer(
            input_texts, return_tensors="pt", padding=True, truncation=True
        )
        # inputs["position_ids"] = torch.cumsum(inputs["attention_mask"], dim=1) * inputs["attention_mask"] - 1

        return inputs

    model_predictions = []

    dataloader = DataLoader(
        target_dataset, batch_size=16, collate_fn=collate_fn, shuffle=False
    )

    model = model.to("cuda")
    for batch in tqdm(dataloader):
        batch = {k: v.to("cuda") for k, v in batch.items()}
        outputs = model.generate(**batch, max_new_tokens=max_num_tokens)

        outputs = outputs[:, batch["input_ids"].shape[1] :]

        predictions = [
            tokenizer.decode(output, skip_special_tokens=True) for output in outputs
        ]
        model_predictions.extend(predictions)

    target_dataset = target_dataset.add_column("model_predictions", model_predictions)

    return target_dataset

In [9]:
from transformers import AutoTokenizer, LlamaForCausalLM

model = LlamaForCausalLM.from_pretrained(
    "/nlp/scr/sjd24/llama3-8b", torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained("/nlp/scr/sjd24/llama3-8b")

Loading checkpoint shards: 100%|██████████| 4/4 [04:10<00:00, 62.52s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
model = model.to("cuda")

In [55]:
import random


def process_autoravel_dataset(
    model, tokenizer, target_dataset, max_num_tokens=3, source_per_target=1, domain=None
):
    tokenizer.padding_side = "left"
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    processed_dataset = []

    all_entities = set(d["entity"] for d in target_dataset)

    def collate_fn(batch):
        input_texts = []

        for b in batch:
            input_text = tokenizer.bos_token + " " + b["input_prompt"]
            input_texts.append(input_text)

        inputs = tokenizer(
            input_texts, return_tensors="pt", padding=True, truncation=True
        )
        # inputs["position_ids"] = torch.cumsum(inputs["attention_mask"], dim=1) * inputs["attention_mask"] - 1

        return inputs

    for entity in tqdm(all_entities):
        entity_data = target_dataset.filter(lambda x: x["entity"] == entity)

        all_contexts = set(d["context"] for d in entity_data)

        for i, context in enumerate(all_contexts):
            context_data = entity_data.filter(lambda x: x["context"] == context)
            attributes = set([d["attribute"] for d in context_data])

            # Remove data entry with duplicate attributes
            for attribute in attributes:
                attribute_data = context_data.filter(
                    lambda x: x["attribute"] == attribute
                )
                if len(attribute_data) > 1:
                    context_data = context_data.filter(
                        lambda x: x["attribute"] != attribute
                    )

                    context_data = context_data.add_item(random.choice(attribute_data))

            for _ in range(source_per_target):
                source_entity = random.choice(list(all_entities - {entity}))
                source_entity_data = target_dataset.filter(
                    lambda x: x["entity"] == source_entity
                )
                source_all_contexts = set(d["context"] for d in source_entity_data)

                source_context = random.choice(list(source_all_contexts))

                labels_dict = {"base": dict(), "source": dict()}

                label_dataset = []
                for attribute in attributes:
                    base_attribute_data = context_data.filter(
                        lambda x: x["attribute"] == attribute
                    )
                    assert (
                        len(base_attribute_data) == 1
                    ), f"Attribute {attribute} has {len(base_attribute_data)} entries in the dataset"
                    base_attribute_data = base_attribute_data[0]

                    input_prompt = (
                        base_attribute_data["context"] + base_attribute_data["prompt"]
                    )
                    counterfactual_input_prompt = (
                        source_context + base_attribute_data["prompt"]
                    )

                    label_dataset.append(
                        {
                            "input_prompt": input_prompt,
                            "attribute": attribute,
                            "entity_type": "base",
                        }
                    )
                    label_dataset.append(
                        {
                            "input_prompt": counterfactual_input_prompt,
                            "attribute": attribute,
                            "entity_type": "source",
                        }
                    )

                label_dataset = Dataset.from_list(label_dataset)
                model_predictions = []

                label_dataloader = DataLoader(
                    label_dataset, batch_size=4, collate_fn=collate_fn, shuffle=False
                )
                for batch in tqdm(label_dataloader):
                    batch = {k: v.to(model.device) for k, v in batch.items()}
                    outputs = model.generate(**batch, max_new_tokens=max_num_tokens)
                    outputs = outputs[:, batch["input_ids"].shape[1] :]
                    predictions = [
                        tokenizer.decode(output, skip_special_tokens=True)
                        for output in outputs
                    ]
                    model_predictions.extend(predictions)

                for label, data in zip(model_predictions, label_dataset):
                    labels_dict[data["entity_type"]][data["attribute"]] = label

                for attribute in attributes:
                    base_attribute_data = context_data.filter(
                        lambda x: x["attribute"] == attribute
                    )
                    assert (
                        len(base_attribute_data) == 1
                    ), f"Attribute {attribute} has {len(base_attribute_data)} entries in the dataset"
                    base_attribute_data = base_attribute_data[0]

                    input_prefix = base_attribute_data["context"]
                    input_suffix = base_attribute_data["prompt"]
                    counterfactual_input_prefix = source_context
                    counterfactual_input_suffix = base_attribute_data["prompt"]
                    edit_instruction = attribute.title()
                    entity = entity
                    counterfactual_entity = source_entity
                    attribute_type = "cause"
                    domain = "unknown" if domain is None else domain
                    attribute = attribute
                    target = labels_dict["base"][attribute]
                    counterfactual_target = labels_dict["source"][attribute]

                    processed_dataset.append(
                        {
                            "input_prefix": input_prefix,
                            "input_suffix": input_suffix,
                            "counterfactual_input_prefix": counterfactual_input_prefix,
                            "counterfactual_input_suffix": counterfactual_input_suffix,
                            "edit_instruction": edit_instruction,
                            "entity": entity,
                            "counterfactual_entity": counterfactual_entity,
                            "target": target,
                            "counterfactual_target": counterfactual_target,
                            "attribute_type": attribute_type,
                            "domain": domain,
                            "attribute": attribute,
                        }
                    )

                    all_other_attributes = attributes - {attribute}

                    for other_attribute in all_other_attributes:
                        other_attribute_base_data = context_data.filter(
                            lambda x: x["attribute"] == other_attribute
                        )
                        assert (
                            len(other_attribute_base_data) == 1
                        ), f"Attribute {other_attribute} has {len(other_attribute_base_data)} entries in the dataset"
                        other_attribute_base_data = other_attribute_base_data[0]

                        other_target = labels_dict["base"][other_attribute]
                        other_counterfactual_target = labels_dict["source"][
                            other_attribute
                        ]

                        processed_dataset.append(
                            {
                                "input_prefix": input_prefix,
                                "input_suffix": other_attribute_base_data["prompt"],
                                "counterfactual_input_prefix": counterfactual_input_prefix,
                                "counterfactual_input_suffix": other_attribute_base_data[
                                    "prompt"
                                ],
                                "edit_instruction": edit_instruction,
                                "entity": entity,
                                "counterfactual_entity": counterfactual_entity,
                                "target": other_target,
                                "counterfactual_target": other_counterfactual_target,
                                "attribute_type": "isolate",
                                "domain": domain,
                                "attribute": other_attribute,
                            }
                        )

    return Dataset.from_list(processed_dataset)


human_dataset = load_from_disk("auto_ravel/Q5")
processed_dataset = process_autoravel_dataset(model, tokenizer, human_dataset)

  0%|          | 0/522 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
100%|██████████| 6/6 [00:00<00:00,  9.86it/s]
  0%|          | 1/522 [00:01<09:31,  1.10s/it]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
100%|██████████| 3/3 [00:00<00:00, 10.10it/s]
  0%|          | 2/522 [00:01<05:59,  1.45it/s]Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:1280

In [59]:
processed_dataset.save_to_disk("auto_ravel/Q5_processed")

Saving the dataset (1/1 shards): 100%|██████████| 35208/35208 [00:00<00:00, 385421.41 examples/s]
