# test training flow

## setup

### imports

In [1]:
from typing import List, Dict
import itertools

from fenparser import FenParser

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam

from transformers import BertModel, BertTokenizer

from tqdm.notebook import tqdm

### magic numbers (and strings)

In [2]:
FEN_FILE = 'fens.txt'
BATCH_SIZE = 32
LR = 3e-4
ITERATIONS = 100
FEN_LEN = 64
BOARD_LEN = 768
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
print(f'Device is: {DEVICE}')

Device is: mps


## data stuff

### load data

In [3]:
with open(FEN_FILE, 'r') as f:
    fens = [row for row in f]

print(fens[:5])

['2kr2nr/pp1n1bpp/1bpq1p2/3p3P/1P1N2P1/2P1PP2/PB2N1B1/R2QK2R b - KQ\n', '2rq1rk1/pp1nppbp/5np1/2pP4/2P1b1P1/1P2QN1P/P3BP1B/2RN1RK1 w - -\n', 'rnbqkbn1/pppp4/3r1pp1/4p1Np/2B1P3/3P1Q2/PPP2PPP/RNB1K2R w - KQq\n', 'r2qkb1r/p2npppp/2pp1n2/1p6/4P3/1P3Q1N/1PPP1PPP/RNB1K2R w - KQkq\n', '5r2/B3p1b1/6kp/6n1/6p1/2P5/PP1r2PP/R3R1K1 w - -\n']


### create datasets

In [4]:
def fen_to_targets(f: str) -> torch.Tensor:
    p = FenParser(f)
    b = list(itertools.chain.from_iterable(p.parse()))
    pieces = ['p', 'r', 'n', 'b', 'q', 'k', 'P', 'R', 'N', 'B', 'Q', 'K']
    out = torch.zeros(BOARD_LEN)
    for i in range(12):
        target = torch.zeros(FEN_LEN)
        target[[pieces[i] == x for x in b]] = 1
        out[FEN_LEN*i:FEN_LEN*(i+1)] = target
    return out

In [5]:
class FenDataset(Dataset):
    def __init__(self, fens: List[str], tokenizer: BertTokenizer, max_len: int = 64):
        self.fens = fens
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.fens)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        inputs = self.tokenizer.encode_plus(
            self.fens[idx],
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=True
        )
        targets = fen_to_targets(self.fens[idx])
        return {
            'input_ids': torch.tensor(inputs['input_ids']),
            'attention_mask': torch.tensor(inputs['attention_mask']),
            'token_type_ids': torch.tensor(inputs['token_type_ids']),
            'targets': targets
        }

In [6]:
split = int(len(fens) * 0.9)
train_split = fens[:split]
test_split = fens[split:]

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

train_ds = FenDataset(train_split, tokenizer, FEN_LEN)
val_ds = FenDataset(test_split, tokenizer, FEN_LEN)

print(train_ds[0])

{'input_ids': tensor([  101,   123,  1377,  1197,  1477,  1179,  1197,   120,  4329,  1475,
         1179,  1475,  1830,  8661,   120,   122,  1830,  1643,  4426,  1475,
         1643,  1477,   120,   124,  1643,  1495,  2101,   120,   122,  2101,
         1475,  2249,  1477,  2101,  1475,   120,   123,  2101,  1475, 20923,
         1477,   120,   153,  2064,  1477,  2249,  1475,  2064,  1475,   120,
          155,  1477,  4880,  2428,  1477,  2069,   171,   118,   148,  4880,
          102,     0,     0,     0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]), 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'targets



### create dataloaders

In [8]:
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)

print(next(iter(train_dl)))

{'input_ids': tensor([[ 101,  129,  120,  ...,    0,    0,    0],
        [ 101,  187, 1179,  ...,    0,    0,    0],
        [ 101,  187, 1477,  ...,    0,    0,    0],
        ...,
        [ 101,  124, 1197,  ...,    0,    0,    0],
        [ 101,  187, 1477,  ...,    0,    0,    0],
        [ 101,  123, 1197,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'targets': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 1.],
 

## model setup

### model architecture

In [9]:
class FenModel(nn.Module):
    def __init__(self, out_size: int = BOARD_LEN, dropout: bool = False):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.bert.pooler = nn.Linear(BOARD_LEN, BOARD_LEN)
        self.linear = nn.Linear(BOARD_LEN, out_size)
        if not dropout:
            self.dropout = nn.Dropout(0.2)
        else:
            self.dropout = None

    def forward(self, x: Dict[str, torch.Tensor]):
        x = self.bert(
            input_ids=x['input_ids'],
            attention_mask=x['attention_mask'],
            token_type_ids=x['token_type_ids'],
            return_dict=False
        )[0][:,0]
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.linear(x)
        return x

In [18]:
# just for testing
model = FenModel()
data = next(iter(train_dl))
model(data)

tensor([[-3.4766e-01,  8.8781e-02, -6.9183e-01,  ..., -8.8613e-02,
         -4.9581e-01,  5.7306e-01],
        [-7.4112e-01, -1.0786e-01, -2.2075e-01,  ...,  1.2325e-01,
         -2.7112e-01,  1.6935e-01],
        [-2.5089e-01, -6.2655e-02, -9.7424e-01,  ...,  1.2594e-02,
         -1.6093e-01,  3.7004e-01],
        ...,
        [-1.0087e-04, -1.0422e-01, -7.0711e-01,  ...,  8.9315e-02,
         -4.0760e-01,  4.7180e-01],
        [-2.8052e-01, -9.2481e-02, -7.7559e-01,  ..., -1.0110e-01,
         -4.4745e-01,  4.2153e-01],
        [-3.2615e-01, -2.9795e-02, -8.4383e-01,  ...,  8.4649e-02,
         -4.9991e-01,  7.8660e-01]], grad_fn=<AddmmBackward0>)

## training setup

## training loop

*(We don't need to do full epochs as they're too big and unnecessary!)*

In [12]:
def train_model(model, optimizer, dataloader, iterations, device='cpu'):
    train_losses = []
    train_acc = []
    
    for i, data in tqdm(enumerate(dataloader), total=iterations):
        model.train()
        data = {k: v.to(device) for k, v in data.items()}
        optimizer.zero_grad()
        labels_pred = model(data)
        batch_loss = F.binary_cross_entropy_with_logits(labels_pred, data['targets'])
    
        batch_loss.backward()
        optimizer.step()
    
        train_losses.append(batch_loss.item())
    
        labels_pred_binary = torch.zeros_like(data['targets'])
        labels_pred_binary[labels_pred > 0] = 1.0
        train_acc.append(torch.mean((labels_pred_binary == data['targets']).float()).item())
    
        if (i+1) % (iterations // 20) == 0:
            print(f'Iteration: {i+1}, train_loss: {train_losses[-1]:.4f}, train_acc: {train_acc[-1]:.4f}')

        if i == iterations:
            return train_losses, train_acc

### validation test

In [13]:
def get_val(model, dataloader, iterations, device='cpu'):
    val_losses = []
    val_acc = []
    
    model.eval()
    with torch.no_grad():
        for i, data in tqdm(enumerate(dataloader), total=iterations):
            data = {k: v.to(device) for k, v in data.items()}
            optimizer.zero_grad()
            labels_pred = model(data)
            batch_loss = F.binary_cross_entropy_with_logits(labels_pred, data['targets'])
            
            val_losses.append(batch_loss.item())
            labels_pred_binary = torch.zeros_like(data['targets'])
            labels_pred_binary[labels_pred > 0] = 1.0
            val_acc.append(torch.mean((labels_pred_binary == data['targets']).float()).item())
            
            if i == iterations:
                return val_losses, val_acc

## train and examine the model!

### train the model!

In [16]:
# create model
model = FenModel().to(DEVICE)
optimizer = Adam(model.parameters(), lr=LR)

# train model
train_losses, train_acc = train_model(model, optimizer, train_dl, ITERATIONS, device=DEVICE)

# get validation stats
val_losses, val_acc = get_val(model, val_dl, ITERATIONS // 10, device=DEVICE)

  0%|          | 0/100 [00:00<?, ?it/s]

Iteration: 5, train_loss: 0.4725, train_acc: 0.8892
Iteration: 10, train_loss: 0.2792, train_acc: 0.9709
Iteration: 15, train_loss: 0.1840, train_acc: 0.9716
Iteration: 20, train_loss: 0.1320, train_acc: 0.9726
Iteration: 25, train_loss: 0.1137, train_acc: 0.9708
Iteration: 30, train_loss: 0.0958, train_acc: 0.9731
Iteration: 35, train_loss: 0.0956, train_acc: 0.9707
Iteration: 40, train_loss: 0.0955, train_acc: 0.9695
Iteration: 45, train_loss: 0.0847, train_acc: 0.9741
Iteration: 50, train_loss: 0.0852, train_acc: 0.9731
Iteration: 55, train_loss: 0.0845, train_acc: 0.9715
Iteration: 60, train_loss: 0.0839, train_acc: 0.9727
Iteration: 65, train_loss: 0.0845, train_acc: 0.9711
Iteration: 70, train_loss: 0.0784, train_acc: 0.9733
Iteration: 75, train_loss: 0.0876, train_acc: 0.9702
Iteration: 80, train_loss: 0.0857, train_acc: 0.9709
Iteration: 85, train_loss: 0.0848, train_acc: 0.9713
Iteration: 90, train_loss: 0.0828, train_acc: 0.9731
Iteration: 95, train_loss: 0.0860, train_acc: 0

  0%|          | 0/10 [00:00<?, ?it/s]

In [17]:
print(f'val loss: {torch.mean(torch.tensor(val_losses)).item():.4f}, val acc: {torch.mean(torch.tensor(val_acc)).item():.4f}')

val loss: 0.0827, val acc: 0.9721
