In [135]:
from datasets import Dataset, load_dataset
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import PreTrainedTokenizerBase
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
import itertools

import torch
import numpy as np

In [136]:
import sys
sys.path.append('..')

In [139]:
ds = load_dataset('tasksource/bigbench', 'movie_recommendation', split="train")

Found cached dataset bigbench (/Users/yeeb/.cache/huggingface/datasets/tasksource___bigbench/movie_recommendation/1.0.0/c5da5ac497141c7435da10444495b8577405d4ed01e524265b144a7063718c0c)


In [140]:
ds

Dataset({
    features: ['inputs', 'targets', 'multiple_choice_targets', 'multiple_choice_scores', 'idx'],
    num_rows: 400
})

In [146]:
load_dataset(
    "json",
    data_files=f"../data/baseline_test/crash_blossom.json",
    split="train"
)

Found cached dataset json (/Users/yeeb/.cache/huggingface/datasets/json/default-8b0cc3f6e2fb8516/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)


Dataset({
    features: ['inputs', 'idx', 'targets', 'multiple_choice_scores', 'true_idx', 'is_generated', 'multiple_choice_targets'],
    num_rows: 22
})

In [127]:
ds = load_dataset('tasksource/bigbench', 'crash_blossom')["train"]

Found cached dataset bigbench (/Users/yeeb/.cache/huggingface/datasets/tasksource___bigbench/crash_blossom/1.0.0/c5da5ac497141c7435da10444495b8577405d4ed01e524265b144a7063718c0c)


  0%|          | 0/2 [00:00<?, ?it/s]

In [128]:
ds[0]

{'inputs': 'Identify the part of speech (verb, adjective, noun, or preposition) of the specified word in the following headlines.\n\nIn the following sentence, what part of speech is bears? Sentence: Nevada poll bears good news for Obama\nA:',
 'targets': ['verb'],
 'multiple_choice_targets': ['verb', 'adjective', 'noun', 'preposition'],
 'multiple_choice_scores': [1, 0, 0, 0],
 'idx': 0}

In [5]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('huggyllama/llama-7b')
tokenizer.bos_token = "<s>"
tokenizer.eos_token = "</s>"
tokenizer.pad_token = tokenizer.eos_token

In [None]:
\n  choice:

In [129]:
def preprocess_dataset(
    ds: Dataset,
    tokenizer: PreTrainedTokenizerBase,
    method="direct",
    num_procs=8,
) -> Dataset:
    def remove_choices(s: str) -> str:
        choice_start = s.find("\n  choice:")
        if choice_start == -1:
            return s
        return s[:choice_start]

    def target_to_index(choices: List, target: List[str]) -> int:
        if not target:
            return -1
        try:
            index = choices.index(target[0])
            return index
        except ValueError:
            return -1

    def preprocess_function(examples: dict) -> dict:
        """
        * tokenizes dataset
        * token_type_ids are 1 where there are label tokens and 0 otherwise
        """
        inputs = [
            f"{remove_choices(inp)}\n"
            + "\n".join([f"choice {i}: {choice}" for i, choice in enumerate(choices)])
            + "\nanswer:"
            for inp, choices in zip(
                examples["inputs"], examples["multiple_choice_targets"]
            )
        ]

        targets = [
            f"{target_to_index(choices, target)}\n\n"
            for target, choices in zip(
                examples["targets"], examples["multiple_choice_targets"]
            )
        ]

        # swap inputs and targets if method is "channel"
        if method == "channel":
            inputs, targets = targets, inputs

        # tokenize inputs and targets, and prepare outputs dictionary
        input_tokenized = tokenizer(inputs, add_special_tokens=False)
        target_tokenized = tokenizer(targets, add_special_tokens=False)
        outputs = {
            "input_ids": [],
            "attention_mask": [],
            "token_type_ids": [],
        }

        # merge input and target tokens and prepare outputs
        for i in range(len(input_tokenized["input_ids"])):
            input_ids = input_tokenized["input_ids"][i]
            target_ids = target_tokenized["input_ids"][i]
            outputs["input_ids"].append(input_ids + target_ids)

            input_attention = input_tokenized["attention_mask"][i]
            target_attention = target_tokenized["attention_mask"][i]
            outputs["attention_mask"].append(input_attention + target_attention)

            input_token_type = [0] * (len(input_ids) + 1)
            target_token_type = [1] * (len(target_ids) - 3) + [0, 0]
            outputs["token_type_ids"].append(input_token_type + target_token_type)

            print(len(outputs["input_ids"][i]), len(outputs["attention_mask"][i]), len(outputs["token_type_ids"][i]))

        return outputs

    ds = ds.map(
        preprocess_function,
        batched=True,
        num_proc=num_procs,
    )
    return ds

In [130]:
pds = preprocess_dataset(ds, tokenizer, method='direct')

Map (num_proc=8):   0%|          | 0/22 [00:00<?, ? examples/s]

94 94 94
90 90 90
96 9688  9688 
88
96 96 96
94 9493  9493
 8993 
8995  8995
 9295 
9289  899293
  999389  
9399
 8999
 89 89
93 93 93
94 94 94
95 95 95
93 93 93
102 102 102
91 91 91
89 89 89
89 89 89


In [131]:
pds[0]['token_type_ids']

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0]

In [132]:
print(tokenizer.decode(pds[0]['input_ids']))

Identify the part of speech (verb, adjective, noun, or preposition) of the specified word in the following headlines.

In the following sentence, what part of speech is bears? Sentence: Nevada poll bears good news for Obama
A:
choice 0: verb
choice 1: adjective
choice 2: noun
choice 3: preposition
answer: 0




In [133]:
for i in range(len(pds)):
    print(len(pds[i]['input_ids']), len(pds[i]['attention_mask']), len(pds[i]['token_type_ids']))

94 94 94
90 90 90
96 96 96
88 88 88
96 96 96
94 94 94
93 93 93
95 95 95
89 89 89
89 89 89
92 92 92
99 99 99
93 93 93
89 89 89
93 93 93
94 94 94
95 95 95
93 93 93
102 102 102
91 91 91
89 89 89
89 89 89


In [134]:
tokenizer.decode(np.array(pds[0]['input_ids'])*np.array(pds[0]['token_type_ids']))

'<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>0<unk><unk>'

In [99]:
@dataclass
class ICLCollator:
    tokenizer: PreTrainedTokenizerBase
    k_examples: int = 16
    max_length: int = 2048
    return_tensors: str = "pt"
    for_eval: bool = False

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        * creates batches for in context/few shot learning
        * length of [features] should be (k_examples * batch_size)
        * if for_eval create a labels field
        """
        batch = {"input_ids": [], "attention_mask": [], "token_type_ids": []}

        if self.for_eval:
            # if collation for evaluation, features is a List[List[Dict[str, Any]]]
            # where the inner list contains our k_examples, so flatten it
            features = list(itertools.chain.from_iterable(features))

        for i in range(0, len(features), self.k_examples):
            batch["input_ids"].append(
                list(
                    itertools.chain.from_iterable(
                        example["input_ids"]
                        for example in features[i : i + self.k_examples]
                    )
                )[: self.max_length]
            )
            batch["attention_mask"].append(
                list(
                    itertools.chain.from_iterable(
                        example["attention_mask"]
                        for example in features[i : i + self.k_examples]
                    )
                )[: self.max_length]
            )
            batch["token_type_ids"].append(
                list(
                    itertools.chain.from_iterable(
                        example["token_type_ids"]
                        for example in features[i : i + self.k_examples]
                    )
                )[: self.max_length]
            )

        batch = self.tokenizer.pad(
            batch,
            padding="longest",
            max_length=self.max_length,
            pad_to_multiple_of=None,
            return_tensors=self.return_tensors,
        )

        if self.for_eval:
            batch["labels"] = batch["input_ids"].clone()
            batch["labels"] *= batch["token_type_ids"]

        return batch

In [100]:
k=3
batch_size=8

In [102]:
collate_fn = ICLCollator(tokenizer, k_examples=k)
dataloader = DataLoader(pds, batch_size=k*batch_size, collate_fn=collate_fn)

In [103]:
data = next(iter(dataloader))

You're using a LlamaTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


In [104]:
data['input_ids'].shape

torch.Size([8, 253])

In [105]:
print(tokenizer.batch_decode(data['input_ids'], skip_special_tokens=True)[0])

Find a movie similar to Batman, The Mask, The Fugitive, Pretty Woman:
choice 0: The Front Page
choice 1: Maelstrom
choice 2: The Lion King
choice 3: Lamerica
answer: 2

 Find a movie similar to The Sixth Sense, The Matrix, Forrest Gump, The Shawshank Redemption:
choice 0: Street Fighter II The Animated Movie
choice 1: The Sheltering Sky
choice 2: The Boy Who Could Fly
choice 3: Terminator 2 Judgment Day
answer: 3

 Find a movie similar to Schindler's List, Braveheart, The Silence of the Lambs, Tombstone:
choice 0: Orlando
choice 1: Guilty of Romance
choice 2: Forrest Gump
choice 3: All the Real Girls
answer: 2




In [106]:
print(tokenizer.batch_decode(data['input_ids'] * data['token_type_ids'], skip_special_tokens=True)[0])

232
