## Train Othello-GPT and save to `ckpts`

To execute the whole notebook non-interactively (for example as a background job), run `jupyter nbconvert --execute --to notebook --allow-errors --ExecutePreprocessor.timeout=-1 train_gpt_othello.ipynb --inplace --output ckpts/checkpoint.ipynb`.

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
# Make the run deterministic: fix the seed before any dataset or model code executes.
# NOTE(review): which random number generators set_seed covers is defined in
# mingpt.utils and is not visible from this notebook — confirm there.
from mingpt.utils import set_seed
set_seed(44)

In [5]:
import os
import math
import time
from tqdm import tqdm
import numpy as np
from copy import deepcopy
import torch
import torch.nn as nn
from torch.nn import functional as F
from data import get_othello
from data.othello import permit, start_hands, OthelloBoardState, permit_reverse
from mingpt.dataset import CharDataset
from mingpt.utils import sample
from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig

In [6]:
# Dataset switch read by later cells:
#   True  -> get_othello is called with data_root=None and the step checkpoint is loaded;
#   False -> get_othello reads "data/othello_championship" and the championship checkpoint is loaded.
synthetic_or_championship = True  # True for training on the synthetic dataset

In [7]:
from torch.utils.data import DataLoader

# Choose the data source from the notebook-level flag: no data_root for the
# synthetic setting, the championship folder otherwise.
if synthetic_or_championship:
    data_root = None
else:
    data_root = "data/othello_championship"
othello = get_othello(ood_num=-1, data_root=data_root, wthor=True)

# Wrap the games in a CharDataset; it exposes the vocab_size and block_size
# that the model configuration below is built from.
train_dataset = CharDataset(othello)

# Shuffled mini-batches for the hand-written training loop.
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# GPT with 8 layers, 8 attention heads and 512-dimensional embeddings.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.block_size,
    n_layer=8,
    n_head=8,
    n_embd=512,
)
model = GPT(mconf)

Mem Used: 3.647 GB: 100%|███████████████████████████████████████████████████████████████████████████████| 65/65 [00:08<00:00,  7.78it/s]


Deduplicating...
Deduplicating finished with 6499157 games left
Using 20 million for training, 0 for validation
Dataset created has 6499157 sequences, 61 unique words.


In [None]:
import torch
import torch.nn as nn
import os
import time

# `model` must already exist (it is built in an earlier cell); the optimizer
# and the target device are created here.
import torch.optim as optim

# Fixed step size handed to Adam.
learning_rate = 5e-4
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

# Custom loss function for causal language modeling
class CausalLanguageModelingLoss(nn.Module):
    """Cross-entropy between per-position logits and next-token targets.

    With the default arguments the criterion behaves exactly as before: the
    logits at position ``t`` are scored against ``targets[:, t + 1]``, i.e. the
    last logit position and the first target column are dropped before the
    loss is computed.

    Args:
        shift_targets: When True (default) apply the one-step shift described
            above. Pass False when ``targets`` is already aligned with
            ``outputs`` position by position (for example when the data loader
            itself yields next-token targets), so no second shift is applied.
        ignore_index: Target value that does not contribute to the loss;
            forwarded to ``nn.CrossEntropyLoss``. The default (-100) is the
            ``nn.CrossEntropyLoss`` default.
    """

    def __init__(self, shift_targets=True, ignore_index=-100):
        super().__init__()
        self.shift_targets = shift_targets
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=ignore_index)

    def forward(self, outputs, targets):
        """Return the scalar loss.

        Args:
            outputs: logits indexed as ``[batch, position, vocab]``.
            targets: integer class ids indexed as ``[batch, position]``.
        """
        if self.shift_targets:
            # Position t predicts token t + 1: drop the last prediction and the first target.
            outputs = outputs[:, :-1, :]
            targets = targets[:, 1:]
        # CrossEntropyLoss expects (N, C) logits and (N,) targets, so flatten
        # batch and position together. reshape copes with the non-contiguous slices.
        flat_logits = outputs.reshape(-1, outputs.size(-1))
        flat_targets = targets.reshape(-1)
        return self.loss_fn(flat_logits, flat_targets)

# Criterion used by the training loop in this cell.
# NOTE(review): as instantiated here the criterion drops the last logit position
# and the first target column before computing cross-entropy. If train_loader
# already yields next-token-aligned targets, that would amount to a second shift —
# confirm against what CharDataset returns per item.
loss_fn = CausalLanguageModelingLoss()

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``filename`` is a path, its parent directory is created first if it
    does not exist yet, so that saving into a fresh ``./ckpts`` folder does not
    fail. Non-path destinations (e.g. an open file object) are passed to
    ``torch.save`` untouched.

    Args:
        state: Object to serialize; the training loop in this cell passes a
            dict holding the epoch, global step, model state and optimizer state.
        filename: Destination path or file-like object.
    """
    print("=> Saving checkpoint")
    if isinstance(filename, (str, os.PathLike)):
        parent_dir = os.path.dirname(filename)
        # A bare file name has an empty dirname: it is written to the current
        # directory and there is nothing to create.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
    torch.save(state, filename)

# Loop configuration.
max_epochs = 250
# A checkpoint is written every `checkpoint_step_interval` optimizer steps.
# NOTE(review): with the 6499157 sequences reported above and batch_size = 64,
# one epoch is roughly 100k steps, so an interval of 20 produces thousands of
# checkpoint files per epoch — confirm this is intended.
checkpoint_step_interval = 20

step_global = 0  # optimizer steps completed so far, counted across epochs

for epoch in range(max_epochs):
    model.train()
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Move the batch to the device the model lives on.
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        model_out = model(inputs)
        # If the model returns a tuple, its first element is used as the logits.
        logits = model_out[0] if isinstance(model_out, tuple) else model_out
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        # `step_global` is incremented only after this check, hence the "+ 1":
        # checkpoints land on steps 20, 40, ... counted from 1.
        reached_checkpoint_step = (step_global + 1) % checkpoint_step_interval == 0
        if reached_checkpoint_step:
            # The step number is followed by "_" and a timestamp that itself
            # starts with "_", so file names contain a double underscore.
            ckpt_path = f"./ckpts/gpt_step_{step_global+1}_{time.strftime('_%Y%m%d_%H%M%S')}.ckpt"
            training_state = {
                'epoch': epoch + 1,
                'global_step': step_global + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }
            save_checkpoint(training_state, filename=ckpt_path)

        step_global += 1

        # Report the loss on every 100th batch of the current epoch.
        if batch_idx % 100 == 0:
            print(f"Epoch: {epoch+1}, Step: {step_global}, Loss: {loss.item()}")


# Notes:
# - 'train_loader' must yield batches of (inputs, targets) pairs.
# - 'model' must already be defined by an earlier cell; 'optimizer', 'loss_fn' and 'device' are created at the top of this cell.
# - Adjust the 'device' setting to match your setup.
# - Each checkpoint stores the epoch, global step, model state, and optimizer state, so both the weights and the optimizer can be restored from it.


Epoch: 1, Step: 1, Loss: 4.220006942749023
=> Saving checkpoint
=> Saving checkpoint


In [8]:
# Alternative to the hand-written loop: configure a Trainer and start training.
max_epochs = 250
# Timestamp suffix that makes the checkpoint file name unique per run.
t_start = time.strftime("_%Y%m%d_%H%M%S")

# Number of tokens seen in one pass over the dataset; the warm-up and decay
# horizons below are expressed as multiples of it (5 epochs and max_epochs epochs).
tokens_per_epoch = len(train_dataset)*train_dataset.block_size

trainer_settings = dict(
    max_epochs=max_epochs,
    batch_size=512*4,  # assuming 8 GPU's
    learning_rate=5e-4,
    lr_decay=True,
    warmup_tokens=tokens_per_epoch*5,
    final_tokens=tokens_per_epoch*max_epochs,
    num_workers=0,
    ckpt_path=f"./ckpts/gpt_at{t_start}.ckpt",
)
tconf = TrainerConfig(**trainer_settings)

# The third positional argument is None.
# NOTE(review): presumably this is the evaluation dataset slot — confirm in mingpt.trainer.
trainer = Trainer(model, train_dataset, None, tconf)
# Later cells move tensors to `device`, so pick up the device the trainer uses.
device = trainer.device
print(t_start)
trainer.train()

_20240212_171716


  0%|                                                                                                          | 0/1587 [01:26<?, ?it/s]


KeyboardInterrupt: 

## Or load trained model from `ckpts`

In [11]:
# Pick the checkpoint file that matches the dataset flag set near the top of the notebook.
checkpoint_path = "./ckpts/gpt_step_20__20240212_173548.ckpt" if synthetic_or_championship else "./ckpts/gpt_championship.ckpt"
# Deserialize onto the CPU first to avoid CUDA memory issues; the model is moved afterwards.
checkpoint = torch.load(checkpoint_path, map_location="cpu")

# Checkpoints written by save_checkpoint in this notebook wrap the weights under
# 'model_state_dict'. Also accept a file that holds the bare state dict.
# NOTE(review): this cell previously passed the loaded object straight to
# load_state_dict, so some checkpoint files are presumably in the bare format — confirm.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    model_state_dict = checkpoint['model_state_dict']
else:
    model_state_dict = checkpoint

# load_res holds the result returned by load_state_dict.
load_res = model.load_state_dict(model_state_dict)

# Always define `device`: the validation cell moves its inputs to it, and this
# cell may be the first one to set it when the training cells are skipped.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = torch.device("cpu")
model = model.to(device)

## Validate it: for what percentage of all partial games in the validation set is the top-1 prediction a legal move?

In [None]:
# When the flag selects the championship setting, `othello` was loaded from
# "data/othello_championship". Reload it with data_root=None (the value used when
# the flag is True) so the validation below reads `othello.val` from that dataset.
if not synthetic_or_championship:  # for GPT trained on both datasets, use the validation set of synthetic for validation
    othello = get_othello(ood_num=-1, data_root=None, wthor=True)

In [None]:
# For each of the first 1000 validation games, take every proper prefix, let the
# model append one token, and replay the extended sequence on a fresh board.
# A prefix counts as a success when the replay raises no exception.
total_nodes = 0    # prefixes evaluated
success_nodes = 0  # prefixes whose extended sequence replayed cleanly

bar = tqdm(othello.val[:1000])
for game in bar:
    for prefix_len in range(1, len(game)):
        total_nodes += 1
        # Encode the prefix with the dataset vocabulary and add a batch axis.
        prefix_ids = [train_dataset.stoi[move] for move in game[:prefix_len]]
        x = torch.tensor(prefix_ids, dtype=torch.long).unsqueeze(0).to(device)
        # Ask for one additional token and take the single sequence of the batch.
        y = sample(model, x, 1, temperature=1.0)[0]
        # Decode back to moves, skipping entries equal to -1.
        completion = [train_dataset.itos[int(token)] for token in y if token != -1]
        try:
            OthelloBoardState().update(completion, prt=False)
        except Exception:
            # Any exception raised during the replay is counted as a failed prediction.
            continue
        success_nodes += 1
    bar.set_description(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")
print(f"{success_nodes/total_nodes*100:.2f}% pass rate: {success_nodes}/{total_nodes} among all searched nodes")

In [None]:
# Error rate: fraction of evaluated partial games that were not counted as a success.
1 - success_nodes/total_nodes