# Chapter 5. Pretraining on unlabeled data

In [1]:
import util
import tiktoken
import torch

from torch.utils.data import DataLoader
from datetime import datetime

In [2]:
# Split the raw corpus by position (before tokenization): the first 80% is
# training text, the remainder is held out for validation.
data = util.text_corpus()
split_ratio = 0.8
split_idx = int(split_ratio * len(data))
train_data = data[:split_idx]
val_data = data[split_idx:]

In [3]:
# GPT-2 BPE tokenizer; reused by both data loaders and by the sampling helper.
tokenizer: tiktoken.Encoding = tiktoken.get_encoding("gpt2")

In [4]:
config = util.Config.gpt2_xl()
# Seeds torch's global RNG (presumably what the shuffled loader and the model
# initialisation below draw from) so Restart & Run All is reproducible.
torch.manual_seed(123)

# Settings common to both loaders. stride equals the window size, so windows
# presumably do not overlap — confirm against util.create_dataloader_v1.
_shared_loader_kwargs = dict(
    batch_size=2,
    context_window=config.context_length,
    stride=config.context_length,
    tokenizer=tokenizer,
)
# Training: shuffled, incomplete final batch dropped.
train_loader = util.create_dataloader_v1(
    train_data, shuffle=True, drop_last=True, **_shared_loader_kwargs,
)
# Validation: fixed order, every window kept (the last batch may be short).
val_loader = util.create_dataloader_v1(
    val_data, shuffle=False, drop_last=False, **_shared_loader_kwargs,
)
del _shared_loader_kwargs

In [5]:
# Sanity check of batch shapes. Only val_loader keeps an incomplete final
# batch (drop_last=False), so a smaller leading dimension can appear there.
for x, y in train_loader:
    print(f'train x: {x.shape}, y: {y.shape}')

for x, y in val_loader:
    print(f'val x: {x.shape}, y: {y.shape}')

train x: torch.Size([2, 1024]), y: torch.Size([2, 1024])
train x: torch.Size([2, 1024]), y: torch.Size([2, 1024])
val x: torch.Size([1, 1024]), y: torch.Size([1, 1024])


In [6]:
def calc_loss_batch(
    input_batch: torch.Tensor,
    target_batch: torch.Tensor,
    model: util.GPTModel,
    device: torch.device,
) -> torch.Tensor:
    """Return the mean cross-entropy loss of ``model`` on one batch.

    Both batches are moved to ``device`` before the forward pass. The returned
    tensor is still attached to the autograd graph, so callers may call
    ``.backward()`` on it.

    Assumes inputs/targets are (batch, seq) token ids and the model returns
    (batch, seq, vocab) logits — TODO confirm against util.GPTModel.
    """
    inputs = input_batch.to(device)
    targets = target_batch.to(device)
    logits = model(inputs)
    # Merge the batch and sequence dims so every position is one classification
    # sample: logits -> (N, vocab), targets -> (N,).
    flat_logits = logits.flatten(0, 1)
    flat_targets = targets.flatten()
    return torch.nn.functional.cross_entropy(flat_logits, flat_targets)

In [7]:
# Build the model from the config above, then move its parameters to the
# device chosen by util.auto_device (which reports its choice, e.g. the
# "Using CUDA backend." line in this cell's output).
model = util.GPTModel(config)
device: torch.device = util.auto_device()
model = model.to(device)

Using CUDA backend.


In [8]:
def calc_loss_data_loader(
    data_loader: DataLoader,
    model: util.GPTModel,
    device: torch.device,
    num_batches: int | None = None,
) -> float | None:
    """Average the per-batch loss over (a prefix of) ``data_loader``.

    Each batch is scored with ``calc_loss_batch`` and the batch losses are
    averaged with equal weight, so a short final batch (e.g. from a loader
    built with ``drop_last=False``) counts as much as a full one. Gradients
    are not disabled here; callers wrap the call in ``torch.no_grad()``.

    Args:
        data_loader: Loader yielding ``(input_batch, target_batch)`` pairs.
        model: Model passed through to ``calc_loss_batch``.
        device: Device the batches are moved to.
        num_batches: Maximum number of batches to score. ``None`` scores the
            whole loader; larger values are clamped to ``len(data_loader)``.

    Returns:
        The mean batch loss, or ``None`` when there is nothing to average
        (empty loader, or ``num_batches`` not positive).
    """
    available = len(data_loader)
    limit = available if num_batches is None else min(num_batches, available)
    # Guard: an empty loader or num_batches <= 0 would otherwise divide by zero
    # (or by a negative count) below.
    if limit <= 0:
        return None

    total_loss = 0.0
    scored = 0
    for i, (x, y) in enumerate(data_loader):
        if i >= limit:
            break
        total_loss += calc_loss_batch(x, y, model, device).item()
        scored += 1
    # Divide by the batches actually scored, in case the loader yields fewer
    # batches than len() reports.
    return total_loss / scored if scored else None

# Baseline loss of the untrained model. Use eval mode, exactly as
# evaluate_model does for later measurements, so that dropout (if the model
# has any) does not perturb the estimate; restore train mode afterwards.
model.eval()
with torch.no_grad():
    print(f'Train loss: {calc_loss_data_loader(train_loader, model, device):.3f}')
    print(f'Val loss: {calc_loss_data_loader(val_loader, model, device):.3f}')
model.train()

Train loss: 10.978
Val loss: 10.976


In [9]:
def evaluate_model(
    model: util.GPTModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    eval_iter: int,
) -> tuple[float | None, float | None]:
    """Estimate train and validation loss on up to ``eval_iter`` batches each.

    The model is switched to eval mode with gradients disabled for the
    measurement. Afterwards its previous train/eval mode is restored, even if
    the evaluation raises.

    Args:
        model: Model to evaluate.
        train_loader: Loader for the training split.
        val_loader: Loader for the validation split.
        device: Device the batches are moved to.
        eval_iter: Maximum number of batches scored from each loader.

    Returns:
        ``(train_loss, val_loss)`` as produced by ``calc_loss_data_loader``.
        Either may be ``None`` if that loader yields no batches.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            train_loss = calc_loss_data_loader(
                train_loader, model, device, num_batches=eval_iter,
            )
            val_loss = calc_loss_data_loader(
                val_loader, model, device, num_batches=eval_iter,
            )
    finally:
        # Restore the caller's mode rather than unconditionally forcing train mode.
        model.train(was_training)
    return train_loss, val_loss

def generate_and_print_sample(
    model: util.GPTModel,
    tokenizer: tiktoken.Encoding,
    device: torch.device,
    start_context: str,
    max_len: int,
) -> None:
    """Generate a continuation of ``start_context`` and print it on one line.

    This is called from inside the training loop while the model is in train
    mode. Generation is therefore forced into eval mode under
    ``torch.no_grad()``, and the previous mode is restored afterwards.
    ``util.predict_text``'s own mode handling is not visible here; if it
    already does this, the extra wrapping is a harmless no-op.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer passed through to ``util.predict_text``.
        device: Device passed through to ``util.predict_text``.
        start_context: Prompt text.
        max_len: Passed as ``max_len`` to ``util.predict_text`` (presumably the
            number of tokens to generate — confirm in util).
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            generated = util.predict_text(
                model, device, tokenizer, start_context, max_len=max_len,
            )
    finally:
        model.train(was_training)
    print(generated.replace('\n', ' ')) # for easier reading

def train_model_simple(
    model: util.GPTModel,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    num_epochs: int,
    eval_freq: int,
    eval_iter: int,
    start_context: str,
    tokenizer: tiktoken.Encoding,
    sample_max_len: int = 8,
) -> tuple[list[float], list[float], list[int]]:
    """Plain next-token training loop with periodic evaluation and sampling.

    After every ``eval_freq`` optimizer steps, starting with the very first
    step, the function does four things:

    - evaluates the model on up to ``eval_iter`` batches of each loader;
    - records both losses and the cumulative token count;
    - prints a timestamped progress line;
    - prints a short sample continuing ``start_context``.

    Args:
        model: Model to train (switched to train mode at each epoch start).
        train_loader: Loader of ``(input_batch, target_batch)`` training pairs.
        val_loader: Loader used only for evaluation.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
        eval_freq: Evaluate/sample every this many optimizer steps; must be >= 1.
        eval_iter: Max batches per loader scored at each evaluation.
        start_context: Prompt for the progress samples.
        tokenizer: Tokenizer used for the progress samples.
        sample_max_len: ``max_len`` for the progress samples (defaults to 8).

    Returns:
        ``(train_losses, val_losses, track_tokens_seen)``, with one entry per
        evaluation.

    Raises:
        ValueError: if ``eval_freq`` is not positive (it is used as a modulus).
    """
    if eval_freq <= 0:
        raise ValueError(f'eval_freq must be a positive integer, got {eval_freq}')

    train_losses: list[float] = []
    val_losses: list[float] = []
    track_tokens_seen: list[int] = []
    tokens_seen, global_step = 0, -1

    for epoch in range(num_epochs):
        model.train()
        for x, y in train_loader:
            optimizer.zero_grad()
            loss = calc_loss_batch(x, y, model, device)
            loss.backward()
            optimizer.step()
            # Cumulative count of input token positions (all elements of x).
            tokens_seen += x.numel()
            global_step += 1

            # global_step is 0 right after the first update, so the first
            # evaluation happens immediately; printed step numbers are 1-based.
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(
                    model, train_loader, val_loader, device, eval_iter
                )
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                track_tokens_seen.append(tokens_seen)
                print(
                    f'[{datetime.now().strftime("%H:%M:%S")}] Epoch {epoch + 1:04d} (step {global_step + 1:06d}): '
                    f'train loss {train_loss:.3f}, val loss {val_loss:.3f}'
                )
                generate_and_print_sample(
                    model, tokenizer, device, start_context, max_len=sample_max_len,
                )
    return train_losses, val_losses, track_tokens_seen

In [10]:
optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4, weight_decay=0.1)
num_epochs = 1_000
# NOTE(review): train_loader has only two batches per epoch (see the shape
# check), so 1,000 epochs massively overfits. In the log below, val loss is
# lowest around step 101 and climbs afterwards while train loss approaches 0.
# Consider far fewer epochs or early stopping on val loss.
train_losses, val_losses, tokens_seen = train_model_simple(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    device=device,
    num_epochs=num_epochs,
    eval_freq=50,  # evaluate and sample every 50 optimizer steps
    eval_iter=1,  # score one batch per loader at each evaluation
    start_context="He",
    tokenizer=tokenizer,
)

[16:28:21] Epoch 0001 (step 000001): train loss 9.200, val loss 9.463
He,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,
[16:28:53] Epoch 0026 (step 000051): train loss 5.766, val loss 8.308
He I, I I      
[16:29:25] Epoch 0051 (step 000101): train loss 5.603, val loss 8.299
He the of of of the of the of
[16:29:56] Epoch 0076 (step 000151): train loss 4.586, val loss 8.696
He't- was, and I I "
[16:30:28] Epoch 0101 (step 000201): train loss 2.846, val loss 10.248
He" life of of who, to his
[16:31:00] Epoch 0126 (step 000251): train loss 0.891, val loss 11.717
He" widow Clauderaft Ver Ver widow
[16:31:31] Epoch 0151 (step 000301): train loss 0.088, val loss 12.393
He	 liked Claudeanielive ad existedrees
[16:32:03] Epoch 0176 (step 000351): train loss 0.041, val loss 12.635
He	 liked Claude Claude Ver women left	
[16:32:35] Epoch 0201 (step 000401): train loss 0.008, val loss 12.800
He of discussion only first prism, on any
[16:33:07] Epoch 0226 (step 000451): train loss 0.006, val loss 12.801
