In [9]:
import dgl
import torch
import random
import os
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F
from collections import defaultdict
from dgl.nn.pytorch import GraphConv
from itertools import chain
from time import time

In [10]:
# Fix seed for reproducibility
seed_value = 1
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)

# Set GPU/CPU
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TORCH_DTYPE = torch.float32
print(f'Using device: {TORCH_DEVICE}, dtype: {TORCH_DTYPE}')

Using device: cpu, dtype: torch.float32


In [11]:
# GNN class for spin glass problem
class GCN_dev(nn.Module):
    def __init__(self, in_feats, hidden_size, number_classes, dropout, device):
        super(GCN_dev, self).__init__()
        self.dropout_frac = dropout
        self.conv1 = GraphConv(in_feats, hidden_size).to(device)
        self.conv2 = GraphConv(hidden_size, number_classes).to(device)

    def forward(self, g, inputs):
        h = self.conv1(g, inputs)
        h = torch.relu(h)
        h = F.dropout(h, p=self.dropout_frac)
        h = self.conv2(g, h)
        h = torch.tanh(h)  # Use tanh for spin values in [-1, 1]
        return h

# Generate a 2D grid graph with periodic boundary conditions
def generate_grid_graph(n, graph_type='grid', random_seed=0):
    if graph_type == 'grid':
        side_length = int(n**0.5)
        if side_length**2 != n:
            raise ValueError("n must be a perfect square for a 2D grid graph.")
        print(f'Generating 2D grid graph with n={n}, side_length={side_length}, seed={random_seed}')
        nx_temp = nx.grid_2d_graph(side_length, side_length, periodic=True)
    else:
        raise NotImplementedError(f'Graph type {graph_type} not handled.')

    nx_temp = nx.relabel.convert_node_labels_to_integers(nx_temp)
    nx_graph = nx.OrderedGraph()
    nx_graph.add_nodes_from(sorted(nx_temp.nodes()))
    nx_graph.add_edges_from(nx_temp.edges)

    # Add random Gaussian-distributed edge weights
    torch.manual_seed(random_seed)
    for u, v in nx_graph.edges:
        nx_graph[u][v]['weight'] = torch.randn(1).item()

    return nx_graph

# Convert spin glass couplings to QUBO matrix
def spin_glass_to_qubo(J):
    """
    Convert a spin glass coupling matrix J to a QUBO matrix Q.
    J: Coupling matrix (n x n torch tensor)
    Returns: QUBO matrix Q (n x n torch tensor)
    """
    n = J.shape[0]
    Q = torch.zeros_like(J)
    for i in range(n):
        for j in range(n):
            Q[i, j] = -J[i, j]  # Map spin glass couplings to QUBO
    return Q

# Loss function for spin glass problem
def spin_glass_loss(spins, Q):
    """
    Compute the energy of a spin configuration using the QUBO matrix.
    spins: Spin configuration (n x 1 tensor)
    Q: QUBO matrix (n x n tensor)
    """
    return torch.matmul(spins.T, torch.matmul(Q, spins))

# Train the GNN
def run_gnn_training(Q, dgl_graph, net, embed, optimizer, number_epochs, tol, patience):
    inputs = embed.weight
    prev_loss = float('inf')
    count = 0
    best_spin_config = torch.ones(dgl_graph.number_of_nodes()).type(Q.dtype).to(Q.device)
    best_loss = spin_glass_loss(best_spin_config, Q)

    t_gnn_start = time()

    for epoch in range(number_epochs):
        # Forward pass
        probs = net(dgl_graph, inputs)[:, 0]
        spins = torch.tanh(probs)  # Map to [-1, 1]

        # Compute loss
        loss = spin_glass_loss(spins, Q)
        loss_ = loss.detach().item()

        # Track best solution
        if loss < best_loss:
            best_loss = loss
            best_spin_config = torch.sign(spins).clone()

        # Log progress
        if epoch % 1000 == 0:
            print(f'Epoch: {epoch}, Energy: {loss_}')

        # Early stopping
        if abs(loss_ - prev_loss) <= tol:
            count += 1
        else:
            count = 0

        if count >= patience:
            print(f'Stopping early at epoch {epoch} (patience: {patience}).')
            break

        # Update loss tracking
        prev_loss = loss_

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    t_gnn = time() - t_gnn_start
    print(f'GNN training (n={dgl_graph.number_of_nodes()}) took {round(t_gnn, 3)} seconds')
    print(f'GNN final energy: {loss_}')
    print(f'GNN best energy: {best_loss}')

    final_spin_config = torch.sign(probs)
    return net, epoch, final_spin_config, best_spin_config

In [12]:
# Main script
if __name__ == '__main__':
    # Graph parameters
    n = 36  # Number of nodes (must be a perfect square for a 2D grid graph) 
    random_seed = 137 

    # Generate graph with weights
    nx_graph = generate_grid_graph(n, graph_type='grid', random_seed=random_seed)
    edge_weights = torch.zeros((n, n))
    for u, v, data in nx_graph.edges(data=True):
        edge_weights[u, v] = data['weight']
        edge_weights[v, u] = data['weight']

    # Convert to DGL graph
    dgl_graph = dgl.from_networkx(nx_graph)
    dgl_graph = dgl_graph.to(TORCH_DEVICE)

    # Convert spin glass couplings to QUBO matrix
    Q = spin_glass_to_qubo(edge_weights).type(TORCH_DTYPE).to(TORCH_DEVICE)

    # GNN hyperparameters
    dim_embedding = 8 # int(np.sqrt(n))  # Embedding dimension
    hidden_dim = int(dim_embedding / 2)  # Hidden layer dimension
    dropout = 0.05
    number_classes = 1

    # Optimizer parameters
    learning_rate = 0.001
    opt_params = {'lr': learning_rate}

    # Instantiate GNN
    net = GCN_dev(dim_embedding, hidden_dim, number_classes, dropout, TORCH_DEVICE)
    net = net.type(TORCH_DTYPE).to(TORCH_DEVICE)
    embed = nn.Embedding(n, dim_embedding)
    embed = embed.type(TORCH_DTYPE).to(TORCH_DEVICE)

    # Set up Adam optimizer
    params = chain(net.parameters(), embed.parameters())
    optimizer = torch.optim.Adam(params, **opt_params)

    # Train GNN
    number_epochs = 15000
    tol = 1e-6
    patience = 2000

    print('Running GNN...')
    net, epoch, final_spin_config, best_spin_config = run_gnn_training(
        Q, dgl_graph, net, embed, optimizer, number_epochs, tol, patience
    )

    # Print final results
    print("Final spin configuration:", final_spin_config)
    print("Best spin configuration:", best_spin_config)

    

Generating 2D grid graph with n=36, side_length=6, seed=137
Running GNN...
Epoch: 0, Energy: -0.2203582525253296
Epoch: 1000, Energy: -46.405052185058594
Epoch: 2000, Energy: -45.649879455566406
Epoch: 3000, Energy: -47.42603302001953
Epoch: 4000, Energy: -47.69712829589844
Epoch: 5000, Energy: -47.71731948852539
Epoch: 6000, Energy: -47.69609832763672
Epoch: 7000, Energy: -47.62809753417969
Epoch: 8000, Energy: -47.72010803222656
Epoch: 9000, Energy: -47.72175598144531
Epoch: 10000, Energy: -47.72443389892578
Epoch: 11000, Energy: -47.72057342529297
Epoch: 12000, Energy: -47.72633743286133
Epoch: 13000, Energy: -47.672752380371094
Epoch: 14000, Energy: -47.72498321533203
GNN training (n=36) took 208.224 seconds
GNN final energy: -47.72235107421875
GNN best energy: -47.726829528808594
Final spin configuration: tensor([ 1.,  1.,  1., -1.,  1.,  1.,  1.,  1.,  1., -1., -1.,  1.,  1.,  1.,
        -1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1.,  1., -1.,  1.,  1., -1.,
         1.,  1.,  1.,