In [2]:
%load_ext autoreload
%autoreload 2

from itertools import islice

from torch.utils.data.dataset import IterableDataset

In [3]:
class MyCustomDataset:
    """Iterable over the lines of a collection of text files.

    Every line of every file is one example; files are read lazily, in the
    order given.
    """

    def __init__(self, files):
        """Remember which files to read.

        Args:
            files: iterable of paths to text files.
        """
        self.files = files

    def __iter__(self):
        """Yield each line of each file with newline characters stripped from its ends."""
        for path in self.files:
            with open(path) as handle:
                yield from (raw.strip("\n") for raw in handle)

In [4]:
from glob import glob

# glob() returns matches in arbitrary, filesystem-dependent order, so sort
# before slicing: "[:1]" then selects the same file on every run and machine.
tweet_files = sorted(glob("../data/filtered_tweets/*.txt"))[:1]

# Plain-Python baseline dataset over the selected file(s).
my_ds = MyCustomDataset(
    files=tweet_files
)

In [5]:
from tqdm.auto import tqdm
# Baseline: walk the plain-Python line iterator end to end with an empty loop
# body, so the rate tqdm reports reflects only the cost of producing examples.
for x in tqdm(my_ds):
    pass

0it [00:00, ?it/s]

In [6]:
from datasets import load_dataset

# Load the same text files through the Hugging Face "text" loader, to compare
# its iteration speed against the plain-Python iterator.
dataset = load_dataset("text", data_files=tweet_files)

Using custom data configuration default-574c5e67937ee6f5
Reusing dataset text (/home/jmperez/.cache/huggingface/datasets/text/default-574c5e67937ee6f5/0.0.0/e16f44aa1b321ece1f87b07977cc5d70be93d69b20486d6dacd62e12cf25c9a5)


In [7]:
from tqdm.auto import tqdm
# Empty-body loop over the loaded train split; tqdm shows the iteration rate.
for x in tqdm(dataset["train"]):
    pass

  0%|          | 0/5233769 [00:00<?, ?it/s]

Es impresionante. Anda muy mal de la otra forma. Y con más datos es todavía peor



In [8]:
# NOTE(review): presumably the two measured iteration times being compared
# (units are not shown anywhere) — confirm where 4 and 118 come from.
4 / 118

0.03389830508474576

Ese es el ratio

## Con transformación

In [9]:
from transformers import AutoTokenizer

# Tokenizer loaded from a local model directory; the cells below use this
# module-level name directly.
tokenizer = AutoTokenizer.from_pretrained("../models/twerto-base-uncased")

Veamos qué pasa con la paralelización


In [18]:
# Synthetic workload: one tweet repeated many times, tokenized with one
# tokenizer call per string.
# NOTE(review): the saved progress bar reports 100000 items, so that output
# predates the current 500_000 setting — re-run to refresh it.
tweets = ["@usuario este es un tweet sarasa"] * 500_000

for tw in tqdm(tweets):
    tokenizer(tw)

  0%|          | 0/100000 [00:00<?, ?it/s]

In [20]:
%%time

# A single batched call over the whole list, for comparison with the
# per-string loop; the trailing "; None" keeps the cell from displaying the
# tokenizer output.
tokenizer(tweets); None

CPU times: user 10.5 s, sys: 3.92 s, total: 14.5 s
Wall time: 2.9 s


In [21]:
class NaiveProcessedDataset:
    """Iterable over text files that tokenizes every line with its own call.

    This is the unbatched baseline: each line is passed individually to the
    module-level ``tokenizer``.
    """

    def __init__(self, files, batch_size=1024):
        """Store the constructor arguments.

        Args:
            files: iterable of paths to text files, one example per line.
            batch_size: accepted so the signature matches
                ``BatchProcessedDataset``; it is stored on the instance, but
                iteration here is always line by line and does not use it.
        """
        self.files = files
        # This argument used to be accepted and silently discarded; keep it
        # so the value a caller passed is at least inspectable.
        self.batch_size = batch_size

    def __iter__(self):
        """Yield the tokenizer output for each line of each file, in order."""
        for file_path in self.files:
            with open(file_path) as f:
                for line in f:
                    yield tokenizer(line.strip("\n"))

In [22]:
# Unbatched tokenizing dataset over the same file(s) as the earlier benchmarks.
my_ds = NaiveProcessedDataset(files=tweet_files)

In [23]:
# Empty-body loop again; my_ds now tokenizes each line with its own tokenizer
# call, so the rate includes that per-line cost.
for x in tqdm(my_ds):
    pass

0it [00:00, ?it/s]

In [41]:
# Cap the tokenizer's maximum sequence length at 128 tokens.
tokenizer.model_max_length = 128


def tokenize(batch, padding='max_length'):
    """Tokenize the ``text`` field of a batch with the module-level tokenizer.

    Args:
        batch: mapping with a ``'text'`` entry holding the strings to encode.
        padding: padding strategy forwarded to the tokenizer.

    Returns:
        The tokenizer output, produced with truncation enabled and the
        special-tokens mask included.
    """
    return tokenizer(
        batch['text'],
        padding=padding,
        truncation=True,
        return_special_tokens_mask=True,
    )


# Register tokenize as the transform of the train split.
dataset["train"].set_transform(tokenize)

In [61]:
from torch.utils.data import IterableDataset
class BatchProcessedDataset(IterableDataset):
    """Iterable dataset that tokenizes text files in batches, one example per line.

    Lines are read ``batch_size`` at a time and handed to the module-level
    ``tokenizer`` in a single call; the resulting encodings are then yielded
    one by one as plain dicts.

    NOTE(review): every iterator walks all of ``files``. If this is wrapped in
    a DataLoader with ``num_workers > 0``, each worker would presumably replay
    the full data unless the files are sharded per worker — confirm before
    enabling workers.
    """

    def __init__(self, files, batch_size=1024):
        """Store the file list and the tokenization batch size.

        Args:
            files: iterable of paths to text files, one example per line.
            batch_size: number of lines passed to the tokenizer per call.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (such a value
                would otherwise make the dataset silently yield nothing).
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        self.files = files
        self.batch_size = batch_size

    def __iter__(self):
        """Yield one dict per line with the keys ``input_ids``,
        ``token_type_ids``, ``attention_mask`` and ``special_tokens_mask``."""
        for file_path in self.files:
            with open(file_path) as f:
                while True:
                    # islice takes at most batch_size lines and leaves the
                    # rest of the file for the next round.
                    next_batch = [line.strip("\n") for line in islice(f, self.batch_size)]
                    if not next_batch:
                        # End of file reached.
                        break

                    tokenized_batch = tokenizer(next_batch, padding='max_length', truncation=True, return_special_tokens_mask=True)
                    for encoding in tokenized_batch.encodings:
                        yield {
                            "input_ids": encoding.ids,
                            "token_type_ids": encoding.type_ids,
                            "attention_mask": encoding.attention_mask,
                            "special_tokens_mask": encoding.special_tokens_mask
                        }
# Batched tokenizing dataset over the same file(s), 1024 lines per tokenizer call.
my_ds = BatchProcessedDataset(files=tweet_files, batch_size=1024)

In [39]:
# Re-create the batched dataset and run a full pass with an empty loop body,
# so tqdm reports the iteration rate of batched tokenization.
my_ds = BatchProcessedDataset(
    files=tweet_files,
    batch_size=1024,
)
for x in tqdm(my_ds):
    pass

0it [00:00, ?it/s]

Con `batch_size=1024` (hice algunas pruebas) parece tener la mejor performance.

Veamos si podemos engancharlo en el trainer

Ok, tenemos que emular esto

In [44]:
# Grab the first example the dataset produces, to inspect its structure.
# NOTE(review): the saved output shows a tokenizers Encoding object, while
# BatchProcessedDataset as defined in this notebook yields plain dicts — the
# output seems to predate that change; re-run to confirm.
encoding = (next(iter(my_ds)))

encoding

Encoding(num_tokens=128, attributes=[ids, type_ids, tokens, offsets, attention_mask, special_tokens_mask, overflowing])

In [45]:

# Put the first example into the dict layout used by dataset["train"], so the
# two can be compared key by key.
# BatchProcessedDataset yields dicts in this layout, so on a fresh kernel
# `encoding` is already a dict and attribute access (encoding.ids) would raise
# AttributeError. The attribute-based path is kept for raw Encoding objects.
if isinstance(encoding, dict):
    ret = dict(encoding)
else:
    ret = {
        "input_ids": encoding.ids,
        "token_type_ids": encoding.type_ids,
        "attention_mask": encoding.attention_mask,
        "special_tokens_mask": encoding.special_tokens_mask
    }

In [47]:
# First example of the train split, as returned by the dataset itself.
ds_ex = dataset["train"][0]

# Field-by-field equality between the hand-built example and the dataset's.
{key: (ret[key] == value) for key, value in ds_ex.items()}

{'input_ids': True,
 'token_type_ids': True,
 'attention_mask': True,
 'special_tokens_mask': True}

## set_transform vs custom

In [49]:
# Custom-dataset side of the comparison: a fresh BatchProcessedDataset
# iterated with an empty loop body.
my_ds = BatchProcessedDataset(
    files=tweet_files,
    batch_size=1024,
)
for x in tqdm(my_ds):
    pass

0it [00:00, ?it/s]

In [None]:
# set_transform side of the comparison: the train split has `tokenize`
# registered as its transform, and is iterated with an empty loop body.
for ex in tqdm(dataset["train"]):
    pass

Todo indica que tarda mucho más (¡al menos 3 veces más!)

In [74]:
from transformers import DataCollatorForLanguageModeling
from torch.utils.data.dataloader import DataLoader

# Number of examples per DataLoader batch, shared by both loaders.
batch_size = 1024

# Collator configured for masked language modeling with probability 0.15.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

# Two loaders that differ only in the dataset they wrap.
my_dataloader = DataLoader(my_ds, batch_size=batch_size, collate_fn=data_collator)
ds_dataloader = DataLoader(dataset["train"], batch_size=batch_size, collate_fn=data_collator)

In [75]:
from tqdm.auto import tqdm
# Pull 500 batches from the custom-dataset loader with an empty loop body.
# range() goes first in zip() so iteration stops after exactly 500 batches;
# with the loader first, zip() would fetch a 501st batch before noticing that
# range() is exhausted. total=500 lets tqdm draw a real progress bar; it is an
# upper bound, and a shorter loader simply ends the bar early.
for step, batch in tqdm(zip(range(500), my_dataloader), total=500):
    pass

0it [00:00, ?it/s]

In [76]:
from tqdm.auto import tqdm
# Pull 500 batches from the set_transform-based loader with an empty loop body.
# range() goes first in zip() so iteration stops after exactly 500 batches;
# with the loader first, zip() would fetch a 501st batch before noticing that
# range() is exhausted. total=500 lets tqdm draw a real progress bar; it is an
# upper bound, and a shorter loader simply ends the bar early.
for step, batch in tqdm(zip(range(500), ds_dataloader), total=500):
    pass

0it [00:00, ?it/s]

Bueno, pareciera andar *mucho* mejor nuestro dataset... habrá que probar en el finetuning a ver qué onda