In [5]:
import torch

print("torch version: ", torch.__version__)

torch version:  2.6.0+cpu


In [92]:
import os
import json
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import from_networkx
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity

In [93]:
# Load card attributes from JSON file
def load_card_attributes(file="data/labeled_cards.json"):
    with open(file, "r") as f:
        return {card["id"]: card for card in json.load(f)}

# Load deck graphs from exported GraphML files
def load_deck_graphs(folder="graphs"):
    deck_graphs = []
    for file in os.listdir(folder):
        if file.endswith(".graphml"):
            G = nx.read_graphml(os.path.join(folder, file))
            deck_graphs.append(G)
    return deck_graphs

# Normalize features
def normalize_features(features):
    scaler = MinMaxScaler()
    return scaler.fit_transform(features)

# Convert NetworkX graphs to PyTorch Geometric format with normalized attributes
def convert_to_pyg(graphs, card_data):
    pyg_graphs = []
    feature_list = ["copies", "cost", "power", "labels", "counter", "type", "traits", "color_Red", "color_Green", "color_Blue", "color_Purple", "color_Black", "color_Yellow"]
    
    for G in graphs:
        node_features = []
        for node in G.nodes():
            card_id = node
            card = card_data.get(card_id, {})
            
            features = [
                float(G.nodes[node].get("copies", 1)),
                float(card.get("cost", 0)),
                float(card.get("power", 0)),
                float(len(card.get("labels", []))),  # Count of labels
                float(card.get("counter", 0))
            ]
            
            colors = ["Red", "Green", "Blue", "Purple", "Black", "Yellow"]
            features.extend([1.0 if c in card.get("color", []) else 0.0 for c in colors])
            
            features.append(float(hash(card.get("type", "None")) % 10))
            features.append(float(hash(" ".join(card.get("traits", []))) % 10))
            
            node_features.append(features)
        
        normalized_features = normalize_features(node_features)
        for i, node in enumerate(G.nodes()):
            for j, attr in enumerate(feature_list):
                G.nodes[node][attr] = normalized_features[i][j]
        
        pyg_data = from_networkx(G, group_node_attrs=feature_list)
        pyg_graphs.append(pyg_data)
    return pyg_graphs

In [94]:
# Define GNN Model for Synergy Learning with Dropout
class SynergyGNN(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super(SynergyGNN, self).__init__()
        self.conv1 = GATConv(in_dim, hidden_dim, heads=4)
        self.dropout = nn.Dropout(p=0.3)
        self.conv2 = GATConv(hidden_dim * 4, out_dim, heads=1)
        self.fc = nn.Linear(out_dim, out_dim)  # Predict embeddings instead of full features
        self.relu = nn.ReLU()

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.conv2(x, edge_index)
        x = self.fc(x)  # Output embedding
        return x

In [95]:
# Train the GNN Model with Learning Rate Scheduling
def train_synergy_gnn(graphs, card_data, epochs=100, lr=0.01):
    pyg_graphs = convert_to_pyg(graphs, card_data)
    dataset = DataLoader(pyg_graphs, batch_size=1, shuffle=True)
    
    input_dim = pyg_graphs[0].num_features if len(pyg_graphs) > 0 else 17
    model = SynergyGNN(in_dim=input_dim, hidden_dim=32, out_dim=16)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    criterion = nn.MSELoss()
    
    for epoch in range(epochs):
        total_loss = 0
        for data in dataset:
            optimizer.zero_grad()
            x, edge_index = data.x, data.edge_index
            pred = model(x.float(), edge_index)
            loss = criterion(pred, torch.randn_like(pred))  # Predict meaningful embeddings
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        
        scheduler.step(total_loss)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    return model, input_dim

# Save trained model
def save_model(model, input_dim, path="graphs/synergy_gnn.pth"):
    torch.save({
        'model_state_dict': model.state_dict(),
        'input_dim': input_dim
    }, path)
    print(f"Model saved to {path}")

# Load trained model
def load_model(model_path="graphs/synergy_gnn.pth", hidden_dim=32, out_dim=16):
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    input_dim = checkpoint['input_dim']
    model = SynergyGNN(in_dim=input_dim, hidden_dim=hidden_dim, out_dim=out_dim)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, input_dim

In [96]:
card_data = load_card_attributes()
deck_graphs = load_deck_graphs()
model, input_dim = train_synergy_gnn(deck_graphs, card_data, epochs=1000)



Epoch 1/1000, Loss: 1.0174, LR: 0.010000
Epoch 2/1000, Loss: 1.0243, LR: 0.010000
Epoch 3/1000, Loss: 1.0166, LR: 0.010000
Epoch 4/1000, Loss: 0.9998, LR: 0.010000
Epoch 5/1000, Loss: 0.9978, LR: 0.010000
Epoch 6/1000, Loss: 1.0112, LR: 0.010000
Epoch 7/1000, Loss: 0.9864, LR: 0.010000
Epoch 8/1000, Loss: 1.0055, LR: 0.010000
Epoch 9/1000, Loss: 0.9996, LR: 0.010000
Epoch 10/1000, Loss: 1.0114, LR: 0.010000
Epoch 11/1000, Loss: 1.0110, LR: 0.010000
Epoch 12/1000, Loss: 1.0049, LR: 0.010000
Epoch 13/1000, Loss: 0.9946, LR: 0.010000
Epoch 14/1000, Loss: 0.9961, LR: 0.010000
Epoch 15/1000, Loss: 1.0072, LR: 0.010000
Epoch 16/1000, Loss: 1.0055, LR: 0.010000
Epoch 17/1000, Loss: 0.9899, LR: 0.010000
Epoch 18/1000, Loss: 1.0005, LR: 0.005000
Epoch 19/1000, Loss: 0.9871, LR: 0.005000
Epoch 20/1000, Loss: 1.0099, LR: 0.005000
Epoch 21/1000, Loss: 1.0054, LR: 0.005000
Epoch 22/1000, Loss: 0.9961, LR: 0.005000
Epoch 23/1000, Loss: 0.9873, LR: 0.005000
Epoch 24/1000, Loss: 1.0145, LR: 0.005000
E

In [97]:
save_model(model, input_dim)

Model saved to graphs/synergy_gnn.pth


In [None]:
import os
import json
import networkx as nx
import torch
import torch.nn as nn
import torch.optim as optim
import torch_geometric.transforms as T
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.data import Data, DataLoader
from torch_geometric.utils import from_networkx
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np

# Load card attributes from JSON file
def load_card_attributes(file="data/labeled_cards.json"):
    with open(file, "r") as f:
        return {card["id"]: card for card in json.load(f)}

# Load deck graphs from exported GraphML files
def load_deck_graphs(folder="graphs"):
    deck_graphs = []
    for file in os.listdir(folder):
        if file.endswith(".graphml"):
            G = nx.read_graphml(os.path.join(folder, file))
            deck_graphs.append(G)
    return deck_graphs

# Generate a legal deck using the trained GNN
def generate_deck(model, card_data, deck_size=50):
    G = nx.Graph()
    for card_id in card_data.keys():
        G.add_node(card_id)
    
    pyg_data = convert_to_pyg([G], card_data)[0]  # Use a real card dataset graph
    x, edge_index = pyg_data.x, pyg_data.edge_index

    with torch.no_grad():
        card_embeddings = model(x.float(), edge_index).numpy()
    
    # Rank cards based on learned synergy scores (sum of embeddings)
    ranked_cards = sorted(zip(list(card_data.keys()), card_embeddings.sum(axis=1)), key=lambda x: x[1], reverse=True)
    
    # Apply deck-building constraints
    leader = None
    for card_id, _ in ranked_cards:
        if card_data[card_id]["type"] == "Leader":
            leader = card_id
            break
    if leader is None:
        raise ValueError("No Leader found in generated deck.")
    
    leader_color = set(card_data[leader]["color"])  # Convert leader's colors to a set
    banned_cards = {"ST10-001", "OP03-098", "OP05-041", "ST06-015", "OP06-116", "OP02-024", "OP02-052"}  # Manually defined banned card IDs
    
    selected_cards = []
    card_counts = {}
    for card_id, _ in ranked_cards:
        if len(selected_cards) >= deck_size:
            break
        if card_id == leader:
            continue
        if card_id in banned_cards:
            continue
        if not leader_color.intersection(set(card_data[card_id]["color"])):  # Ensure color matches
            continue
        
        if card_id not in card_counts:
            card_counts[card_id] = 0
        if card_counts[card_id] < 4:
            card_counts[card_id] += 1
            selected_cards.append((card_id, card_counts[card_id]))
    
    # Ensure exactly 50 cards
    if len(selected_cards) < deck_size:
        raise ValueError("Generated deck does not meet the 50-card requirement.")
    
    # Format deck list
    deck_list = [f"{count}x{card_id}" for card_id, count in selected_cards]
    
    print("Generated Legal Deck List:")
    print(f"1x{leader}")
    print("\n".join(deck_list))
    return deck_list


In [109]:
# Generate and print a legal deck
generate_deck(model, card_data)

Generated Legal Deck List:
1xEB01-021
1xOP01-061
1xOP08-056
1xOP07-058
1xOP08-002
1xOP06-042
1xOP01-067
1xOP02-062
1xP-052
1xOP07-072
1xOP01-075
1xST12-011
1xST03-001
1xP-047
1xOP08-055
1xOP01-071
1xOP08-040
1xOP02-063
1xOP06-047
1xOP05-055
1xOP03-049
1xOP06-119
1xOP07-044
1xOP05-051
1xOP08-052
1xOP02-071
1xOP04-044
1xOP07-051
1xOP07-054
1xP-056
1xOP08-051
1xOP01-062
1xP-009
1xST03-003
1xOP05-042
1xOP01-118
1xOP01-091
1xOP04-001
1xOP02-051
1xOP04-056
1xOP08-043
1xP-055
1xST04-001
1xOP06-058
1xOP07-038
1xOP02-090
1xOP04-058
1xOP02-026
1xOP05-050
1xOP07-052
1xST03-016


['1xOP01-061',
 '1xOP08-056',
 '1xOP07-058',
 '1xOP08-002',
 '1xOP06-042',
 '1xOP01-067',
 '1xOP02-062',
 '1xP-052',
 '1xOP07-072',
 '1xOP01-075',
 '1xST12-011',
 '1xST03-001',
 '1xP-047',
 '1xOP08-055',
 '1xOP01-071',
 '1xOP08-040',
 '1xOP02-063',
 '1xOP06-047',
 '1xOP05-055',
 '1xOP03-049',
 '1xOP06-119',
 '1xOP07-044',
 '1xOP05-051',
 '1xOP08-052',
 '1xOP02-071',
 '1xOP04-044',
 '1xOP07-051',
 '1xOP07-054',
 '1xP-056',
 '1xOP08-051',
 '1xOP01-062',
 '1xP-009',
 '1xST03-003',
 '1xOP05-042',
 '1xOP01-118',
 '1xOP01-091',
 '1xOP04-001',
 '1xOP02-051',
 '1xOP04-056',
 '1xOP08-043',
 '1xP-055',
 '1xST04-001',
 '1xOP06-058',
 '1xOP07-038',
 '1xOP02-090',
 '1xOP04-058',
 '1xOP02-026',
 '1xOP05-050',
 '1xOP07-052',
 '1xST03-016']