参考: https://github.com/amazon-science/co-with-gnns-example/blob/main/gnn_example.ipynb

In [1]:
import dgl
import torch
import random
import os
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict, defaultdict
from dgl.nn.pytorch import GraphConv
from itertools import chain, islice, combinations
# from networkx.algorithms.approximation.independent_set import maximum_independent_set as mis
from networkx.algorithms.approximation import maximum_independent_set as mis
from time import time

# MacOS can have issues with MKL. For more details, see
# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [2]:
# fix seed to ensure consistent results
seed_value = 1
random.seed(seed_value)        # seed python RNG
np.random.seed(seed_value)     # seed global NumPy RNG
torch.manual_seed(seed_value)  # seed torch RNG

# Set GPU/CPU
TORCH_DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
TORCH_DTYPE = torch.float32
print(f'Will use device: {TORCH_DEVICE}, torch dtype: {TORCH_DTYPE}')

Will use device: cuda, torch dtype: torch.float32


In [3]:
from utils import generate_graph, get_gnn, run_gnn_training, qubo_dict_to_torch, gen_combinations, loss_func

In [4]:
# helper function to generate Q matrix for Maximum Independent Set problem (MIS)
def gen_q_dict_mis(nx_G, penalty=2):
    """
    Helper function to generate QUBO matrix for MIS as minimization problem.
    
    Input:
        nx_G: graph as networkx graph object (assumed to be unweigthed)
    Output:
        Q_dic: QUBO as defaultdict
    """

    # Initialize our Q matrix
    Q_dic = defaultdict(int)

    # Update Q matrix for every edge in the graph
    # all off-diagonal terms get penalty
    for (u, v) in nx_G.edges:
        Q_dic[(u, v)] = penalty

    # all diagonal terms get -1
    for u in nx_G.nodes:
        Q_dic[(u, u)] = -1

    return Q_dic

# helper function to generate Q matrix for Max Cut problem (MC)
def gen_q_dict_mc(nx_G):
    """
    Helper function to generate QUBO matrix for MC as minimization problem.
    
    Input:
        nx_G: graph as networkx graph object (assumed to be unweigthed)
    Output:
        Q_dic: QUBO as defaultdict
    """

    # Initialize our Q matrix
    Q_dic = defaultdict(int)

    # Update Q matrix for every edge in the graph
    # all off-diagonal terms get penalty
    for (u, v) in nx_G.edges:
        Q_dic[(u, v)] = 2
        Q_dic[(u, u)] -= 1
        Q_dic[(v, v)] -= 1

    return Q_dic

def gen_adjacency_list(nx_G):
  adjacency_list = [[] for _ in range(nx_G.number_of_nodes())]

  for (u, v) in nx_G.edges:
    adjacency_list[u].append(v)
    adjacency_list[v].append(u)
  
  return adjacency_list

# Run classical MIS solver (provided by NetworkX)
def run_mis_solver(nx_graph):
    """
    helper function to run traditional solver for MIS.
    
    Input:
        nx_graph: networkx Graph object
    Output:
        ind_set_bitstring_nx: bitstring solution as list
        ind_set_nx_size: size of independent set (int)
        number_violations: number of violations of ind.set condition
    """
    # compare with traditional solver
    t_start = time()
    ind_set_nx = mis(nx_graph)
    t_solve = time() - t_start
    ind_set_nx_size = len(ind_set_nx)

    # get bitstring list
    nx_bitstring = [1 if (node in ind_set_nx) else 0 for node in sorted(list(nx_graph.nodes))]
    edge_set = set(list(nx_graph.edges))

    # Updated to be able to handle larger scale
    print('Calculating violations...')
    # check for violations
    number_violations = 0
    for ind_set_chunk in gen_combinations(combinations(ind_set_nx, 2), 100000):
        number_violations += len(set(ind_set_chunk).intersection(edge_set))

    return nx_bitstring, ind_set_nx_size, number_violations, t_solve

# Calculate results given bitstring and graph definition, includes check for violations
def postprocess_gnn_mis(best_bitstring, nx_graph):
    """
    helper function to postprocess MIS results

    Input:
        best_bitstring: bitstring as torch tensor
    Output:
        size_mis: Size of MIS (int)
        ind_set: MIS (list of integers)
        number_violations: number of violations of ind.set condition
    """

    # get bitstring as list
    bitstring_list = list(best_bitstring)

    # compute cost
    size_mis = sum(bitstring_list)

    # get independent set
    ind_set = set([node for node, entry in enumerate(bitstring_list) if entry == 1])
    edge_set = set(list(nx_graph.edges))

    print('Calculating violations...')
    # check for violations
    number_violations = 0
    for ind_set_chunk in gen_combinations(combinations(ind_set, 2), 100000):
        number_violations += len(set(ind_set_chunk).intersection(edge_set))

    return size_mis, ind_set, number_violations

In [5]:
for n in range(100, 101):
  for seed_value in range(1):
    # Step 1 - Set hyperparameters
    # Graph hypers
    # n = 100
    d = 3
    p = 0.1
    graph_type = 'reg'

    # NN learning hypers #
    number_epochs = int(1e5)
    learning_rate = 1e-4
    PROB_THRESHOLD = 0.5

    # Early stopping to allow NN to train to near-completion
    tol = 1e-4          # loss must change by more than tol, or trigger
    patience = 1000    # number early stopping triggers before breaking loop

    # Problem size (e.g. graph size)
    # n = 100

    # Establish dim_embedding and hidden_dim values
    dim_embedding = int(np.cbrt(n))    # e.g. 10
    hidden_dim = int(dim_embedding/2)  # e.g. 5

    # Step 2 - Generate random graph
    # Constructs a random d-regular or p-probabilistic graph
    nx_graph = generate_graph(n=n, d=d, p=p, graph_type=graph_type, random_seed=seed_value)
    # get DGL graph from networkx graph, load onto device
    graph_dgl = dgl.from_networkx(nx_graph=nx_graph)
    graph_dgl = graph_dgl.to(TORCH_DEVICE)

    # Construct Q matrix for graph
    q_torch = qubo_dict_to_torch(nx_graph, gen_q_dict_mc(nx_graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)


    # Step 3 - Set up optimizer/GNN architecture
    # Establish pytorch GNN + optimizer
    opt_params = {'lr': learning_rate}
    gnn_hypers = {
        'dim_embedding': dim_embedding,
        'hidden_dim': hidden_dim,
        'dropout': 0.0,
        'number_classes': 1,
        'prob_threshold': PROB_THRESHOLD,
        'number_epochs': number_epochs,
        'tolerance': tol,
        'patience': patience
    }

    net, embed, optimizer = get_gnn(n, gnn_hypers, opt_params, TORCH_DEVICE, TORCH_DTYPE)

    # For tracking hyperparameters in results object
    gnn_hypers.update(opt_params)


    # Step 4 - Run GNN training
    print('Running GNN...')
    gnn_start = time()

    _, epoch, final_bitstring, best_bitstring = run_gnn_training(
        q_torch, graph_dgl, net, embed, optimizer, gnn_hypers['number_epochs'],
        gnn_hypers['tolerance'], gnn_hypers['patience'], gnn_hypers['prob_threshold'], n, d, seed_value)

    gnn_time = time() - gnn_start


    # Step 5 - Post-process GNN results
    final_loss = loss_func(final_bitstring.float(), q_torch)
    final_bitstring_str = ','.join([str(x) for x in final_bitstring])

    # Process bitstring reported by GNN
    # size_mis, ind_set, number_violations = postprocess_gnn_mis(best_bitstring, nx_graph)
    gnn_tot_time = time() - gnn_start

    # print(f'Independence number found by GNN is {size_mis} with {number_violations} violations')
    print(f'Max cut size found by GNN is {-final_loss}')
    print(f'Took {round(gnn_tot_time, 3)}s, model training took {round(gnn_time, 3)}s')

    outfile = open('data/2gnn_n' + str(n) + '_d' + str(d) + '_seed' + str(seed_value) + '.dat', 'w')
    outfile.write('# 01: mc_size 02: time 03: total time 04: epoch\n')
    outfile.write(str(-final_loss.item()) + ' ' + str(gnn_time) + ' ' + str(gnn_tot_time) + ' ' + str(epoch) + '\n')
    outfile.close()

Generating d-regular graph with n=100, d=3, seed=0
Running GNN...
Epoch: 0, Loss: -0.14785480499267578
Epoch: 1000, Loss: -1.819901943206787
Epoch: 2000, Loss: -12.973966598510742
Epoch: 3000, Loss: -39.122989654541016
Epoch: 4000, Loss: -67.81595611572266
Epoch: 5000, Loss: -89.59819030761719
Epoch: 6000, Loss: -103.88700103759766
Epoch: 7000, Loss: -113.13985443115234
Epoch: 8000, Loss: -119.13922119140625
Epoch: 9000, Loss: -122.97570037841797
Epoch: 10000, Loss: -125.3886947631836
Epoch: 11000, Loss: -126.84629821777344
Epoch: 12000, Loss: -127.71759033203125
Epoch: 13000, Loss: -128.23631286621094
Epoch: 14000, Loss: -128.54486083984375
Epoch: 15000, Loss: -128.72906494140625
Epoch: 16000, Loss: -128.83843994140625
Epoch: 17000, Loss: -128.903564453125
Stopping early on epoch 17050 (patience: 1000)
GNN training (n=100) took 89.751
GNN final continuous loss: -128.90602111816406
GNN best continuous loss: -128.90602111816406
Max cut size found by GNN is 129.0
Took 90.8s, model traini