# Multimodal chess move prediction


In [1]:
import pickle
import torch
import pandas as pd
import importlib
import utils
import models
import data_processing_utils

importlib.reload(data_processing_utils)
from data_processing_utils import *
importlib.reload(utils)
from utils import *
importlib.reload(models)
from models import *
from torch.utils.data import DataLoader, Subset
import torch.optim as optim

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")



In [4]:
# Load the grouped game data from its CSV export.
raw_csv_path = '../data/haha-longer-longer-001.csv'
grouped_df = pd.read_csv(raw_csv_path)

## Prepare the data

In [2]:
# Move vocabulary (id <-> move lookups used below via `id_to_move`, `get_move`, `get_id`).
# NOTE: pickle.load can execute arbitrary code — only load vocab files you produced yourself.
vocab_path = '../data/combined/vocab.pkl'
with open(vocab_path, 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [48]:
# Vocabulary size: the number of move ids, which is also the model's output dimension below.
num_moves = len(vocab.id_to_move.keys())
num_moves

9223

In [5]:
# Convert the first 1000 rows of `grouped_df` into model inputs, then pad the move-history sequences.
# NOTE(review): only `grouped_df[:1000]` is converted here, yet the memmaps reloaded further down
# are declared with 18,568,313 rows — confirm which run produced the files on disk.
# NOTE(review): `color='w'` while the arrays are saved under `../data/1500/black/` — confirm.
first_rows = grouped_df[:1000]
trainX_sequences, fens, trainX, trainY, vocab = df_to_data(
    first_rows, vocab, sampling_rate=1, color='w', with_checkmate=False
)
trainX_sequences, trainX_seqlengths = pad_sequences(trainX_sequences)

In [8]:
# Persist each array as a memmap so later sessions can reload it with `load_memmap`
# instead of re-running preprocessing.
arrays_and_paths = [
    (trainX, '../data/1500/black/trainX_boards.memmap'),
    (trainX_sequences, '../data/1500/black/trainX_sequences.memmap'),
    (trainX_seqlengths, '../data/1500/black/trainX_seqlengths.memmap'),
    (trainY, '../data/1500/black/trainY.memmap'),
]
filenames = [save_as_memmap(array, path) for array, path in arrays_and_paths]

In [9]:
# Store the FEN of every sample next to the memmaps (single `fens` column, no index).
csv_filename = './../data/1500/black/fens.csv'
fens_df = pd.DataFrame(fens, columns=['fens'])
fens_df.to_csv(csv_filename, index=False)

In [4]:

# Reload the saved FEN strings as a pandas Series.
fens_csv_path = '../data/1500/black/fens.csv'
fens = pd.read_csv(fens_csv_path)['fens']

In [5]:
# Reload the cached arrays. `load_memmap` is given dtype and shape explicitly, so both must
# match what `save_as_memmap` wrote — TODO confirm against the preprocessing run.
MEMMAP_DIR = './../data/1500/black'
N_SAMPLES = 18568313        # number of samples (first dimension of every array)
MAX_SEQ_LEN = 16            # padded length of each move-history sequence
BOARD_SHAPE = (12, 8, 8)    # per-sample board tensor (presumably 12 piece planes on 8x8 — verify)

# Boards: one boolean (12, 8, 8) tensor per sample.
trainX = load_memmap(f'{MEMMAP_DIR}/trainX_boards.memmap', np.bool_, (N_SAMPLES, *BOARD_SHAPE))

# Targets: one label id per sample.
trainY = load_memmap(f'{MEMMAP_DIR}/trainY.memmap', np.int64, (N_SAMPLES,))

# Length of each sequence before padding (second output of `pad_sequences`).
trainX_seqlengths = load_memmap(f'{MEMMAP_DIR}/trainX_seqlengths.memmap', np.int64, (N_SAMPLES,))

# Padded sequences (first output of `pad_sequences`), MAX_SEQ_LEN entries per sample.
trainX_sequences = load_memmap(f'{MEMMAP_DIR}/trainX_sequences.memmap', np.int64, (N_SAMPLES, MAX_SEQ_LEN))


## Training functions

In [5]:
def train(device, model, train_loader, val_loader, criterion, optimizer, num_epochs, learn_decay, max_decay_epoch=10):
    """Train a ``model(boards, sequences, lengths)`` classifier and record per-epoch metrics.

    Every batch from ``train_loader`` / ``val_loader`` must unpack as
    ``(boards, sequences, lengths, labels)``. ``boards``, ``sequences`` and
    ``labels`` are moved to ``device``; ``lengths`` is passed to the model as-is.

    Args:
        device: torch device the batch tensors are moved to.
        model: module returning logits (assumed shape (batch, n_classes) — ranked with argmax(dim=1)).
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches, or ``None`` to skip validation.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        learn_decay: factor every param group's ``lr`` is multiplied by after an epoch.
        max_decay_epoch: the decay is applied only after epochs whose 0-based index is
            ``<= max_decay_epoch`` (default 10, as before); later epochs keep the lr.

    Returns:
        ``(train_error, train_loss_values, val_error, val_loss_values)``: per-epoch
        training error (%), mean training loss per batch, validation error (%) and
        mean validation loss per batch. The validation lists stay empty when
        ``val_loader`` is ``None``; metrics of an empty loader are ``nan``.
    """
    train_loss_values = []
    train_error = []
    val_loss_values = []
    val_error = []
    val_3_accuracy = []
    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()
        train_correct = 0
        train_total = 0
        training_loss = 0.0
        count = 0
        for boards, sequences, lengths, labels in train_loader:
            count += 1
            boards = boards.to(device, non_blocking=True)
            sequences = sequences.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # Forward pass
            output = model(boards, sequences, lengths)
            loss = criterion(output, labels)
            # Backpropagate & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # For logging purposes
            training_loss += loss.item()
            predicted = output.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            if count % 1000 == 0:
                print(f'Epoch {epoch+1}, Batch: {count}| Training Loss: {training_loss/count}')

        # ---- Validation ----
        # eval() is set even without a validation loader, so the model is returned in eval mode.
        model.eval()
        if val_loader is not None:
            val_correct = 0
            val_total = 0
            val_top3_correct = 0
            validation_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for boards, sequences, lengths, labels in val_loader:
                    val_batches += 1
                    boards = boards.to(device, non_blocking=True)
                    sequences = sequences.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    outputs = model(boards, sequences, lengths)
                    predicted = outputs.argmax(dim=1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()
                    # Assumes top_3_accuracy returns the batch hit-rate as a fraction — verify.
                    val_top3_correct += top_3_accuracy(labels, outputs) * labels.size(0)
                    validation_loss += criterion(outputs, labels).item()
            val_loss_values.append(validation_loss / val_batches if val_batches else float('nan'))
            if val_total:
                val_error.append(100 - 100 * val_correct / val_total)
                val_3_accuracy.append(100 * val_top3_correct / val_total)
            else:
                val_error.append(float('nan'))
                val_3_accuracy.append(float('nan'))
            val_summary = f'Validation Error: {val_error[-1]}, Validation Top-3 Accuracy: {val_3_accuracy[-1]}'
        else:
            val_summary = 'Validation: skipped'

        # ---- Log model performance ----
        # Store the per-batch mean so training and validation losses are on the same scale.
        mean_train_loss = training_loss / count if count else float('nan')
        train_loss_values.append(mean_train_loss)
        train_error.append(100 - 100 * train_correct / train_total if train_total else float('nan'))
        print(f'Epoch {epoch+1}, Training Loss: {mean_train_loss}, {val_summary}, Training Error: {train_error[-1]}')

        if epoch <= max_decay_epoch:
            for param_group in optimizer.param_groups:
                param_group['lr'] *= learn_decay
    return train_error, train_loss_values, val_error, val_loss_values

def train_with_fen(device, model, train_loader, val_loader, criterion, optimizer, num_epochs, learn_decay, experiment_name, checkpoint_dir='model_images'):
    """Train on batches that also carry FEN strings, saving a checkpoint after every epoch.

    Every batch must unpack as ``(boards, sequences, lengths, fens, labels)``; ``fens``
    is not used here. ``boards``, ``sequences`` and ``labels`` are moved to ``device``;
    ``lengths`` is passed to ``model(boards, sequences, lengths)`` as-is.

    Args:
        device: torch device the batch tensors are moved to.
        model: module returning logits (assumed shape (batch, n_classes) — ranked with argmax(dim=1)).
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches, or ``None`` to skip validation.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        learn_decay: factor every param group's ``lr`` is multiplied by after each epoch.
        experiment_name: inserted into the checkpoint file names.
        checkpoint_dir: directory for ``multimodalmodel-exp-{experiment_name}-checkpoint-{epoch}.pth``
            files; created if missing.

    Returns:
        ``(train_error, train_loss_values, val_error, val_loss_values)``: per-epoch
        training error (%), mean training loss per batch, validation error (%) and
        mean validation loss per batch. The validation lists stay empty when
        ``val_loader`` is ``None``; metrics of an empty loader are ``nan``.
    """
    import os

    os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create directories
    train_loss_values = []
    train_error = []
    val_loss_values = []
    val_error = []
    val_3_accuracy = []
    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()
        train_correct = 0
        train_total = 0
        training_loss = 0.0
        count = 0
        for boards, sequences, lengths, _fens, labels in train_loader:
            count += 1
            boards = boards.to(device, non_blocking=True)
            sequences = sequences.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # Forward pass
            output = model(boards, sequences, lengths)
            loss = criterion(output, labels)
            # Backpropagate & optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # For logging purposes
            training_loss += loss.item()
            predicted = output.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            if count % 1000 == 0:
                print(f'Epoch {epoch+1}, Batch: {count}| Training Loss: {training_loss/count}')

        # ---- Validation ----
        # eval() is set even without a validation loader, matching the original behavior.
        model.eval()
        if val_loader is not None:
            val_correct = 0
            val_total = 0
            val_top3_correct = 0
            validation_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for boards, sequences, lengths, _fens, labels in val_loader:
                    val_batches += 1
                    boards = boards.to(device, non_blocking=True)
                    sequences = sequences.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    outputs = model(boards, sequences, lengths)
                    predicted = outputs.argmax(dim=1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()
                    # Assumes top_3_accuracy returns the batch hit-rate as a fraction — verify.
                    val_top3_correct += top_3_accuracy(labels, outputs) * labels.size(0)
                    validation_loss += criterion(outputs, labels).item()
            val_loss_values.append(validation_loss / val_batches if val_batches else float('nan'))
            if val_total:
                val_error.append(100 - 100 * val_correct / val_total)
                val_3_accuracy.append(100 * val_top3_correct / val_total)
            else:
                val_error.append(float('nan'))
                val_3_accuracy.append(float('nan'))
            val_summary = f'Validation Error: {val_error[-1]}, Validation Top-3 Accuracy: {val_3_accuracy[-1]}'
        else:
            val_summary = 'Validation: skipped'

        # ---- Log model performance ----
        # Store the per-batch mean so training and validation losses are on the same scale.
        mean_train_loss = training_loss / count if count else float('nan')
        train_loss_values.append(mean_train_loss)
        train_error.append(100 - 100 * train_correct / train_total if train_total else float('nan'))
        print(f'Epoch {epoch+1}, Training Loss: {mean_train_loss}, {val_summary}, Training Error: {train_error[-1]}')

        for param_group in optimizer.param_groups:
            param_group['lr'] *= learn_decay
        checkpoint_path = os.path.join(checkpoint_dir, f'multimodalmodel-exp-{experiment_name}-checkpoint-{epoch}.pth')
        torch.save(model.state_dict(), checkpoint_path)
    return train_error, train_loss_values, val_error, val_loss_values


def train_for_multi_transformer(device, model, train_loader, val_loader, criterion, optimizer, num_epochs, learn_decay, max_grad_norm=5):
    """Train a ``model(boards, sequences)`` classifier with gradient-norm clipping.

    Every batch must unpack as ``(boards, sequences, lengths, fens, labels)``;
    ``lengths`` and ``fens`` are not used. ``boards``, ``sequences`` and ``labels``
    are moved to ``device``.

    Args:
        device: torch device the batch tensors are moved to.
        model: module returning logits (assumed shape (batch, n_classes) — ranked with argmax(dim=1)).
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches, or ``None`` to skip validation.
        criterion: loss called as ``criterion(logits, labels)``.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of passes over ``train_loader``.
        learn_decay: factor every param group's ``lr`` is multiplied by after each epoch.
        max_grad_norm: total gradient norm passed to ``clip_grad_norm_`` before each step
            (default 5, as before).

    Returns:
        ``(train_error, train_loss_values, val_error, val_loss_values)``: per-epoch
        training error (%), mean training loss per batch, validation error (%) and
        mean validation loss per batch. The validation lists stay empty when
        ``val_loader`` is ``None``; metrics of an empty loader are ``nan``.
    """
    train_loss_values = []
    train_error = []
    val_loss_values = []
    val_error = []
    val_3_accuracy = []
    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()
        train_correct = 0
        train_total = 0
        training_loss = 0.0
        count = 0
        for boards, sequences, _lengths, _fens, labels in train_loader:
            count += 1
            boards = boards.to(device, non_blocking=True)
            sequences = sequences.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            # Forward pass
            output = model(boards, sequences)
            loss = criterion(output, labels)
            # Backpropagate, clip & optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            # For logging purposes
            training_loss += loss.item()
            predicted = output.detach().argmax(dim=1)
            train_total += labels.size(0)
            train_correct += (predicted == labels).sum().item()
            if count % 1000 == 0:
                # Report the error in percent, consistent with the end-of-epoch summary.
                running_error = 100 - 100 * train_correct / train_total
                print(f'Epoch {epoch+1}, Batch: {count}| Training Loss: {training_loss/count} | Training Error: {running_error}')

        # ---- Validation ----
        # eval() is set even without a validation loader, matching the original behavior.
        model.eval()
        if val_loader is not None:
            val_correct = 0
            val_total = 0
            val_top3_correct = 0
            validation_loss = 0.0
            val_batches = 0
            with torch.no_grad():
                for boards, sequences, _lengths, _fens, labels in val_loader:
                    val_batches += 1
                    boards = boards.to(device, non_blocking=True)
                    sequences = sequences.to(device, non_blocking=True)
                    labels = labels.to(device, non_blocking=True)
                    outputs = model(boards, sequences)
                    predicted = outputs.argmax(dim=1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()
                    # Assumes top_3_accuracy returns the batch hit-rate as a fraction — verify.
                    val_top3_correct += top_3_accuracy(labels, outputs) * labels.size(0)
                    validation_loss += criterion(outputs, labels).item()
            val_loss_values.append(validation_loss / val_batches if val_batches else float('nan'))
            if val_total:
                val_error.append(100 - 100 * val_correct / val_total)
                val_3_accuracy.append(100 * val_top3_correct / val_total)
            else:
                val_error.append(float('nan'))
                val_3_accuracy.append(float('nan'))
            val_summary = f'Validation Error: {val_error[-1]}, Validation Top-3 Accuracy: {val_3_accuracy[-1]}'
        else:
            val_summary = 'Validation: skipped'

        # ---- Log model performance ----
        # Store the per-batch mean so training and validation losses are on the same scale.
        mean_train_loss = training_loss / count if count else float('nan')
        train_loss_values.append(mean_train_loss)
        train_error.append(100 - 100 * train_correct / train_total if train_total else float('nan'))
        print(f'Epoch {epoch+1}, Training Loss: {mean_train_loss}, {val_summary}, Training Error: {train_error[-1]}')

        for param_group in optimizer.param_groups:
            param_group['lr'] *= learn_decay
    return train_error, train_loss_values, val_error, val_loss_values

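## Evaluate a saved checkpoint

The cells below load saved weights and measure accuracy on the prepared data; they do not run the training loop.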

In [11]:
# Wrap the memmapped arrays plus FEN strings in a dataset and iterate it in a fixed order.
# NOTE(review): this "val" loader is built from the `trainX*`/`trainY` arrays loaded above —
# confirm those files hold held-out data before reading the numbers as validation accuracy.
batch_size = 64
dataset = MultimodalDatasetWithFEN(trainX_sequences, trainX, trainX_seqlengths, fens, trainY)
val_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=6,
    pin_memory=True,
)
print(len(dataset))

26912


In [10]:
# Build the model with the architecture of the saved checkpoint and restore its weights.
# NOTE(review): the checkpoint name says "white" while the data above lives under `black/` — confirm.
d_hidden = 256
d_embed = 64
NUM_EPOCHS = 7
d_out = len(vocab.id_to_move.keys())
checkpoint_path = 'model_images/multimodalmodel-exp-12-white-1500-checkpoint-8.pth'
model = MultiModalSeven(vocab, d_embed, d_hidden, d_out).to(device)
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))


<All keys matched successfully>

In [13]:

# Accuracy with only legal moves allowed: for every sample, walk the model's moves in
# descending score order, skip moves `is_legal_move` rejects on the board rebuilt from the
# FEN, and compare the first (up to) three legal moves with the label.
MAX_EVAL_BATCHES = None  # set to an int for a quick partial run; None evaluates every batch

model.eval()
val_correct_3 = 0        # label is among the top-3 legal predictions
val_correct = 0          # top-1 legal prediction equals the label
val_total = 0
no_legal_prediction = 0  # samples for which no ranked move was legal (counted as misses)

if val_loader is not None:
    with torch.no_grad():
        # `batch_fens` (not `fens`) so the loop does not overwrite the global FEN Series.
        for batch_idx, (boards, sequences, lengths, batch_fens, labels) in enumerate(val_loader):
            if MAX_EVAL_BATCHES is not None and batch_idx >= MAX_EVAL_BATCHES:
                break
            boards = boards.to(device, non_blocking=True)
            sequences = sequences.to(device, non_blocking=True)
            outputs = model(boards, sequences, lengths)
            # softmax is monotonic, so ranking logits gives the probability ranking. One CPU
            # transfer per batch avoids a device sync on every `.item()` below.
            ranked_moves = torch.argsort(outputs, dim=1, descending=True).cpu()
            for fen, label, ranking in zip(batch_fens, labels.tolist(), ranked_moves):
                chess_board = chess.Board(fen)
                top_legal = []  # ids of the first legal moves, stopping at the label or at 3
                for move_idx in ranking:
                    move = vocab.get_move(move_idx.item())
                    if not is_legal_move(chess_board, move):
                        continue
                    move_id = vocab.get_id(move)
                    top_legal.append(move_id)
                    if move_id == label or len(top_legal) == 3:
                        break
                if not top_legal:
                    no_legal_prediction += 1
                elif top_legal[0] == label:
                    val_correct += 1
                if label in top_legal:
                    val_correct_3 += 1
            val_total += labels.size(0)

if val_total:
    print(f"Top-3 Validation Accuracy (with only legal moves allowed): {100 * val_correct_3 / val_total}%")
    print(f"Top-1 Validation Accuracy (with only legal moves allowed): {100 * val_correct / val_total}%")
    if no_legal_prediction:
        print(f"Samples with no legal predicted move (counted as misses): {no_legal_prediction}")
else:
    print("No samples were evaluated.")



## Evaluation of the most up-to-date model

In [8]:
# Rebuild the data loaders with a larger batch size.
# NOTE(review): `train_dataset` and `val_dataset` are not created anywhere in this notebook, so
# this cell only works with state left over from another session — define them above first.
# 'fork' is POSIX-only (unavailable on Windows); force=True allows re-running this cell.
torch.multiprocessing.set_start_method('fork', force=True)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=6, pin_memory=True)
# Evaluation counts do not depend on batch order, so validation batches are not shuffled.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=6, pin_memory=True)

# Build the model; the hyperparameters must match the architecture of the checkpoint loaded below.
d_hidden = 256
d_embed = 64
NUM_EPOCHS = 20
d_out = len(vocab.id_to_move.keys())
model = MultiModalSeven(vocab, d_embed, d_hidden, d_out).to(device)


def count_parameters(model):
    """Return the number of trainable (``requires_grad``) parameters of ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(count_parameters(model))

# map_location lets a checkpoint saved on GPU load on a CPU-only machine as well.
model.load_state_dict(torch.load('model_images/multimodalmodel-exp-12-white-1500-checkpoint-8.pth', map_location=device))


7513159


<All keys matched successfully>

In [9]:
# Accuracy with only legal moves allowed: for every sample, walk the model's moves in
# descending score order, skip moves `is_legal_move` rejects on the board rebuilt from the
# FEN, and compare the first (up to) three legal moves with the label.
model.eval()
val_correct_3 = 0        # label is among the top-3 legal predictions
val_correct = 0          # top-1 legal prediction equals the label
val_total = 0
no_legal_prediction = 0  # samples for which no ranked move was legal (counted as misses)

if val_loader is not None:
    with torch.no_grad():
        # `batch_fens` (not `fens`) so the loop does not overwrite the global FEN Series.
        for boards, sequences, lengths, batch_fens, labels in val_loader:
            boards = boards.to(device, non_blocking=True)
            sequences = sequences.to(device, non_blocking=True)
            outputs = model(boards, sequences, lengths)
            # softmax is monotonic, so ranking logits gives the probability ranking. One CPU
            # transfer per batch avoids a device sync on every `.item()` below.
            ranked_moves = torch.argsort(outputs, dim=1, descending=True).cpu()
            for fen, label, ranking in zip(batch_fens, labels.tolist(), ranked_moves):
                chess_board = chess.Board(fen)
                top_legal = []  # ids of the first legal moves, stopping at the label or at 3
                for move_idx in ranking:
                    move = vocab.get_move(move_idx.item())
                    if not is_legal_move(chess_board, move):
                        continue
                    move_id = vocab.get_id(move)
                    top_legal.append(move_id)
                    if move_id == label or len(top_legal) == 3:
                        break
                # A sample with no legal move must not inherit the previous sample's prediction.
                if not top_legal:
                    no_legal_prediction += 1
                elif top_legal[0] == label:
                    val_correct += 1
                if label in top_legal:
                    val_correct_3 += 1
            val_total += labels.size(0)

if val_total:
    print(f"Top-3 Validation Accuracy (with only legal moves allowed): {100 * val_correct_3 / val_total}%")
    print(f"Top-1 Validation Accuracy (with only legal moves allowed): {100 * val_correct / val_total}%")
    if no_legal_prediction:
        print(f"Samples with no legal predicted move (counted as misses): {no_legal_prediction}")
else:
    print("No samples were evaluated.")

Top-3 Validation Accuracy (with only legal moves allowed): 76.61637931034483%
Top-1 Validation Accuracy (with only legal moves allowed): 48.00646551724138%
