In [42]:
import datasets
import numpy as np
import torch
import sys
sys.path.append('..')
from src.data import EvalDatasetWrapper, ICLCollator, preprocess_dataset

In [43]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

In [19]:
# the HF json loader puts every record of a file into a single 'train' split,
# so split='train' is also what selects the rows of the test file
train, test = (
    datasets.load_dataset('json', data_files=path, split='train')
    for path in (
        '../data/baseline_train/abstract_narrative_understanding.json',
        '../data/baseline_test/crash_blossom.json',
    )
)

Found cached dataset json (/Users/yeeb/.cache/huggingface/datasets/json/default-96676ccee0c43a59/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Found cached dataset json (/Users/yeeb/.cache/huggingface/datasets/json/default-8b0cc3f6e2fb8516/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [20]:
# peek at one raw training row (inputs, targets, multiple_choice_* fields, ...)
train[0]

{'inputs': "In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Ralph had a clothing store and sales had peaked. He looked at his sales data and noticed one demographic wasn't buying his clothes. So, Ralph started a rumor that he wouldn't sell his clothes to them and when word got out that group of people started purchasing his clothes just to spite him. Now, Ralph enjoys taking money from all demographic groups.\nThis narrative is a good illustration of the following proverb:",
 'idx': 463,
 'targets': ['All publicity is good publicity'],
 'multiple_choice_scores': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 'true_idx': 2063,
 'is_generated': False,
 'multiple_choice_targets': ['He who laughs last laughs longest',
  'Honey catches more flies than vinegar',
  'Build a better mousetrap and the world will beat a path to your doorLink to proverb',
  'All publicity is good publicity',
  'Revenge is a dish best served cold',
  'Give credit where credi

In [44]:
# one raw eval row; it carries the same fields as the training rows above
test[0]

{'inputs': 'Identify the part of speech (verb, adjective, noun, or preposition) of the specified word in the following headlines.\n\nIn the following sentence, what part of speech is watch? Sentence: Watch batteries while you wait\nA:',
 'idx': 13,
 'targets': ['noun'],
 'multiple_choice_scores': [0, 0, 1, 0],
 'true_idx': 13,
 'is_generated': False,
 'multiple_choice_targets': ['verb', 'adjective', 'noun', 'preposition']}

In [45]:
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# reuse EOS as the pad token so ICLCollator's tokenizer.pad() call can pad
# batches (the gpt2 tokenizer has no dedicated pad token)
tokenizer.pad_token = tokenizer.eos_token

In [46]:
# tokenize both splits with the same tokenizer (results are cached by datasets)
train_processed, test_processed = (
    preprocess_dataset(split, tokenizer) for split in (train, test)
)

Loading cached processed dataset at /Users/yeeb/.cache/huggingface/datasets/json/default-96676ccee0c43a59/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-ff2376a2b0057b13_*_of_00008.arrow
Loading cached processed dataset at /Users/yeeb/.cache/huggingface/datasets/json/default-8b0cc3f6e2fb8516/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-ad266f0c897c866b_*_of_00008.arrow


In [47]:
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
from typing import List, Dict, Any, Union
import itertools

In [48]:
# NOTE: this redefinition shadows the ICLCollator imported from src.data in the
# first cell; keep the two in sync or drop one of them.
@dataclass
class ICLCollator:
    """
    Collate function that packs several tokenized examples into one sequence
    for in-context / few-shot learning.

    Training mode (``for_eval=False``): ``features`` is a flat list whose length
    should be ``k_examples * batch_size``; each consecutive run of
    ``k_examples`` rows is concatenated into one sequence.

    Eval mode (``for_eval=True``): ``features`` is a list of groups (one list of
    example dicts per dataset item, e.g. from ``EvalDatasetWrapper``); each
    group is concatenated into one sequence and a ``labels`` tensor is added.

    Every example dict must provide ``input_ids``, ``attention_mask`` and
    ``token_type_ids`` lists. Packed sequences longer than ``max_length`` are
    truncated from the end in training mode and from the start in eval mode,
    so the final (evaluated) example of a group is never the part that is cut.
    """

    tokenizer: "PreTrainedTokenizerBase"
    k_examples: int = 16
    max_length: int = 2048
    return_tensors: str = "pt"
    for_eval: bool = False

    # per-token fields concatenated across the examples of a group
    _KEYS = ("input_ids", "attention_mask", "token_type_ids")

    def _groups(self, features):
        """Split ``features`` into the groups that each become one sequence."""
        if self.for_eval:
            # already grouped by the dataset; using the groups directly keeps
            # items aligned even when their size differs from ``k_examples``
            return features
        return [
            features[i : i + self.k_examples]
            for i in range(0, len(features), self.k_examples)
        ]

    def _pack(self, group, key):
        """Concatenate ``key`` across ``group`` and truncate to ``max_length``."""
        packed = list(itertools.chain.from_iterable(example[key] for example in group))
        if self.for_eval:
            # keep the tail: the evaluated example sits at the end of the group
            return packed[max(len(packed) - self.max_length, 0) :]
        return packed[: self.max_length]

    def __call__(
        self,
        features: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
    ) -> Dict[str, Any]:
        """
        Build a padded batch from ``features`` (flat list in training mode,
        list of groups in eval mode -- see the class docstring).

        Returns the output of ``tokenizer.pad`` holding ``input_ids``,
        ``attention_mask`` and ``token_type_ids``; in eval mode also ``labels``.
        """
        batch = {key: [] for key in self._KEYS}
        for group in self._groups(features):
            for key in self._KEYS:
                batch[key].append(self._pack(group, key))

        batch = self.tokenizer.pad(
            batch,
            padding="longest",
            max_length=self.max_length,
            pad_to_multiple_of=None,
            return_tensors=self.return_tensors,
        )

        if self.for_eval:
            # keep token ids only where token_type_ids == 1 (label positions);
            # every other position becomes token id 0
            # NOTE(review): HF causal-LM losses ignore -100, not 0 -- confirm how
            # the consumer of "labels" treats 0 before using these for a loss
            batch["labels"] = batch["input_ids"].clone()
            batch["labels"] *= batch["token_type_ids"]

        return batch

In [94]:
# NOTE: this redefinition shadows the EvalDatasetWrapper imported from src.data
# in the first cell; keep the two in sync or drop one of them.
@dataclass
class EvalDatasetWrapper(torch.utils.data.Dataset):
    """
    Dataset wrapper for in-context-learning evaluation.

    Item ``i`` is a list of ``k_examples`` feature dicts: ``k_examples - 1``
    demonstrations drawn at random (with replacement) from ``train_dataset``,
    followed by ``eval_dataset[i]`` with its label tokens removed.

    Rows must carry ``input_ids``, ``attention_mask`` and ``token_type_ids``
    lists; ``token_type_ids`` acts as the label mask (1 marks label tokens).
    Rows are copied before being modified, so the source datasets are never
    mutated.

    Args:
        train_dataset: pool of demonstration rows.
        eval_dataset: rows to evaluate.
        k_examples: total examples per item, including the evaluated one.
        seed: if given, the demonstrations for item ``i`` are drawn from a
            generator seeded with ``(seed, i)``, so every access -- and every
            DataLoader worker -- sees the same context. If ``None``, the global
            NumPy RNG is used.
    """

    train_dataset: "datasets.Dataset"
    eval_dataset: "datasets.Dataset"
    k_examples: int = 16
    seed: Union[int, None] = None

    def __len__(self):
        return len(self.eval_dataset)

    def _sample_indices(self, index):
        """Return the train-set indices of the demonstrations for item ``index``."""
        size = (self.k_examples - 1,)
        if self.seed is None:
            return np.random.randint(0, len(self.train_dataset), size=size)
        rng = np.random.default_rng([self.seed, index])
        return rng.integers(0, len(self.train_dataset), size=size)

    def __getitem__(self, index):
        examples = []
        for i in self._sample_indices(index):
            # copy so the mask reset below cannot leak into the source dataset
            # (HF datasets return fresh dicts, plain Python lists do not)
            example = dict(self.train_dataset[int(i)])
            # demonstrations contribute no labels; only the final example counts
            example["token_type_ids"] = [0] * len(example["token_type_ids"])
            examples.append(example)

        target = dict(self.eval_dataset[index])
        # drop everything from the first label token onward so the answer is
        # not visible in the context; a row without label tokens is kept whole
        type_ids = target["token_type_ids"]
        label_start = type_ids.index(1) if 1 in type_ids else len(type_ids)
        for key in ("input_ids", "attention_mask", "token_type_ids"):
            target[key] = target[key][:label_start]

        return examples + [target]

In [95]:
# each item: k_examples - 1 random training demonstrations followed by one test
# row with its label tokens stripped. The demonstrations come from the global,
# unseeded NumPy RNG, so re-running this cell yields different contexts.
# TODO: seed the sampling for reproducible evaluation.
wrapped_dataset = EvalDatasetWrapper(train_processed, test_processed, k_examples=3)

In [96]:
# for_eval=True: each dataset item (a list of examples) is packed into one
# sequence and a "labels" tensor is added; keep k_examples equal to the
# wrapper's value in the cell above.
collate_fn = ICLCollator(tokenizer, k_examples=3, for_eval=True)

In [97]:
# batch_size counts wrapped eval items, i.e. 3 packed prompts per batch
dl = DataLoader(wrapped_dataset, batch_size=3, collate_fn=collate_fn)

In [98]:
# iterator over the eval DataLoader, consumed one batch at a time below
it = iter(dl)

In [99]:
# pull a single collated batch to inspect
ex = next(it)

In [100]:
# decode the second packed prompt of the batch to eyeball the ICL format
tokenizer.decode(ex['input_ids'][1])

'In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Herbert is man of words but sometimes he tells stories that are unimaginable. We thought that he was always lying. He once said about a two-legged dog which everyone thought was a lie. Then Herbert actually showed up with the dog and we were shocked.\nThis narrative is a good illustration of the following proverb: \n\nSeeing is believing \n\n\nIn what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Robbie delighted in telling people stories about his wealthy background even though he actually came from a poor home.  In college, his new roommate confronted Robbie about his true background.  His roommate told Robbie he also told people false stories about his rich parents.\nThis narrative is a good illustration of the following proverb: \n\nIt takes a thief to catch a thief \n\n\nIdentify the part of speech (verb, adjective, noun, or prepo

In [101]:
# NOTE(review): identical to the previous cell; candidate for deletion
tokenizer.decode(ex['input_ids'][1])

'In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Herbert is man of words but sometimes he tells stories that are unimaginable. We thought that he was always lying. He once said about a two-legged dog which everyone thought was a lie. Then Herbert actually showed up with the dog and we were shocked.\nThis narrative is a good illustration of the following proverb: \n\nSeeing is believing \n\n\nIn what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Robbie delighted in telling people stories about his wealthy background even though he actually came from a poor home.  In college, his new roommate confronted Robbie about his true background.  His roommate told Robbie he also told people false stories about his rich parents.\nThis narrative is a good illustration of the following proverb: \n\nIt takes a thief to catch a thief \n\n\nIdentify the part of speech (verb, adjective, noun, or prepo

In [102]:
# print (instead of displaying the repr) so the prompt's newlines render;
# decoding just the one row gives the same text as batch_decode(...)[1]
print(tokenizer.decode(ex['input_ids'][1]))

In what follows, we provide short narratives, each of which illustrates a common proverb. 
Narrative: Herbert is man of words but sometimes he tells stories that are unimaginable. We thought that he was always lying. He once said about a two-legged dog which everyone thought was a lie. Then Herbert actually showed up with the dog and we were shocked.
This narrative is a good illustration of the following proverb: 

Seeing is believing 


In what follows, we provide short narratives, each of which illustrates a common proverb. 
Narrative: Robbie delighted in telling people stories about his wealthy background even though he actually came from a poor home.  In college, his new roommate confronted Robbie about his true background.  His roommate told Robbie he also told people false stories about his rich parents.
This narrative is a good illustration of the following proverb: 

It takes a thief to catch a thief 


Identify the part of speech (verb, adjective, noun, or preposition) of the 