In [1]:
import sys
import os
import yaml
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
sys.path.append('../../')
from constants import ROOT_DIR
from src.data.datasets.event_dataset import EventDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make the directory two levels up importable (presumably the repository root
# that holds constants.py -- confirm). The membership check keeps this cell
# idempotent: re-running it no longer appends a duplicate entry to sys.path.
if '../../' not in sys.path:
    sys.path.append('../../')
from constants import ROOT_DIR

In [3]:
# Relative path of the YAML experiment configuration that the next cell loads.
YAML_FILE = '../../configs/model_configs/bert_tic_tac_toe_single_agent.yml'
# Project-local model helpers. create_model is used by train_model and
# evaluate_model below; get_tokenizer is imported but not used in the visible cells.
from src.models.bert_model import create_model, get_tokenizer

In [4]:
# Parse the experiment configuration from YAML_FILE and echo it so the
# hyper-parameters used for this run are recorded in the notebook output.
with open(YAML_FILE) as f:
    config = yaml.safe_load(f)
print(config)

{'name': 'bert_tic_tac_toe_single_agent_1k_model', 'model': {'type': 'BERT', 'pretrained_model': 'bert-base-uncased', 'hidden_size': 768, 'num_layers': 12}, 'training': {'batch_size': 16, 'learning_rate': 0.001, 'num_epochs': 15, 'warmup_steps': 500}, 'data': {'game': 'tic-tac-toe', 'sequence_length': 20, 'max_event_length': 10, 'path': '/games/tic-tac-toe/1k_single_agent.csv'}, 'tokenizer': {'max_length': 128, 'padding': 'max_length', 'truncation': True}}


In [5]:
def collate_fn(batch):
    """Collate dataset samples into padded model inputs for a DataLoader.

    Each sample is a 4-tuple ``(input, target, sequence, target_sequence)``.
    ``input`` is an iterable of sequences whose elements expose ``.item()``
    (e.g. 1-D tensors); ``target`` is indexable and its last element exposes
    ``.item()``.

    Returns:
        A 5-tuple ``(input_ids, attention_mask, batch_target, sequences, targets)``:
        * ``input_ids`` -- every sample's sequences flattened into one list of
          token ids, right-padded with 0 to the longest sample in the batch.
        * ``attention_mask`` -- float tensor, 1.0 where ``input_ids`` is
          non-zero and 0.0 elsewhere.
        * ``batch_target`` -- the last element of each sample's target.
        * ``sequences`` / ``targets`` -- the untouched third and fourth fields
          of the samples, as tuples.

    Note:
        The mask is derived from the values, not from the lengths, so any real
        token whose id is 0 is masked out exactly like padding.
    """
    raw_inputs, raw_targets, sequences, targets = zip(*batch)

    # Flatten each sample's nested sequences into a single 1-D tensor of ids.
    flattened = []
    for sample in raw_inputs:
        token_ids = []
        for seq in sample:
            for token in seq:
                token_ids.append(token.item())
        flattened.append(torch.tensor(token_ids))

    # Right-pad with 0 so the batch becomes one rectangular tensor.
    padded_ids = nn.utils.rnn.pad_sequence(flattened, batch_first=True, padding_value=0)

    # Positions holding 0 are treated as padding.
    attention_mask = padded_ids.ne(0).float()

    # Only the final token of each target is used as the label.
    last_tokens = []
    for target in raw_targets:
        last_tokens.append(target[-1].item())
    batch_target = torch.tensor(last_tokens)

    return padded_ids, attention_mask, batch_target, sequences, targets

def train_model(config):
    """Train the model described by ``config`` and save its weights.

    Args:
        config: Parsed configuration mapping. Reads
            ``config['data']['path']``, ``config['data']['sequence_length']``,
            ``config['training']['batch_size']``,
            ``config['training']['learning_rate']``,
            ``config['training']['num_epochs']`` and ``config['name']``.

    Side effects:
        Prints the vocabulary size, the device and the average loss of every
        epoch, then writes the model's ``state_dict`` to
        ``../../results/models/<config['name']>.pth`` (the directory is
        created if it does not exist).

    Note:
        No learning-rate scheduler is created here, so a ``warmup_steps``
        entry in the training config has no effect on this loop.
    """
    dataset = EventDataset(config['data']['path'], config['data']['sequence_length'])
    dataloader = DataLoader(
        dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True,
        collate_fn=collate_fn,
    )

    model = create_model(config, dataset.vocab_size)
    criterion = nn.CrossEntropyLoss()
    # float() tolerates a learning rate that YAML parsed as a string (e.g. "1e-3").
    optimizer = optim.AdamW(model.parameters(), lr=float(config['training']['learning_rate']))

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    print(f"Vocabulary size: {dataset.vocab_size}")
    print(f"Device: {device}")

    # Create the checkpoint directory up front so a missing folder fails (or is
    # fixed) before training rather than after all epochs have been spent.
    save_dir = '../../results/models'
    os.makedirs(save_dir, exist_ok=True)
    save_path = f'{save_dir}/{config["name"]}.pth'

    num_epochs = config['training']['num_epochs']
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_input, attention_mask, batch_target, _, _ in dataloader:
            # collate_fn already returns a rectangular, padded id tensor and its
            # mask, so the batch only needs to be moved to the device.
            batch_input = batch_input.to(device)
            attention_mask = attention_mask.to(device)
            batch_target = batch_target.to(device)

            optimizer.zero_grad()
            # The model's return value is fed to the loss directly as class logits.
            logits = model(input_ids=batch_input, attention_mask=attention_mask)

            loss = criterion(logits, batch_target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {total_loss/len(dataloader):.4f}")

    torch.save(model.state_dict(), save_path)
    print(f"\nModel saved to {save_path}")

def evaluate_model(config):
    """Load the saved checkpoint for ``config`` and print its accuracy.

    Args:
        config: Parsed configuration mapping. Reads
            ``config['data']['path']``, ``config['data']['sequence_length']``,
            ``config['training']['batch_size']`` and ``config['name']``.

    Raises:
        ValueError: If the dataset yields no samples, so no accuracy can be
            computed.

    Note:
        The dataset is built from the same ``config['data']['path']`` that
        ``train_model`` uses, so the printed figure is measured on the training
        data rather than on a held-out split.
    """
    dataset = EventDataset(config['data']['path'], config['data']['sequence_length'])
    dataloader = DataLoader(
        dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = create_model(config, dataset.vocab_size)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(f'../../results/models/{config["name"]}.pth', map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for batch_input, attention_mask, batch_target, _, _ in dataloader:
            batch_input = batch_input.to(device)
            attention_mask = attention_mask.to(device)
            batch_target = batch_target.to(device)

            logits = model(input_ids=batch_input, attention_mask=attention_mask)

            # Predicted class = index of the largest logit in each row.
            predicted = logits.argmax(dim=1)

            total += batch_target.size(0)
            correct += (predicted == batch_target).sum().item()

    if total == 0:
        raise ValueError("Evaluation dataset is empty; cannot compute accuracy.")

    accuracy = 100 * correct / total
    print(f'Accuracy on the test set: {accuracy:.2f}%')

In [6]:
# Train with the configuration loaded from YAML_FILE; prints the average loss
# of each epoch and saves the weights under ../../results/models/.
train_model(config=config)

Vocabulary size: 20
Device: cuda


  input_tensor = [torch.tensor(seq).clone().detach() for seq in input_seq]


Epoch 1/15, Average Loss: 2.6884
Epoch 2/15, Average Loss: 2.6326
Epoch 3/15, Average Loss: 2.6408
Epoch 4/15, Average Loss: 2.6388
Epoch 5/15, Average Loss: 2.6345
Epoch 6/15, Average Loss: 2.6118
Epoch 7/15, Average Loss: 2.5509
Epoch 8/15, Average Loss: 2.5128
Epoch 9/15, Average Loss: 2.5097
Epoch 10/15, Average Loss: 2.5033
Epoch 11/15, Average Loss: 2.5033
Epoch 12/15, Average Loss: 2.5008
Epoch 13/15, Average Loss: 2.5003
Epoch 14/15, Average Loss: 2.4979
Epoch 15/15, Average Loss: 2.4998

Model saved to ../../results/models/bert_tic_tac_toe_single_agent_1k_model.pth


In [7]:
# Score the checkpoint saved by the training cell. evaluate_model rebuilds the
# dataset from the same config['data']['path'] used for training, so the
# accuracy printed below is measured on the training data, not a held-out split.
evaluate_model(config=config)

  input_tensor = [torch.tensor(seq).clone().detach() for seq in input_seq]


Accuracy on the test set: 8.87%
