# reproduce data load in train_gpt2.py

**learning**

L1. Although each raw document starts with the EOT (end-of-text) token, this is not true for each mini-batch: a batch can begin anywhere inside a document.

L2. The data loader keeps only one shard in memory at a time, which is the point of splitting the data into shards.



**question**

Q1. In the data loader, should `self.current_position` advance by B * T or by B * T + 1 after each batch?

A1. B*T

The one-token overlap between y(t-1) and x(t) ensures that no token transition is skipped. By advancing the window by B*T but using B*T + 1 tokens per batch, the model captures all sequential relationships in the data — every token and its next-token prediction is used exactly once.

Example: shard = [1,3,4,2,10,11] where B=1 and T = 2
 

  first batch: x [1,3] y [3,4]

second batch: x [4,2] y [2,10]

In [1]:
import os 
from typing import Tuple
import numpy as np
import torch
from config import local_dir


data_dir = local_dir

# List the entries of the local data directory. Only OS-level failures (missing
# or unreadable directory) are handled, and `contents` is always bound afterwards
# so the code below never hits an undefined name.
try:
    contents = os.listdir(data_dir)
except OSError:
    print("no local dir created, please run 1_fineweb_re1.ipynb first")
    contents = []
else:
    if contents:
        print(f"found {len(contents)} shards")
    else:
        print("no shards created, please run 1_fineweb_re1.ipynb first")

# The first listed entry serves as the example shard; None when nothing is
# available (indexing an empty listing would raise IndexError).
example_shard_path = contents[0] if contents else None

found 11 shards


In [5]:
# Read the example shard from disk and convert its tokens to a torch.long tensor.
data = np.load(
    os.path.join(data_dir, example_shard_path)
)

data_tensor = torch.tensor(
    data,
    dtype=torch.long,
)

# Shape comparison between the tensor and the source array; the result of this
# expression is not stored or used.
data_tensor.shape == data.shape


# create function
def _load_tokens(filename): # L1
    """Load the array stored in ``filename`` and return it as a ``torch.long`` tensor.

    Used by the data loader to read one shard.

    Args:
        filename: Path of a file readable by ``np.load``. A relative path that
            is not already located under the module-level ``data_dir`` is
            resolved against ``data_dir``; absolute paths and paths already
            under ``data_dir`` are used unchanged.

    Returns:
        A tensor of dtype ``torch.long`` holding the loaded values.
    """
    # Compare whole path components instead of substrings: a plain
    # ``data_dir not in filename`` test misfires when the directory name merely
    # occurs inside the file name (e.g. data_dir "tok" and file "tok_0.npy"),
    # which would leave the file unresolved.
    root = os.path.normpath(data_dir)
    candidate = os.path.normpath(filename)
    already_rooted = (
        os.path.isabs(filename)
        or candidate == root
        or candidate.startswith(root + os.sep)
    )
    if not already_rooted:
        filename = os.path.join(data_dir, filename)

    data = np.load(filename)
    # No intermediate np.int32 cast: for values strictly within the uint16
    # range (0–65535), converting directly to torch.long works fine.
    data_tensor = torch.tensor(data, dtype=torch.long)

    return data_tensor

In [8]:
class DataLoader:
    """Serve consecutive (input, target) token batches from sharded token files.

    Only one shard is held in memory at a time. When the current shard cannot
    supply another full batch, the next shard is loaded, wrapping around to the
    first shard after the last one.
    """

    def __init__(self, B:int, T:int, data_dir:str, split:str):
        """Collect the shard files of one split and load the first shard.

        Args:
            B: Number of sequences per batch.
            T: Number of tokens per sequence.
            data_dir: Directory that contains the shard files.
            split: Substring a file name must contain to be used as a shard
                (e.g. ``'train'``).

        Raises:
            ValueError: If no file name in ``data_dir`` contains ``split``.
        """
        self.B = B
        self.T = T
        self.data_dir = data_dir
        self.split = split
        # os.listdir returns entries in arbitrary order; sort them so that the
        # shard sequence is reproducible from run to run.
        self.all_shards_paths = sorted(
            os.path.join(data_dir, filename)
            for filename in os.listdir(data_dir)
            if split in filename
        )
        if not self.all_shards_paths:
            raise ValueError(
                f"no shards found for split {split!r} in {data_dir!r}"
            )
        self.reset()

    def _load_shard(self, shard_index: int) -> None:
        """Make shard ``shard_index`` the current shard and rewind to its start."""
        self.current_shard_index = shard_index
        self.current_shard = _load_tokens(self.all_shards_paths[shard_index])  # L2
        self.current_position = 0

    def reset(self) -> None:
        """Rewind the loader to the beginning of the first shard."""
        self._load_shard(0)

    def next_batch(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the next batch and advance the read position.

        Returns:
            A pair ``(x, y)`` of tensors viewed as ``(B, T)``, where ``y`` holds
            the same tokens as ``x`` shifted one position to the right.
        """
        tokens_per_batch = self.B * self.T
        # One extra token is read so the last input position also has a target.
        window = tokens_per_batch + 1
        start = self.current_position
        batch = self.current_shard[start: start + window]
        x = batch[:-1].view(self.B, self.T)
        y = batch[1:].view(self.B, self.T)
        # Advance by B*T only: the final target of this batch becomes the first
        # input of the next one.
        self.current_position += tokens_per_batch # Q1

        # Move on to the next shard when the rest of this one is too short for
        # another full window.
        if self.current_position + window > len(self.current_shard):
            next_index = (self.current_shard_index + 1) % len(self.all_shards_paths)
            self._load_shard(next_index)

        return x, y

In [4]:
# Try the loader out: build it for the training split with a tiny batch shape,
# pull a single batch, and display the inputs followed by the targets.
data_loader = DataLoader(
    B=2,
    T=4,
    data_dir=data_dir,
    split='train',
)
x, y = data_loader.next_batch()
print(x, y, sep="\n")

tensor([[10416,   351,   663, 18875],
        [ 6770,   357,    69,   451]])
tensor([[  351,   663, 18875,  6770],
        [  357,    69,   451,   286]])
