In [93]:
from datasets import Dataset, load_dataset
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import PreTrainedTokenizerBase
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import itertools

import torch
import numpy as np

In [94]:
ds = load_dataset('tasksource/bigbench', 'movie_recommendation', split="train[:100]")



In [95]:
ds[0]

{'inputs': 'Find a movie similar to Batman, The Mask, The Fugitive, Pretty Woman:\n  choice: Maelstrom\n  choice: The Lion King\n  choice: Lamerica\n  choice: The Front Page',
 'targets': ['The Lion King'],
 'multiple_choice_targets': ['The Front Page',
  'Maelstrom',
  'The Lion King',
  'Lamerica'],
 'multiple_choice_scores': [0, 0, 1, 0],
 'idx': 0}

In [96]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

In [193]:
def preprocess_dataset(
    ds, tokenizer, method="direct", num_procs=8
):
    def preprocess_function(examples):
        """
        * tokenizes dataset
        * token_type_ids are 1 where there are label tokens and 0 otherwise
        """
        inputs = tokenizer(
            [inputs + " \n\n" for inputs in examples["inputs"]], add_special_tokens=False
        )
        targets = tokenizer(
            [" ".join(targets) + " \n\n\n" for targets in examples["targets"]],
            add_special_tokens=False,
        )

        # flip the location of inputs and targets for channel method
        if method == "channel":
            inputs, targets = targets, inputs

        outputs = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
        }

        for i in range(len(inputs["input_ids"])):
            outputs["input_ids"].append(
                inputs["input_ids"][i] + targets["input_ids"][i]
            )
            outputs["attention_mask"].append(
                inputs["attention_mask"][i] + targets["attention_mask"][i]
            )
            outputs["token_type_ids"].append(
                [0] * len(inputs["input_ids"][i])
                + [1] * (len(targets["input_ids"][i])-2)
                + [0, 0]
            )

        return outputs

    ds = ds.map(
        preprocess_function,
        remove_columns=ds.column_names,
        batched=True,
        num_proc=num_procs,
    )
    return ds

In [194]:
pds = preprocess_dataset(ds, tokenizer, method='direct')



In [195]:
print(tokenizer.decode(pds[0]['input_ids']))

Find a movie similar to Batman, The Mask, The Fugitive, Pretty Woman:
  choice: Maelstrom
  choice: The Lion King
  choice: Lamerica
  choice: The Front Page 

The Lion King 





In [197]:
for i in range(len(pds)):
    print(len(pds[i]['input_ids']), len(pds[i]['attention_mask']), len(pds[i]['token_type_ids']))

53 53 53
67 67 67
60 60 60
60 60 60
58 58 58
60 60 60
66 66 66
65 65 65
61 61 61
79 79 79
68 68 68
68 68 68
58 58 58
59 59 59
59 59 59
57 57 57
67 67 67
71 71 71
60 60 60
61 61 61
60 60 60
63 63 63
77 77 77
56 56 56
71 71 71
54 54 54
54 54 54
71 71 71
52 52 52
61 61 61
64 64 64
57 57 57
61 61 61
64 64 64
74 74 74
78 78 78
62 62 62
58 58 58
60 60 60
50 50 50
50 50 50
67 67 67
65 65 65
63 63 63
60 60 60
64 64 64
66 66 66
56 56 56
66 66 66
64 64 64
79 79 79
67 67 67
57 57 57
65 65 65
52 52 52
72 72 72
73 73 73
77 77 77
73 73 73
64 64 64
67 67 67
60 60 60
58 58 58
61 61 61
65 65 65
65 65 65
56 56 56
61 61 61
62 62 62
59 59 59
68 68 68
64 64 64
55 55 55
63 63 63
65 65 65
71 71 71
61 61 61
69 69 69
56 56 56
71 71 71
68 68 68
67 67 67
67 67 67
63 63 63
61 61 61
72 72 72
68 68 68
55 55 55
76 76 76
74 74 74
60 60 60
62 62 62
71 71 71
69 69 69
63 63 63
59 59 59
61 61 61
61 61 61
59 59 59
49 49 49


In [198]:
tokenizer.decode(np.array(pds[0]['input_ids'])*np.array(pds[0]['token_type_ids']))

'!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!The Lion King!!'

In [199]:
@dataclass
class ICLCollator:
    tokenizer: PreTrainedTokenizerBase
    batch_size: int = 1
    max_length: int = 1024
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        * creates batches for in context/few shot learning
        """
        batch = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

        for i in range(0, len(features), k):
            batch["input_ids"].append(
                list(itertools.chain.from_iterable(
                    example["input_ids"]
                    for example in features[i : i + k]
                ))
            )
            batch["attention_mask"].append(
                list(itertools.chain.from_iterable(
                    example["attention_mask"]
                    for example in features[i : i + k]
                ))
            )
            batch["token_type_ids"].append(
                list(itertools.chain.from_iterable(
                    example["token_type_ids"]
                    for example in features[i : i + k]
                ))
            )

        batch = self.tokenizer.pad(
            batch,
            padding="longest",
            max_length=self.max_length,
            pad_to_multiple_of=None,
            return_tensors=self.return_tensors,
        )

        return batch

In [200]:
k=3
batch_size=8

In [201]:
collate_fn = ICLCollator(tokenizer, batch_size=batch_size)
dataloader = DataLoader(pds, batch_size=k*batch_size, collate_fn=collate_fn)

In [202]:
dataloader = DataLoader(pds, batch_size=k*batch_size, collate_fn=collate_fn)

In [203]:
data = next(iter(dataloader))

In [204]:
data['input_ids'].shape

torch.Size([8, 215])

In [205]:
print(tokenizer.batch_decode(data['input_ids'], skip_special_tokens=True)[0])

Find a movie similar to Batman, The Mask, The Fugitive, Pretty Woman:
  choice: Maelstrom
  choice: The Lion King
  choice: Lamerica
  choice: The Front Page 

The Lion King 


Find a movie similar to The Sixth Sense, The Matrix, Forrest Gump, The Shawshank Redemption:
  choice: Terminator 2 Judgment Day
  choice: The Sheltering Sky
  choice: Street Fighter II The Animated Movie
  choice: The Boy Who Could Fly 

Terminator 2 Judgment Day 


Find a movie similar to Schindler's List, Braveheart, The Silence of the Lambs, Tombstone:
  choice: Forrest Gump
  choice: Orlando
  choice: All the Real Girls
  choice: Guilty of Romance 

Forrest Gump 





In [206]:
print(tokenizer.batch_decode(data['input_ids'] * data['token_type_ids'], skip_special_tokens=True)[0])

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!The Lion King!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!Terminator 2 Judgment Day!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!Forrest Gump!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
