In [None]:
import os
# DGL picks its tensor backend from the DGLBACKEND environment variable, so it is
# set here, before `import dgl` in the next cell.
backend = 'pytorch'
os.environ.update(DGLBACKEND=backend)

In [None]:
import torch
import dgl
import networkx as nx
import tqdm.auto as tqdm
import pickle
import numpy as np
import pathlib

import torch.nn as nn
import torch.nn.functional as F
import dgl.nn

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score, accuracy_score
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

In [None]:
# Use the default CUDA device when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Dataset root. Defaults to the original scratch location; set the DATA_DIR
# environment variable to run the notebook on another machine.
data_dir = pathlib.Path(os.environ.get('DATA_DIR', '/local/scratch/bh511/data/100000_instances'))

# The whole validation set is batched into a single graph during training,
# so trim it to keep that batch small enough for the GPU.
n_val_graphs = 1000

train_set, _ = dgl.load_graphs(str(data_dir / 'train_graphs.bin'))
val_set, _ = dgl.load_graphs(str(data_dir / 'val_graphs.bin'))
val_set = val_set[:n_val_graphs]

In [None]:
def prepare_batch(paths, scaler=None):
    """Build one batched DGL line graph from pickled networkx instance files.

    Each file is unpickled into a networkx graph ``g``. Its line graph is built so
    that every edge of ``g`` becomes a node carrying the scaled edge features
    ``'x'`` and the edge label ``'y'``.

    Args:
        paths: iterable of paths to pickled networkx graphs.
        scaler: object with a ``transform`` method applied to the stacked edge
            features. If ``None``, the notebook-level ``scaler`` global is used.
            That global is only defined after the cell that loads ``scaler.pkl`` has run.

    Returns:
        ``dgl.batch`` of the per-file line graphs.

    Raises:
        ValueError: if ``paths`` is empty.
        NameError: if no scaler is passed and no global ``scaler`` exists.
    """
    paths = list(paths)
    if not paths:
        raise ValueError("prepare_batch needs at least one path")

    if scaler is None:
        # Backward compatibility: the original function read the global `scaler`.
        scaler = globals().get('scaler')
        if scaler is None:
            raise NameError("name 'scaler' is not defined; load scaler.pkl first or pass scaler=...")

    graphs = []
    for path in paths:
        # nx.read_gpickle was removed in networkx 3.0; it only wrapped pickle.load.
        # Unpickling can execute arbitrary code, so only pass trusted files here.
        with open(path, 'rb') as f:
            g = pickle.load(f)

        lg = nx.line_graph(g)
        edges = list(lg.nodes)  # each line-graph node is an edge of g

        # One transform call for all edges instead of one per edge. This assumes the
        # scaler transforms rows independently (true for sklearn's StandardScaler /
        # MinMaxScaler) -- confirm for the pickled scaler.
        scaled = scaler.transform(np.stack([np.asarray(g.edges[e]['x']) for e in edges]))

        nx.set_node_attributes(lg, dict(zip(edges, scaled)), 'x')
        nx.set_node_attributes(lg, {e: g.edges[e]['y'] for e in edges}, 'y')
        graphs.append(dgl.from_networkx(lg, node_attrs=['x', 'y']))

    return dgl.batch(graphs)

In [None]:
class Net(nn.Module):
    """Per-node classifier: MLP embedding -> GatedGraphConv -> MLP decision head.

    Intermediate MLP widths are powers of two no smaller than 32, stepping from the
    input width up to ``hidden_size`` and back down to ``out_size`` (see
    ``_layer_sizes``).

    Args:
        in_size: number of input node features.
        hidden_size: width of the GatedGraphConv message-passing layer. It does not
            need to be a power of two.
        out_size: number of output logits per node.
        n_steps: number of GatedGraphConv propagation steps.
        activation: elementwise nonlinearity applied after every layer except the last.
        dropout: dropout probability applied after each hidden activation. The
            default 0.0 disables it and reproduces the previous behaviour.
    """

    def __init__(self, in_size, hidden_size, out_size, n_steps, activation=F.relu, dropout=0.0):
        super().__init__()

        self.activation = activation
        # nn.Dropout has no parameters, so adding it leaves the state_dict layout unchanged.
        self.dropout = nn.Dropout(dropout)

        embedding_layer_sizes, decision_layer_sizes = self._layer_sizes(in_size, hidden_size, out_size)

        self.embedding_layers = nn.ModuleList(
            [nn.Linear(s1, s2) for s1, s2 in zip(embedding_layer_sizes[:-1], embedding_layer_sizes[1:])])

        self.msg_layer = dgl.nn.GatedGraphConv(hidden_size, hidden_size, n_steps=n_steps, n_etypes=1)

        self.decision_layers = nn.ModuleList(
            [nn.Linear(s1, s2) for s1, s2 in zip(decision_layer_sizes[:-1], decision_layer_sizes[1:])])

    @staticmethod
    def _layer_sizes(in_size, hidden_size, out_size):
        """Return ``(embedding_sizes, decision_sizes)``, the widths of the two MLP stacks.

        The embedding stack goes ``in_size -> 2**k ... -> hidden_size``. The decision
        stack goes ``hidden_size -> 2**k ... -> out_size``. The intermediate k runs
        between floor(log2) of the end widths, clamped to at least 5 (width 32).

        Both stacks always end or start exactly at ``hidden_size``, so the conv layer's
        input and output widths match. For power-of-two ``hidden_size`` the result is
        identical to the previous implementation.
        """
        # floor(log2(n)) for n >= 1, computed exactly on integers.
        log2_in = max(int(in_size).bit_length() - 1, 5)
        log2_hidden = int(hidden_size).bit_length() - 1
        log2_out = max(int(out_size).bit_length() - 1, 5)

        embedding = [int(in_size)] + [2 ** k for k in range(log2_in, log2_hidden)] + [int(hidden_size)]
        decision = [int(hidden_size)] + [2 ** k for k in range(log2_hidden - 1, log2_out - 1, -1)] + [int(out_size)]
        return embedding, decision

    def forward(self, g, h):
        """Return raw logits, one row of ``out_size`` values per node of ``g``.

        ``h`` holds the input node features (presumably shape (num_nodes, in_size)).
        """
        for layer in self.embedding_layers:
            h = self.dropout(self.activation(layer(h)))

        # Single edge type: GatedGraphConv takes integer edge-type ids, all 0 here.
        etypes = torch.zeros(g.number_of_edges(), dtype=torch.long, device=h.device)
        h = self.dropout(self.activation(self.msg_layer(g, h, etypes)))

        for layer in self.decision_layers[:-1]:
            h = self.dropout(self.activation(layer(h)))
        return self.decision_layers[-1](h)  # no activation on the output layer

In [None]:
# Model / training hyperparameters.
batch_size = 128           # graphs per training mini-batch
hidden_size = 256          # width of the message-passing layer
n_steps = 4                # GatedGraphConv propagation steps
activation_name = 'relu'   # name of a function in torch.nn.functional
activation = getattr(F, activation_name)

In [None]:
# Input width is the number of features stored on each line-graph node.
in_size = train_set[0].ndata['x'].shape[1]
# Two output logits per node; class 1 is read later as "edge is in the solution".
net = Net(in_size, hidden_size, out_size=2, n_steps=n_steps, activation=activation)
net

In [None]:
# Train with a class-weighted cross-entropy loss; log train/validation metrics to TensorBoard.
net = net.to(device)  # same as net.cuda() when a GPU is available, no-op on CPU

optimizer = torch.optim.Adam(net.parameters(), lr=5e-4)

n_epochs = 100

data_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, collate_fn=dgl.batch)

# The validation graphs never change, so batch them and move them to the device once.
val_batch = dgl.batch(val_set).to(device)


def _weighted_criterion(y):
    """Cross-entropy loss whose class-1 weight is the negative/positive count ratio in ``y``.

    ``y`` holds 0/1 labels. If there are no positives the ratio is undefined (the
    old ``len(y)/y.sum() - 1`` gave ``inf`` and a NaN loss), so both classes get weight 1.
    """
    n_pos = y.sum().item()
    pos_weight = (len(y) - n_pos) / n_pos if n_pos > 0 else 1.0
    w = torch.tensor([1.0, pos_weight], dtype=torch.float32, device=y.device)
    return torch.nn.CrossEntropyLoss(weight=w)


writer = SummaryWriter()
try:
    pbar = tqdm.trange(n_epochs)
    for epoch in pbar:
        net.train()

        epoch_loss = 0.0
        n_batches = 0
        for batch in data_loader:
            batch = batch.to(device)
            x = batch.ndata['x']
            y = batch.ndata['y']

            criterion = _weighted_criterion(y)

            optimizer.zero_grad()
            y_pred = net(batch, x)
            loss = criterion(y_pred, y.reshape(-1))
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        epoch_loss /= max(n_batches, 1)  # guard against an empty loader
        writer.add_scalar("Loss/train", epoch_loss, epoch)

        net.eval()
        with torch.no_grad():
            x = val_batch.ndata['x']
            y = val_batch.ndata['y']

            y_pred = net(val_batch, x)
            val_loss = _weighted_criterion(y)(y_pred, y.reshape(-1)).item()
            writer.add_scalar("Loss/validation", val_loss, epoch)

            # With two classes, P(class 1) > 0.5 is the argmax decision (ties aside).
            y_true = y.reshape(-1).cpu().numpy()
            y_hat = (F.softmax(y_pred, dim=1)[:, 1] > 0.5).cpu().numpy().astype(int)
            f1 = f1_score(y_true, y_hat)
            acc = accuracy_score(y_true, y_hat)
            writer.add_scalar("Metrics/F1 Score/validation", f1, epoch)
            writer.add_scalar("Metrics/Accuracy/validation", acc, epoch)

        pbar.set_postfix({
            'Train Loss': '{:.4f}'.format(epoch_loss),
            'Validation Loss': '{:.4f}'.format(val_loss),
        })

        writer.flush()
finally:
    # Close the writer even if training is interrupted, so pending events are written.
    writer.close()

In [None]:
# Paths of the held-out test instances, one per line of test.txt.
with open(data_dir / 'test.txt') as f:
    test_set = [line.strip() for line in f]

# Edge-feature scaler applied before inference (see prepare_batch and the next cell).
# pickle.load can execute arbitrary code -- only load scaler.pkl from a trusted source.
with open(data_dir / 'scaler.pkl', 'rb') as f:
    scaler = pickle.load(f)

In [None]:
# Build the line graph of one test instance. Each edge of g becomes a node carrying
# its scaled features 'x', its label 'y', and its endpoints 'e' (used to map
# predictions back to edges of g).
instance_idx = 7  # which test instance to visualise

# nx.read_gpickle was removed in networkx 3.0; it only wrapped pickle.load.
with open(test_set[instance_idx], 'rb') as f:
    g = pickle.load(f)
lg = nx.line_graph(g)

line_nodes = list(lg.nodes)
# Scale all edge features in one call (assumes the scaler is row-wise, like sklearn's scalers).
scaled = scaler.transform(np.stack([np.asarray(g.edges[e]['x']) for e in line_nodes]))

features = dict(zip(line_nodes, scaled))
labels = {e: g.edges[e]['y'] for e in line_nodes}
edges = {e: e for e in line_nodes}
nx.set_node_attributes(lg, features, 'x')
nx.set_node_attributes(lg, labels, 'y')
nx.set_node_attributes(lg, edges, 'e')

h = dgl.from_networkx(lg, node_attrs=['x', 'y', 'e'])

In [None]:
h = h.to(device)
x = h.ndata['x']
y = h.ndata['y']
e = h.ndata['e']

# Put the model in inference mode explicitly (disables dropout), whatever ran before this cell.
net.eval()
with torch.no_grad():
    y_pred = net(h, x)
    y_prob = F.softmax(y_pred, dim=1)  # column 1 = probability the edge is in the solution

In [None]:
# DGL numbers line-graph nodes in its own order, which need not match g.edges order.
# Key the predicted probabilities by edge first, then rebuild them in g.edges order.
# That way p_in_solution and in_solution iterate over the same edges in the same
# order, which the plotting cell relies on.
p_by_edge = {tuple(k): float(v[1]) for k, v in zip(e.cpu().numpy(), y_prob.cpu().numpy())}
in_solution = {edge: float(g.edges[edge]['y'][0]) for edge in g.edges}
# Undirected edges may be stored with either endpoint first, so fall back to the reversed tuple.
p_in_solution = {
    edge: p_by_edge[edge] if edge in p_by_edge else p_by_edge[edge[::-1]]
    for edge in in_solution
}

In [None]:
# Pure-red colormap whose alpha ramps linearly from transparent (0) to opaque (1),
# so low-valued edges fade out.
n_levels = 100
alpha_ramp = np.linspace(0, 1, n_levels)
cmap_colors = np.column_stack([
    np.ones(n_levels),   # R
    np.zeros(n_levels),  # G
    np.zeros(n_levels),  # B
    alpha_ramp,          # A
])
cmap = ListedColormap(cmap_colors)

In [None]:
pos = nx.get_node_attributes(g, 'pos')

# Pass an explicit edgelist and build the colour lists in that same order, rather than
# relying on the insertion order of the dicts.
edge_list = list(g.edges)
true_colors = [in_solution[edge] for edge in edge_list]
# Undirected edges may be keyed with either endpoint first.
pred_colors = [
    p_in_solution[edge] if edge in p_in_solution else p_in_solution[edge[::-1]]
    for edge in edge_list
]

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

nx.draw(g, pos, edgelist=edge_list, edge_color=true_colors, edge_cmap=cmap, ax=ax[0], edge_vmax=1, edge_vmin=0)
nx.draw(g, pos, edgelist=edge_list, edge_color=pred_colors, edge_cmap=cmap, ax=ax[1], edge_vmax=1, edge_vmin=0)
ax[0].set_title('Optimal Solution')
ax[1].set_title('Output');  # trailing ';' suppresses the Text repr in the notebook

In [None]:
# Save only the learned weights; to reuse them, rebuild Net(...) and call load_state_dict().
torch.save(obj=net.state_dict(), f='net.bin')