In [3]:
import sys
import yaml
from sklearn.model_selection import train_test_split

# Put the directory three levels up on the import path so the `scripts` and
# `src` packages imported in the following cells can be found (presumably the
# repository root - TODO confirm). The path is relative, so it is resolved
# against the kernel's current working directory.
sys.path.append('../../../')

In [4]:
from scripts.utils import load_data, preprocess_data_label_encoder, create_sequences, train_and_test, simulate_predictions, simulate_predictions_from_pretrained

In [5]:
from src.models.bidir_lstm_multihead import build_model

In [6]:
# Experiment configuration file. It is parsed with yaml.safe_load in the next
# cell and supplies the run name, the data path, the sequence length and the
# training hyperparameters used throughout the notebook.
YAML_FILE = '../../../configs/games/tic-tac-toe/bidir_lstm_multihead_1k.yml'

In [7]:
# Parse the experiment configuration. The encoding is pinned to UTF-8 so the
# file decodes the same way on every platform instead of following the locale
# default; safe_load restricts parsing to plain YAML data types.
with open(YAML_FILE, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# Echo the parsed settings so the saved notebook records what this run used.
print(config)

{'name': '/games/tic-tac-toe/bidir_lstm_multihead_1k', 'model': {'type': 'Bidirectional Multihead LSTM', 'embedding_dim': 128, 'hidden_size': 256, 'num_layers': 2}, 'training': {'batch_size': 64, 'learning_rate': 0.001, 'num_epochs': 5}, 'data': {'game': 'tic-tac-toe', 'sequence_length': 20, 'max_event_length': 50, 'path': '/games/tic-tac-toe/1k_single_agent.csv'}}


In [8]:
# Load and preprocess data
# The f-string is double-quoted because the subscripts inside the braces use
# single quotes: reusing the enclosing quote there is a SyntaxError on Python
# versions before 3.12.
# Plain string interpolation (not os.path.join) is kept on purpose: the
# configured path printed above starts with '/', and os.path.join would treat
# it as absolute and drop the '../../../data/processed' prefix.
df = load_data(f"../../../data/processed/{config['data']['path']}")
# The three extra return values are passed on to training and to the
# simulation helpers in later cells (presumably the fitted encoders for the
# event, agent and context columns - TODO confirm against scripts.utils).
df, e_event, e_agent, e_context = preprocess_data_label_encoder(df)

# Create sequences
sequence_length = config['data']['sequence_length']
X, y = create_sequences(df, sequence_length)

# Split data
# 80/20 split with a fixed seed so the partition is reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [9]:
# Train and evaluate through the shared helper. Epoch count and batch size
# come from the YAML config, and the architecture is injected via the
# imported `build_model` factory. The returned object is kept as `model`
# for the simulation cell that follows.
model = train_and_test(
    X_train, X_test, y_train, y_test,
    e_event, e_agent, e_context,
    name=config['name'],
    epochs=config['training']['num_epochs'],
    batch_size=config['training']['batch_size'],
    build_model=build_model,
)

Epoch 1/5
[1m120/120[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 18ms/step - agent_id_accuracy: 0.4056 - context_accuracy: 0.0964 - event_type_accuracy: 0.7358 - loss: 4.4250 - val_agent_id_accuracy: 0.5302 - val_context_accuracy: 0.0979 - val_event_type_accuracy: 0.8010 - val_loss: 3.8364
Epoch 2/5
[1m120/120[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - agent_id_accuracy: 0.5749 - context_accuracy: 0.1876 - event_type_accuracy: 0.8746 - loss: 3.3326 - val_agent_id_accuracy: 0.7109 - val_context_accuracy: 0.2198 - val_event_type_accuracy: 0.9307 - val_loss: 2.8814
Epoch 3/5
[1m120/120[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - agent_id_accuracy: 0.7780 - context_accuracy: 0.2340 - event_type_accuracy: 0.9342 - loss: 2.7442 - val_agent_id_accuracy: 0.9005 - val_context_accuracy: 0.2573 - val_event_type_accuracy: 0.9260 - val_loss: 2.4628
Epoch 4/5
[1m120/120[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 12ms/step - agent_




loss: 2.2806
compile_metrics: 0.9302


In [10]:
# Run the simulation helper with the model trained in this session.
# Note that `X` is the full sequence set, so it includes sequences the model
# was fitted on. `n` and `k` are forwarded as-is; `k` is the configured
# sequence length (their exact meaning is defined in scripts.utils - TODO
# confirm).
simulate_predictions(
    data=X,
    model=model,
    e_event=e_event,
    e_agent=e_agent,
    e_context=e_context,
    n=15,
    k=config['data']['sequence_length'],
)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 344ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2

In [11]:
# Repeat the simulation, this time loading the saved model from disk instead
# of using the in-memory one.
# The f-string is double-quoted because the subscript inside the braces uses
# single quotes: reusing the enclosing quote there is a SyntaxError on Python
# versions before 3.12. The resulting path is unchanged.
simulate_predictions_from_pretrained(
    data=X,
    modelpath=f"../../../models/{config['name']}.h5",
    e_event=e_event,
    e_agent=e_agent,
    e_context=e_context,
    n=15,
    k=config['data']['sequence_length'],
)



[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 335ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 21ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2