In [None]:
# default_exp data_manager

In [None]:
#export
from datasets import load_dataset
import numpy as np
import torch
from transformers import GPT2Tokenizer
from kirby.run_params import RunParams

# Data Manager

> Loads, tokenizes, and formats the training and validation datasets

In [None]:
#export
class DataManager():
    """Prepare and load the train/validation datasets for a GPT-2 style model.

    All settings come from ``run_params``; the attributes read here are
    ``model``, ``data_files``, ``debug``, ``batch_size``,
    ``data_set_percentage`` and ``seq_length``.
    """

    def __init__(self, run_params):
        self.run_params = run_params
        # Loaded lazily and shared by both splits (see _get_tokenizer).
        self._tokenizer = None

    # Load and tokenize both splits
    def prepare_data(self):
        """Return ``(train_ds, val_ds)``, tokenized and formatted as torch tensors."""
        train_ds, val_ds = map(self.prepare_ds, ('train', 'valid'))
        return train_ds, val_ds

    def prepare_ds(self, split):
        """Load, tokenize and torch-format a single split.

        Args:
            split: split name (``'train'`` or ``'valid'``); it is used as the
                key into ``run_params.data_files`` via ``load_dataset``.

        Returns:
            The tokenized dataset with only ``input_ids`` and
            ``attention_mask`` exposed as torch tensors.
        """
        tokenizer = self._get_tokenizer()
        ds = load_dataset('text', data_files=self.run_params.data_files,
                          split=self._split_spec(split))
        ds = ds.map(self.tokenize, batched=True, fn_kwargs={'tokenizer': tokenizer})
        ds.set_format(type='torch', columns=['input_ids', 'attention_mask'])
        return ds

    def _split_spec(self, split):
        """Build the `datasets` slicing spec, e.g. ``'train[:8]'`` or ``'train[:10%]'``.

        In debug mode only the first ``batch_size`` rows are kept; otherwise the
        first ``data_set_percentage`` percent of the split.
        """
        rp = self.run_params
        size = rp.batch_size if rp.debug else f'{rp.data_set_percentage}%'
        return f'{split}[:{size}]'

    def _get_tokenizer(self):
        """Load the tokenizer for ``run_params.model`` once and reuse it afterwards."""
        tokenizer = getattr(self, '_tokenizer', None)
        if tokenizer is None:
            tokenizer = GPT2Tokenizer.from_pretrained(self.run_params.model)
            # Padding needs a pad token; GPT-2 has none by default, so reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token
            self._tokenizer = tokenizer
        return tokenizer

    # Tokenize a sequence
    def tokenize(self, x, tokenizer=None):
        """Tokenize a batch of rows produced by ``Dataset.map(batched=True)``.

        Args:
            x: batch mapping; must contain a ``'text'`` column.
            tokenizer: tokenizer to apply; defaults to the cached tokenizer
                returned by ``_get_tokenizer``.

        Returns:
            The tokenizer's encoding, with every sequence truncated or padded
            to exactly ``run_params.seq_length`` tokens.
        """
        if tokenizer is None:
            tokenizer = self._get_tokenizer()
        # 'max_length' rather than True/'longest': with a batched map, 'longest'
        # pads each map batch only to its own longest row, so rows from different
        # map batches get different lengths and cannot be stacked into a batch.
        return tokenizer(
            x['text'],
            max_length=self.run_params.seq_length,
            truncation=True,
            padding='max_length')
    

# Testing

In [None]:
# Creation: build a DataManager from the default run parameters and verify
# that prepare_data hands back a Hugging Face `Dataset` for each split.
from datasets import Dataset
run_params = RunParams()
data_manager = DataManager(run_params)
train_ds, valid_ds = data_manager.prepare_data()
assert all(isinstance(prepared, Dataset) for prepared in (train_ds, valid_ds))

Using custom data configuration default


Downloading and preparing dataset text/default-04bff418a63932f2 (download: Unknown size, generated: Unknown size, post-processed: Unknown size, total: Unknown size) to /root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab...


HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(width='20px'), max=1.0), HTML(value=''…

Dataset text downloaded and prepared to /root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab. Subsequent calls will reuse this data.


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




Using custom data configuration default
Reusing dataset text (/root/.cache/huggingface/datasets/text/default-04bff418a63932f2/0.0.0/daf90a707a433ac193b369c8cc1772139bb6cca21a9c7fe83bdd16aad9b9b6ab)


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [None]:
# Inspect one tokenized training example: its `input_ids` and `attention_mask` tensors.
train_ds[1]

  return torch.tensor(x, **format_kwargs)


{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0]),
 'input_ids': tensor([  796,   569, 18354,  7496, 17740,  6711,   796,   220, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256,
         50256, 50256])}