In [1]:
import datasets
import numpy as np
import sys
sys.path.append('..')
from src.data import EvalDatasetWrapper, ICLCollator, preprocess_dataset

In [2]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

In [3]:
# Load the few-shot source task and the evaluation task from local JSON files.
# NOTE(review): the test file is also requested with split='train' — presumably
# because the json loader exposes a single default split when data_files is a
# plain path; confirm against the datasets documentation.
train = datasets.load_dataset('json', data_files='../data/baseline_train/abstract_narrative_understanding.json', split='train')
test = datasets.load_dataset('json', data_files='../data/baseline_test/crash_blossom.json', split='train')

Found cached dataset json (/Users/yeeb/.cache/huggingface/datasets/json/default-96676ccee0c43a59/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Found cached dataset json (/Users/yeeb/.cache/huggingface/datasets/json/default-8b0cc3f6e2fb8516/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


In [4]:
# Display the first raw training record to inspect its fields before tokenization.
train[0]

{'inputs': "In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Ralph had a clothing store and sales had peaked. He looked at his sales data and noticed one demographic wasn't buying his clothes. So, Ralph started a rumor that he wouldn't sell his clothes to them and when word got out that group of people started purchasing his clothes just to spite him. Now, Ralph enjoys taking money from all demographic groups.\nThis narrative is a good illustration of the following proverb:",
 'idx': 463,
 'targets': ['All publicity is good publicity'],
 'multiple_choice_scores': [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
 'true_idx': 2063,
 'is_generated': False,
 'multiple_choice_targets': ['He who laughs last laughs longest',
  'Honey catches more flies than vinegar',
  'Build a better mousetrap and the world will beat a path to your doorLink to proverb',
  'All publicity is good publicity',
  'Revenge is a dish best served cold',
  'Give credit where credi

In [5]:
# Display the first raw test record to inspect its fields before tokenization.
test[0]

{'inputs': 'Identify the part of speech (verb, adjective, noun, or preposition) of the specified word in the following headlines.\n\nIn the following sentence, what part of speech is watch? Sentence: Watch batteries while you wait\nA:',
 'idx': 13,
 'targets': ['noun'],
 'multiple_choice_scores': [0, 0, 1, 0],
 'true_idx': 13,
 'is_generated': False,
 'multiple_choice_targets': ['verb', 'adjective', 'noun', 'preposition']}

In [6]:
# Load the pretrained 'gpt2' tokenizer.
tokenizer = AutoTokenizer.from_pretrained('gpt2')
# Reuse the end-of-sequence token as the padding token so that padding in the
# collator (tokenizer.pad) has a pad token to use.
# NOTE(review): presumably needed because the gpt2 tokenizer ships without a
# pad token — confirm.
tokenizer.pad_token = tokenizer.eos_token

In [7]:
# Run the project's preprocess_dataset (from src.data) over both datasets with
# the same tokenizer.
# NOTE(review): the result is presumably per-example input_ids / attention_mask /
# token_type_ids, since those are the keys ICLCollator reads — verify in src.data.
train_processed = preprocess_dataset(train, tokenizer)
test_processed = preprocess_dataset(test, tokenizer)

Loading cached processed dataset at /Users/yeeb/.cache/huggingface/datasets/json/default-96676ccee0c43a59/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-ff2376a2b0057b13_*_of_00008.arrow
Loading cached processed dataset at /Users/yeeb/.cache/huggingface/datasets/json/default-8b0cc3f6e2fb8516/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-ad266f0c897c866b_*_of_00008.arrow


In [76]:
from dataclasses import dataclass
from transformers import PreTrainedTokenizerBase
from typing import List, Dict, Any, Union
import itertools

In [84]:
@dataclass
class ICLCollator:
    """Collate tokenized examples into in-context (few-shot) learning sequences.

    Every output row is the concatenation, in order, of one group of examples:
    their ``input_ids``, ``attention_mask`` and ``token_type_ids`` lists are
    chained together, and the resulting rows are padded to the longest row of
    the batch with ``tokenizer.pad``.

    How examples are grouped depends on ``for_eval``:

    * ``for_eval=False`` -- ``features`` is a flat list of examples whose
      length should be ``k_examples * batch_size``. Each consecutive run of
      ``k_examples`` examples becomes one row (a shorter trailing run becomes
      its own row).
    * ``for_eval=True`` -- ``features`` is a list of lists and each inner list
      becomes exactly one row. Rows therefore never mix examples that belong
      to different dataset items, even if an inner list does not contain
      exactly ``k_examples`` entries.
    """

    # Tokenizer whose ``pad`` method pads the concatenated rows.
    tokenizer: PreTrainedTokenizerBase
    # Number of consecutive examples merged into one row when for_eval is False.
    k_examples: int = 16
    # Forwarded to ``tokenizer.pad``. No truncation is performed in this class.
    # NOTE(review): with padding="longest" this value likely has no effect on
    # the padded length, so over-long rows are passed through -- confirm.
    max_length: int = 1024
    # Tensor format requested from ``tokenizer.pad`` (e.g. "pt").
    return_tensors: str = "pt"
    # True when each element of ``features`` is already a list of examples.
    for_eval: bool = False

    def __call__(
        self,
        features: Union[List[Dict[str, Any]], List[List[Dict[str, Any]]]],
    ) -> Dict[str, Any]:
        """
        * creates batches for in context/few shot learning
        * training: length of [features] should be (k_examples * batch_size)
        * evaluation: [features] is a list of example lists, one list per row

        Args:
            features: Examples to collate. Every example must provide the keys
                ``input_ids``, ``attention_mask`` and ``token_type_ids``, each
                mapping to a list.

        Returns:
            The value returned by ``tokenizer.pad`` for the concatenated rows,
            with keys ``input_ids``, ``attention_mask`` and ``token_type_ids``.

        Raises:
            KeyError: If an example lacks one of the three required keys.
        """
        if self.for_eval:
            # Each inner list is one row. The previous flatten-then-rechunk
            # approach only produced the same rows when every inner list held
            # exactly k_examples entries; otherwise examples from neighbouring
            # items leaked into each other's rows.
            groups = [list(group) for group in features]
        else:
            groups = [
                features[start : start + self.k_examples]
                for start in range(0, len(features), self.k_examples)
            ]

        # "input_ids" stays first so the dict keeps the original key order.
        batch = {
            key: [
                list(itertools.chain.from_iterable(example[key] for example in group))
                for group in groups
            ]
            for key in ("input_ids", "attention_mask", "token_type_ids")
        }

        return self.tokenizer.pad(
            batch,
            padding="longest",
            max_length=self.max_length,
            pad_to_multiple_of=None,
            return_tensors=self.return_tensors,
        )

In [85]:
# Combine the processed train and test sets through the project's
# EvalDatasetWrapper (from src.data).
# NOTE(review): presumably each item is a list of examples, since the collator's
# for_eval path expects a list of lists — verify in src.data.
wrapped_dataset = EvalDatasetWrapper(train_processed, test_processed)

In [90]:
# Build the collator in evaluation mode with its default k_examples / max_length.
# Note: this uses the ICLCollator defined in this notebook, which shadows the
# class of the same name imported from src.data in the first cell.
collate_fn = ICLCollator(tokenizer, for_eval=True)

In [91]:
# Batch three dataset items at a time and hand each batch to the ICL collator.
dl = DataLoader(wrapped_dataset, batch_size=3, collate_fn=collate_fn)

In [92]:
# Keep one iterator around so successive next(it) calls walk through the batches.
it = iter(dl)

In [95]:
# Pull the next collated batch and decode its input_ids back to text to check
# how the examples were concatenated. Every run of this cell consumes one batch
# from `it`; next() raises StopIteration once the iterator is exhausted.
tokenizer.batch_decode(next(it)['input_ids'])

[{'input_ids': [818, 644, 5679, 11, 356, 2148, 1790, 26274, 11, 1123, 286, 543, 21290, 257, 2219, 36950, 13, 220, 198, 45750, 876, 25, 35015, 750, 407, 30419, 2031, 1141, 2159, 1810, 2873, 287, 257, 29923, 13, 2399, 37451, 287, 11228, 1321, 290, 4634, 510, 26355, 2957, 284, 1943, 13, 198, 1212, 8689, 318, 257, 922, 20936, 286, 262, 1708, 36950, 25, 220, 628, 37, 11608, 284, 1410, 318, 5410, 284, 2038, 220, 628, 198], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]}, {'input_ids': [818, 644, 5679, 11, 356, 2148, 1790, 26274, 11, 1123, 286, 543, 21290, 257, 2219, 369

['In what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: Eisenhower did not invade Europe during World War II in a whim. His brilliance in gathering information and setting up logistics led to success.\nThis narrative is a good illustration of the following proverb: \n\nFailing to plan is planning to fail \n\n\nIn what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: One of my students invited me on the opening of his new coffee shop business. I advised him that never comprise on quality of products and services as once people don\'t like anything they will never come back\nThis narrative is a good illustration of the following proverb: \n\nOnce bitten, twice shy \n\n\nIn what follows, we provide short narratives, each of which illustrates a common proverb. \nNarrative: The man wanted to write a new program to play chess. He thought he could start from scratch and develop something novel. The 