In [1]:
from datasets import load_dataset

# TinyStories from the Hugging Face Hub; loads a DatasetDict with
# "train" and "validation" splits, one short story per row.
DATASET_ID = "roneneldan/TinyStories"
ds = load_dataset(DATASET_ID)

In [2]:
# Display the story at index 10 (the 11th row); each row is a complete short story.
# Select the row first, then its "text" field: `ds["train"]["text"][10]` may build
# the entire text column (~2.1M strings) just to read one element (it does in
# older `datasets` releases).
ds["train"][10]["text"]

'Once upon a time, there was a big car named Dependable. He had a very important job. Dependable would take a family to the park every day. The family had a mom, dad, and a little girl named Lily. They all had a lot of love for each other.\n\nOne day, when they got to the park, they saw a big sign that said, "Fun Race Today!" The family was very excited. They knew that Dependable was very fast and could win the race. So, they decided to join the race.\n\nThe race started, and Dependable went very fast. The other cars tried to catch up, but Dependable was too quick. In the end, Dependable won the race! The family was so happy and proud of their car. They knew that their love for each other and their trust in Dependable made them win the race. And from that day on, they had even more fun at the park, knowing that they had the fastest and most dependable car around.'

In [3]:
!pip install tiktoken 

Collecting tiktoken
  Using cached tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl.metadata (6.7 kB)
Using cached tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl (993 kB)
Installing collected packages: tiktoken
Successfully installed tiktoken-0.12.0


In [4]:
import tiktoken
import os
import numpy as np
from tqdm.auto import tqdm

In [6]:
# GPT-2 byte-pair encoding from tiktoken; used to tokenize every story below.
TOKENIZER_NAME = "gpt2"
encoding = tiktoken.get_encoding(TOKENIZER_NAME)

In [7]:
def process(record, append_eot=True):
    """Tokenize one dataset record with the global GPT-2 `encoding`.

    Args:
        record: mapping with a "text" field holding one story.
        append_eot: if True (default), append the end-of-text token so that
            stories remain separated after all token ids are concatenated
            into a single flat .bin stream for training.

    Returns:
        dict with "ids" (list of int token ids) and "len" (number of ids,
        including the end-of-text token when appended).
    """
    # encode_ordinary treats special-token strings in the text as plain text.
    ids = encoding.encode_ordinary(record["text"])
    if append_eot:
        # Without a boundary token, consecutive stories run into each other in
        # the training stream and the model never learns where a story ends.
        ids.append(encoding.eot_token)
    return {"ids": ids, "len": len(ids)}

In [10]:
# Tokenize one sample record to inspect the resulting ids and length.
sample_record = ds["train"][10]
process(sample_record)

{'ids': [7454,
  2402,
  257,
  640,
  11,
  612,
  373,
  257,
  1263,
  1097,
  3706,
  37947,
  540,
  13,
  679,
  550,
  257,
  845,
  1593,
  1693,
  13,
  37947,
  540,
  561,
  1011,
  257,
  1641,
  284,
  262,
  3952,
  790,
  1110,
  13,
  383,
  1641,
  550,
  257,
  1995,
  11,
  9955,
  11,
  290,
  257,
  1310,
  2576,
  3706,
  20037,
  13,
  1119,
  477,
  550,
  257,
  1256,
  286,
  1842,
  329,
  1123,
  584,
  13,
  198,
  198,
  3198,
  1110,
  11,
  618,
  484,
  1392,
  284,
  262,
  3952,
  11,
  484,
  2497,
  257,
  1263,
  1051,
  326,
  531,
  11,
  366,
  24629,
  12588,
  6288,
  2474,
  383,
  1641,
  373,
  845,
  6568,
  13,
  1119,
  2993,
  326,
  37947,
  540,
  373,
  845,
  3049,
  290,
  714,
  1592,
  262,
  3234,
  13,
  1406,
  11,
  484,
  3066,
  284,
  4654,
  262,
  3234,
  13,
  198,
  198,
  464,
  3234,
  2067,
  11,
  290,
  37947,
  540,
  1816,
  845,
  3049,
  13,
  383,
  584,
  5006,
  3088,
  284,
  4929,
  510,
  11,
  475,
  37

In [11]:
# Tokenize every split in parallel, replacing "text" with "ids" and "len".
# This always runs so that `tokenized` is defined for the cells below; the old
# `if not os.path.exists("train.bin")` guard left it undefined on re-runs and
# made the next cells fail with NameError. Dataset.map caches its result on
# disk, so a re-run normally reloads the cache instead of re-tokenizing.
tokenized = ds.map(
    process,
    remove_columns=["text"],
    desc="tokenization of data",
    num_proc=8,
)

tokenization of data (num_proc=8):   0%|          | 0/2119719 [00:00<?, ? examples/s]

tokenization of data (num_proc=8):   0%|          | 0/21990 [00:00<?, ? examples/s]

In [12]:
# Inspect the tokenized splits: rows now carry 'ids' and 'len' instead of 'text'.
tokenized

DatasetDict({
    train: Dataset({
        features: ['ids', 'len'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['ids', 'len'],
        num_rows: 21990
    })
})

In [13]:
# Write each split's token ids, concatenated in row order, into one flat
# uint16 file named `{split}.bin`.
MAX_WRITE_BATCHES = 1024  # rows are written in contiguous shards for faster I/O

# uint16 is only safe while every token id fits in 16 bits.
assert encoding.max_token_value < 2**16, "token ids do not fit in uint16"

for split, dset in tokenized.items():
    arr_len = int(np.sum(dset["len"], dtype=np.uint64))
    filename = f"{split}.bin"
    tmp_filename = f"{filename}.tmp"
    dtype = np.uint16
    # Never request more shards than rows: an empty shard would make
    # np.concatenate raise "need at least one array to concatenate".
    total_batches = min(MAX_WRITE_BATCHES, len(dset))

    # Write to a temporary file and rename it only once complete, so an
    # interrupted run cannot leave a truncated `{split}.bin` behind.
    arr = np.memmap(tmp_filename, dtype=dtype, mode="w+", shape=(arr_len,))
    idx = 0
    for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"):
        # Batch together samples for faster write
        batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format("numpy")
        arr_batch = np.concatenate(batch["ids"])
        arr[idx : idx + len(arr_batch)] = arr_batch
        idx += len(arr_batch)
    arr.flush()
    del arr  # drop the memory map before renaming the underlying file
    assert idx == arr_len, f"wrote {idx} tokens to {filename}, expected {arr_len}"
    os.replace(tmp_filename, filename)

writing train.bin:   0%|          | 0/1024 [00:00<?, ?it/s]

writing validation.bin:   0%|          | 0/1024 [00:00<?, ?it/s]