# Experimenting with transformers
Transformers remain as a promising replacement of RNNs due to their parallelizability. However, RNNs are unique in their hidden state which tends to be uniquely useful for games. 



In [1]:
import pickle
import torch
import pandas as pd
import importlib
import numpy as np
import utils
import models

importlib.reload(utils)
from utils import *
importlib.reload(models)
from models import *

import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")



In [2]:
grouped_df = pd.read_csv('../data/haha-longer-mid.csv')

In [10]:

    
def df_to_data_with_special_tokens(df, fixed_window=False, fixed_window_size=16, sampling_rate=1, algebraic_notation=True):
    """
    Input: Dataframe of training data in which each row represents a full game played between players
    Output: List in which each item represents some game's history up until a particular move, List in the same order in which the associated label is the following move
    """
    subsequences = []
    next_moves = []
    vocab = VocabularyWithCLS()
    board = chess.Board()
    for game in df['moves']:
        moves = game.split()
        # Turn the game into a list of moves
        encoded_moves = [1]
        for move in moves:
            # Create a move object from the coordinate notation
            move_obj = chess.Move.from_uci(move)
            if move_obj not in board.legal_moves:
                break 
            else:
                if algebraic_notation:
                    algebraic_move = board.san(move_obj)
                    board.push(move_obj)
                    vocab.add_move(algebraic_move)
                    encoded_move = vocab.get_id(algebraic_move)
                    encoded_moves.append(encoded_move)
                else:
                    encoded_move = vocab.get_id(move)
                    encoded_moves.append(encoded_move)
        board.reset()
        # Turn the list of moves into subsequences
        for i in range(len(encoded_moves)-1):
            if random.uniform(0, 1) <= sampling_rate:
                subseq = encoded_moves[0:i+1]
                if fixed_window and len(subseq) > fixed_window_size:
                    subseq = subseq[-fixed_window_size:]
                label = encoded_moves[i+1]
                subsequences.append(subseq)
                next_moves.append(label)

    return subsequences, next_moves, vocab

In [161]:
trainX, trainY, vocab = df_to_data_with_special_tokens(grouped_df, fixed_window=True, sampling_rate=0.0875)
trainX, trainX_seqlengths  = pad_sequences(trainX)

In [20]:
# Function to load a memmap file
def load_memmap(filename, dtype, shape):
    # Load the memmap file with read-only mode
    return np.memmap(filename, dtype=dtype, mode='r', shape=shape)

In [21]:
# For trainX
dtype_trainX = np.int32  # or the correct dtype for your data
shape_trainX = (2161482, 750)  # replace with the correct shape
trainX = load_memmap('./../data/transformer/jan/trainX.memmap', dtype_trainX, shape_trainX)

# For trainY
dtype_trainY = np.int32 # or the correct dtype for your data
shape_trainY = (2161482, 7)  # replace with the correct shape
trainY = load_memmap('./../data/transformer/jan/trainY.memmap', dtype_trainY, shape_trainY)

with open('./../data/transformer/jan/vocab.pkl', 'rb') as inp:
    vocab = pickle.load(inp)

In [4]:
print(trainX[:1,:4])  

[[1 0 2 0]]


In [31]:

class TransformerDataset(Dataset):
    def __init__(self, sequences, labels):
        self.sequences, self.labels = sequences, labels
        
    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx], dtype=torch.long), torch.tensor(self.labels[idx], dtype=torch.long)

In [162]:
dataset = TransformerDataset(trainX, trainY)
total_size = len(dataset)
# We're scaling the model size so let's bring in more data as well
train_size = int(0.97 * total_size)
val_size = int(total_size * 0.02)

# Create subsets for training and validation
train_dataset = Subset(dataset, range(0, train_size))
val_dataset = Subset(dataset, range(train_size, train_size + val_size))
print(train_size)
# Reload the data with particular batch size
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


2084538


In [156]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
class ChessTransformerDecoder(nn.Module):
    def __init__(self, vocab, d_model, nhead, num_layers, max_seq_length=750, dropout=0.1):
        super(ChessTransformerDecoder, self).__init__()
        self.vocab = vocab
        self.d_model =d_model
        self.vocab_size = len(vocab.id_to_move.keys())
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)
        
        # Only decoder is needed for autoregressive models
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, 
                                       nhead=nhead, 
                                       dropout=dropout,
                                       batch_first=True,),
            num_layers=num_layers,
            norm=nn.LayerNorm(d_model)
        )
        
        self.fc = nn.Linear(d_model, self.vocab_size)

    def forward(self, tgt):
        # Memory is optional and could be used for incorporating encoder states in a hybrid model
        tgt_padding_mask = self.create_padding_mask(tgt).to(tgt.device)
        tgt_mask = self.square_subsequent_mask(tgt).to(tgt.device)

        # Embedding and Positional Encoding for tgt
        tgt_emb = self.embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoder(tgt_emb)

        # Autoregressive decoding using the Transformer Decoder
        output = self.transformer_decoder(tgt_emb, memory=None,
                                          tgt_mask=tgt_mask,
                                          tgt_is_causal = True,
                                          tgt_key_padding_mask=tgt_padding_mask)
        
        # Linear layer to predict vocab
        output = self.fc(output)
        return output
    
    def create_padding_mask(self, src):
        PAD_IDX = 0
        src_padding_mask = (src == PAD_IDX)
        return src_padding_mask
    
    def square_subsequent_mask(self, tgt):
        """ Generate a square mask for the sequence to mask out subsequent positions. """
        sz = tgt.size(1)
        mask = torch.triu(torch.ones(sz, sz, device=tgt.device, dtype=torch.bool), diagonal=1)
        return mask


class ChessTransformer(nn.Module):
    def __init__(self, vocab, d_model, nhead, num_layers, max_seq_length=750, dropout=0.1):
        super(ChessTransformer, self).__init__()
        self.vocab = vocab
        self.d_model = d_model
        self.vocab_size = len(vocab.id_to_move.keys())
        self.embedding = nn.Embedding(self.vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers,
                                          batch_first=True)
        self.fc = nn.Linear(d_model, self.vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, src, tgt):

        # Create source padding mask
        src_padding_mask = self.create_padding_mask(src).to(src.device)
        tgt_padding_mask = self.create_padding_mask(tgt).to(tgt.device)
        # Embedding and Positional Encoding for src

        src_emb = self.embedding(src) * math.sqrt(self.d_model) # [batch_size, seq_len] -> [batch_size, seq_len, d_model]

        src_emb = self.pos_encoder(src_emb)
    
        # Transformer
        output = self.transformer(src_emb, src_emb, 
                                  src_key_padding_mask=src_padding_mask, 
                                  tgt_key_padding_mask=tgt_padding_mask,
                                  tgt_is_causal = True,
                                  src_is_causal = True, 
                                  src_mask = self.square_subsequent_mask(src),
                                  tgt_mask = self.square_subsequent_mask(tgt))
        # Linear layer to predict vocab
        output = self.fc(output)
        return output
    
    def create_padding_mask(self, src):
        PAD_IDX = 0
        src_padding_mask = (src == PAD_IDX)
        return src_padding_mask
    
    def square_subsequent_mask(self, tgt):
        """
        Generate a square mask for the sequence. The masked positions are filled with `True`.
        This mask ensures that for any position `i` in `tgt`, the decoder's self-attention mechanism
        can only attend to positions at or before `i`.
        
        Args:
            tgt (Tensor): The target input tensor of shape [batch_size, tgt_len].
        
        Returns:
            Tensor: A mask of shape [tgt_len, tgt_len] where `True` indicates that attention is not allowed.
        """
        # tgt_len could be derived from the second dimension of tgt
        tgt_len = tgt.size(1)
        
        # Generate an upper triangular matrix with `True` in the upper triangle
        mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.bool), diagonal=1)
        return mask
    
    def generate_sequence(self, src, src_length, start_symbol_id, sep_token_id, max_length=100):
        """
        Generate a sequence autoregressively using the trained transformer model.

        Args:
        - src (Tensor): The input source sequence tensor.
        - src_length (Tensor): The length of the source sequence.
        - start_symbol_id (int): The ID of the start symbol to begin generation.
        - sep_token_id (int): The ID of the SEP token for sequence termination.
        - max_length (int): Maximum length of the generated sequence to prevent infinite loops.

        Returns:
        - The generated sequence tensor.
        """
        self.eval()  # Ensure the model is in eval mode

        # Initialize the target sequence with the start symbol
        tgt = torch.tensor([start_symbol_id], dtype=torch.long).to(src.device)
        
        for _ in range(max_length):
            # Assuming src_length is a tensor with the length of src. Adjust as needed.

            # Perform a forward pass to get logits for the next token
            logits = self.forward(src, src_length, tgt, src)
            # Get the last token logits and apply softmax to get probabilities
            probs = torch.softmax(logits[:, -1, :], dim=-1)
            # get most likely token from probs
            next_token = torch.max(probs, 1)
            
            # Append the predicted token to the target sequence
            tgt = torch.cat((tgt, next_token), dim=1)
            
            # Check if the <SEP> token is generated
            if next_token.item() == sep_token_id:
                break

        return tgt


Functions for training

In [159]:
# Function to calculate top-3 accuracy
def top_3_accuracy(y_true, y_pred):
    top3 = torch.topk(y_pred, 3, dim=1).indices
    correct = top3.eq(y_true.view(-1, 1).expand_as(top3))
    return correct.any(dim=1).float().mean().item()

def train_decoder(device, model, train_loader, val_loader, criterion, optimizer, num_epochs, learn_decay):
    train_loss_values = []
    train_error = []
    val_loss_values = []
    val_error = []
    val_3_accuracy = []
    for epoch in range(num_epochs):
        # Training
        model.train()
        train_correct = 0
        train_total = 0
        training_loss = 0.0
        count = 0
        for sequences, labels in train_loader:
            sequences, labels = sequences.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            # Forward Pass
            tgt_labels = torch.cat([sequences[:,1:],labels.unsqueeze(1)],dim=1).to(device)
            logits = model(sequences, sequences)
            loss = criterion(logits.view(-1, model.vocab_size), tgt_labels.view(-1))

            # Backpropogate & Optimize
            optimizer.zero_grad()
            loss.backward()
            # Clip it
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            optimizer.step()

            # For logging purposes
            training_loss += loss.item()
            # Calculate accuracy
            _, predicted = torch.max(logits, dim=2)
            train_correct += (predicted == tgt_labels).sum().item()
            train_total += tgt_labels.numel()
            count += 1
            if count % 1000 == 0:
                print(f'Epoch {epoch+1}, Batch: {count}| Training Loss: {training_loss/count}')
        # Validation
        model.eval()

        if val_loader is not None:
            with torch.no_grad():
                val_correct = 0
                val_total = 0
                val_top3_correct = 0
                val_loss = 0

                for sequences, labels in val_loader:
                    sequences, labels = sequences.to(device), labels.to(device)
                    tgt_labels = torch.cat([sequences[:,1:],labels.unsqueeze(1)],dim=1).to(device)
                    logits = model(sequences, sequences)
                    loss = criterion(logits.view(-1, model.vocab_size), tgt_labels.view(-1))

                    # For logging purposes
                    val_loss += loss.item()
                    # Calculate accuracy
                    _, predicted = torch.max(logits, dim=2)
                    val_correct += (predicted == tgt_labels).sum().item()
                    val_total += tgt_labels.numel()

                    val_loss_values.append(val_loss / len(val_loader))
                    val_accuracy = 100 * val_correct / val_total
                    val_top3_accuracy = 100 * val_top3_correct / val_total
                    val_error.append(100 - val_accuracy)
                    val_3_accuracy.append(val_top3_accuracy)
        # Log Model Performance  
        train_loss_values.append(training_loss)
        train_error.append(100-100*train_correct/train_total)
        print(f'Epoch {epoch+1}, Training Loss: {training_loss/len(train_loader)}, Validation Error: {val_error[-1]}, Validation Top-3 Accuracy: {val_3_accuracy[-1]}, Training Error: {train_error[-1]}')
        for op_params in optimizer.param_groups:
            op_params['lr'] = op_params['lr'] * learn_decay
    return train_error,train_loss_values, val_error, val_loss_values

def train_transformer(device, model, train_loader, val_loader, criterion, optimizer, num_epochs, learn_decay):
    train_loss_values = []
    train_error = []
    val_loss_values = []
    val_error = []
    val_3_accuracy = []
    for epoch in range(num_epochs):
        train_correct = 0
        train_total = 0
        training_loss = 0.0
        # Training
        model.train()
        count = 0
        for sequences, labels in train_loader:
            count += 1
            sequences, labels = sequences.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            # Forward Pass
            logits = model(sequences, labels)
            print(logits)
            loss = criterion(logits.view(-1, model.vocab_size), labels.contiguous().view(-1))
            # Backpropogate & Optimize
            optimizer.zero_grad()
            loss.backward()
            # Clip it
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            # For logging purposes
            training_loss += loss.item()
            # _, predicted = torch.max(output.data, 1)
            # train_total += labels.size(0)
            # train_correct += (predicted == labels).sum().item()
            # Get the predicted class indices for each position in each sequence
            _, predicted = torch.max(logits.data, dim=2)  # Shape: (batch_size, seq_length)
            correct_predictions = predicted == labels  # Shape: (batch_size, seq_length)
            correct_sequences = correct_predictions.all(dim=1)  # Shape: (batch_size)
            train_correct += correct_sequences.sum().item()
            train_total += labels.size(0) 
            break
            if count % 1000 == 0:
                print(f'Epoch {epoch+1}, Batch: {count}| Training Loss: {training_loss/count}')
        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        validation_loss = 0.0
        # if val_loader is not None:
        #     with torch.no_grad():
        #         val_correct = 0
        #         val_total = 0
        #         val_top3_correct = 0
        #         validation_loss = 0

        #         for sequences, lengths, labels in val_loader:
        #             sequences, lengths, labels = sequences.to(device), lengths.to(device), labels.to(device)
        #             outputs = model.generate(sequences, lengths)
        #             _, predicted = torch.max(outputs.data, 1)
        #             val_total += labels.size(0)
        #             val_correct += (predicted == labels).sum().item()
        #             val_top3_correct += top_3_accuracy(labels, outputs) * labels.size(0)
        #             loss = criterion(outputs, labels)
        #             validation_loss += loss.item()

        #         val_loss_values.append(validation_loss / len(val_loader))
        #         val_accuracy = 100 * val_correct / val_total
        #         val_top3_accuracy = 100 * val_top3_correct / val_total
        #         val_error.append(100 - val_accuracy)
        #         val_3_accuracy.append(val_top3_accuracy)
        # Log Model Performance  
        train_loss_values.append(training_loss)
        train_error.append(100-100*train_correct/train_total)
        print(f'Epoch {epoch+1}, Training Loss: {training_loss/len(train_loader)}, Validation Error: {val_error[-1]}, Validation Top-3 Accuracy: {val_3_accuracy[-1]}, Training Error: {train_error[-1]}')
        for op_params in optimizer.param_groups:
            op_params['lr'] = op_params['lr'] * learn_decay
    return train_error,train_loss_values, val_error, val_loss_values

# Experiments

Experiment 1

In [163]:
# Reload the data with particular batch size
# torch.multiprocessing.set_start_method('fork', force=True)
# batch_size = 1
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True, num_workers=2,pin_memory=True)

# Initialize model, loss function, and optimizer
d_model = 128
NUM_EPOCHS = 10
vocab_size = len(vocab.id_to_move.keys())
nhead = 8
num_layers = 2
model = ChessTransformer(vocab, d_model, nhead, num_layers = num_layers)
model = model.to(device)
# This ignores loss on pad tokens from the label's perspective
criterion = nn.CrossEntropyLoss(ignore_index=vocab.get_id('<PAD>'))  # Assuming you have a PAD token
lr = 2e-3
weight_decay=1e-7
learn_decay = 0.65 # This causes the LR to be 2e-5 by epoch 10
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(model))

4592570


In [164]:
# Train the model
train_error,train_loss_values, val_error, val_loss_value = train_decoder(device, model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, learn_decay)

# Plot the training error
plt.figure(figsize=(10, 5))
plt.plot(val_error, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.title('Validation Error')
plt.legend()
plt.show()
plt.savefig('transformer-decoder-4-22.png')  # This will save the plot as an image

Epoch 1, Batch: 1000| Training Loss: 6.008774438381195
Epoch 1, Batch: 2000| Training Loss: 5.91263079738617
Epoch 1, Batch: 3000| Training Loss: 5.866261313597361
Epoch 1, Batch: 4000| Training Loss: 5.8419452521801
Epoch 1, Batch: 5000| Training Loss: 5.8258527850151065
Epoch 1, Batch: 6000| Training Loss: 5.8119971399307255
Epoch 1, Batch: 7000| Training Loss: 5.803614270176206
Epoch 1, Batch: 8000| Training Loss: 5.792056399792433
Epoch 1, Batch: 9000| Training Loss: 5.780102597051196
Epoch 1, Batch: 10000| Training Loss: 5.768266700720787
Epoch 1, Batch: 11000| Training Loss: 5.76097933754054
Epoch 1, Batch: 12000| Training Loss: 5.752273487309615
Epoch 1, Batch: 13000| Training Loss: 5.743259393471938
Epoch 1, Batch: 14000| Training Loss: 5.737206680689539
Epoch 1, Batch: 15000| Training Loss: 5.730790600506465
Epoch 1, Batch: 16000| Training Loss: 5.727749186083674
Epoch 1, Batch: 17000| Training Loss: 5.721433475382188
Epoch 1, Batch: 18000| Training Loss: 5.71827676361137
Epoc

KeyboardInterrupt: 