### create vizualization tech

In [1]:
import matplotlib.pyplot as plt

def plot_heatmap(matrix):
    """
    Plots an 8x8 heatmap using the provided matrix of floats ranging from 0 to 1.

    :param matrix: A nested list (8x8) of floats (0-1)
    """
    plt.imshow(matrix, cmap='hot', interpolation='nearest')
    plt.colorbar()
    plt.show()

### get model and make a prediction

In [2]:
import torch
from train_fen import FenModel

model = FenModel()
model.load_state_dict(torch.load('models/fenmodel.pt'))
model.eval()

FenModel(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=T

## get example fens

In [4]:
with open('fens.txt', 'r') as f:
    fens = [row for row in f]

print(fens[:5])

['2kr2nr/pp1n1bpp/1bpq1p2/3p3P/1P1N2P1/2P1PP2/PB2N1B1/R2QK2R b - KQ\n', '2rq1rk1/pp1nppbp/5np1/2pP4/2P1b1P1/1P2QN1P/P3BP1B/2RN1RK1 w - -\n', 'rnbqkbn1/pppp4/3r1pp1/4p1Np/2B1P3/3P1Q2/PPP2PPP/RNB1K2R w - KQq\n', 'r2qkb1r/p2npppp/2pp1n2/1p6/4P3/1P3Q1N/1PPP1PPP/RNB1K2R w - KQkq\n', '5r2/B3p1b1/6kp/6n1/6p1/2P5/PP1r2PP/R3R1K1 w - -\n']


In [5]:
def fen_to_targets(f: str) -> torch.Tensor:
    p = FenParser(f)
    b = list(itertools.chain.from_iterable(p.parse()))
    pieces = ['p', 'r', 'n', 'b', 'q', 'k', 'P', 'R', 'N', 'B', 'Q', 'K']
    out = torch.zeros(BOARD_LEN)
    for i in range(12):
        target = torch.zeros(FEN_LEN)
        target[[pieces[i] == x for x in b]] = 1
        out[FEN_LEN*i:FEN_LEN*(i+1)] = target
    return out

In [None]:
        inputs = self.tokenizer.encode_plus(
            self.fens[idx],
            max_length=self.max_len,
            pad_to_max_length=True,
            truncation=True,
            return_token_type_ids=True
        )
        targets = fen_to_targets(self.fens[idx])
        return {
            'input_ids': torch.tensor(inputs['input_ids']),
            'attention_mask': torch.tensor(inputs['attention_mask']),
            'token_type_ids': torch.tensor(inputs['token_type_ids']),
            'targets': targets
        }

In [None]:
split = int(len(fens) * 0.9)
train_split = fens[:split]
test_split = fens[split:]

tokenizer = BertTokenizer.from_pretrained('bert-base-cased')

train_ds = FenDataset(train_split, tokenizer, FEN_LEN)
val_ds = FenDataset(test_split, tokenizer, FEN_LEN)

print(train_ds[0])