In [1]:
import tiktoken
import torch
import util
import os

from torch.utils.data import Dataset, DataLoader

In [2]:
# Load the training corpus through the project-local `util` helper
# (presumably a single string, since it is later passed as `content: str` -- TODO confirm).
text = util.text_corpus()
# Tokenizer for the 'gpt2' encoding, reused by the demo cells below.
tokenizer = tiktoken.get_encoding('gpt2')

In [3]:
# Show the concrete type of the object returned by tiktoken.get_encoding.
print(type(tokenizer))

<class 'tiktoken.core.Encoding'>


In [4]:
# Sample string embedding the `<|endoftext|>` marker, which the next cell
# passes to the tokenizer as an allowed special token.
text_test = "Hello<|endoftext|> > !!!"

In [5]:
# Explicitly allow `<|endoftext|>` so encode accepts the marker in text_test
# (tiktoken rejects special-token text by default).
encoded = tokenizer.encode(text_test, allowed_special={'<|endoftext|>'})

In [6]:
class GPTDatasetV1(Dataset):
    """Sliding-window next-token-prediction dataset over a tokenized text.

    The text is tokenized once and cut into (possibly overlapping) windows of
    ``context_window_size`` tokens. Each sample is an ``(input, target)`` pair
    of 1-D tensors where ``target`` is ``input`` shifted one token to the right.
    """

    _input_ids: list[torch.Tensor]
    _target_ids: list[torch.Tensor]

    def __init__(
        self,
        content: str,
        tokenizer: "tiktoken.core.Encoding",
        context_window_size: int,
        stride: int
    ):
        """Tokenize ``content`` and pre-compute every input/target window.

        Args:
            content: Raw text to tokenize.
            tokenizer: Object exposing ``encode(text, allowed_special=...)``
                that returns a sequence of token ids.
            context_window_size: Number of tokens in each sample; must be > 0.
            stride: Number of tokens the window advances between consecutive
                samples; must be > 0.

        Raises:
            ValueError: If ``context_window_size`` or ``stride`` is not positive.
        """
        if context_window_size <= 0:
            raise ValueError(
                f"context_window_size must be positive, got {context_window_size}"
            )
        if stride <= 0:
            raise ValueError(f"stride must be positive, got {stride}")

        self._input_ids = []
        self._target_ids = []

        # Allow the end-of-text marker so corpora that contain it as a document
        # separator can be encoded instead of raising.
        token_ids = tokenizer.encode(content, allowed_special={'<|endoftext|>'})

        # Stop at len - window so the target, which starts one token later,
        # always has a full window available.
        for i in range(0, len(token_ids) - context_window_size, stride):
            input_chunk = token_ids[i:i + context_window_size]
            target_chunk = token_ids[i + 1: i + 1 + context_window_size]
            self._input_ids.append(torch.tensor(input_chunk, dtype=torch.long))
            self._target_ids.append(torch.tensor(target_chunk, dtype=torch.long))

    def __len__(self) -> int:
        """Return the number of pre-computed windows."""
        return len(self._input_ids)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(input_ids, target_ids)`` pair stored at ``idx``."""
        return self._input_ids[idx], self._target_ids[idx]

In [7]:
def create_dataloader_v1(
    content: str,
    batch_size: int,
    context_window: int,
    stride: int = 1,
    shuffle: bool = True,
    drop_last: bool = True,
    # os.cpu_count() returns None when the CPU count cannot be determined;
    # fall back to 0 (load in the main process) so DataLoader always gets an int.
    num_workers: int = os.cpu_count() or 0,
) -> DataLoader:
    """Build a DataLoader of sliding-window (input, target) token batches.

    Args:
        content: Raw text to tokenize with the 'gpt2' tiktoken encoding.
        batch_size: Number of samples per batch.
        context_window: Number of tokens in each sample.
        stride: Number of tokens the window advances between samples.
        shuffle: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``; defaults to the CPU count
            (or 0 if it is unknown).

    Returns:
        A ``DataLoader`` wrapping a ``GPTDatasetV1`` built from ``content``.
    """
    tokenizer = tiktoken.get_encoding('gpt2')
    dataset = GPTDatasetV1(content, tokenizer, context_window, stride)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
    )

In [8]:
# Batches of 8 samples, each a 256-token window advancing 128 tokens at a time.
# shuffle is left at its default (True), so the batch order is randomized.
data_loader = create_dataloader_v1(
    text,
    batch_size=8,
    context_window=256,
    stride=128,
)
data_iter = iter(data_loader)

In [9]:
# Draw the first batch; it displays as a pair of token-id tensors [inputs, targets].
next(data_iter)

[tensor([[12917,   905,    11,  ...,   550,  1813,   510],
         [  526,   198,   198,  ...,    13,  1320,   338],
         [16153,   312,   328,  ...,  2982,   257,  2366],
         ...,
         [  550,  1908,   477,  ...,    11,   290,  3088],
         [  628,   198,   198,  ..., 22988,   198,   198],
         [  617,   286,   616,  ...,   616,  1243,    30]]),
 tensor([[  905,    11,  5025,  ...,  1813,   510,   465],
         [  198,   198,     1,  ...,  1320,   338,   262],
         [  312,   328,  3780,  ...,   257,  2366,  9662],
         ...,
         [ 1908,   477,   616,  ...,   290,  3088,   617],
         [  198,   198,  8585,  ...,   198,   198,    35],
         [  286,   616, 49025,  ...,  1243,    30,  1119]])]

In [11]:
# Fetch another batch and check that inputs and targets are both
# (batch_size, context_window) = (8, 256).
x, y = next(data_iter)
print((x.shape, y.shape))

(torch.Size([8, 256]), torch.Size([8, 256]))
