In [1]:
import sys
import os
import yaml

# Make the repository root importable so `src.*` resolves. The root is two
# levels above the current working directory (the notebook's folder when it is
# launched from Jupyter). Resolving it to an absolute path means a later
# os.chdir cannot break imports. Adding it only when missing keeps re-runs of
# this cell from piling duplicate entries onto sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
os.getcwd()

'/Users/dev/Research/EventForge/code/notebooks/models'

In [2]:
# Experiment config for the single-agent tic-tac-toe LSTM.
# The path is relative to the current working directory.
CONFIG_DIR = '../../configs/model_configs'
YAML_FILE = f'{CONFIG_DIR}/lstm_tic_tac_toe_single_agent.yml'
from src.models.lstm_model import create_model

In [3]:
# Read the experiment config as text, then parse it. safe_load only builds
# plain YAML types (dicts, lists, scalars), never arbitrary Python objects.
with open(YAML_FILE, 'r') as config_file:
    raw_config_text = config_file.read()
config = yaml.safe_load(raw_config_text)
config

{'name': 'lstm_tic_tac_toe_single_agent_1k_model',
 'model': {'type': 'LSTM', 'hidden_size': 256, 'num_layers': 1},
 'training': {'batch_size': 64, 'learning_rate': 0.001, 'num_epochs': 3},
 'data': {'game': 'tic-tac-toe',
  'sequence_length': 20,
  'path': '/games/tic-tac-toe/1k_single_agent.csv'}}

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from src.data.datasets.event_dataset import EventDataset

def train_model(config, seed=None, save_dir='../../results/models'):
    """Train the model described by ``config`` and save its weights.

    Builds an ``EventDataset`` from ``config['data']`` and creates the model
    with ``create_model``. Trains it with Adam and cross-entropy loss, printing
    the mean per-batch loss after every epoch. Saves the final ``state_dict``
    to ``<save_dir>/<config['name']>.pth``.

    Args:
        config: Parsed experiment config. Keys read here:
            ``name``, ``data.path``, ``data.sequence_length``,
            ``training.batch_size``, ``training.learning_rate`` and
            ``training.num_epochs``. The whole dict is also passed to
            ``create_model``.
        seed: Optional seed for torch's global RNG, applied before the dataset
            and model are built. This makes the DataLoader shuffle order
            repeatable, and presumably the model's weight init too.
            ``None`` (default) leaves the RNG unseeded.
        save_dir: Directory for the checkpoint; created if it does not exist.

    Returns:
        The trained model, left on the training device.

    Raises:
        ValueError: If the dataset yields no batches. The epoch-average loss
            would otherwise divide by zero.
    """
    if seed is not None:
        torch.manual_seed(seed)

    data_cfg = config['data']
    train_cfg = config['training']
    num_epochs = train_cfg['num_epochs']

    # Load data
    dataset = EventDataset(data_cfg['path'], data_cfg['sequence_length'])
    dataloader = DataLoader(dataset, batch_size=train_cfg['batch_size'], shuffle=True)
    if len(dataloader) == 0:
        raise ValueError(f"Dataset at {data_cfg['path']!r} produced no batches")

    # Choose the device and move the model there *before* constructing the
    # optimizer, as the torch.optim documentation recommends.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = create_model(config, dataset.get_vocab_size()).to(device)

    # NOTE(review): nn.CrossEntropyLoss expects logits of shape (batch, vocab)
    # with class-index targets of shape (batch,), or (batch, vocab, d...) with
    # (batch, d...). This assumes create_model / EventDataset produce those
    # shapes; confirm.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=train_cfg['learning_rate'])

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        for batch_input, batch_target in dataloader:
            batch_input, batch_target = batch_input.to(device), batch_target.to(device)

            optimizer.zero_grad()
            output = model(batch_input)
            loss = criterion(output, batch_target)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Mean of per-batch losses. It is not weighted by batch size, so a
        # short final batch counts as much as a full one.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataloader):.4f}")

    # Save the model. Make sure the target directory exists, and build the path
    # without nesting same-type quotes inside an f-string, which requires
    # Python 3.12+.
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, f"{config['name']}.pth")
    torch.save(model.state_dict(), save_path)

    return model

In [5]:
# Run the full training loop defined above. Binding the result to a name means
# the cell never echoes a return value; only the per-epoch log lines appear.
trained_model = train_model(config=config)

Epoch 1/3, Loss: 0.9688
Epoch 2/3, Loss: 0.3244
Epoch 3/3, Loss: 0.2016
