In [1]:
from typing import List, Dict, Any

from datasets import load_dataset
from datasets import DatasetDict
from transformers import AutoTokenizer
import random
import numpy as np

random.seed(93)
np.random.seed(93)

In [2]:
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
list(bert_tokenizer(["The quick brown fox", "jumped over the lazy dog"]))

['input_ids', 'token_type_ids', 'attention_mask']

In [3]:
def sliding_window(lst: List[Any], max_len: int, overlaps: int):
    return [lst[i:i + max_len] for i in range(0, max_len, overlaps)]

def chunk_data(batch: List[Dict[str, Any]], max_seq=512, overlaps=256):
    tokenized_texts = bert_tokenizer(batch["text"])
    effective_len = max_seq - 2 # [CLS] and [SEP]
    result = {}
    for k in tokenized_texts:
        result[k] = []
        for items in tokenized_texts[k]:
            result[k].extend([*sliding_window(items, effective_len, overlaps)])
    return result

In [4]:
wikisource_ds = load_dataset("data", data_files={
    "train": "wikisource_train.parquet",
    "valid": "wikisource_valid.parquet"
})
wikisource_ds

DatasetDict({
    train: Dataset({
        features: ['id', 'url', 'title', 'text'],
        num_rows: 187451
    })
    valid: Dataset({
        features: ['id', 'url', 'title', 'text'],
        num_rows: 20828
    })
})

In [5]:
wikisource_ds = wikisource_ds.map(chunk_data, batched=True, remove_columns=wikisource_ds["train"].column_names)
wikisource_ds

Map:   0%|          | 0/187451 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (3314 > 512). Running this sequence through the model will result in indexing errors


Map:   0%|          | 0/20828 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 374902
    })
    valid: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 41656
    })
})

In [7]:
wikisource_ds["train"].to_parquet("data/wikisource_train_512.parquet")
wikisource_ds["valid"] = wikisource_ds["valid"].to_parquet("data/wikisource_eval_512.parquet")

Creating parquet from Arrow format:   0%|          | 0/375 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/42 [00:00<?, ?ba/s]