In [1]:
import sys
import os
import yaml
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
sys.path.append('../../')
from constants import ROOT_DIR
from src.data.datasets.event_dataset import EventDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
sys.path.append('../../')
from constants import ROOT_DIR

In [3]:
YAML_FILE = '../../configs/model_configs/bert_tic_tac_toe_single_agent.yml'
from src.models.bert_model import create_model

In [4]:
with open(YAML_FILE, 'r') as f:
    config = yaml.safe_load(f)
print(config)

{'name': 'bert_tic_tac_toe_single_agent_1k_model', 'model': {'type': 'BERT', 'pretrained_model': 'bert-base-uncased', 'hidden_size': 768, 'num_layers': 12}, 'training': {'batch_size': 16, 'learning_rate': 0.001, 'num_epochs': 15, 'warmup_steps': 500}, 'data': {'game': 'tic-tac-toe', 'sequence_length': 20, 'max_event_length': 10, 'path': '/games/tic-tac-toe/1k_single_agent.csv'}, 'tokenizer': {'max_length': 128, 'padding': 'max_length', 'truncation': True}}


In [5]:
def collate_fn(batch, tokenizer, max_length):
    batch_input, batch_target, sequences, targets = zip(*batch)
    encoded_batch = tokenizer.batch_encode_plus(
        batch_input,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    return {
        'input_ids': encoded_batch['input_ids'],
        'attention_mask': encoded_batch['attention_mask'],
        'labels': torch.tensor(batch_target)
    }

def train_model(config):
    pretrained_model_name = config['model']['pretrained_model']
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
    dataset = EventDataset(config['data']['path'], config['data']['sequence_length'], tokenizer)
    dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=True, collate_fn=lambda b: collate_fn(b, tokenizer, config['tokenizer']['max_length']))
    model = create_model(config, dataset.num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=config['training']['learning_rate'])
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    for epoch in range(config['training']['num_epochs']):
        model.train()
        total_loss = 0
        for batch_idx, batch in enumerate(dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}/{config['training']['num_epochs']}, Average Loss: {total_loss/len(dataloader):.4f}")
    torch.save(model.state_dict(), f'../../results/models/{config["name"]}.pth')
    print(f"\nModel saved to ../../results/models/{config['name']}.pth")

def evaluate_model(config):
    pretrained_model_name = config['model']['pretrained_model']
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
    dataset = EventDataset(config['data']['path'], config['data']['sequence_length'], tokenizer)
    dataloader = DataLoader(dataset, batch_size=config['training']['batch_size'], shuffle=False, collate_fn=lambda b: collate_fn(b, tokenizer, config['tokenizer']['max_length']))
    model = create_model(config, dataset.num_classes)
    model.load_state_dict(torch.load(f'../../results/models/{config["name"]}.pth'))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy on the test set: {accuracy:.2f}%')

In [6]:
train_model(config=config)

TypeError: EventDataset.__init__() takes 3 positional arguments but 4 were given

In [None]:
evaluate_model(config=config)