In [8]:
import sys
import os
import yaml
import json

sys.path.append('../../')

from constants import ROOT_DIR


In [9]:
# Path to the YAML config that drives this notebook (model, training and data settings).
YAML_FILE = '../../configs/model_configs/transformer_tic_tac_toe_single_agent.yml'
# NOTE(review): the module name is spelled 'trasnformer_model' here; presumably that
# matches the file on disk -- confirm before renaming anything.
from src.models.trasnformer_model import create_model

In [10]:
# Parse the experiment configuration; safe_load only builds plain Python objects.
with open(YAML_FILE, 'r') as config_stream:
    config = yaml.safe_load(config_stream)
# Echo the parsed settings so the run records exactly what was used.
print(config)

{'name': 'transformer_tic_tac_toe_single_agent_1k_model', 'model': {'type': 'Transformer', 'embedding_dim': 128, 'nhead': 4, 'num_encoder_layers': 3}, 'training': {'batch_size': 64, 'learning_rate': 0.0001, 'num_epochs': 5}, 'data': {'game': 'tic-tac-toe', 'sequence_length': 20, 'max_event_length': 10, 'path': '/games/tic-tac-toe/1k_single_agent.csv'}}


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from src.data.datasets.event_dataset import EventDataset

In [12]:
def collate_fn(batch):
    """Merge a list of dataset samples into padded batch tensors.

    Every sample is unpacked as a 4-tuple ``(inputs, target, sequence, raw_target)``
    where ``inputs`` is an iterable of tensors and ``target`` is a tensor.

    Args:
        batch: Non-empty list of samples as produced by the dataset.

    Returns:
        A tuple ``(batch_input, batch_target, sequences, targets)``:

        * ``batch_input`` -- every input tensor of every sample, flattened into
          one list and right-padded with 0 (one row per input tensor, not per
          sample).
        * ``batch_target`` -- the targets, right-padded with 0 to the longest
          target in the batch and stacked (one row per sample).
        * ``sequences`` / ``targets`` -- the untouched third and fourth sample
          fields, as tuples.
    """
    input_groups, target_rows, sequences, targets = zip(*batch)

    # Right-pad each target with zeros up to the longest target in this batch.
    # assumes targets are 1-D (len() reads the first dim, pad() the last) -- TODO confirm
    longest = max(len(row) for row in target_rows)
    padded_rows = []
    for row in target_rows:
        missing = longest - len(row)
        padded_rows.append(torch.nn.functional.pad(row, (0, missing), value=0))

    # Concatenate the per-sample input lists, then pad them to a common length.
    flat_inputs = []
    for group in input_groups:
        flat_inputs.extend(group)
    batch_input = pad_sequence(flat_inputs, batch_first=True, padding_value=0)

    batch_target = torch.stack(padded_rows)

    return batch_input, batch_target, sequences, targets

def train_model(config):
    """Train the model described by ``config`` and save its weights.

    Args:
        config: Parsed YAML configuration. Keys read here:
            ``config['name']`` (checkpoint file stem),
            ``config['data']['path']`` and ``config['data']['sequence_length']``
            (passed to ``EventDataset``),
            ``config['training']['batch_size']``, ``['learning_rate']`` and
            ``['num_epochs']``, and the optional ``config['training']['seed']``
            which, when present, seeds torch's RNGs for a repeatable run.

    Side effects:
        Prints the vocabulary size, the device, and one average-loss line per
        epoch (plus a line reporting skipped batches, if any), then writes the
        state dict to ``../../results/models/<name>.pth``, creating that
        directory when it is missing.
    """
    # Optional reproducibility hook: seed before the model is initialised and
    # before the DataLoader draws its shuffling order.
    seed = config['training'].get('seed')
    if seed is not None:
        torch.manual_seed(seed)

    dataset = EventDataset(config['data']['path'], config['data']['sequence_length'])
    dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True, collate_fn=collate_fn)

    model = create_model(config, dataset.vocab_size)
    criterion = nn.CrossEntropyLoss(ignore_index=0)  # Ignore padding index
    optimizer = optim.Adam(model.parameters(), lr=config['training']['learning_rate'])

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"Vocabulary size: {dataset.vocab_size}")
    print(f"Device: {device}")

    num_epochs = config['training']['num_epochs']
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        processed = 0  # batches that actually contributed an optimizer step
        skipped = 0    # batches dropped because of an error or a non-finite loss
        for batch_input, batch_target, _, _ in dataloader:
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)

            optimizer.zero_grad()
            output = model(batch_input)

            # Flatten logits to (N, vocab) and targets to (N,) for the loss.
            _, _, vocab_size = output.shape
            output = output.reshape(-1, vocab_size)
            batch_target = batch_target.reshape(-1)

            # Trim both to the shorter length so the loss call does not fail.
            # NOTE(review): this hides any shape mismatch between logits and
            # targets; if the two lengths differ, positions are presumably
            # misaligned rather than merely truncated -- confirm against the
            # dataset/model contract.
            min_length = min(output.size(0), batch_target.size(0))
            output = output[:min_length]
            batch_target = batch_target[:min_length]

            try:
                loss = criterion(output, batch_target)

                # A batch whose targets are all padding yields a NaN mean loss;
                # back-propagating it would poison every weight.
                if not torch.isfinite(loss):
                    skipped += 1
                    continue

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                processed += 1
            except RuntimeError as e:
                # Best-effort: report the failing batch and keep training.
                print(f"Runtime error during training: {e}")
                skipped += 1
                continue

        # Average over the batches that really trained, not over len(dataloader),
        # so skipped batches do not deflate the figure (and an empty loader does
        # not raise ZeroDivisionError).
        average_loss = total_loss / processed if processed else float('nan')
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {average_loss:.4f}")
        if skipped:
            print(f"  Skipped {skipped} of {processed + skipped} batches in epoch {epoch+1}")

    # Make sure the target directory exists so a finished run is not lost at save time.
    save_path = f'../../results/models/{config["name"]}.pth'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    torch.save(model.state_dict(), save_path)
    print(f"\nModel saved to {save_path}")
     
def evaluate_model(config):
    """Load the saved weights for ``config`` and print token-level accuracy.

    Args:
        config: Parsed YAML configuration. Keys read here:
            ``config['name']`` (checkpoint file stem),
            ``config['data']['path']`` and ``config['data']['sequence_length']``
            (passed to ``EventDataset``), and
            ``config['training']['batch_size']``.

    Side effects:
        Reads ``../../results/models/<name>.pth`` and prints the accuracy over
        all non-padding target positions (padding index 0 is excluded). If
        there are no such positions, prints a notice instead of dividing by
        zero.
    """
    # NOTE(review): the dataset is built from the same config['data']['path']
    # that train_model uses, so the "test set" in the message below is the
    # training data unless the config is changed between the two calls.
    dataset = EventDataset(config['data']['path'], config['data']['sequence_length'])
    dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=False, collate_fn=collate_fn)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = create_model(config, dataset.vocab_size)
    model_path = f'../../results/models/{config["name"]}.pth'
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for batch_input, batch_target, _, _ in dataloader:
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)

            outputs = model(batch_input)

            # Flatten logits to (N, vocab) and targets to (N,).
            _, _, vocab_size = outputs.shape
            outputs = outputs.reshape(-1, vocab_size)
            batch_target = batch_target.reshape(-1)

            # Trim both to the shorter length so they can be compared.
            # NOTE(review): as in training, this hides any logits/targets
            # shape mismatch -- confirm the two are meant to line up.
            min_length = min(outputs.size(0), batch_target.size(0))
            outputs = outputs[:min_length]
            batch_target = batch_target[:min_length]

            predicted = outputs.argmax(dim=1)

            # Score only real targets; 0 is the padding index.
            mask = batch_target != 0
            predicted = predicted[mask]
            batch_target = batch_target[mask]

            total += batch_target.size(0)
            correct += (predicted == batch_target).sum().item()

    if total == 0:
        # Empty dataset or padding-only targets: accuracy is undefined.
        print('No non-padding targets found; accuracy is undefined.')
        return

    accuracy = 100 * correct / total
    print(f'Accuracy on the test set: {accuracy:.2f}%')

In [13]:
# Train with the loaded config; prints per-epoch average loss and saves the weights
# under ../../results/models/.
train_model(config=config)

  input_tensor = [torch.tensor(seq).clone().detach() for seq in input_seq]


Vocabulary size: 20
Device: cuda
Epoch 1/5, Average Loss: 2.4491
Epoch 2/5, Average Loss: 2.4308
Epoch 3/5, Average Loss: 2.4285
Epoch 4/5, Average Loss: 2.4289
Epoch 5/5, Average Loss: 2.4284

Model saved to ../../results/models/transformer_tic_tac_toe_single_agent_1k_model.pth


In [14]:
# Reload the weights saved by the training cell and print accuracy over non-padding targets.
evaluate_model(config=config)

  input_tensor = [torch.tensor(seq).clone().detach() for seq in input_seq]


Accuracy on the test set: 20.87%
