In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import csv
import random

class SequenceMaskedTicTacToeDataset(Dataset):
    def __init__(self, csv_file, sequence_length=25, state_dim=(3, 3)):
        self.data = []
        self.state_dim = state_dim
        self.sequence_length = sequence_length
        
        with open(csv_file, 'r') as file:
            csv_reader = csv.reader(file)
            next(csv_reader)  # Skip header
            current_sequence = []
            for row in csv_reader:
                agent = int(row[0])
                state = np.array([float(x) for x in row[1][1:-1].split()]).reshape(state_dim)
                current_sequence.append((agent, state))
                
                if len(current_sequence) == sequence_length:
                    self.data.append(current_sequence)
                    current_sequence = current_sequence[1:]  # Slide the window
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        sequence = self.data[idx]
        
        # Choose a random position to mask
        # mask_position = random.randint(0, self.sequence_length - 1)
        mask_position = -1
        
        # Create masked version of the sequence
        masked_sequence = [state.copy() for _, state in sequence]
        original_sequence = [state for _, state in sequence]
        
        # Mask the chosen position
        mask = np.ones_like(masked_sequence[mask_position], dtype=bool)
        masked_sequence[mask_position] = np.full_like(masked_sequence[mask_position], -1)
        
        return (
            torch.tensor([agent for agent, _ in sequence], dtype=torch.long),
            torch.tensor(np.array(masked_sequence), dtype=torch.float),
            torch.tensor(np.array(original_sequence), dtype=torch.float),
            torch.tensor(mask, dtype=torch.bool),
            torch.tensor(mask_position, dtype=torch.long)
        )


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np

class SequenceTicTacToePredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=2):
        super(SequenceTicTacToePredictor, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_dim, input_dim)

    def forward(self, x):
        # x shape: (batch_size, sequence_length, 3, 3)
        batch_size, seq_len, rows, cols = x.shape
        x = x.float()  # Convert to float for LSTM processing
        x = x.view(batch_size, seq_len, -1)  # Flatten each state
        
        lstm_out, _ = self.lstm(x)
        predictions = self.fc(lstm_out)
        predictions = predictions.view(batch_size, seq_len, rows, cols)
        
        # Round predictions to nearest integer during forward pass
        rounded_predictions = torch.round(predictions)
        # Use straight-through estimator for backpropagation
        rounded_predictions = predictions + (rounded_predictions - predictions).detach()
        
        return rounded_predictions

In [3]:
class SequenceTicTacToePredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers=2, dropout=0.2):
        super(SequenceTicTacToePredictor, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        
        # Embedding layer to learn representations of board states
        self.embedding = nn.Linear(input_dim, hidden_dim)
        
        # Bidirectional LSTM
        self.lstm = nn.LSTM(hidden_dim, hidden_dim, num_layers, 
                            batch_first=True, bidirectional=True, dropout=dropout)
        
        # Output layers
        self.fc1 = nn.Linear(hidden_dim * 2, hidden_dim)  # *2 for bidirectional
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        # x shape: (batch_size, sequence_length, 3, 3)
        batch_size, seq_len, rows, cols = x.shape
        x = x.float().view(batch_size, seq_len, -1)  # Flatten each state
        
        # Embed the input
        x = self.relu(self.embedding(x))
        
        # LSTM layer
        lstm_out, _ = self.lstm(x)
        
        # Output layers
        out = self.dropout(lstm_out)
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        
        predictions = out.view(batch_size, seq_len, rows, cols)
        
        # Round predictions to nearest integer during forward pass
        rounded_predictions = torch.round(predictions)
        # Use straight-through estimator for backpropagation
        rounded_predictions = predictions + (rounded_predictions - predictions).detach()
        
        return rounded_predictions

In [4]:
# Hyperparameters
input_dim = 9  # 3x3 grid flattened
hidden_dim = 64
learning_rate = 0.001
batch_size = 32
num_epochs = 500

# Create dataset and dataloader
dataset = SequenceMaskedTicTacToeDataset('../data/processed/synthetic/tic-tac-toe.csv', sequence_length=25)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss function, and optimizer
model = SequenceTicTacToePredictor(input_dim, hidden_dim)
criterion = nn.MSELoss()  # Using MSE loss for integer predictions
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for agents, masked_sequences, original_sequences, masks, mask_positions in dataloader:
        optimizer.zero_grad()
        
        # Forward pass
        predictions = model(masked_sequences)
        
        # Compute loss only on masked positions
        loss = criterion(predictions[torch.arange(predictions.size(0)), mask_positions], 
                         original_sequences[torch.arange(original_sequences.size(0)), mask_positions])
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

Epoch [1/500], Loss: 1.0782
Epoch [2/500], Loss: 0.9470
Epoch [3/500], Loss: 1.1407
Epoch [4/500], Loss: 1.1595
Epoch [5/500], Loss: 1.1146
Epoch [6/500], Loss: 0.9356
Epoch [7/500], Loss: 0.7889
Epoch [8/500], Loss: 0.7944
Epoch [9/500], Loss: 0.7650
Epoch [10/500], Loss: 0.7137
Epoch [11/500], Loss: 0.8014
Epoch [12/500], Loss: 0.8764
Epoch [13/500], Loss: 0.7965
Epoch [14/500], Loss: 0.7222
Epoch [15/500], Loss: 0.7500
Epoch [16/500], Loss: 0.7639
Epoch [17/500], Loss: 0.7150
Epoch [18/500], Loss: 0.7396
Epoch [19/500], Loss: 0.7238
Epoch [20/500], Loss: 0.7417
Epoch [21/500], Loss: 0.6905
Epoch [22/500], Loss: 0.7486
Epoch [23/500], Loss: 0.7567
Epoch [24/500], Loss: 0.7412
Epoch [25/500], Loss: 0.6961
Epoch [26/500], Loss: 0.6840
Epoch [27/500], Loss: 0.6970
Epoch [28/500], Loss: 0.6502
Epoch [29/500], Loss: 0.7215
Epoch [30/500], Loss: 0.7667
Epoch [31/500], Loss: 0.6227
Epoch [32/500], Loss: 0.7484
Epoch [33/500], Loss: 0.6644
Epoch [34/500], Loss: 0.7049
Epoch [35/500], Loss: 0

In [5]:
# Inference function
def sample_inference(model, dataset, num_predictions):
    model.eval()
    with torch.no_grad():
        for i in range(num_predictions):
            random_idx = random.randint(0, len(dataset) - 1)
            agents, masked_sequences, original_sequences, masks, mask_positions = dataset[random_idx]
            
            masked_sequences_batch = masked_sequences.unsqueeze(0)
            
            predictions = model(masked_sequences_batch)
            predictions = predictions.round().long()  # Ensure integer predictions
            
            print(f"\nSample {i+1}:")
            print("Input sequence:")
            visualize_sequence(masked_sequences.numpy(), mask_positions.item())
            print("\nPredicted sequence:")
            visualize_sequence(predictions[0].numpy(), mask_positions.item())
            print("\nOriginal sequence:")
            visualize_sequence(original_sequences.numpy(), mask_positions.item())
            print("\n" + "="*50)

def visualize_sequence(sequence, mask_position):
    for i, state in enumerate(sequence):
        if i == mask_position:
            print("*Masked State*")
        print(state.reshape(3, 3))
        print()

# Run inference
num_predictions = 5
sample_inference(model, dataset, num_predictions)


Sample 1:
Input sequence:
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

[[0. 0. 0.]
 [1. 0. 0.]
 [0. 0. 0.]]

[[2. 0. 0.]
 [1. 0. 0.]
 [0. 0. 0.]]

[[2. 0. 0.]
 [1. 0. 0.]
 [0. 1. 0.]]

[[2. 0. 2.]
 [1. 0. 0.]
 [0. 1. 0.]]

[[2. 0. 2.]
 [1. 0. 0.]
 [1. 1. 0.]]

[[2. 0. 2.]
 [1. 2. 0.]
 [1. 1. 0.]]

[[2. 0. 2.]
 [1. 2. 1.]
 [1. 1. 0.]]

[[2. 2. 2.]
 [1. 2. 1.]
 [1. 1. 0.]]

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

[[0. 0. 1.]
 [0. 0. 0.]
 [0. 0. 0.]]

[[0. 0. 1.]
 [0. 0. 2.]
 [0. 0. 0.]]

[[0. 0. 1.]
 [0. 1. 2.]
 [0. 0. 0.]]

[[2. 0. 1.]
 [0. 1. 2.]
 [0. 0. 0.]]

[[2. 0. 1.]
 [0. 1. 2.]
 [0. 1. 0.]]

[[2. 0. 1.]
 [0. 1. 2.]
 [2. 1. 0.]]

[[2. 1. 1.]
 [0. 1. 2.]
 [2. 1. 0.]]

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 1.]]

[[0. 0. 0.]
 [0. 0. 2.]
 [0. 0. 1.]]

[[0. 0. 1.]
 [0. 0. 2.]
 [0. 0. 1.]]

[[0. 2. 1.]
 [0. 0. 2.]
 [0. 0. 1.]]

[[1. 2. 1.]
 [0. 0. 2.]
 [0. 0. 1.]]

[[1. 2. 1.]
 [0. 2. 2.]
 [0. 0. 1.]]

[[-1. -1. -1.]
 [-1. -1. -1.]
 [-1. -1. -1.]]


Predicted seq