In [1]:
from gpt2 import GPT, GPTConfig # our GPT class
import time
import tiktoken
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# GPT-2 tokenizer obtained from tiktoken.
enc = tiktoken.get_encoding('gpt2')

# Pick the device dynamically: use CUDA when it is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [3]:
# Single source of truth for the RNG seed so the CPU and CUDA generators
# can never be seeded with different values by accident.
SEED = 13

torch.manual_seed(SEED)  # for reproducibility
if torch.cuda.is_available():
    # Explicitly seed the generator of every visible CUDA device as well.
    torch.cuda.manual_seed_all(SEED)

## Data Loader Lite

In [4]:
class DataLoaderLite:
    """Minimal sequential batch loader over one tokenized text file.

    Reads ``data/input.txt``, removes every newline character, encodes the
    text with the tiktoken ``gpt2`` encoding and stores all tokens in a single
    1-D ``torch.long`` tensor placed on the notebook-level ``device``.
    ``next_batch`` walks that tensor in consecutive chunks of ``B * T`` tokens
    and wraps back to the start after the last complete chunk.
    """

    def __init__(self, B, T):
        """Load and tokenize the corpus and reset the batch cursor.

        Args:
            B: number of rows in each returned batch.
            T: number of tokens in each row.

        Raises:
            ValueError: if the corpus holds fewer than ``B * T + 1`` tokens,
                i.e. not even one (input, target) batch can be formed.
        """
        self.B, self.T = B, T

        # Explicit encoding so the result does not depend on the platform default.
        with open('data/input.txt', 'r', encoding='utf-8') as file:
            text = file.read().replace('\n', '')

        enc = tiktoken.get_encoding('gpt2')
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens, dtype=torch.long, device=device)

        self.current_batch = 0
        self.number_of_batches = self._count_batches(len(self.tokens), B, T)
        if self.number_of_batches == 0:
            raise ValueError(
                f'Need at least {B * T + 1} tokens for one {B}x{T} batch, '
                f'got {len(self.tokens)}'
            )

        print(f'Loaded {len(self.tokens)} tokens, {self.number_of_batches} batches of size {B}x{T}')

    @staticmethod
    def _count_batches(num_tokens, B, T):
        """Return how many complete batches fit into ``num_tokens`` tokens.

        Every batch consumes ``B * T`` input tokens plus one extra look-ahead
        token for the shifted targets, so the final token of the corpus can
        only ever serve as a target. Hence ``(num_tokens - 1) // (B * T)``
        rather than ``num_tokens // (B * T)``; the latter over-counts by one
        whenever ``num_tokens`` is an exact multiple of ``B * T``.
        """
        return max(num_tokens - 1, 0) // (B * T)

    def next_batch(self):
        """Return the next ``(x, y)`` pair and advance the cursor.

        Returns:
            x: ``(B, T)`` view of the current chunk of tokens.
            y: ``(B, T)`` view of the same chunk shifted one token to the
                right, i.e. the next-token target for every position of ``x``.
        """
        B, T = self.B, self.T

        # Slice B*T + 1 tokens: the extra one supplies the target of the last input.
        buf = self.tokens[self.current_batch * B * T : (self.current_batch + 1) * B * T + 1]
        x = buf[:-1].view(B, T)
        y = buf[1:].view(B, T)

        # Advance, wrapping around once every complete batch has been served.
        self.current_batch += 1
        if self.current_batch >= self.number_of_batches:
            self.current_batch = 0

        return x, y

## Training and Timing

In [5]:
# Build the model first, then move it to the selected device.
# NOTE(review): GPTConfig is passed as the class itself, not as an instance —
# confirm that GPT only reads attributes that are available on the class.
model = GPT(GPTConfig)
model = model.to(device)

In [6]:
# Each step consumes one batch of B rows with T tokens per row.
B, T = 2, 1024
NUM_STEPS = 25        # number of optimizer steps to run and time
LEARNING_RATE = 3e-4  # AdamW learning rate
data_loader = DataLoaderLite(B, T)

model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
for i in range(NUM_STEPS):
    # perf_counter is monotonic and high-resolution, so it is the right clock
    # for measuring short intervals (time.time can jump with wall-clock changes).
    t0 = time.perf_counter()
    x, y = data_loader.next_batch()
    optimizer.zero_grad()
    logits, loss = model(x, y)
    loss.backward()
    optimizer.step()
    if device.type == 'cuda':
        # CUDA kernels are launched asynchronously: wait for the GPU to finish
        # so the measured interval covers the real work. Skipped on CPU, where
        # calling torch.cuda.synchronize() would raise.
        torch.cuda.synchronize()
    t1 = time.perf_counter()
    elapsed = t1 - t0
    dt = elapsed * 1000  # time difference in milliseconds
    throughput = (B * T) / elapsed  # tokens per second
    print(f"Step {i} | Loss {loss.item():.4f} | {dt:.1f} ms | {throughput:.2f} tok/s")

Loaded 297884 tokens, 145 batches of size 2x1024
Step 0 | Loss 11.0395 | 552.2 ms | 3708.59 tok/s
Step 1 | Loss 10.1234 | 373.6 ms | 5481.76 tok/s
Step 2 | Loss 9.5636 | 375.3 ms | 5456.97 tok/s
Step 3 | Loss 9.3886 | 371.6 ms | 5510.60 tok/s
Step 4 | Loss 9.1647 | 372.4 ms | 5500.17 tok/s
Step 5 | Loss 8.8967 | 371.7 ms | 5510.55 tok/s
Step 6 | Loss 8.7383 | 373.3 ms | 5486.83 tok/s
Step 7 | Loss 8.3745 | 373.0 ms | 5491.34 tok/s
Step 8 | Loss 8.2043 | 372.8 ms | 5493.51 tok/s
Step 9 | Loss 7.9448 | 373.3 ms | 5485.48 tok/s
Step 10 | Loss 7.6409 | 374.6 ms | 5467.39 tok/s
Step 11 | Loss 7.5783 | 373.7 ms | 5480.69 tok/s
Step 12 | Loss 7.4408 | 374.1 ms | 5474.25 tok/s
Step 13 | Loss 7.3493 | 374.7 ms | 5465.77 tok/s
Step 14 | Loss 7.4255 | 373.7 ms | 5480.61 tok/s
Step 15 | Loss 7.0674 | 378.8 ms | 5406.97 tok/s
Step 16 | Loss 6.8323 | 375.3 ms | 5457.26 tok/s
Step 17 | Loss 7.0912 | 376.0 ms | 5447.38 tok/s
Step 18 | Loss 7.0000 | 373.8 ms | 5479.19 tok/s
Step 19 | Loss 7.0514 | 375.