参考 (Reference): https://github.com/amazon-science/co-with-gnns-example/blob/main/gnn_example.ipynb

In [1]:
import dgl
import torch
import random
import os
import numpy as np
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict, defaultdict
from dgl.nn.pytorch import GraphConv
from itertools import chain, islice, combinations
from networkx.algorithms.approximation import maximum_independent_set as mis
from time import time

# MacOS can have issues with MKL. For more details, see
# https://stackoverflow.com/questions/53014306/error-15-initializing-libiomp5-dylib-but-found-libiomp5-dylib-already-initial
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

from scipy import optimize

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Reproducibility: seed the Python, NumPy and PyTorch random number generators
# with one shared value.
seed_value = 1
random.seed(seed_value)          # Python built-in RNG
np.random.seed(seed_value)       # NumPy global RNG
torch.manual_seed(seed_value)    # PyTorch RNG

# Compute target: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    TORCH_DEVICE = torch.device('cuda')
else:
    TORCH_DEVICE = torch.device('cpu')
# Floating-point precision used for all tensors built later in the notebook.
TORCH_DTYPE = torch.float32
print(f'Will use device: {TORCH_DEVICE}, torch dtype: {TORCH_DTYPE}')

Will use device: cuda, torch dtype: torch.float32


In [3]:
from utils import generate_graph, get_gnn, run_gnn_training, qubo_dict_to_torch, gen_combinations, loss_func

In [4]:
# helper function to generate Q matrix for Maximum Independent Set problem (MIS)
def gen_q_dict_mis(nx_G, penalty=2):
    """
    Helper function to generate QUBO matrix for MIS as minimization problem.
    
    Input:
        nx_G: graph as networkx graph object (assumed to be unweigthed)
    Output:
        Q_dic: QUBO as defaultdict
    """

    # Initialize our Q matrix
    Q_dic = defaultdict(int)

    # Update Q matrix for every edge in the graph
    # all off-diagonal terms get penalty
    for (u, v) in nx_G.edges:
        Q_dic[(u, v)] = penalty

    # all diagonal terms get -1
    for u in nx_G.nodes:
        Q_dic[(u, u)] = -1

    return Q_dic

# helper function to generate Q matrix for Max Cut problem (MC)
def gen_q_dict_mc(nx_G):
    """
    Helper function to generate QUBO matrix for MC as minimization problem.
    
    Input:
        nx_G: graph as networkx graph object (assumed to be unweigthed)
    Output:
        Q_dic: QUBO as defaultdict
    """

    # Initialize our Q matrix
    Q_dic = defaultdict(int)

    # Update Q matrix for every edge in the graph
    # all off-diagonal terms get penalty
    for (u, v) in nx_G.edges:
        Q_dic[(u, v)] = 2
        Q_dic[(u, u)] -= 1
        Q_dic[(v, v)] -= 1

    return Q_dic

def gen_adjacency_list(nx_G):
  adjacency_list = [[] for _ in range(nx_G.number_of_nodes())]

  for (u, v) in nx_G.edges:
    adjacency_list[u].append(v)
    adjacency_list[v].append(u)
  
  return adjacency_list

# Run classical MIS solver (provided by NetworkX)
def run_mis_solver(nx_graph):
    """
    helper function to run traditional solver for MIS.
    
    Input:
        nx_graph: networkx Graph object
    Output:
        ind_set_bitstring_nx: bitstring solution as list
        ind_set_nx_size: size of independent set (int)
        number_violations: number of violations of ind.set condition
    """
    # compare with traditional solver
    t_start = time()
    ind_set_nx = mis(nx_graph)
    t_solve = time() - t_start
    ind_set_nx_size = len(ind_set_nx)

    # get bitstring list
    nx_bitstring = [1 if (node in ind_set_nx) else 0 for node in sorted(list(nx_graph.nodes))]
    edge_set = set(list(nx_graph.edges))

    # Updated to be able to handle larger scale
    print('Calculating violations...')
    # check for violations
    number_violations = 0
    for ind_set_chunk in gen_combinations(combinations(ind_set_nx, 2), 100000):
        number_violations += len(set(ind_set_chunk).intersection(edge_set))

    return nx_bitstring, ind_set_nx_size, number_violations, t_solve

# Calculate results given bitstring and graph definition, includes check for violations
def postprocess_gnn_mis(best_bitstring, nx_graph):
    """
    helper function to postprocess MIS results

    Input:
        best_bitstring: bitstring as torch tensor
    Output:
        size_mis: Size of MIS (int)
        ind_set: MIS (list of integers)
        number_violations: number of violations of ind.set condition
    """

    # get bitstring as list
    bitstring_list = list(best_bitstring)

    # compute cost
    size_mis = sum(bitstring_list)

    # get independent set
    ind_set = set([node for node, entry in enumerate(bitstring_list) if entry == 1])
    edge_set = set(list(nx_graph.edges))

    print('Calculating violations...')
    # check for violations
    number_violations = 0
    for ind_set_chunk in gen_combinations(combinations(ind_set, 2), 100000):
        number_violations += len(set(ind_set_chunk).intersection(edge_set))

    return size_mis, ind_set, number_violations

In [5]:
for n in range(100, 101):
  for seed_value in range(1):
    # Step 1 - Set hyperparameters
    # Graph hypers
    # n = 100
    d = 3
    p = 0.1
    graph_type = 'reg'

    # Step 2 - Generate random graph
    # Constructs a random d-regular or p-probabilistic graph
    nx_graph = generate_graph(n=n, d=d, p=p, graph_type=graph_type, random_seed=seed_value)
    # get DGL graph from networkx graph, load onto device
    graph_dgl = dgl.from_networkx(nx_graph=nx_graph)
    graph_dgl = graph_dgl.to(TORCH_DEVICE)

    # Construct Q matrix for graph
    q_torch = qubo_dict_to_torch(nx_graph, gen_q_dict_mc(nx_graph), torch_dtype=TORCH_DTYPE, torch_device=TORCH_DEVICE)

    # # 平均場近似
    adjacency_list = gen_adjacency_list(nx_graph)
    # print(adjacency_list)

    # 解きたい関数をリストで戻す
    def func(x):
      t = 0.01
      beta = 1 / t

      func_list = []
      for i in range(n):
        func_list.append(x[i] + np.tanh(beta * sum(1 * x[j] for j in adjacency_list[i])))
      
      return func_list

    mfa_start = time()
    x_0 = 2 * np.random.rand(n) - 1
    result = optimize.root(func, x_0, method="broyden1")
    mfa_time = time() - mfa_start

    solution = []
    for m in result.x:
      if m > 0:
        solution.append(1)
      else:
        solution.append(0)

    solution = torch.tensor(solution).to(TORCH_DEVICE)
    mc_size_mfa = -loss_func(solution.float(), q_torch)
    mfa_tot_time = time() - mfa_start
    print(mc_size_mfa.item())

    outfile = open('data/mfa_n' + str(n) + '_d' + str(d) + '_seed' + str(seed_value) + '.dat', 'w')

    outfile.write('# 01: mc size 02: time 03: total time\n')
    outfile.write(str(mc_size_mfa.item()) + ' ' + str(mfa_time) + ' ' + str(mfa_tot_time) + '\n')

    outfile.close()

Generating d-regular graph with n=100, d=3, seed=0
133.0
