# Experimenting with transformers
Transformers remain as a promising replacement of RNNs due to their parallelizability. However, RNNs are unique in their hidden state which tends to be uniquely useful for games. 



In [1]:
import pickle
import torch
import pandas as pd
import importlib
import utils
import models
importlib.reload(utils)
from utils import *
importlib.reload(models)
from models import *
from torch.utils.data import DataLoader, Subset
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
grouped_df = pd.read_csv('../data/haha-longer-longer-001.csv')

In [2]:
# Function to load a memmap file
def load_memmap(filename, dtype, shape):
    # Load the memmap file with read-only mode
    return np.memmap(filename, dtype=dtype, mode='r', shape=shape)

In [16]:
# For trainX
dtype_trainX = np.int32  # or the correct dtype for your data
shape_trainX = (130177, 750)  # replace with the correct shape
trainX = load_memmap('./../data/transformer/trainX.memmap', dtype_trainX, shape_trainX)

# For trainY
dtype_trainY = np.int32 # or the correct dtype for your data
shape_trainY = (130177, 7)  # replace with the correct shape
trainY = load_memmap('./../data/transformer/trainY.memmap', dtype_trainY, shape_trainY)

# For trainY
dtype_seq_lengths = np.int32 # or the correct dtype for your data
shape_seq_lengths = (130177, 7)  # replace with the correct shape
seq_lengths = load_memmap('./../data/transformer/seq_lengths.memmap', dtype_seq_lengths, shape_seq_lengths)

with open('./../data/transformer/vocab.pkl', 'rb') as inp:
    vocab = pickle.load(inp)

In [46]:
print(seq_lengths.shape)

(130177, 7)


In [38]:

class TransformerDataset(Dataset):
    def __init__(self, sequences, seq_lengths, labels):
        self.sequences, self.seq_lengths, self.labels = sequences, seq_lengths, labels
        
    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return torch.tensor(self.sequences[idx], dtype=torch.long), torch.tensor(self.seq_lengths, dtype = torch.long), torch.tensor(self.labels[idx], dtype=torch.long)

In [39]:
dataset = TransformerDataset(trainX, seq_lengths, trainY)
total_size = len(dataset)
# We're scaling the model size so let's bring in more data as well
train_size = int(0.98 * total_size)
val_size = int(total_size * 0.02)

# Create subsets for training and validation
train_dataset = Subset(dataset, range(0, train_size))
val_dataset = Subset(dataset, range(train_size, train_size + val_size))
print(train_size)
# Reload the data with particular batch size
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


127573


In [40]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    

class ChessTransformer(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, num_layers, max_seq_length=750, dropout=0.1):
        super(ChessTransformer, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_length)
        self.transformer = nn.Transformer(d_model=d_model, nhead=nhead,
                                          num_encoder_layers=num_layers,
                                          num_decoder_layers=num_layers,
                                          batch_first=True)
        self.fc = nn.Linear(d_model, vocab_size)
        self.max_seq_length = max_seq_length

    def forward(self, src, src_lengths, tgt):
        # Create source padding mask
        src_padding_mask = self.create_padding_mask(src_lengths, self.max_seq_length).to(src.device)
        
        # Embedding and Positional Encoding for src
        src_emb = self.embedding(src)  # [batch_size, seq_len] -> [batch_size, seq_len, d_model]
        src_emb = self.pos_encoder(src_emb)

        # Assuming tgt is provided correctly and also needs embedding & positional encoding
        tgt_emb = self.embedding(tgt)
        tgt_emb = self.pos_encoder(tgt_emb)

        # Transformer
        output = self.transformer(src_emb, tgt_emb, src_key_padding_mask=src_padding_mask)
        
        # Linear layer to predict vocab
        output = self.fc(output)
        return output
    
    def create_padding_mask(self, seq_lengths, max_len):
        print(seq_lengths.shape)
        batch_size = seq_lengths.size(0)
        mask = (torch.arange(max_len).expand(batch_size, max_len) >= seq_lengths.unsqueeze(1))
        return mask
    
    def generate_sequence(self, src, src_length, start_symbol_id, sep_token_id, max_length=100):
        """
        Generate a sequence autoregressively using the trained transformer model.

        Args:
        - src (Tensor): The input source sequence tensor.
        - src_length (Tensor): The length of the source sequence.
        - start_symbol_id (int): The ID of the start symbol to begin generation.
        - sep_token_id (int): The ID of the SEP token for sequence termination.
        - max_length (int): Maximum length of the generated sequence to prevent infinite loops.

        Returns:
        - The generated sequence tensor.
        """
        self.eval()  # Ensure the model is in eval mode

        # Initialize the target sequence with the start symbol
        tgt = torch.tensor([start_symbol_id], dtype=torch.long).to(src.device)
        
        for _ in range(max_length):
            # Assuming src_length is a tensor with the length of src. Adjust as needed.

            # Perform a forward pass to get logits for the next token
            logits = self.forward(src, src_length, tgt, src)
            # Get the last token logits and apply softmax to get probabilities
            probs = torch.softmax(logits[:, -1, :], dim=-1)
            # Sample the next token from the probability distribution
            next_token = torch.multinomial(probs, 1)
            
            # Append the predicted token to the target sequence
            tgt = torch.cat((tgt, next_token), dim=1)
            
            # Check if the <SEP> token is generated
            if next_token.item() == sep_token_id:
                break

        return tgt


In [15]:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel

Functions for training

In [41]:
# Function to calculate top-3 accuracy
def top_3_accuracy(y_true, y_pred):
    top3 = torch.topk(y_pred, 3, dim=1).indices
    correct = top3.eq(y_true.view(-1, 1).expand_as(top3))
    return correct.any(dim=1).float().mean().item()

def train_transformer(device, model, train_loader, val_loader, criterion, optimizer, num_epochs, learn_decay):
    train_loss_values = []
    train_error = []
    val_loss_values = []
    val_error = []
    val_3_accuracy = []
    for epoch in range(num_epochs):
        train_correct = 0
        train_total = 0
        training_loss = 0.0
        # Training
        model.train()
        count = 0
        for sequences, lengths, labels in train_loader:
            count += 1
            sequences, lengths, labels = sequences.to(device), lengths.to(device), labels.to(device)
            # Forward Pass
            logits = model(sequences, lengths, labels).data
            loss = criterion(logits.view(-1, model.vocab_size), tgt_labels.contiguous().view(-1))
            # Backpropogate & Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # For logging purposes
            training_loss += loss.item()
            # _, predicted = torch.max(output.data, 1)
            # train_total += labels.size(0)
            # train_correct += (predicted == labels).sum().item()
            if count % 1000 == 0:
                print(f'Epoch {epoch+1}, Batch: {count}| Training Loss: {training_loss/count}')
        # Validation
        model.eval()
        val_correct = 0
        val_total = 0
        validation_loss = 0.0
        # if val_loader is not None:
        #     with torch.no_grad():
        #         val_correct = 0
        #         val_total = 0
        #         val_top3_correct = 0
        #         validation_loss = 0

        #         for sequences, lengths, labels in val_loader:
        #             sequences, lengths, labels = sequences.to(device), lengths.to(device), labels.to(device)
        #             outputs = model.generate(sequences, lengths)
        #             _, predicted = torch.max(outputs.data, 1)
        #             val_total += labels.size(0)
        #             val_correct += (predicted == labels).sum().item()
        #             val_top3_correct += top_3_accuracy(labels, outputs) * labels.size(0)
        #             loss = criterion(outputs, labels)
        #             validation_loss += loss.item()

        #         val_loss_values.append(validation_loss / len(val_loader))
        #         val_accuracy = 100 * val_correct / val_total
        #         val_top3_accuracy = 100 * val_top3_correct / val_total
        #         val_error.append(100 - val_accuracy)
        #         val_3_accuracy.append(val_top3_accuracy)

        # Log Model Performance  
        train_loss_values.append(training_loss)
        train_error.append(100-100*train_correct/train_total)
        print(f'Epoch {epoch+1}, Training Loss: {training_loss/len(train_loader)}, Validation Error: {val_error[-1]}, Validation Top-3 Accuracy: {val_3_accuracy[-1]}, Training Error: {train_error[-1]}')
        for op_params in optimizer.param_groups:
            op_params['lr'] = op_params['lr'] * learn_decay
    return train_error,train_loss_values, val_error, val_loss_values

# Experiments

Experiment 1

In [42]:
# Reload the data with particular batch size
batch_size = 8
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss function, and optimizer
d_model = 256
NUM_EPOCHS = 5
d_out = len(vocab.id_to_word.keys())
nhead = 8
num_layers = 4
model = ChessTransformer(d_out, d_model, nhead, num_layers = num_layers)
model = model.to(device)
criterion = nn.CrossEntropyLoss(ignore_index=vocab.get_id('<PAD>'))  # Assuming you have a PAD token
lr = 2e-3
weight_decay=1e-7
learn_decay = 0.65 # This causes the LR to be 2e-5 by epoch 10
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(count_parameters(model))

11598892


In [43]:
# Train the model
train_error,train_loss_values, val_error, val_loss_value = train_transformer(device, model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS, learn_decay)

# Plot the training error
plt.figure(figsize=(10, 5))
plt.plot(val_error, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.title('Validation Error')
plt.legend()
plt.show()
plt.savefig('transformer-4-04.png')  # This will save the plot as an image

torch.Size([8, 130177, 7])


RuntimeError: The size of tensor a (750) must match the size of tensor b (7) at non-singleton dimension 3