In [1]:
# default_exp data_manager

In [2]:
#export
from datasets import load_dataset
import numpy as np
import torch
from transformers import GPT2Tokenizer
from kirby.run_params import RunParams

# Data Manager

> Prepares and Loads data

In [12]:
#export
class DataManager():
    def __init__(self, run_params):
        self.run_params = run_params
        self.block_size = 128
    # Load, Tokenize, and Augment data
    def prepare_data(self):
        train_ds, val_ds = map(self.prepare_ds, ('train', 'valid'))
        return train_ds, val_ds
    
    def prepare_ds(self, split):
        tokenizer = GPT2Tokenizer.from_pretrained(self.run_params.model)
        tokenizer.pad_token = tokenizer.eos_token 
        split = f'{split}[:{self.run_params.batch_size if self.run_params.debug else f"{self.run_params.data_set_percentage}%"}]'
        ds = load_dataset('text', data_files=self.run_params.data_files, split=split)
        ds = ds.map(self.tokenize, batched=True, num_proc=4, remove_columns=['text'], fn_kwargs={'tokenizer':tokenizer})
        ds = ds.map(
            self.group_texts,
            batched=True,
            batch_size=self.block_size,
            num_proc=self.run_params.num_workers
        )
        ds.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        return ds

    # Tokenize a sequence
    def tokenize(self, x, tokenizer=None):
        tokens = tokenizer(x['text'])
        return tokens
    
    def group_texts(self, examples):
        import pdb; pdb.set_trace()
        # Concatenate all texts.
        concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
            # customize this part to your needs.
        total_length = (total_length // self.block_size) * self.block_size
        if total_length == 0:
            total_length = self.block_size
        # Split by chunks of max_len.
        result = {
            k: [t[i : i + self.block_size] for i in range(0, total_length, self.block_size)]
            for k, t in concatenated_examples.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result
    

# Testing

In [13]:
# Creation
from datasets import Dataset
run_params = RunParams()
data_manager = DataManager(run_params)
train_ds, valid_ds = data_manager.prepare_data()
assert isinstance(train_ds, Dataset)
import pdb; pdb.set_trace()
assert isinstance(valid_ds, Dataset)

Using custom data configuration default-04bff418a63932f2
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)






> [0;32m<ipython-input-12-f201a9877dc5>[0m(34)[0;36mgroup_texts[0;34m()[0m
[0;32m     32 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     33 [0;31m        [0;31m# Concatenate all texts.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 34 [0;31m        [0mconcatenated_examples[0m [0;34m=[0m [0;34m{[0m[0mk[0m[0;34m:[0m [0msum[0m[0;34m([0m[0mexamples[0m[0;34m[[0m[0mk[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m)[0m [0;32mfor[0m [0mk[0m [0;32min[0m [0mexamples[0m[0;34m.[0m[0mkeys[0m[0;34m([0m[0;34m)[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     35 [0;31m        [0mtotal_length[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mconcatenated_examples[0m[0;34m[[0m[0mlist[0m[0;34m([0m[0mexamples[0m[0;34m.[0m[0mkeys[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0

ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(35)[0;36mgroup_texts[0;34m()[0m
[0;32m     33 [0;31m        [0;31m# Concatenate all texts.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     34 [0;31m        [0mconcatenated_examples[0m [0;34m=[0m [0;34m{[0m[0mk[0m[0;34m:[0m [0msum[0m[0;34m([0m[0mexamples[0m[0;34m[[0m[0mk[0m[0;34m][0m[0;34m,[0m [0;34m[[0m[0;34m][0m[0;34m)[0m [0;32mfor[0m [0mk[0m [0;32min[0m [0mexamples[0m[0;34m.[0m[0mkeys[0m[0;34m([0m[0;34m)[0m[0;34m}[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 35 [0;31m        [0mtotal_length[0m [0;34m=[0m [0mlen[0m[0;34m([0m[0mconcatenated_examples[0m[0;34m[[0m[0mlist[0m[0;34m([0m[0mexamples[0m[0;34m.[0m[0mkeys[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m][0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     36 [0;31m        [0;31m# We drop the small remainder, we could add padding if the model supported it inst

ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(38)[0;36mgroup_texts[0;34m()[0m
[0;32m     36 [0;31m        [0;31m# We drop the small remainder, we could add padding if the model supported it instead of this drop, you can[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     37 [0;31m            [0;31m# customize this part to your needs.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 38 [0;31m        [0mtotal_length[0m [0;34m=[0m [0;34m([0m[0mtotal_length[0m [0;34m//[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m)[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0;32mif[0m [0mtotal_length[0m [0;34m==[0m [0;36m0[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m            [0mtotal_length[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(39)[0;36mgroup_texts[0;34m()[0m
[0;32m     37 [0;31m            [0;31m# customize this part to your needs.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     38 [0;31m        [0mtotal_length[0m [0;34m=[0m [0;34m([0m[0mtotal_length[0m [0;34m//[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m)[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 39 [0;31m        [0;32mif[0m [0mtotal_length[0m [0;34m==[0m [0;36m0[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     40 [0;31m            [0mtotal_length[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m        [0;31m# Split by chunks of max_len.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(40)[0;36mgroup_texts[0;34m()[0m
[0;32m     38 [0;31m        [0mtotal_length[0m [0;34m=[0m [0;34m([0m[0mtotal_length[0m [0;34m//[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m)[0m [0;34m*[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     39 [0;31m        [0;32mif[0m [0mtotal_length[0m [0;34m==[0m [0;36m0[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m            [0mtotal_length[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m        [0;31m# Split by chunks of max_len.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m        result = {
[0m


ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(42)[0;36mgroup_texts[0;34m()[0m
[0;32m     40 [0;31m            [0mtotal_length[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m        [0;31m# Split by chunks of max_len.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m        result = {
[0m[0;32m     43 [0;31m            [0mk[0m[0;34m:[0m [0;34m[[0m[0mt[0m[0;34m[[0m[0mi[0m [0;34m:[0m [0mi[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mrange[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0mtotal_length[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m            [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0mconcatenated_examples[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  total_length


128


ipdb>  examples


{'attention_mask': [[], [1, 1, 1, 1, 1, 1, 1, 1]], 'input_ids': [[], [796, 569, 18354, 7496, 17740, 6711, 796, 220]]}


ipdb>  self.block_size


128


ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(44)[0;36mgroup_texts[0;34m()[0m
[0;32m     42 [0;31m        result = {
[0m[0;32m     43 [0;31m            [0mk[0m[0;34m:[0m [0;34m[[0m[0mt[0m[0;34m[[0m[0mi[0m [0;34m:[0m [0mi[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mrange[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0mtotal_length[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 44 [0;31m            [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0mconcatenated_examples[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m        }
[0m[0;32m     46 [0;31m        [0mresult[0m[0;34m[[0m[0;34m"labels"[0m[0;34m][0m [0;34m=[0m [0mresult[0m[0;34m[[0m[0;34m"input_ids"[0m[0;34m][0m[0;34m.[0m[0mcopy[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34

ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(42)[0;36mgroup_texts[0;34m()[0m
[0;32m     40 [0;31m            [0mtotal_length[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m        [0;31m# Split by chunks of max_len.[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 42 [0;31m        result = {
[0m[0;32m     43 [0;31m            [0mk[0m[0;34m:[0m [0;34m[[0m[0mt[0m[0;34m[[0m[0mi[0m [0;34m:[0m [0mi[0m [0;34m+[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m][0m [0;32mfor[0m [0mi[0m [0;32min[0m [0mrange[0m[0;34m([0m[0;36m0[0m[0;34m,[0m [0mtotal_length[0m[0;34m,[0m [0mself[0m[0;34m.[0m[0mblock_size[0m[0;34m)[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     44 [0;31m            [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0mconcatenated_examples[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(46)[0;36mgroup_texts[0;34m()[0m
[0;32m     44 [0;31m            [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0mconcatenated_examples[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m        }
[0m[0;32m---> 46 [0;31m        [0mresult[0m[0;34m[[0m[0;34m"labels"[0m[0;34m][0m [0;34m=[0m [0mresult[0m[0;34m[[0m[0;34m"input_ids"[0m[0;34m][0m[0;34m.[0m[0mcopy[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     47 [0;31m        [0;32mreturn[0m [0mresult[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m[0;34m[0m[0m
[0m


ipdb>  result


{'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1]], 'input_ids': [[796, 569, 18354, 7496, 17740, 6711, 796, 220]]}


ipdb>  n


> [0;32m<ipython-input-12-f201a9877dc5>[0m(47)[0;36mgroup_texts[0;34m()[0m
[0;32m     44 [0;31m            [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0mconcatenated_examples[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m        }
[0m[0;32m     46 [0;31m        [0mresult[0m[0;34m[[0m[0;34m"labels"[0m[0;34m][0m [0;34m=[0m [0mresult[0m[0;34m[[0m[0;34m"input_ids"[0m[0;34m][0m[0;34m.[0m[0mcopy[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 47 [0;31m        [0;32mreturn[0m [0mresult[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


--Return--
{'attention_mask': [[1, 1, 1, 1, 1, 1, ...]], 'input_ids': [[796, 569, 18354, 7496, 17740, 6711, ...]], 'labels': [[796, 569, 18354, 7496, 17740, 6711, ...]]}
> [0;32m<ipython-input-12-f201a9877dc5>[0m(47)[0;36mgroup_texts[0;34m()[0m
[0;32m     44 [0;31m            [0;32mfor[0m [0mk[0m[0;34m,[0m [0mt[0m [0;32min[0m [0mconcatenated_examples[0m[0;34m.[0m[0mitems[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     45 [0;31m        }
[0m[0;32m     46 [0;31m        [0mresult[0m[0;34m[[0m[0;34m"labels"[0m[0;34m][0m [0;34m=[0m [0mresult[0m[0;34m[[0m[0;34m"input_ids"[0m[0;34m][0m[0;34m.[0m[0mcopy[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 47 [0;31m        [0;32mreturn[0m [0mresult[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     48 [0;31m[0;34m[0m[0m
[0m


ipdb>  n


> [0;32m/usr/local/lib/python3.8/dist-packages/datasets/arrow_dataset.py[0m(1378)[0;36mdoes_function_return_dict[0;34m()[0m
[0;32m   1376 [0;31m            [0;34m""" Does the function returns a dict. """[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1377 [0;31m            [0mfn_args[0m [0;34m=[0m [0;34m[[0m[0minputs[0m[0;34m][0m [0;32mif[0m [0minput_columns[0m [0;32mis[0m [0;32mNone[0m [0;32melse[0m [0;34m[[0m[0minputs[0m[0;34m[[0m[0mcol[0m[0;34m][0m [0;32mfor[0m [0mcol[0m [0;32min[0m [0minput_columns[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1378 [0;31m            processed_inputs = (
[0m[0;32m   1379 [0;31m                [0mfunction[0m[0;34m([0m[0;34m*[0m[0mfn_args[0m[0;34m,[0m [0mindices[0m[0;34m,[0m [0;34m**[0m[0mfn_kwargs[0m[0;34m)[0m [0;32mif[0m [0mwith_indices[0m [0;32melse[0m [0mfunction[0m[0;34m([0m[0;34m*[0m[0mfn_args[0m[0;34m,[0m [0;34m**[0m[0mfn_kwargs[0m[0;34m)[0m[0;34

ipdb>  c


IndexError: index out of bounds

In [None]:
train_ds[1]

  return torch.tensor(x, **format_kwargs)


{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'input_ids': tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256])}