# test training flow

## setup

### imports

In [132]:
from typing import List, Dict
import itertools

from fenparser import FenParser

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

from transformers import BertModel, BertTokenizer

from tqdm import tqdm

### magic numbers (and strings)

In [120]:
FEN_FILE = 'fens.txt'
MAX_LEN = 64
BATCH_SIZE = 32
LR = 3e-4
EPOCHS = 3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

## data stuff

### load data

In [121]:
with open(FEN_FILE, 'r') as f:
    fens = [row for row in f]

print(fens[:5])

['2kr2nr/pp1n1bpp/1bpq1p2/3p3P/1P1N2P1/2P1PP2/PB2N1B1/R2QK2R b - KQ\n', '2rq1rk1/pp1nppbp/5np1/2pP4/2P1b1P1/1P2QN1P/P3BP1B/2RN1RK1 w - -\n', 'rnbqkbn1/pppp4/3r1pp1/4p1Np/2B1P3/3P1Q2/PPP2PPP/RNB1K2R w - KQq\n', 'r2qkb1r/p2npppp/2pp1n2/1p6/4P3/1P3Q1N/1PPP1PPP/RNB1K2R w - KQkq\n', '5r2/B3p1b1/6kp/6n1/6p1/2P5/PP1r2PP/R3R1K1 w - -\n']


### create datasets

In [122]:
def fen_to_targets(f: str) -> torch.Tensor:
    p = FenParser(f)
    b = list(itertools.chain.from_iterable(p.parse()))
    pieces = ['p', 'r', 'n', 'b', 'q', 'k', 'P', 'R', 'N', 'B', 'Q', 'K']
    out = torch.zeros(768)
    for i in range(12):
        target = torch.zeros(64)
        target[[pieces[i] == x for x in b]] = 1
        out[64*i:64*(i+1)] = target
    return out

In [123]:
class FenDataset(Dataset):
    def __init__(self, fens: List[str], tokenizer: BertTokenizer, max_len: int = 64):
        self.fens = fens
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self) -> int:
        return len(self.fens)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        inputs = self.tokenizer.encode_plus(
            self.fens[idx],
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=True
        )
        targets = fen_to_targets(self.fens[idx])
        return {
            'input_ids': torch.tensor(inputs['input_ids']),
            'attention_mask': torch.tensor(inputs['attention_mask']),
            'token_type_ids': torch.tensor(inputs['token_type_ids']),
            'targets': targets
        }

In [124]:
split = int(len(fens) * 0.9)
train_split = fens[:split]
test_split = fens[split:]

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

train_ds = FenDataset(train_split, tokenizer, MAX_LEN)
test_ds = FenDataset(test_split, tokenizer, MAX_LEN)

print(train_ds[0])

{'input_ids': tensor([  101,   123,  1377,  1197,  1477,  1179,  1197,   120,  4329,  1475,
         1179,  1475,  1830,  8661,   120,   122,  1830,  1643,  4426,  1475,
         1643,  1477,   120,   124,  1643,  1495,  2101,   120,   122,  2101,
         1475,  2249,  1477,  2101,  1475,   120,   123,  2101,  1475, 20923,
         1477,   120,   153,  2064,  1477,  2249,  1475,  2064,  1475,   120,
          155,  1477,  4880,  2428,  1477,  2069,   171,   118,   148,  4880,
          102,     0,     0,     0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]), 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]), 'targets



### create dataloaders

In [125]:
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True)

print(next(iter(train_dl)))

{'input_ids': tensor([[ 101,  187, 1477,  ...,    0,    0,    0],
        [ 101,  187, 1475,  ...,    0,    0,    0],
        [ 101,  124, 1377,  ...,    0,    0,    0],
        ...,
        [ 101,  129,  120,  ...,    0,    0,    0],
        [ 101,  187, 1527,  ...,    0,    0,    0],
        [ 101,  187, 1477,  ...,    0,    0,    0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'targets': tensor([[0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
 

## model setup

### model architecture

In [126]:
class FenModel(nn.Module):
    def __init__(self, out_size: int = 768, dropout: bool = False):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-cased')
        self.bert.pooler = nn.Linear(768, 768)
        self.linear = nn.Linear(768, out_size)
        if not dropout:
            self.dropout = nn.Dropout(0.2)
        else:
            self.dropout = None

    def forward(self, x: Dict[str, torch.Tensor]):
        x = self.bert(
            input_ids=x['input_ids'],
            attention_mask=x['attention_mask'],
            token_type_ids=x['token_type_ids'],
            return_dict=False
        )[0][:,0]
        if self.dropout is not None:
            x = self.dropout(x)
        x = self.linear(x)
        return x

In [127]:
model = FenModel()
print(model(next(iter(train_dl))))

tensor([[-0.1578, -0.3257, -0.7287,  ...,  0.2109,  0.2511,  0.5525],
        [-0.1092, -0.1571, -0.5159,  ..., -0.1429,  0.2435,  0.6094],
        [-0.2528, -0.2661, -0.6674,  ...,  0.2251,  0.5902,  0.5964],
        ...,
        [-0.1927, -0.1783, -0.7268,  ...,  0.1172,  0.1255,  0.4930],
        [-0.2626, -0.4377, -0.3554,  ...,  0.0797,  0.3587,  0.4315],
        [-0.0535, -0.5219, -0.5574,  ...,  0.0124,  0.3870,  0.4190]],
       grad_fn=<AddmmBackward0>)


## training setup

### optimizer

In [128]:
optimizer = AdamW(model.parameters(), lr=LR)

### training loop

In [135]:
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
for e in tqdm(range(EPOCHS)):
    model.train()
    train_loss = 0.0
    train_acc = 0.0

    for data in train_dl:
        data = {k: v.to(DEVICE) for k, v in data.items()}
        optimizer.zero_grad()
        labels_pred = model(data)
        batch_loss = F.binary_cross_entropy_with_logits(labels_pred, data['targets'])
        train_loss = train_loss + batch_loss.item()

        labels_pred_binary = torch.zeros_like(data['targets'])
        labels_pred_binary[labels_pred > 0] = 1.0
        batch_acc = torch.sum(labels_pred_binary == data['targets'])
        train_acc += batch_acc

        batch_loss.backward()
        optimizer.step()
    train_losses.append(train_loss / len(train_loader))
    train_accuracies.append(train_acc / (768 * batch_size * len(train_loader)))

    model.eval()
    val_loss = 0.0
    val_acc = 0.0

    with torch.no_grad():
        for data in test_dl:
            data = {k: v.to(DEVICE) for k, v in data.items()}
            labels_pred = model(data)
            v_batch_loss = loss(labels_pred, data['targets'])
            val_loss = val_loss + v_batch_loss.item()

            labels_pred_binary = torch.zeros_like(data['targets'])
            labels_pred_binary[lables_pred > 0] = 1.0
            batch_acc = torch.sum(labels_pred_binary == data['targets'])
            val_acc += batch_acc
        val_losses.append(val_loss / len(val_loader))
        val_accuracies.append(val_acc /(768 * batch_size * len(val_loader)))

  0%|                                                                              | 0/3 [00:00<?, ?it/s]

foobar
foobar
foobar


  0%|                                                                              | 0/3 [00:08<?, ?it/s]


KeyboardInterrupt: 