In [48]:
import os
import os.path as osp

import numpy as np
import pandas as pd
from collections import defaultdict
from sklearn.metrics import roc_auc_score

import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import DataLoader
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
#import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

In [49]:
from dataset_dummy_features import Dec

In [50]:
dataset = Dec()

Processing...
Done!


In [51]:
data = dataset[0]

In [52]:
#data = pyg_utils.train_test_split_edges(data)

In [53]:
data.train_pos_edge_index = data.edge_index
data.val_pos_edge_index = data.edge_index
data.test_pos_edge_index = data.edge_index

In [54]:
data

Data(edge_index=[2, 43953], edge_type=[43953], test_pos_edge_index=[2, 43953], train_pos_edge_index=[2, 43953], val_pos_edge_index=[2, 43953], x=[7736, 10184])

In [6]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 128)
        self.conv2 = GCNConv(128, 64)

    def encode(self):
        x = self.conv1(data.x, data.train_pos_edge_index)
        x = x.relu()
        return self.conv2(x, data.train_pos_edge_index)

    def decode(self, z, pos_edge_index, neg_edge_index):
        edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        logits = (z[edge_index[0]] * z[edge_index[1]]).sum(dim=-1)
        return logits

    def decode_all(self, z):
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()

In [36]:
model = Net()

In [37]:
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

In [38]:
def get_link_labels(pos_edge_index, neg_edge_index):
    E = pos_edge_index.size(1) + neg_edge_index.size(1)
    link_labels = torch.zeros(E, dtype=torch.float)
    link_labels[:pos_edge_index.size(1)] = 1.
    return link_labels

In [39]:
def train():
    model.train()

    neg_edge_index = pyg_utils.negative_sampling(
        edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1))

    optimizer.zero_grad()
    z = model.encode()
    link_logits = model.decode(z, data.train_pos_edge_index, neg_edge_index)
    link_labels = get_link_labels(data.train_pos_edge_index, neg_edge_index)
    loss = F.binary_cross_entropy_with_logits(link_logits, link_labels)
    loss.backward()
    optimizer.step()

    return loss

In [43]:
@torch.no_grad()
def test():
    model.eval()
    perfs = []

    pos_edge_index = data.edge_index
    neg_edge_index = pyg_utils.negative_sampling(
        edge_index=data.train_pos_edge_index, num_nodes=data.num_nodes,
        num_neg_samples=data.val_pos_edge_index.size(1))

    z = model.encode()
    link_logits = model.decode(z, pos_edge_index, neg_edge_index)
    link_probs = link_logits.sigmoid()
    link_labels = get_link_labels(pos_edge_index, neg_edge_index)
    perfs.append(roc_auc_score(link_labels.cpu(), link_probs.cpu()))
    return perfs

In [47]:
best_val_perf = test_perf = 0
for epoch in range(1, 11):
    train_loss = train()
    val_perf = test()
    log = 'Epoch: {:03d}, Loss: {:.4f}, Val: {:.4f}, Test: {:.4f}'
    print(log.format(epoch, train_loss, best_val_perf, test_perf))

Epoch: 001, Loss: 0.6932, Val: 0.0000, Test: 0.0000
Epoch: 002, Loss: 0.6932, Val: 0.0000, Test: 0.0000
Epoch: 003, Loss: 0.6932, Val: 0.0000, Test: 0.0000
Epoch: 004, Loss: 0.6933, Val: 0.0000, Test: 0.0000
Epoch: 005, Loss: 0.6934, Val: 0.0000, Test: 0.0000
Epoch: 006, Loss: 0.6935, Val: 0.0000, Test: 0.0000
Epoch: 007, Loss: 0.6936, Val: 0.0000, Test: 0.0000
Epoch: 008, Loss: 0.6937, Val: 0.0000, Test: 0.0000
Epoch: 009, Loss: 0.6938, Val: 0.0000, Test: 0.0000
Epoch: 010, Loss: 0.6939, Val: 0.0000, Test: 0.0000
