In [5]:
# default_exp data_manager
%load_ext autoreload
%autoreload 2

In [6]:
#export
from datasets import load_dataset
import numpy as np
import torch
from transformers import GPT2Tokenizer
from kirby.run_params import RunParams

# Data Manager

> Loads, filters, tokenizes, and chunks text datasets for language-model training.

In [7]:
#export
class DataManager():
    """Load, filter, tokenize, and chunk text datasets for causal-LM training.

    All settings (model name, data files/type, batch size, debug flag and the
    percentage of each split to use) are read from ``run_params``.
    """

    def __init__(self, run_params, block_size=128):
        """
        Args:
            run_params: run configuration. Attributes read: ``model``,
                ``batch_size``, ``debug``, ``data_set_percentage``,
                ``data_file_type`` and ``data_files``.
            block_size: number of tokens per example produced by
                ``group_texts`` (defaults to 128, the previously hard-coded value).
        """
        self.run_params = run_params
        self.block_size = block_size

    def prepare_data(self):
        """Build the training and validation datasets.

        The tokenizer is loaded once and shared by both splits.

        Returns:
            ``(train_ds, val_ds)`` built from the ``'train'`` and ``'valid'`` splits.
        """
        tokenizer = self._load_tokenizer()
        train_ds = self.prepare_ds('train', tokenizer=tokenizer)
        val_ds = self.prepare_ds('valid', tokenizer=tokenizer)
        return train_ds, val_ds

    def _load_tokenizer(self):
        """Load the GPT-2 tokenizer named by ``run_params.model`` with EOS as padding."""
        tokenizer = GPT2Tokenizer.from_pretrained(self.run_params.model)
        # The GPT-2 tokenizer has no pad token by default, so reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def _split_spec(self, split):
        """Return the ``datasets`` split string selecting the slice of ``split`` to load.

        In debug mode only the first ``batch_size * block_size`` rows are used;
        otherwise the first ``data_set_percentage`` percent of the split.
        """
        params = self.run_params
        if params.debug:
            return f'{split}[:{params.batch_size * self.block_size}]'
        return f'{split}[:{params.data_set_percentage}%]'

    def prepare_ds(self, split, tokenizer=None):
        """Load, filter, tokenize and chunk a single split.

        Args:
            split: split name, e.g. ``'train'`` or ``'valid'``.
            tokenizer: optional pre-loaded tokenizer. When omitted, it is loaded
                from ``run_params.model``.

        Returns:
            A ``datasets.Dataset`` whose ``input_ids``, ``attention_mask`` and
            ``labels`` columns are formatted as torch tensors.
        """
        if tokenizer is None:
            tokenizer = self._load_tokenizer()
        ds = load_dataset(self.run_params.data_file_type,
                          data_files=self.run_params.data_files,
                          split=self._split_spec(split))
        # Drop blank lines and section headings before tokenizing.
        ds = ds.filter(function=self.criteria)
        ds = ds.map(self.tokenize, batched=True, num_proc=4,
                    remove_columns=['text'], fn_kwargs={'tokenizer': tokenizer})
        ds = ds.map(self.group_texts, batched=True)
        ds.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
        return ds

    def tokenize(self, x, tokenizer=None):
        """Tokenize a batch of rows.

        Args:
            x: batch dict with a ``'text'`` column (a list of strings, since the
                map in ``prepare_ds`` is batched).
            tokenizer: callable tokenizer, supplied via ``fn_kwargs`` in ``prepare_ds``.

        Returns:
            The tokenizer's output for ``x['text']``.

        Raises:
            ValueError: if no tokenizer is given.
        """
        if tokenizer is None:
            raise ValueError('tokenize() requires a tokenizer')
        return tokenizer(x['text'])

    def group_texts(self, examples):
        """Concatenate a batch of tokenized rows and re-split it into fixed-size blocks.

        Every column (e.g. ``input_ids``, ``attention_mask``) is flattened and
        cut into chunks of ``self.block_size`` tokens. The trailing partial
        chunk is dropped, except when the whole batch holds fewer than
        ``block_size`` tokens: it is then kept as one undersized chunk. A batch
        with no tokens at all yields no rows. ``labels`` is an independent copy
        of ``input_ids``.

        Args:
            examples: batch dict mapping column name -> list of token lists.

        Returns:
            Dict with the same columns plus ``labels``, each a list of chunks.
        """
        # Flatten each column in linear time (sum(lists, []) is quadratic).
        concatenated = {k: [tok for row in rows for tok in row] for k, rows in examples.items()}
        lengths = [len(seq) for seq in concatenated.values()]
        total_length = lengths[0] if lengths else 0
        if total_length == 0:
            # Nothing to chunk: emit zero rows rather than a single empty example.
            result = {k: [] for k in concatenated}
            result['labels'] = []
            return result
        if total_length >= self.block_size:
            # Drop the remainder that does not fill a whole block (no padding).
            total_length -= total_length % self.block_size
        else:
            # Keep a short batch as one undersized chunk instead of losing it.
            total_length = self.block_size
        starts = range(0, total_length, self.block_size)
        result = {
            k: [seq[i:i + self.block_size] for i in starts]
            for k, seq in concatenated.items()
        }
        # Copy each chunk so labels never alias input_ids.
        result['labels'] = [list(chunk) for chunk in result['input_ids']]
        return result

    def criteria(self, x):
        """Filter predicate: keep a row unless it is blank or a section heading.

        Blank means empty or whitespace-only. Lines starting with ``' ='`` are
        treated as headings.
        """
        text = x['text']
        return bool(text.strip()) and not text.startswith(' =')
    

# Testing

In [8]:
# Creation: build both splits and check the invariants DataManager promises.
from datasets import Dataset
run_params = RunParams()
data_manager = DataManager(run_params)
train_ds, valid_ds = data_manager.prepare_data()
assert isinstance(train_ds, Dataset)
assert isinstance(valid_ds, Dataset)
for split_ds in (train_ds, valid_ds):
    assert len(split_ds) > 0, 'split produced no examples'
    assert {'input_ids', 'attention_mask', 'labels'} <= set(split_ds.column_names)
    sample = split_ds[0]
    # Each example is one chunk of at most block_size tokens; labels mirror input_ids.
    assert sample['input_ids'].shape[0] <= data_manager.block_size
    assert torch.equal(sample['input_ids'], sample['labels'])

Using custom data configuration default-d64f335cc8a13d66
Reusing dataset text (/home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-913970d3a7a8309a.arrow
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-d343d49cae84df08.arrow
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5/cache-bd75306c84ce1e89.arrow
Loading cached processed dataset at /home/rob/.cache/huggingface/datasets/text/default-d64f335cc8a13d66/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf2

In [9]:
# Peek at the first training example to sanity-check the tokenized, grouped output.
train_ds[0]

{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]),
 'input_ids': tensor([ 2311,    73, 13090,   645,   569, 18354,  7496,   513,  1058,   791,
         47398, 17740,   357,  4960,  1058, 10545,   230,    99,   161,   254,
           112,  5641, 44444,  9202, 25084, 24440, 12675, 11839,    18,   837,
          6578,   764,   569, 18354,  7496,   286,   262, 30193,   513,  1267,
           837,  8811,  6412,   284,   355,   569, 18354,  7496, 17740,  6711,
          2354,  2869,   837,   318,   257, 16106,  2597,  2488,    12,    31,
          2712,  2008,   983,  4166,   416, 29490,   29