<a href="https://colab.research.google.com/github/pepsibetter/EJOR/blob/main/EJOR_Sample_Code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
# Expose the installed torch version string to the shell so the pip commands
# below can select the matching PyG wheel index through ${TORCH}.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# torch-scatter / torch-sparse: install wheels from the index built for this torch version.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# NOTE(review): installs the unpinned development head of PyTorch Geometric, so
# reruns may pick up a different PyG version — consider pinning a release.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

In [None]:
import networkx as nx
import numpy as np
import random
import collections
import heapq
import copy
import math
import queue
from torch_geometric.data import Data
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv
from torch.nn import Linear

In [None]:
def generate_pages(num_pages, a=1.72, m=1):
    """Generate a random directed "web graph" of `num_pages` pages.

    Page i (ids 0 .. num_pages-1) draws an out-degree
    ceil((np.random.pareto(a) + 1) * m), capped at num_pages - 1 (the number of
    other pages), and links to that many distinct pages chosen uniformly at
    random, never itself. Every page then gets a size abs(gauss(10, 6)).

    Args:
        num_pages: number of pages to generate.
        a: shape parameter passed to np.random.pareto (default 1.72).
        m: multiplier applied to the sampled out-degree (default 1).

    Returns:
        (homepage, adjacency, size):
            homepage  -- always page 0;
            adjacency -- defaultdict(set), page -> set of linked pages, with an
                         entry for every page;
            size      -- dict page -> non-negative size, keys in page-id order.
    """
    adjacency = collections.defaultdict(set)  # page -> set of linked pages
    size = dict()  # page -> size
    max_links = num_pages - 1  # a page cannot link to itself

    for i in range(num_pages):
        raw_degree = (np.random.pareto(a) + 1) * m
        # Cap before ceil: the heavy tail can exceed the number of possible
        # targets (the link-drawing loop would then never finish) or overflow.
        out_degree = math.ceil(min(raw_degree, max_links))
        links = adjacency[i]  # touching the key gives every page an entry
        while len(links) < out_degree:
            link = random.randint(0, num_pages - 1)
            if link != i:
                links.add(link)  # duplicates leave len() unchanged -> redraw

    for page in adjacency:
        size[page] = abs(random.gauss(10, 6))  # random size for each page

    homepage = 0
    return homepage, adjacency, size

In [None]:
def Dijkstra(source, graph=None, weights=None):
    """Node-weighted shortest paths from `source` to every page reachable from it.

    A path's cost is the sum of weights[p] over every page p on it except
    `source` (which starts at distance 0).

    Args:
        source: page id to start from.
        graph: mapping page -> iterable of linked pages. Defaults to the
            notebook-level `adjacency`.
        weights: mapping page -> weight. Defaults to the notebook-level `size`.

    Returns:
        dict page -> [distance, path] for every page reachable from `source`,
        where `path` is the page ids along a shortest path joined by '-'
        (e.g. '0-4-7'). Unreachable pages are absent.
    """
    if graph is None:
        graph = adjacency
    if weights is None:
        weights = size

    def neighbours(page):
        # .get avoids inserting empty entries into a defaultdict graph.
        return graph.get(page, ())

    # Collect every page reachable from `source` (iterative DFS).
    reachable = {source}
    stack = [source]
    while stack:
        page = stack.pop()
        for nxt in neighbours(page):
            if nxt not in reachable:
                reachable.add(nxt)
                stack.append(nxt)

    dist = {page_id: [float('inf'), str()] for page_id in reachable}
    dist[source] = [0, str(source)]

    heap = [(0, source, str(source))]  # (distance, page, path string)
    settled = set()
    while heap and len(settled) < len(reachable):
        dis, page_id, path = heapq.heappop(heap)
        if page_id in settled:
            continue  # outdated entry for an already-finalised page
        settled.add(page_id)
        for next_page in neighbours(page_id):
            if next_page in settled:
                continue
            candidate = dis + weights[next_page]
            # Strict '<' keeps the first path found among equal-cost ties.
            if candidate < dist[next_page][0]:
                dist[next_page][0] = candidate
                dist[next_page][1] = path + '-' + str(next_page)
                heapq.heappush(heap, (candidate, next_page, dist[next_page][1]))

    return dist

In [None]:
def subsets(terminals):
    """Return the power set of `terminals` as a list of lists.

    Subset number `mask` contains terminals[j] exactly when bit j of `mask` is
    set, so the result starts with [] and ends with the full list, and each
    subset keeps the original element order.
    """
    n = len(terminals)
    return [
        [terminals[j] for j in range(n) if (mask >> j) & 1]
        for mask in range(1 << n)
    ]

In [None]:
def Transition(S):
    """Propagate dp[.][S] backwards along links until nothing improves.

    Works on the notebook globals: Q (heap of (cost, node) entries for subset
    S, seeded by the DP cell), dp, path, indegree (node -> set of pages linking
    to it) and size. For each link last_node -> node it tries
    dp[last_node][S] = dp[node][S] + size[node], recording the pages used in
    path[last_node][S]. Q is empty on return.
    """
    while Q:
        cost, node = heapq.heappop(Q)
        if cost > dp[node][S]:
            # Outdated entry: a cheaper entry for this node was already popped
            # and relaxed, so relaxing again cannot improve anything.
            continue
        through_node = dp[node][S] + size[node]
        for last_node in indegree[node]:
            if through_node < dp[last_node][S]:
                path[last_node][S] = {node} | path[node][S]
                dp[last_node][S] = through_node
                heapq.heappush(Q, (through_node, last_node))

In [None]:
num_pages = 100  # graph scale; dp/path tables below are sized from this
homepage, adjacency, size = generate_pages(num_pages)
k = 3  # number of destination pages per sample
dist = Dijkstra(homepage)

# indegree[j] = set of pages linking to j (reverse adjacency, used by Transition)
indegree = collections.defaultdict(set)
for src, targets in adjacency.items():
    for dst in targets:
        indegree[dst].add(src)

dest_set = []
result_set = []

num_samples = 40000

# random.sample() needs a sequence: sampling from dict keys (a set view) is
# deprecated since Python 3.9 and raises TypeError since 3.11.
reachable_pages = list(dist.keys())
assert len(reachable_pages) > k, 'need more than k pages reachable from the homepage'
endstate = 1 << k

for _ in range(num_samples): # generate data
    destinations = random.sample(reachable_pages, k)
    while homepage in destinations:  # resample until the homepage is excluded
        destinations = random.sample(reachable_pages, k)
    destinations = sorted(destinations)

    # initialize: dp[i][S] is the best cost found for connecting page i to the
    # destinations in bitmask S (bit j <-> destinations[j]); path[i][S] holds
    # the pages used for it.
    dp = [[float('inf')] * endstate for _ in range(num_pages)]
    for bit, dest in enumerate(destinations):
        dp[dest][1 << bit] = 0
    path = [[{i} for _ in range(endstate)] for i in range(num_pages)]
    Q = []  # heap consumed by Transition

    # DP over destination subsets
    for S in range(1, endstate):
        for i in range(num_pages):
            # Merge two disjoint sub-trees meeting at page i.
            sub = (S - 1) & S
            while sub:
                merged = dp[i][sub] + dp[i][S ^ sub]
                if merged < dp[i][S]:
                    dp[i][S] = merged
                    path[i][S] = path[i][sub] | path[i][S ^ sub]
                sub = (sub - 1) & S

            if dp[i][S] < float('inf'):
                heapq.heappush(Q, (dp[i][S], i))

        Transition(S)

    result = dp[homepage][endstate - 1]
    guidance = path[homepage][endstate - 1]
    result_set.append(guidance)
    dest_set.append(destinations)

In [None]:
# Node features for the node-weighted Steiner tree task: one [size, flag] row
# per page (row index == page id); flag = 1 marks the homepage and the
# sample's destination pages.
page_sizes = [size[p] for p in sorted(size)]
node_feature = []
for sample_dests in dest_set:
  features = [[s, 0] for s in page_sizes]
  features[homepage][1] = 1
  for j in sample_dests:
    features[j][1] = 1
  node_feature.append(features)

In [None]:
# Every directed link as a [source, target] pair. The resulting tensor has
# shape [num_edges, 2]; it must be transposed to get PyG's [2, num_edges] layout.
edges = [[src, dst] for src, targets in adjacency.items() for dst in targets]
edge_index = torch.tensor(edges, dtype=torch.long)

In [None]:
# Target labels: one [0/1] row per page (row index == page id). 1 marks pages
# in the sample's guidance set; the homepage (the root) is always marked.
num_nodes = len(size)
labels = []
for tree_pages in result_set:
  plain = [[0] for _ in range(num_nodes)]
  plain[homepage] = [1]
  for j in tree_pages:
    plain[j] = [1]
  labels.append(plain)

In [None]:
# The first 70% of the samples form the training set. All samples share the
# same link structure, so the [2, num_edges] edge index is built once and
# reused instead of being copied for every sample.
split = int(num_samples*0.7)
shared_edge_index = edge_index.t().contiguous()
train_dataset = []
for i in range(split):
  x = torch.tensor(node_feature[i], dtype=torch.float) # [num_nodes, 2]: (size, flag)
  y = torch.tensor(labels[i], dtype=torch.float) # [num_nodes, 1] targets
  train_dataset.append(Data(x=x, y=y, edge_index=shared_edge_index)) # one sample

In [None]:
# The remaining 30% of the samples form the test set; the shared
# [2, num_edges] edge index is built once rather than per sample.
split = int(num_samples*0.7)
shared_edge_index = edge_index.t().contiguous()
test_dataset = []
for i in range(split, num_samples):
  x = torch.tensor(node_feature[i], dtype=torch.float) # [num_nodes, 2]: (size, flag)
  y = torch.tensor(labels[i], dtype=torch.float) # [num_nodes, 1] targets
  test_dataset.append(Data(x=x, y=y, edge_index=shared_edge_index)) # one sample

In [None]:
# Training: shuffled mini-batches of 32 graphs. Testing: one graph per batch,
# in a fixed order.
# NOTE(review): in PyG 2.x `torch_geometric.data.DataLoader` is a deprecated
# alias of `torch_geometric.loader.DataLoader` — confirm it exists in the installed version.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
class Net(torch.nn.Module):
    """Stack of 12 GAT layers, each paired with a parallel linear skip branch.

    Layers 1-11 compute conv(x) + lin(x), followed by BatchNorm and ReLU, at
    width 4 * 256. Layer 12 maps to a single channel (6 heads, concat=False),
    and a sigmoid turns it into a per-node probability.
    """

    def __init__(self):
        super().__init__()
        width = 4 * 256

        # Input layer: 2 raw node features -> hidden width.
        self.conv1 = GATConv(2, 256, heads=4)
        self.lin1 = torch.nn.Linear(2, width)

        # Hidden layers 2..11 (attribute names conv2/lin2 ... conv11/lin11).
        for idx in range(2, 12):
            setattr(self, f'conv{idx}', GATConv(width, 256, heads=4))
            setattr(self, f'lin{idx}', torch.nn.Linear(width, width))

        # Output layer: one value per node.
        self.conv12 = GATConv(width, 1, heads=6, concat=False)
        self.lin12 = torch.nn.Linear(width, 1)

        for idx in range(1, 12):
            setattr(self, f'norm{idx}', torch.nn.BatchNorm1d(width))

    def forward(self, x, edge_index):
        for idx in range(1, 12):
            conv = getattr(self, f'conv{idx}')
            lin = getattr(self, f'lin{idx}')
            norm = getattr(self, f'norm{idx}')
            x = F.relu(norm(conv(x, edge_index) + lin(x)))

        logits = self.conv12(x, edge_index) + self.lin12(x)
        return torch.sigmoid(logits)

In [None]:
# Use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = Net().to(device)
# BCELoss takes probabilities; Net already ends with a sigmoid.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adagrad(model.parameters(), lr=0.01)  # Adadelta is an alternative

In [None]:
def train():
    """Run one training epoch over `train_loader` and return the average loss.

    Uses the notebook globals model, train_loader, optimizer, criterion and
    device. Each batch loss is weighted by the number of graphs in the batch,
    and the sum is divided by the training-set size.
    """
    model.train()

    running = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        prediction = model(batch.x, batch.edge_index)
        batch_loss = criterion(prediction, batch.y)
        running += batch_loss.item() * batch.num_graphs
        batch_loss.backward()
        optimizer.step()

    return running / len(train_loader.dataset)

In [None]:
@torch.no_grad()
def test(loader):
    """Evaluate the notebook-level `model` on every graph yielded by `loader`.

    A node is predicted as part of the guidance tree when the model output
    exceeds 0.5.

    Args:
        loader: PyG DataLoader whose batches carry `x`, `edge_index`, `y`
            (one 0/1 target row per node) and the `batch` node->graph vector.

    Returns:
        (path_accuracy, node_accuracy): the fraction of graphs whose nodes are
        all predicted correctly, and the fraction of nodes predicted correctly.

    Raises:
        ValueError: if `loader` yields no nodes.
    """
    model.eval()

    exact_graphs = 0   # graphs with every node correct
    seen_graphs = 0
    correct_nodes = 0
    seen_nodes = 0
    for data in loader:
        out = model(data.x.to(device), data.edge_index.to(device))
        predicted = (out > 0.5).float().cpu()
        # One bool per node: does the predicted row match the target row?
        node_ok = (predicted == data.y).all(dim=-1)
        correct_nodes += int(node_ok.sum())
        seen_nodes += node_ok.numel()

        # Count mistakes per graph via the node->graph index, so batches that
        # hold several graphs are scored graph by graph, not as one unit.
        batch = getattr(data, 'batch', None)
        if batch is None:  # un-batched Data object: a single graph
            batch = torch.zeros(node_ok.numel(), dtype=torch.long)
            graphs_here = 1
        else:
            batch = batch.cpu()
            graphs_here = data.num_graphs
        wrong_per_graph = torch.zeros(graphs_here, dtype=torch.long)
        wrong_per_graph.index_add_(0, batch, (~node_ok).long())
        exact_graphs += int((wrong_per_graph == 0).sum())
        seen_graphs += graphs_here

    if seen_nodes == 0:
        raise ValueError('loader yielded no nodes')
    # Normalise by what this loader actually produced, not by another loader's size.
    return exact_graphs / seen_graphs, correct_nodes / seen_nodes

In [None]:
num_epochs = 100
for epoch in range(num_epochs):
    loss = train()
    # test() returns (path accuracy, node accuracy)
    train_accuracy, train_node = test(train_loader)
    test_accuracy, test_node = test(test_loader)
    report = 'Epoch: {:02d}, Loss: {:.4f}, Node accuracy of train: {:.4f}, Node accuracy of test: {:.4f}, Path accuracy of Test: {:.4f}'.format(
        epoch, loss, train_node, test_node, test_accuracy)
    print(report)