In [3]:
import math
import os.path as osp
from itertools import chain

import momepy
import fiona
import geopandas as gpd
import networkx as nx
import osmnx as ox
import matplotlib.pyplot as plt

import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse.csgraph import shortest_path
from sklearn.metrics import roc_auc_score
from torch.nn import BCEWithLogitsLoss, Conv1d, MaxPool1d, ModuleList

from torch_geometric.data import Data, Batch, InMemoryDataset, download_url, extract_zip
from torch_geometric.loader import DataLoader
from torch_geometric.nn import MLP, GCNConv, global_sort_pool
from torch_geometric.transforms import RandomLinkSplit, OneHotDegree
from torch_geometric.utils import k_hop_subgraph, to_scipy_sparse_matrix
from torch_geometric.utils.convert import from_networkx

def seed_everything(seed: int):
    import random, os
    import numpy as np
    import torch
    
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True
    
seed_everything(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dataset_root = './datasets'



In [4]:
full_dataset_label = 'No Bounds'
training_places = ['London']

# Fields to ignore
meridian_fields = ['meridian_id', 'meridian_gid', 'meridian_code',
                   'meridian_osodr', 'meridian_number', 'meridian_road_name',
                   'meridian_indicator', 'meridian_class', 'meridian_class_scale']
census_geom_fields = ['wz11cd', 'lsoa11nm', 'msoa11nm',
                      'oa11cd', 'lsoa11cd', 'msoa11cd'] # Allowed: lad11cd, lad11nm
misc_fields = ['id']
ignore_fields = meridian_fields + census_geom_fields + misc_fields


all_feature_fields = ['metres', 'choice2km', 'nodecount2km', 'integration2km',
                      'choice10km', 'nodecount10km','integration10km',
                      'choice100km','nodecount100km','integration100km']
rank_fields = ['choice2kmrank', 'choice10kmrank','integration10kmrank', 'integration2kmrank']
log_fields = ['choice2kmlog','choice10kmlog','choice100kmlog']

# dictionary caches - RELOADING THIS CELL DELETES GRAPH CACHES
loaded_gdfs = {}
loaded_graphs={}

full_gdf = gpd.read_file('./OpenMapping-gb-v1_gpkg/gpkg/ssx_openmapping_gb_v1.gpkg',
                    ignore_fields=ignore_fields)
included_places = full_gdf['lad11nm'].unique()
print(included_places)
osmnx_buffer = 10000

def load_gdf(place=training_places[0], verbose=False): #(W, S, E, N)
    """Geodataframe (gdf) loader with caching"""
    if place in included_places:
        # Retrieve matching rows corresponding to the Local Authority
        print(f'Loading {place} from SSx')
        gdf = full_gdf.query(f'lad11nm == "{place}"').copy()
    elif place == full_dataset_label:
        # Read full dataset without boundaries
        gdf = full_gdf.copy()
    else:
        print(f'Loading {place} from OSMnx')
        # Load gdf from osmnx (for testing only, graph will lack target attr)
        # Actually uses the Nominatim API:
        # https://nominatim.org/release-docs/latest/api/Overview/
        g = ox.graph.graph_from_place(place, buffer_dist=osmnx_buffer)
        g = ox.projection.project_graph(g)
        gdf = ox.utils_graph.graph_to_gdfs(g, nodes=False)
        gdf = gdf.rename(columns={'length': 'metres'})
        return gdf

    if verbose:
        print(f'{gdf.size} geometries retrieved from {place}')

    loaded_gdfs[place] = gdf
    return gdf

['Isle of Wight' 'Wycombe' None 'Enfield' 'Slough' 'South Bucks'
 'Hillingdon' 'Ealing' 'Chiltern' 'Copeland' 'Windsor and Maidenhead'
 'Plymouth' 'South Hams' 'Oxford' 'Waltham Forest' 'Mendip' 'Dudley'
 'Cotswold' 'Erewash' 'Redbridge' 'Epping Forest' 'Test Valley'
 'Basingstoke and Deane' 'South Gloucestershire' 'Woking' 'Broxbourne'
 'Wolverhampton' 'Wiltshire' 'Swindon' 'Bath and North East Somerset'
 'Trafford' 'Salford' 'South Staffordshire' 'West Oxfordshire'
 'Malvern Hills' 'Vale of White Horse' 'South Kesteven' 'North Kesteven'
 'Guildford' 'Southwark' 'Chichester' 'Waverley' 'Elmbridge'
 'Forest of Dean' 'Tewkesbury' 'Charnwood' 'Sheffield' 'Ashfield'
 'North West Leicestershire' 'North East Derbyshire' 'Stroud' 'Shropshire'
 'Telford and Wrekin' 'Horsham' 'City of London' 'Newcastle-under-Lyme'
 'Stafford' 'Stoke-on-Trent' 'Arun' 'Lichfield' 'Sandwell' 'Birmingham'
 'Amber Valley' 'Mole Valley' 'Bolsover' 'Rushcliffe' 'Wychavon' 'Gedling'
 'Gloucester' 'Liverpool' 'Newark 

In [5]:
def process_graph(g, feature_fields=[]):
    """
    Takes SSx metrics from adjacent edges and averages them to form the node attribute
    """
    for node, d in g.nodes(data=True):
        # get attributes from adjacent edges
        new_data = {field:[] for field in feature_fields}
        for _, _, edge_data in g.edges(node, data=True):
            for field in feature_fields:
                new_data[field].append(edge_data[field])
        
        d.clear()
        for field in feature_fields:
            d[field] = sum(new_data[field]) / len(new_data[field])
        
        # Encode coordinate as feature
        d['x'], d['y'] = node
    
    for u, v, k, d in g.edges(keys=True, data=True):
        d.clear()
        d['u'] = u
        d['v'] = v
        d['key'] = k
    return g

def load_graph(place, feature_fields=[], reload=True, verbose=False):
  if verbose:
    print(f'Loading graph of {place}...')
  key = (place)
  if key in loaded_graphs and reload:
    g = loaded_graphs[key]
    if verbose:
        print('Loaded existing graph.')
  else:
    gdf = load_gdf(place, verbose=verbose)
    G = momepy.gdf_to_nx(gdf, approach='primal')
    G = process_graph(G, feature_fields)
    
    if verbose:
      print(f'Generated graph with {G.number_of_nodes()} nodes and {G.number_of_edges()} edges')
    node_attrs = set(chain.from_iterable(d.keys() for *_, d in G.nodes(data=True)))
    edge_attrs = set(chain.from_iterable(d.keys() for *_, d in G.edges(data=True)))
    
    if verbose:
        # List node attributes
        print(f'Node attributes: {node_attrs}')
        print(f'Edge attributes: {edge_attrs}')                            
    
    # If no node attributes, node degree will be added later
    # Edge attribute (angle) are not added
    if len(node_attrs) > 0:
        g = from_networkx(G, group_node_attrs=list(node_attrs))
    else:
        g = from_networkx(G)
    loaded_graphs[key] = g
  return g

loaded_graphs = {}
data = load_graph('Birmingham', all_feature_fields, verbose=True)
data = data.to(device)
data

Loading graph of Birmingham...
Loading Birmingham from SSx
323840 geometries retrieved from Birmingham
Generated graph with 12509 nodes and 16192 edges
Node attributes: {'choice2km', 'integration100km', 'nodecount100km', 'choice100km', 'x', 'choice10km', 'metres', 'nodecount2km', 'nodecount10km', 'y', 'integration10km', 'integration2km'}
Edge attributes: {'key', 'v', 'u'}


Data(edge_index=[2, 32384], u=[32384, 2], v=[32384, 2], key=[32384], x=[12509, 12])

In [None]:
from torch_geometric.utils.convert import to_networkx

def visualize_preds(place, model):
    data = load_graph(place)
    gdf = load_gdf(place)
    streets = momepy.gdf_to_nx(gdf, approach='primal')

    # Randomly sample negative links and run predictions
    vis_data, _, _ = RandomLinkSplit(num_val=0, num_test=0, is_undirected=True)(data)
    z = model.encode(vis_data.x, vis_data.edge_index)
    out = model.decode(z, vis_data.edge_label_index).view(-1).sigmoid()
    vis_data.res = (out != vis_data.edge_label)
    
    streets2 = to_networkx(vis_data, edge_attrs=['res', 'u', 'v', 'key']).to_undirected()
    res_dict = {}
    for _, _, d in streets2.edges(data=True):
        u = tuple(d['u'])
        v = tuple(d['v'])
        k = d['key']
        res_dict[(u, v, k)] = d['res']
    nx.set_edge_attributes(streets, res_dict, 'res')
    points, lines = momepy.nx_to_gdf(streets)
    lines.plot(column='res', figsize=(20,10), cmap='Set1', legend=True)

     # plot original and predicted streets
    f, ax = plt.subplots(1, 2, figsize=(20, 10), sharex=True, sharey=True)
    for i, facet in enumerate(ax):
        facet.set_title(("Original", "Predicted")[i], size=20)
        facet.axis("off")
    lines.plot(ax=ax[0], column=f'choice{rad}{mod}', cmap='jet',
               legend=True, legend_kwds={'shrink': 0.3})
    lines.plot(ax=ax[1], column='preds', cmap='jet', legend=True,
               legend_kwds={'shrink': 0.3})
    plt.suptitle(f'\"{place}\": choice{rad}{mod}\n {graph.num_nodes} road nodes. R2: {r2:.3f}, KT: {kt:.3f}',
                 fontweight="bold", fontsize=20)
    plt.tight_layout()
    plt.show()
    

# Standard GAE-GCN Link Pred

In [None]:
from torch_geometric.utils import negative_sampling

transform = RandomLinkSplit(num_val=0.05, num_test=0.1,
                            is_undirected=True, add_negative_train_samples=False)
train_data, val_data, test_data = transform(data)
# After applying the `RandomLinkSplit` transform, the data is transformed from
# a data object to a list of tuples (train_data, val_data, test_data), with
# each element representing the corresponding split.


class Net(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def encode(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)

    def decode(self, z, edge_label_index):
        return (z[edge_label_index[0]] * z[edge_label_index[1]]).sum(dim=-1)

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()


model = Net(data.num_features, 128, 64).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()


def train():
    model.train()
    optimizer.zero_grad()
    z = model.encode(train_data.x, train_data.edge_index)

    # We perform a new round of negative sampling for every training epoch:
    neg_edge_index = negative_sampling(
        edge_index=train_data.edge_index, num_nodes=train_data.num_nodes,
        num_neg_samples=train_data.edge_label_index.size(1), method='sparse')

    edge_label_index = torch.cat(
        [train_data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    edge_label = torch.cat([
        train_data.edge_label,
        train_data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    out = model.decode(z, edge_label_index).view(-1)
    loss = criterion(out, edge_label)
    loss.backward()
    optimizer.step()
    return loss


@torch.no_grad()
def test(data):
    model.eval()
    z = model.encode(data.x, data.edge_index)
    out = model.decode(z, data.edge_label_index).view(-1).sigmoid()
    return roc_auc_score(data.edge_label.cpu().numpy(), out.cpu().numpy())


best_val_auc = final_test_auc = 0
print_every = 10
epochs = 1000
for epoch in range(1, epochs + 1):
    loss = train()
    val_auc = test(val_data)
    test_auc = test(test_data)
    if val_auc > best_val_auc:
        best_val = val_auc
        final_test_auc = test_auc
    if epoch % print_every == 0:
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
              f'Test: {test_auc:.4f}')

print(f'Final Test: {final_test_auc:.4f}')

z = model.encode(test_data.x, test_data.edge_index)
final_edge_index = model.decode_all(z)

In [1]:
train_data.edge_label_index.size(1)

NameError: name 'train_data' is not defined

# GVAE

In [26]:
class GCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv2 = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv2(x, edge_index)


class VariationalGCNEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 2 * out_channels)
        self.conv_mu = GCNConv(2 * out_channels, out_channels)
        self.conv_logstd = GCNConv(2 * out_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)


class LinearEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv(x, edge_index)


class VariationalLinearEncoder(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_mu = GCNConv(in_channels, out_channels)
        self.conv_logstd = GCNConv(in_channels, out_channels)

    def forward(self, x, edge_index):
        return self.conv_mu(x, edge_index), self.conv_logstd(x, edge_index)

In [27]:
from torch_geometric.loader import DataLoader    
import torch_geometric.transforms as T

transform = T.Compose([
    T.NormalizeFeatures(),
    T.ToDevice(device),
])

def process_dataset(dataset, split=0.8, batch_size=64, verbose=False):
    # Train-test split
    idx = torch.randperm(len(dataset))
    split_idx = math.floor(split * len(dataset))
    train_idx = idx[:split_idx]
    test_idx = idx[split_idx:]
    
    train_dataset = [transform(dataset[i]) for i in train_idx]
    test_transform = T.Compose([
        transform,
        RandomLinkSplit(num_val=0, num_test=0, split_labels=True, is_undirected=True)
    ])
    test_dataset = [test_transform(dataset[i])[0] for i in test_idx]
    if verbose:
        print(f'Number of training graphs: {len(train_dataset)}')
        print(f'Number of test graphs: {len(test_dataset)}')

    # Load graphs into dataloader
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    if verbose:
        for step, data in enumerate(train_loader):
            print(f'Step {step + 1}:')
            print('=======')
            print(f'Number of graphs in the current batch: {data.num_graphs}')
            print(data)
            print()

    return train_dataset, test_dataset, train_loader, test_loader


dataset = torch.load('datasets/ssx_cities_dataset.pt')
train_dataset, test_dataset, train_loader, test_loader = process_dataset(dataset, batch_size=32, verbose=True)

Number of training graphs: 277
Number of test graphs: 70
Step 1:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 465518], x=[155354, 17], edge_attr=[465518, 1], y=[32, 3], batch=[155354], ptr=[33])

Step 2:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 580710], x=[189645, 17], edge_attr=[580710, 1], y=[32, 3], batch=[189645], ptr=[33])

Step 3:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 550286], x=[173579, 17], edge_attr=[550286, 1], y=[32, 3], batch=[173579], ptr=[33])

Step 4:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 467834], x=[148436, 17], edge_attr=[467834, 1], y=[32, 3], batch=[148436], ptr=[33])

Step 5:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 497438], x=[163496, 17], edge_attr=[497438, 1], y=[32, 3], batch=[163496], ptr=[33])

Step 6:
Number of graphs in the current batch: 32
DataBatch(edge_index=[2, 507448], x=[161853, 17], edge_attr=[507448, 1], y=[32, 3], batc

In [28]:
from torch_geometric.nn import GAE, VGAE
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/GVAE1_experiment_'+'2d_300_epochs')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
out_channels = 2
in_channels = dataset[0].num_node_features
epochs = 300
variational = True
linear = False

if not variational and not linear:
    model = GAE(GCNEncoder(in_channels, out_channels))
elif not variational and linear:
    model = GAE(LinearEncoder(in_channels, out_channels))
elif variational and not linear:
    model = VGAE(VariationalGCNEncoder(in_channels, out_channels))
elif variational and linear:
    model = VGAE(VariationalLinearEncoder(in_channels, out_channels))

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

def train(loader):
    model.train()
    loss_tot = 0
    loss_recon = 0
    loss_kl = 0
    for data in loader:
        optimizer.zero_grad()
        z = model.encode(data.x, data.edge_index)
        
        loss = model.recon_loss(z, data.edge_index)
        loss_recon += loss
        
        kl_loss = (1 / data.num_nodes) * model.kl_loss()
        loss_kl += kl_loss
        
        loss += kl_loss
        loss_tot += loss
        loss.backward()
        optimizer.step()
        
        loss_tot += loss
    loss_tot /= len(loader.dataset)
    loss_recon /= len(loader.dataset)
    loss_kl /= len(loader.dataset)
    return loss_tot, loss_recon, loss_kl


@torch.no_grad()
def test(loader):
    model.eval()
    auc_tot = 0
    ap_tot = 0
    for data in loader:
        z = model.encode(data.x, data.edge_index)
        auc, ap = model.test(z, data.pos_edge_label_index, data.neg_edge_label_index)
        auc_tot += auc
        ap_tot += ap
    return auc_tot / len(loader), ap_tot / len(loader)

for epoch in range(1, epochs + 1):
    total_loss, recon_loss, kl_loss = train(train_loader)
    auc, ap = test(test_dataset)
    writer.add_scalar('auc train',auc,epoch)
    writer.add_scalar('ap train',ap,epoch)
    print('Epoch: {:03d}, AUC: {:.4f}, AP: {:.4f}'.format(epoch, auc, ap))

Epoch: 001, AUC: 0.5747, AP: 0.6074
Epoch: 002, AUC: 0.5769, AP: 0.6092
Epoch: 003, AUC: 0.5802, AP: 0.6114
Epoch: 004, AUC: 0.5865, AP: 0.6147
Epoch: 005, AUC: 0.5929, AP: 0.6178
Epoch: 006, AUC: 0.5994, AP: 0.6210
Epoch: 007, AUC: 0.6061, AP: 0.6246
Epoch: 008, AUC: 0.6139, AP: 0.6287
Epoch: 009, AUC: 0.6218, AP: 0.6324
Epoch: 010, AUC: 0.6250, AP: 0.6320
Epoch: 011, AUC: 0.6259, AP: 0.6307
Epoch: 012, AUC: 0.6248, AP: 0.6289
Epoch: 013, AUC: 0.6248, AP: 0.6283
Epoch: 014, AUC: 0.6271, AP: 0.6292
Epoch: 015, AUC: 0.6305, AP: 0.6306
Epoch: 016, AUC: 0.6322, AP: 0.6312
Epoch: 017, AUC: 0.6352, AP: 0.6324
Epoch: 018, AUC: 0.6359, AP: 0.6326
Epoch: 019, AUC: 0.6356, AP: 0.6322
Epoch: 020, AUC: 0.6369, AP: 0.6327
Epoch: 021, AUC: 0.6413, AP: 0.6346
Epoch: 022, AUC: 0.6423, AP: 0.6349
Epoch: 023, AUC: 0.6447, AP: 0.6358
Epoch: 024, AUC: 0.6467, AP: 0.6366
Epoch: 025, AUC: 0.6480, AP: 0.6371
Epoch: 026, AUC: 0.6484, AP: 0.6371
Epoch: 027, AUC: 0.6490, AP: 0.6372
Epoch: 028, AUC: 0.6479, AP:

In [30]:
%load_ext tensorboard
%tensorboard --logdir=runs

# SEAL + DGCNN

In [14]:
class SEALDataset(InMemoryDataset):
    def __init__(self, dataset, num_hops, split='train'):
        self.data = dataset[0]
        self.num_hops = num_hops
        super().__init__(dataset_root)
        index = ['train', 'val', 'test'].index(split)
        self.data, self.slices = torch.load(self.processed_paths[index])

    @property
    def processed_file_names(self):
        return ['SEAL_train_data.pt', 'SEAL_val_data.pt', 'SEAL_test_data.pt']

    def process(self):
        transform = RandomLinkSplit(num_val=0.05, num_test=0.1,
                                    is_undirected=True, split_labels=True)
        train_data, val_data, test_data = transform(self.data)

        self._max_z = 0

        # Collect a list of subgraphs for training, validation and testing:
        train_pos_data_list = self.extract_enclosing_subgraphs(
            train_data.edge_index, train_data.pos_edge_label_index, 1)
        train_neg_data_list = self.extract_enclosing_subgraphs(
            train_data.edge_index, train_data.neg_edge_label_index, 0)

        val_pos_data_list = self.extract_enclosing_subgraphs(
            val_data.edge_index, val_data.pos_edge_label_index, 1)
        val_neg_data_list = self.extract_enclosing_subgraphs(
            val_data.edge_index, val_data.neg_edge_label_index, 0)

        test_pos_data_list = self.extract_enclosing_subgraphs(
            test_data.edge_index, test_data.pos_edge_label_index, 1)
        test_neg_data_list = self.extract_enclosing_subgraphs(
            test_data.edge_index, test_data.neg_edge_label_index, 0)

        # Convert node labeling to one-hot features.
        for data in chain(train_pos_data_list, train_neg_data_list,
                          val_pos_data_list, val_neg_data_list,
                          test_pos_data_list, test_neg_data_list):
            # We solely learn links from structure, dropping any node features:
            data.x = F.one_hot(data.z, self._max_z + 1).to(torch.float)

        torch.save(self.collate(train_pos_data_list + train_neg_data_list),
                   self.processed_paths[0])
        torch.save(self.collate(val_pos_data_list + val_neg_data_list),
                   self.processed_paths[1])
        torch.save(self.collate(test_pos_data_list + test_neg_data_list),
                   self.processed_paths[2])

    def extract_enclosing_subgraphs(self, edge_index, edge_label_index, y):
        data_list = []
        for src, dst in edge_label_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            src, dst = mapping.tolist()

            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index, src, dst,
                                        num_nodes=sub_nodes.size(0))

            data = Data(x=self.data.x[sub_nodes], z=z,
                        edge_index=sub_edge_index, y=y)
            data_list.append(data)

        return data_list

    def drnl_node_labeling(self, edge_index, src, dst, num_nodes=None):
        # Double-radius node labeling (DRNL).
        src, dst = (dst, src) if src > dst else (src, dst)
        adj = to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes).tocsr()

        idx = list(range(src)) + list(range(src + 1, adj.shape[0]))
        adj_wo_src = adj[idx, :][:, idx]

        idx = list(range(dst)) + list(range(dst + 1, adj.shape[0]))
        adj_wo_dst = adj[idx, :][:, idx]

        dist2src = shortest_path(adj_wo_dst, directed=False, unweighted=True,
                                 indices=src)
        dist2src = np.insert(dist2src, dst, 0, axis=0)
        dist2src = torch.from_numpy(dist2src)

        dist2dst = shortest_path(adj_wo_src, directed=False, unweighted=True,
                                 indices=dst - 1)
        dist2dst = np.insert(dist2dst, src, 0, axis=0)
        dist2dst = torch.from_numpy(dist2dst)

        dist = dist2src + dist2dst
        dist_over_2, dist_mod_2 = dist // 2, dist % 2

        z = 1 + torch.min(dist2src, dist2dst)
        z += dist_over_2 * (dist_over_2 + dist_mod_2 - 1)
        z[src] = 1.
        z[dst] = 1.
        z[torch.isnan(z)] = 0.

        self._max_z = max(int(z.max()), self._max_z)

        return z.to(torch.long)

dataset = [data]

train_dataset = SEALDataset(dataset, num_hops=2, split='train')
val_dataset = SEALDataset(dataset, num_hops=2, split='val')
test_dataset = SEALDataset(dataset, num_hops=2, split='test')

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)


Processing...
  dist_over_2, dist_mod_2 = dist // 2, dist % 2
Done!


In [15]:
class DGCNN(torch.nn.Module):
    def __init__(self, hidden_channels, num_layers, GNN=GCNConv, k=0.6):
        super().__init__()

        if k < 1:  # Transform percentile to number.
            num_nodes = sorted([data.num_nodes for data in train_dataset])
            k = num_nodes[int(math.ceil(k * len(num_nodes))) - 1]
            k = max(10, k)
        self.k = int(k)

        self.convs = ModuleList()
        self.convs.append(GNN(train_dataset.num_features, hidden_channels))
        for i in range(0, num_layers - 1):
            self.convs.append(GNN(hidden_channels, hidden_channels))
        self.convs.append(GNN(hidden_channels, 1))

        conv1d_channels = [16, 32]
        total_latent_dim = hidden_channels * num_layers + 1
        conv1d_kws = [total_latent_dim, 5]
        self.conv1 = Conv1d(1, conv1d_channels[0], conv1d_kws[0],
                            conv1d_kws[0])
        self.maxpool1d = MaxPool1d(2, 2)
        self.conv2 = Conv1d(conv1d_channels[0], conv1d_channels[1],
                            conv1d_kws[1], 1)
        dense_dim = int((self.k - 2) / 2 + 1)
        dense_dim = (dense_dim - conv1d_kws[1] + 1) * conv1d_channels[1]
        self.mlp = MLP([dense_dim, 128, 1], dropout=0.5, batch_norm=False)

    def forward(self, x, edge_index, batch):
        xs = [x]
        for conv in self.convs:
            xs += [conv(xs[-1], edge_index).tanh()]
        x = torch.cat(xs[1:], dim=-1)

        # Global pooling.
        x = global_sort_pool(x, batch, self.k)
        x = x.unsqueeze(1)  # [num_graphs, 1, k * hidden]
        x = self.conv1(x).relu()
        x = self.maxpool1d(x)
        x = self.conv2(x).relu()
        x = x.view(x.size(0), -1)  # [num_graphs, dense_dim]

        return self.mlp(x)

In [16]:
model = DGCNN(hidden_channels=32, num_layers=3).to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.0001)
criterion = BCEWithLogitsLoss()


def train():
    model.train()

    total_loss = 0
    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = criterion(out.view(-1), data.y.to(torch.float))
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs

    return total_loss / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()

    y_pred, y_true = [], []
    for data in loader:
        data = data.to(device)
        logits = model(data.x, data.edge_index, data.batch)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))

    return roc_auc_score(torch.cat(y_true), torch.cat(y_pred))


best_val_auc = test_auc = 0
for epoch in range(1, 51):
    loss = train()
    val_auc = test(val_loader)
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        test_auc = test(test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_auc:.4f}, '
          f'Test: {test_auc:.4f}')

Epoch: 01, Loss: 0.5439, Val: 0.8484, Test: 0.8939
Epoch: 02, Loss: 0.4729, Val: 0.8508, Test: 0.8964
Epoch: 03, Loss: 0.4682, Val: 0.8532, Test: 0.8993
Epoch: 04, Loss: 0.4654, Val: 0.8562, Test: 0.9029
Epoch: 05, Loss: 0.4600, Val: 0.8595, Test: 0.9057
Epoch: 06, Loss: 0.4584, Val: 0.8601, Test: 0.9067
Epoch: 07, Loss: 0.4546, Val: 0.8630, Test: 0.9085
Epoch: 08, Loss: 0.4515, Val: 0.8620, Test: 0.9085
Epoch: 09, Loss: 0.4477, Val: 0.8633, Test: 0.9105
Epoch: 10, Loss: 0.4458, Val: 0.8654, Test: 0.9116
Epoch: 11, Loss: 0.4447, Val: 0.8671, Test: 0.9119
Epoch: 12, Loss: 0.4444, Val: 0.8678, Test: 0.9132
Epoch: 13, Loss: 0.4419, Val: 0.8682, Test: 0.9140
Epoch: 14, Loss: 0.4423, Val: 0.8671, Test: 0.9140
Epoch: 15, Loss: 0.4401, Val: 0.8672, Test: 0.9140
Epoch: 16, Loss: 0.4406, Val: 0.8668, Test: 0.9140
Epoch: 17, Loss: 0.4395, Val: 0.8675, Test: 0.9140
Epoch: 18, Loss: 0.4396, Val: 0.8679, Test: 0.9140
Epoch: 19, Loss: 0.4388, Val: 0.8673, Test: 0.9140
Epoch: 20, Loss: 0.4380, Val: 0

# RGCN

In [None]:
from torch.nn import Parameter
from torch_geometric.nn import GAE, RGCNConv

class RGCNEncoder(torch.nn.Module):
    def __init__(self, num_nodes, hidden_channels, num_relations):
        super().__init__()
        self.node_emb = Parameter(torch.Tensor(num_nodes, hidden_channels))
        self.conv1 = RGCNConv(hidden_channels, hidden_channels, num_relations,
                              num_blocks=5)
        self.conv2 = RGCNConv(hidden_channels, hidden_channels, num_relations,
                              num_blocks=5)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.node_emb)
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, edge_index, edge_type):
        x = self.node_emb
        x = self.conv1(x, edge_index, edge_type).relu_()
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x


class DistMultDecoder(torch.nn.Module):
    def __init__(self, num_relations, hidden_channels):
        super().__init__()
        self.rel_emb = Parameter(torch.Tensor(num_relations, hidden_channels))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.rel_emb)

    def forward(self, z, edge_index, edge_type):
        z_src, z_dst = z[edge_index[0]], z[edge_index[1]]
        rel = self.rel_emb[edge_type]
        return torch.sum(z_src * rel * z_dst, dim=1)


model = GAE(
    RGCNEncoder(data.num_nodes, hidden_channels=500,
                num_relations=dataset.num_relations),
    DistMultDecoder(dataset.num_relations // 2, hidden_channels=500),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


def negative_sampling(edge_index, num_nodes):
    # Sample edges by corrupting either the subject or the object of each edge.
    mask_1 = torch.rand(edge_index.size(1)) < 0.5
    mask_2 = ~mask_1

    neg_edge_index = edge_index.clone()
    neg_edge_index[0, mask_1] = torch.randint(num_nodes, (mask_1.sum(), ))
    neg_edge_index[1, mask_2] = torch.randint(num_nodes, (mask_2.sum(), ))
    return neg_edge_index


def train():
    model.train()
    optimizer.zero_grad()

    z = model.encode(data.edge_index, data.edge_type)

    pos_out = model.decode(z, data.train_edge_index, data.train_edge_type)

    neg_edge_index = negative_sampling(data.train_edge_index, data.num_nodes)
    neg_out = model.decode(z, neg_edge_index, data.train_edge_type)

    out = torch.cat([pos_out, neg_out])
    gt = torch.cat([torch.ones_like(pos_out), torch.zeros_like(neg_out)])
    cross_entropy_loss = F.binary_cross_entropy_with_logits(out, gt)
    reg_loss = z.pow(2).mean() + model.decoder.rel_emb.pow(2).mean()
    loss = cross_entropy_loss + 1e-2 * reg_loss

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
    optimizer.step()

    return float(loss)


@torch.no_grad()
def test():
    model.eval()
    z = model.encode(data.edge_index, data.edge_type)

    valid_mrr = compute_mrr(z, data.valid_edge_index, data.valid_edge_type)
    test_mrr = compute_mrr(z, data.test_edge_index, data.test_edge_type)

    return valid_mrr, test_mrr


@torch.no_grad()
def compute_mrr(z, edge_index, edge_type):
    ranks = []
    for i in tqdm(range(edge_type.numel())):
        (src, dst), rel = edge_index[:, i], edge_type[i]

        # Try all nodes as tails, but delete true triplets:
        tail_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        for (heads, tails), types in [
            (data.train_edge_index, data.train_edge_type),
            (data.valid_edge_index, data.valid_edge_type),
            (data.test_edge_index, data.test_edge_type),
        ]:
            tail_mask[tails[(heads == src) & (types == rel)]] = False

        tail = torch.arange(data.num_nodes)[tail_mask]
        tail = torch.cat([torch.tensor([dst]), tail])
        head = torch.full_like(tail, fill_value=src)
        eval_edge_index = torch.stack([head, tail], dim=0)
        eval_edge_type = torch.full_like(tail, fill_value=rel)

        out = model.decode(z, eval_edge_index, eval_edge_type)
        perm = out.argsort(descending=True)
        rank = int((perm == 0).nonzero(as_tuple=False).view(-1)[0])
        ranks.append(rank + 1)

        # Try all nodes as heads, but delete true triplets:
        head_mask = torch.ones(data.num_nodes, dtype=torch.bool)
        for (heads, tails), types in [
            (data.train_edge_index, data.train_edge_type),
            (data.valid_edge_index, data.valid_edge_type),
            (data.test_edge_index, data.test_edge_type),
        ]:
            head_mask[heads[(tails == dst) & (types == rel)]] = False

        head = torch.arange(data.num_nodes)[head_mask]
        head = torch.cat([torch.tensor([src]), head])
        tail = torch.full_like(head, fill_value=dst)
        eval_edge_index = torch.stack([head, tail], dim=0)
        eval_edge_type = torch.full_like(head, fill_value=rel)

        out = model.decode(z, eval_edge_index, eval_edge_type)
        perm = out.argsort(descending=True)
        rank = int((perm == 0).nonzero(as_tuple=False).view(-1)[0])
        ranks.append(rank + 1)

    return (1. / torch.tensor(ranks, dtype=torch.float)).mean()


for epoch in range(1, 10001):
    loss = train()
    print(f'Epoch: {epoch:05d}, Loss: {loss:.4f}')
    if (epoch % 500) == 0:
        valid_mrr, test_mrr = test()
        print(f'Val MRR: {valid_mrr:.4f}, Test MRR: {test_mrr:.4f}')