# Chapter 5. Pretraining on unlabeled data

In [1]:
import util
import tiktoken
import torch

from torch.utils.data import DataLoader

In [2]:
# Load the raw corpus and hold out its tail for validation.
data = util.text_corpus()
split_ratio = 0.8  # fraction of the corpus that goes to the training split
split_idx = int(split_ratio * len(data))
train_data = data[:split_idx]
val_data = data[split_idx:]

In [3]:
# Load tiktoken's 'gpt2' encoding; this tokenizer is handed to both data loaders.
tokenizer: tiktoken.Encoding = tiktoken.get_encoding('gpt2')

In [4]:
# Model hyperparameters; only `context_length` is used in this cell.
config = util.Config.gpt2_small()

# Seed torch's global RNG before the loaders are built
# (presumably so the shuffled training order is reproducible -- confirm).
torch.manual_seed(123)

# Training split: shuffled, and a trailing partial batch is dropped.
# The stride equals the context window
# (presumably yielding non-overlapping windows -- confirm in util.create_dataloader_v1).
train_loader = util.create_dataloader_v1(
    train_data,
    tokenizer=tokenizer,
    context_window=config.context_length,
    stride=config.context_length,
    batch_size=2,
    shuffle=True,
    drop_last=True,
)

# Validation split: fixed order, and a trailing partial batch is kept.
val_loader = util.create_dataloader_v1(
    val_data,
    tokenizer=tokenizer,
    context_window=config.context_length,
    stride=config.context_length,
    batch_size=2,
    shuffle=False,
    drop_last=False,
)

In [5]:
# Sanity check: walk every batch of both loaders and print the shapes of the
# input (x) and target (y) tensors.
for x, y in train_loader:
    print(f'train x: {x.shape}, y: {y.shape}')
for x, y in val_loader:
    print(f'val x: {x.shape}, y: {y.shape}')

train x: torch.Size([2, 1024]), y: torch.Size([2, 1024])
train x: torch.Size([2, 1024]), y: torch.Size([2, 1024])
val x: torch.Size([1, 1024]), y: torch.Size([1, 1024])


In [6]:
def calc_loss_batch(
    input_batch: torch.Tensor,
    target_batch: torch.Tensor,
    model: util.GPTModel,
    device: torch.device,
) -> float:
    """Compute the cross-entropy loss of ``model`` on a single batch.

    Args:
        input_batch: Inputs fed to ``model``
            (assumed to be token ids of shape (batch, seq) -- TODO confirm).
        target_batch: Targets compared against the model output; flattened
            to one dimension before the loss is computed.
        model: Callable producing logits whose first two dimensions are
            merged so they line up with the flattened targets.
        device: Device both batches are moved to before the forward pass.

    Returns:
        The loss as a plain Python float (``Tensor.item()``), so the result
        carries no autograd graph.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)

    # Merge the leading two dimensions of the logits and flatten the targets
    # so cross_entropy sees one row of logits per target entry.
    flat_logits = model(inputs).flatten(0, 1)
    flat_targets = targets.flatten()

    batch_loss = torch.nn.functional.cross_entropy(flat_logits, flat_targets)
    return batch_loss.item()

def calc_loss_data_loader(
    data_loader: DataLoader,
    model: util.GPTModel,
    device: torch.device,
    num_batches: int | None = None,
) -> float | None:
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    Args:
        data_loader: Iterable of ``(input, target)`` batches; must support
            ``len()``.
        model: Model forwarded to ``calc_loss_batch``.
        device: Device forwarded to ``calc_loss_batch``.
        num_batches: Upper bound on the number of batches to evaluate.
            ``None`` evaluates every batch; values larger than
            ``len(data_loader)`` are clamped to it.

    Returns:
        The arithmetic mean of ``calc_loss_batch`` over the evaluated
        batches, or ``None`` when there is nothing to average (the loader
        has no batches, or ``num_batches`` is zero or negative).
    """
    available = len(data_loader)
    if num_batches is None:
        num_batches = available
    else:
        num_batches = min(num_batches, available)

    # Covers the empty loader as well as a non-positive request; without this
    # guard a zero count divides by zero and a negative one flips the sign.
    if num_batches <= 0:
        return None

    total = 0.0
    # zip() consults range() first, so iteration stops after `num_batches`
    # batches without pulling an extra, unused batch from the loader.
    for _, (x, y) in zip(range(num_batches), data_loader):
        total += calc_loss_batch(x, y, model, device)
    return total / num_batches

In [7]:
# Build the model from the configuration created for the data loaders.
model = util.GPTModel(config)

In [8]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device: torch.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# calc_loss_batch moves every batch to `device` before calling the model, so
# the model's parameters have to live there as well; without this a CUDA
# machine fails with a device-mismatch error on the first forward pass.
model.to(device)

# Evaluation only: skip gradient tracking while the losses are measured.
with torch.no_grad():
    print(f'Train loss: {calc_loss_data_loader(train_loader, model, device):.3f}')
    print(f'Val loss: {calc_loss_data_loader(val_loader, model, device):.3f}')

Train loss: 10.996
Val loss: 11.017
