In [None]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import torch.optim as optim
import torch.nn as nn
import json
from sklearn.model_selection import train_test_split
import chess
import ast
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim

import math
import numpy as np

In [None]:
class SequenceDataset(Dataset):
    """Paired source/target token sequences, right-padded to fixed lengths on access.

    Each sequence is converted to a tensor once, at construction time; padding is
    applied lazily every time an item is fetched.
    """

    def __init__(self, src_sequences, tar_sequences, src_padd_idx, tar_padd_idx, max_src_len, max_tar_len):
        """Tensorise the sequences and remember the padding configuration.

        Args:
            src_sequences: iterable of source sequences accepted by ``torch.tensor``.
            tar_sequences: iterable of target sequences accepted by ``torch.tensor``.
            src_padd_idx: fill value appended to source sequences.
            tar_padd_idx: fill value appended to target sequences.
            max_src_len: length every source sequence is padded to.
            max_tar_len: length every target sequence is padded to.
        """
        self.src_sequences = list(map(torch.tensor, src_sequences))
        self.tar_sequences = list(map(torch.tensor, tar_sequences))
        self.src_padd_idx = src_padd_idx
        self.tar_padd_idx = tar_padd_idx
        self.max_src_len = max_src_len
        self.max_tar_len = max_tar_len

    def __len__(self):
        """Number of (source, target) pairs, taken from the source list."""
        return len(self.src_sequences)

    @staticmethod
    def _right_pad(sequence, target_len, fill_value):
        """Append ``fill_value`` on the right until ``sequence`` has ``target_len`` entries."""
        missing = target_len - len(sequence)
        return torch.nn.functional.pad(sequence, (0, missing), value=fill_value)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(source, target)`` pair, each padded to its fixed length."""
        padded_source = self._right_pad(self.src_sequences[idx], self.max_src_len, self.src_padd_idx)
        padded_target = self._right_pad(self.tar_sequences[idx], self.max_tar_len, self.tar_padd_idx)
        return padded_source, padded_target

In [None]:
def load_data(csv_path, test_size=0.2, random_state=42):
    """Read the tokenised CSV and split it into training and validation parts.

    The ``sequence`` and ``target`` columns store Python literals as text; every
    cell is parsed with ``ast.literal_eval`` before splitting.

    Args:
        csv_path: path of the CSV file to read.
        test_size: forwarded to ``train_test_split``.
        random_state: forwarded to ``train_test_split``.

    Returns:
        ``(seq_train, seq_val, tar_train, tar_val)``.
    """
    frame = pd.read_csv(csv_path)
    parsed_sources = [ast.literal_eval(cell) for cell in frame['sequence']]
    parsed_targets = [ast.literal_eval(cell) for cell in frame['target']]

    src_fit, src_holdout, tgt_fit, tgt_holdout = train_test_split(
        parsed_sources,
        parsed_targets,
        test_size=test_size,
        random_state=random_state,
    )
    return src_fit, src_holdout, tgt_fit, tgt_holdout


def find_max_length(sequences):
    """Return the length of the longest item in ``sequences``.

    Raises:
        ValueError: if ``sequences`` is empty (propagated from ``max``).
    """
    return max(map(len, sequences))

# --- Data configuration and loaders ---------------------------------------
csv_path = 'tokenized.csv'
# Padding ids used to right-pad source / target sequences.
# NOTE(review): presumably one past the last id of the respective vocabularies —
# confirm against the vocab JSON files loaded further down.
src_padd_idx = 46
tar_padd_idx = 70
batch = 128

seq_train, seq_val, tar_train, tar_val = load_data(csv_path)

# Pad to the longest sequence seen in either split so train and validation
# tensors share the same shape.
max_src_len = max(find_max_length(seq_train), find_max_length(seq_val))
max_tar_len = max(find_max_length(tar_train), find_max_length(tar_val))

dataset_train = SequenceDataset(seq_train, tar_train, src_padd_idx, tar_padd_idx, max_src_len, max_tar_len)
dataset_val = SequenceDataset(seq_val, tar_val, src_padd_idx, tar_padd_idx, max_src_len, max_tar_len)

dataloader_train = DataLoader(dataset_train, batch_size=batch, shuffle=True, drop_last=True)
# Validation must see every sample in a stable order: shuffling adds nothing,
# and drop_last=True would silently discard the final partial batch (or yield
# zero batches when the validation split is smaller than one batch, which makes
# the per-batch loss average divide by zero).
dataloader_val = DataLoader(dataset_val, batch_size=batch, shuffle=False, drop_last=False)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    The encoding table is computed once for ``max_len`` positions and stored as a
    non-trainable buffer named ``pos_encoding`` with shape
    ``(max_len, 1, dim_model)``. Because the table is indexed along dimension 0,
    inputs must be laid out sequence-first: ``(seq_len, batch, dim_model)``.
    """

    def __init__(self, dim_model, dropout_p, max_len):
        """Precompute the sinusoidal table.

        Args:
            dim_model: embedding width (even or odd).
            dropout_p: dropout probability applied after adding the encoding.
            max_len: number of positions to precompute; longer inputs are rejected
                by the addition in ``forward``.
        """
        super().__init__()

        self.dropout = nn.Dropout(dropout_p)

        positions = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per (sin, cos) column pair: 10000^(-2i / dim_model).
        frequencies = torch.exp(torch.arange(0, dim_model, 2).float() * (-math.log(10000.0)) / dim_model)
        angles = positions * frequencies

        table = torch.zeros(max_len, dim_model)
        table[:, 0::2] = torch.sin(angles)
        # With an odd dim_model there is one cosine column fewer than sine columns,
        # so trim the angle matrix to the number of odd-indexed columns.
        table[:, 1::2] = torch.cos(angles[:, : dim_model // 2])

        # Singleton middle axis lets the table broadcast over the batch dimension.
        self.register_buffer("pos_encoding", table.unsqueeze(1))

    def forward(self, token_embedding: torch.Tensor) -> torch.Tensor:
        """Return ``dropout(token_embedding + positional table)``.

        Args:
            token_embedding: tensor of shape ``(seq_len, batch, dim_model)``.
        """
        seq_len = token_embedding.size(0)
        return self.dropout(token_embedding + self.pos_encoding[:seq_len])

In [None]:
class Transformer(nn.Module):
    """Sequence-to-sequence model: token embeddings + positional encoding + ``nn.Transformer``.

    Inputs to ``forward`` are batch-first index tensors ``(batch, seq_len)``; the
    output is sequence-first logits ``(tgt_seq_len, batch, tar_num_tokens)``.
    """

    def __init__(
        self,
        src_num_tokens,
        tar_num_tokens,
        embed_size,
        num_heads,
        num_encoder_layers,
        num_decoder_layers,
        dropout_p,
    ):
        """Build embeddings, positional encoder, transformer core and output projection.

        Args:
            src_num_tokens: number of rows in the source embedding table.
            tar_num_tokens: number of rows in the target embedding table and
                width of the output logits.
            embed_size: model width (``d_model``).
            num_heads: attention heads per layer.
            num_encoder_layers: number of encoder layers.
            num_decoder_layers: number of decoder layers.
            dropout_p: dropout used by the positional encoder and the transformer.
        """
        super().__init__()

        self.model_type = "Transformer"
        self.embed_size = embed_size

        self.positional_encoder = PositionalEncoding(
            dim_model=embed_size, dropout_p=dropout_p, max_len=5000
        )

        self.src_embedding = nn.Embedding(src_num_tokens, embed_size)
        self.tar_embedding = nn.Embedding(tar_num_tokens, embed_size)
        self.transformer = nn.Transformer(
            d_model=embed_size,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout_p,
        )
        self.out = nn.Linear(embed_size, tar_num_tokens)

    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        """Run the encoder-decoder and project to target-vocabulary logits.

        Args:
            src: source token ids, ``(batch, src_seq_len)``.
            tgt: decoder input token ids, ``(batch, tgt_seq_len)``.
            tgt_mask: optional attention mask over target positions
                (see ``get_tgt_mask``).
            src_pad_mask: optional boolean ``(batch, src_seq_len)`` mask, True at
                padded source positions.
            tgt_pad_mask: optional boolean ``(batch, tgt_seq_len)`` mask, True at
                padded target positions.

        Returns:
            Logits of shape ``(tgt_seq_len, batch, tar_num_tokens)``.
        """
        src = self.src_embedding(src) * math.sqrt(self.embed_size)
        tgt = self.tar_embedding(tgt) * math.sqrt(self.embed_size)

        # Switch to sequence-first layout BEFORE positional encoding: the encoder
        # reads dimension 0 as the sequence axis, so encoding a batch-first tensor
        # would index positions by batch row instead of by token position.
        src = src.permute(1, 0, 2)
        tgt = tgt.permute(1, 0, 2)

        src = self.positional_encoder(src)
        tgt = self.positional_encoder(tgt)

        transformer_out = self.transformer(
            src,
            tgt,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_pad_mask,
            tgt_key_padding_mask=tgt_pad_mask,
            # Encoder outputs at padded source positions must also be hidden from
            # the decoder's cross-attention.
            memory_key_padding_mask=src_pad_mask,
        )
        return self.out(transformer_out)

    def get_tgt_mask(self, size) -> torch.Tensor:
        """Return a ``(size, size)`` float mask: 0.0 on/below the diagonal, ``-inf`` above.

        Added to attention scores, it prevents position ``i`` from attending to any
        position ``j > i``.
        """
        return torch.triu(torch.full((size, size), float("-inf")), diagonal=1)

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int) -> torch.Tensor:
        """Return a boolean tensor shaped like ``matrix``, True where it equals ``pad_token``."""
        return (matrix == pad_token)

In [None]:
def train_loop(model, opt, loss_fn, dataloader, src_pad_idx=46, device=None):
    """Run one training epoch and return the mean per-batch loss.

    Each batch is an ``(X, y)`` pair. ``y`` is split into the decoder input
    ``y[:, :-1]`` and the expected output ``y[:, 1:]`` (teacher forcing).

    Args:
        model: module exposing ``get_tgt_mask(size)`` and
            ``create_pad_mask(matrix, pad_token)``, callable as
            ``model(X, y_input, tgt_mask, src_pad_mask)``.
        opt: optimizer stepping the model's parameters.
        loss_fn: criterion called as ``loss_fn(pred, y_expected)``.
        dataloader: sized iterable of ``(X, y)`` batches.
        src_pad_idx: token id treated as padding in ``X`` (default keeps the
            previously hard-coded value).
        device: device the batch and masks are moved to; when omitted, the device
            of the model's first parameter is used, so the function no longer
            depends on a module-level ``device`` variable.

    Returns:
        Sum of batch losses divided by ``len(dataloader)``.
    """
    if device is None:
        device = next(model.parameters()).device

    model.train()
    total_loss = 0.0

    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # Decoder sees tokens [0, n-1) and is asked to predict tokens [1, n).
        y_input = y[:, :-1]
        y_expected = y[:, 1:]

        tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
        src_mask = model.create_pad_mask(X, src_pad_idx).to(device)

        pred = model(X, y_input, tgt_mask, src_mask)

        # Move the class axis to dim 1, as the loss expects: for a sequence-first
        # model output this gives (batch, classes, seq).
        pred = pred.permute(1, 2, 0)
        loss = loss_fn(pred, y_expected)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [None]:
def validation_loop(model, loss_fn, dataloader, src_pad_idx=46, device=None):
    """Evaluate the model on ``dataloader`` without gradients; return the mean per-batch loss.

    Each batch is an ``(X, y)`` pair. ``y`` is split into the decoder input
    ``y[:, :-1]`` and the expected output ``y[:, 1:]``.

    Args:
        model: module exposing ``get_tgt_mask(size)`` and
            ``create_pad_mask(matrix, pad_token)``, callable as
            ``model(X, y_input, tgt_mask, src_pad_mask)``.
        loss_fn: criterion called as ``loss_fn(pred, y_expected)``.
        dataloader: sized iterable of ``(X, y)`` batches.
        src_pad_idx: token id treated as padding in ``X`` (default keeps the
            previously hard-coded value).
        device: device the batch and masks are moved to; when omitted, the device
            of the model's first parameter is used, so the function no longer
            depends on a module-level ``device`` variable.

    Returns:
        Sum of batch losses divided by ``len(dataloader)``.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)

            # Decoder sees tokens [0, n-1) and is scored on tokens [1, n).
            y_input = y[:, :-1]
            y_expected = y[:, 1:]

            tgt_mask = model.get_tgt_mask(y_input.size(1)).to(device)
            src_mask = model.create_pad_mask(X, src_pad_idx).to(device)

            pred = model(X, y_input, tgt_mask, src_mask)
            # Move the class axis to dim 1, as the loss expects.
            pred = pred.permute(1, 2, 0)
            total_loss += loss_fn(pred, y_expected).item()

    return total_loss / len(dataloader)

In [None]:
def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    """Train for ``epochs`` epochs, validating after each one.

    Every epoch runs ``train_loop`` on ``train_dataloader`` followed by
    ``validation_loop`` on ``val_dataloader`` and prints both losses.

    Returns:
        ``(train_losses, validation_losses)``: two lists with one entry per epoch.
    """
    train_history = []
    val_history = []

    for epoch in range(epochs):
        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        validation_loss = validation_loop(model, loss_fn, val_dataloader)

        train_history.append(train_loss)
        val_history.append(validation_loss)

        print(f"ep: {epoch}, train loss: {train_loss:.4f}, val loss: {validation_loss:.4f}")

    return train_history, val_history


In [None]:

# --- Model construction and training ---------------------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
vocab_path = "fen_vocab.json"
tar_vocab_path = "vocab.json"
with open(vocab_path, "r") as f:
    vocab = json.load(f)
with open(tar_vocab_path, "r") as f:
    tar_vocab = json.load(f)

# One extra embedding row beyond the vocabulary entries.
# NOTE(review): presumably reserved for the padding id — confirm.
src_vocab_size = len(vocab) + 1
trg_vocab_size = len(tar_vocab) + 1

# Padding ids are fed through the embedding tables, so they must be valid rows;
# failing here is far clearer than an index error inside a CUDA kernel.
assert 0 <= src_padd_idx < src_vocab_size, "src_padd_idx is outside the source embedding table"
assert 0 <= tar_padd_idx < trg_vocab_size, "tar_padd_idx is outside the target embedding table"

model = Transformer(
    src_num_tokens=src_vocab_size, tar_num_tokens=trg_vocab_size, embed_size=128, num_heads=8, num_encoder_layers=5, num_decoder_layers=5, dropout_p=0.1
).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.01)
# Targets are right-padded to a fixed length with tar_padd_idx; exclude those
# positions from the loss so padding does not dominate the training signal.
loss_fn = nn.CrossEntropyLoss(ignore_index=tar_padd_idx)

train_loss_list, validation_loss_list = fit(model, opt, loss_fn, dataloader_train, dataloader_val, 10)