In [1]:
from __future__ import division
from __future__ import print_function
import os, sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import argparse
import time
import random
import numpy as np
import scipy.sparse as sp
import torch

SEED = 4
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

from torch import optim
import torch.nn.functional as F
from model import ROD_lp
from optimizer import loss_function
from utils import *
from tqdm import tqdm
from sklearn.preprocessing import normalize, MinMaxScaler
from sklearn import metrics
import copy

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument('--num_hops', type=int, default=5)
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--dims', type=int, default=[1024])
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--batch_size', type=int, default=500)
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--upd', type=int, default=10)
args = parser.parse_args(args=[])
print("Using {} dataset".format(args.dataset))

Using cora dataset


### 1. Data Fetching

In [3]:
device = torch.device(f"cuda:{args.device}" if torch.cuda.is_available() else "cpu")
adj, features, true_labels, idx_train, idx_val, idx_test = load_data(args.dataset)
n_nodes, feat_dim = features.shape
dims = [feat_dim] + args.dims

adj = adj - sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj.eliminate_zeros()
adj_orig = adj

adj_train, train_edges, val_edges, val_edges_false, test_edges, test_edges_false = mask_test_edges(adj)
adj = adj_train
n = adj.shape[0]

adj_normalized = preprocess_graph(adj, norm='sym', renorm=True)
features = sp.csr_matrix(features).toarray()
    
feature_list = [features]
for i in range(args.num_hops):
    feature_list.append(adj_normalized.dot(feature_list[-1]))
input_feature = [torch.FloatTensor(feat).to(device) for feat in feature_list] 

adj_1st = (adj + sp.eye(n)).toarray()

adj_label = torch.FloatTensor(adj_1st)

neg_num = pos_num = adj_label.sum().long()

model = ROD_lp(dims, args.num_hops)

optimizer = optim.Adam(model.parameters(), lr=args.lr)

### 2. Data Preprocessing

In [4]:
sm_sim_mx_list = []
for i in range(args.num_hops+1):
    cur_feat = F.normalize(input_feature[i].cpu().data)
    sm_sim_mx_list.append(torch.mm(cur_feat, cur_feat.t()).reshape([-1,]))

adj_label = adj_label.reshape([-1, ])

model = model.to(device)
adj_label = adj_label.to(device)

pos_inds_list = []
neg_inds_list = []
for i in range(args.num_hops+1):
    pos_inds_list.append(np.argpartition(-sm_sim_mx_list[i], pos_num)[:pos_num])
    neg_inds_list.append(np.argpartition(sm_sim_mx_list[i], pos_num*200)[:pos_num*200])

pos_inds_ensemble = np.concatenate(pos_inds_list, axis=0)
neg_inds_ensemble = np.concatenate(neg_inds_list, axis=0)
    
length_ensemble = len(pos_inds_ensemble)
length_ensemble_neg = len(neg_inds_ensemble)
    
length = len(pos_inds_list[0])
length_neg = len(neg_inds_list[0])

pos_inds_ensemble_cuda = torch.LongTensor(pos_inds_ensemble).to(device)
pos_inds_cuda_list = [torch.LongTensor(pos_inds).to(device) for pos_inds in pos_inds_list]

batch_size = args.batch_size

### 3. Link Prediction

In [5]:
best_lp = 0.
best_emb_list = []
tqdm.write('Start Training...')
for epoch in tqdm(range(args.epochs)):
    model.train()
    optimizer.zero_grad()

    all_time_per_epoch = 0.
    t = time.time()
    z_list = model(input_feature)

    start, end = 0, batch_size
    start_ensemble, end_ensemble = 0, batch_size*args.num_hops
    batch_num = 0
    loss = 0.

    ran_head = np.random.randint(0, length_neg - length - 1)
    sampled_neg_list = []
    for i in range(args.num_hops+1):
        sampled_neg_list.append(torch.LongTensor(neg_inds_list[i][np.arange(ran_head, ran_head+length)]).to(device))
   
    ran_head_0 = np.random.randint(0, length_ensemble_neg - length_ensemble - 1)
    sample_neg_ensemble = torch.LongTensor(neg_inds_ensemble[np.arange(ran_head_0, ran_head_0+length_ensemble)]).cuda()

    while end_ensemble <= length_ensemble:
        sample_ensemble_inds = torch.cat((pos_inds_ensemble_cuda[start_ensemble:end_ensemble], sample_neg_ensemble[start_ensemble:end_ensemble]), 0)
        sample_ensemble_inds = sample_ensemble_inds.cuda()
        xind = sample_ensemble_inds // n_nodes
        yind = sample_ensemble_inds % n_nodes
        batch_label_original = torch.index_select(adj_label, 0, sample_ensemble_inds)

        batch_pred_ensemble_list = []
        for i in range(args.num_hops+1):
            sampled_inds = torch.cat((pos_inds_cuda_list[i][start:end], sampled_neg_list[i][start:end]), 0).to(device)
            xind = sampled_inds // n_nodes
            yind = sampled_inds % n_nodes
            zx_ensemble = torch.index_select(z_list[i], 0, xind)
            zy_ensemble = torch.index_select(z_list[i], 0, yind)
            zx = torch.index_select(z_list[i], 0, xind)
            zy = torch.index_select(z_list[i], 0, yind)
            batch_label = torch.cat((torch.ones(end-start), torch.zeros(end-start))).to(device)
            batch_label_original = torch.index_select(adj_label, 0, sampled_inds)
            batch_pred = (zx * zy).sum(1)
            batch_pred_ensemble = (zx_ensemble * zy_ensemble).sum(1)
            batch_pred_ensemble_list.append(batch_pred_ensemble)
            weight = torch.cat((batch_pred[:batch_size], 1-batch_pred[batch_size:]), 0).data
            loss += loss_function(adj_preds=batch_pred, adj_labels=batch_label_original)
            sm_sim_mx = sm_sim_mx_list[i].to(device)
            batch_label_soft = torch.index_select(sm_sim_mx, 0, sampled_inds)
            loss += 0.2*loss_function(adj_preds=batch_pred, adj_labels=batch_label, weight=weight)
            loss += 0.2*F.mse_loss(batch_pred, batch_label_soft)
        
        attention_scores = [torch.sigmoid(model.lr_att2(batch_pred.view(-1,1))).view(batch_pred.shape[0], 1) for batch_pred in batch_pred_ensemble_list]
        W = torch.cat(attention_scores, dim=1)
        W = F.softmax(W, 1)

        pred_ensemble = torch.mul(batch_pred_ensemble_list[0], W[:, 0])
        for i in range(1, args.num_hops+1):
            pred_ensemble += torch.mul(batch_pred_ensemble_list[i], W[:, i])

        #teacher loss
        for i in range(args.num_hops+1):
            loss += 0.1*F.kl_div(F.log_softmax(batch_pred_ensemble_list[i], dim=-1), F.softmax(pred_ensemble, dim=-1), reduction='mean')
            loss += 0.1*loss_function(adj_preds=batch_pred_ensemble_list[i], adj_labels=pred_ensemble)
            loss += loss_function(adj_preds=batch_pred_ensemble_list[i], adj_labels=batch_label_original)
     
        start_ensemble = end_ensemble
        start = end
        if end_ensemble < length_ensemble <= end_ensemble + (args.num_hops+1)*batch_size:
            break
        else:
            end += batch_size
            end_ensemble += (args.num_hops+1)*batch_size
        
    loss.backward()
    cur_loss = loss.item()
    optimizer.step()

    if (epoch + 1) % args.upd == 0:
        model.eval()
        z_list = model(input_feature)
        z_list = [zz.cpu().data.numpy() for zz in z_list]
    
        val_auc, val_ap = get_roc_score_ensemble(z_list, adj_orig, val_edges, val_edges_false)
        if val_auc + val_ap >= best_lp:
            best_lp = val_auc + val_ap
            best_emb_list = [z_list[0]]
            for i in range(1, len(z_list)):
                best_emb_list.append(z_list[i])
        tqdm.write("Epoch: {}, train_loss_gae={:.5f}, time={:.5f}".format(
            epoch + 1, cur_loss, time.time() - t))

tqdm.write("Optimization Finished!")
auc_score, ap_score = get_roc_score_ensemble(best_emb_list, adj_orig, test_edges, test_edges_false)
tqdm.write('Test AUC score: ' + str(auc_score))
tqdm.write('Test AP score: ' + str(ap_score))

Start Training...


  0%|          | 0/400 [00:00<?, ?it/s]

  2%|▎         | 10/400 [00:09<06:01,  1.08it/s]

Epoch: 10, train_loss_gae=265.56262, time=1.17799


  5%|▌         | 20/400 [00:17<05:46,  1.10it/s]

Epoch: 20, train_loss_gae=259.73642, time=1.16615


  8%|▊         | 30/400 [00:26<05:56,  1.04it/s]

Epoch: 30, train_loss_gae=255.40619, time=1.24016


 10%|█         | 40/400 [00:34<05:32,  1.08it/s]

Epoch: 40, train_loss_gae=251.52681, time=1.18879


 12%|█▎        | 50/400 [00:43<05:24,  1.08it/s]

Epoch: 50, train_loss_gae=249.87376, time=1.17592


 15%|█▌        | 60/400 [00:51<05:12,  1.09it/s]

Epoch: 60, train_loss_gae=245.72154, time=1.20336


 18%|█▊        | 70/400 [01:00<05:09,  1.07it/s]

Epoch: 70, train_loss_gae=243.93951, time=1.18776


 20%|██        | 80/400 [01:08<04:50,  1.10it/s]

Epoch: 80, train_loss_gae=242.54346, time=1.17101


 22%|██▎       | 90/400 [01:16<04:41,  1.10it/s]

Epoch: 90, train_loss_gae=241.48589, time=1.15625


 25%|██▌       | 100/400 [01:25<04:30,  1.11it/s]

Epoch: 100, train_loss_gae=239.85323, time=1.16544


 28%|██▊       | 110/400 [01:33<04:22,  1.10it/s]

Epoch: 110, train_loss_gae=238.25244, time=1.15104


 30%|███       | 120/400 [01:41<04:12,  1.11it/s]

Epoch: 120, train_loss_gae=237.89786, time=1.15018


 32%|███▎      | 130/400 [01:49<04:10,  1.08it/s]

Epoch: 130, train_loss_gae=235.63214, time=1.23590


 35%|███▌      | 140/400 [01:58<03:54,  1.11it/s]

Epoch: 140, train_loss_gae=234.91306, time=1.14559


 38%|███▊      | 150/400 [02:06<03:47,  1.10it/s]

Epoch: 150, train_loss_gae=234.75591, time=1.17282


 40%|████      | 160/400 [02:14<03:37,  1.10it/s]

Epoch: 160, train_loss_gae=235.47620, time=1.15709


 42%|████▎     | 170/400 [02:23<03:31,  1.09it/s]

Epoch: 170, train_loss_gae=233.50177, time=1.17283


 45%|████▌     | 180/400 [02:31<03:24,  1.07it/s]

Epoch: 180, train_loss_gae=233.52017, time=1.21372


 48%|████▊     | 190/400 [02:40<03:12,  1.09it/s]

Epoch: 190, train_loss_gae=231.84271, time=1.17852


 50%|█████     | 200/400 [02:48<03:02,  1.10it/s]

Epoch: 200, train_loss_gae=231.12776, time=1.17088


 52%|█████▎    | 210/400 [02:56<02:55,  1.08it/s]

Epoch: 210, train_loss_gae=229.98880, time=1.19796


 55%|█████▌    | 220/400 [03:05<02:44,  1.09it/s]

Epoch: 220, train_loss_gae=230.51297, time=1.16934


 57%|█████▊    | 230/400 [03:13<02:36,  1.09it/s]

Epoch: 230, train_loss_gae=229.94714, time=1.18920


 60%|██████    | 240/400 [03:22<02:31,  1.05it/s]

Epoch: 240, train_loss_gae=229.67207, time=1.20962


 62%|██████▎   | 250/400 [03:31<02:45,  1.10s/it]

Epoch: 250, train_loss_gae=230.17270, time=1.54185


 65%|██████▌   | 260/400 [03:41<02:26,  1.04s/it]

Epoch: 260, train_loss_gae=229.61797, time=1.41976


 68%|██████▊   | 270/400 [03:50<02:13,  1.03s/it]

Epoch: 270, train_loss_gae=228.34282, time=1.38383


 70%|███████   | 280/400 [03:59<02:08,  1.07s/it]

Epoch: 280, train_loss_gae=229.61380, time=1.45446


 72%|███████▎  | 290/400 [04:08<01:46,  1.04it/s]

Epoch: 290, train_loss_gae=227.68752, time=1.19056


 75%|███████▌  | 300/400 [04:17<01:39,  1.01it/s]

Epoch: 300, train_loss_gae=226.75423, time=1.22768


 78%|███████▊  | 310/400 [04:26<01:28,  1.02it/s]

Epoch: 310, train_loss_gae=226.07558, time=1.20234


 80%|████████  | 320/400 [04:36<01:29,  1.12s/it]

Epoch: 320, train_loss_gae=227.29707, time=1.49906


 82%|████████▎ | 330/400 [04:47<01:19,  1.14s/it]

Epoch: 330, train_loss_gae=227.79396, time=1.49225


 85%|████████▌ | 340/400 [04:56<01:03,  1.05s/it]

Epoch: 340, train_loss_gae=226.47716, time=1.41819


 88%|████████▊ | 350/400 [05:06<00:53,  1.08s/it]

Epoch: 350, train_loss_gae=225.37592, time=1.43759


 90%|█████████ | 360/400 [05:16<00:44,  1.11s/it]

Epoch: 360, train_loss_gae=226.34306, time=1.38019


 92%|█████████▎| 370/400 [05:26<00:32,  1.09s/it]

Epoch: 370, train_loss_gae=227.52266, time=1.44697


 95%|█████████▌| 380/400 [05:35<00:21,  1.08s/it]

Epoch: 380, train_loss_gae=226.24983, time=1.45776


 98%|█████████▊| 390/400 [05:45<00:10,  1.00s/it]

Epoch: 390, train_loss_gae=225.49178, time=1.25663


100%|██████████| 400/400 [05:54<00:00,  1.13it/s]


Epoch: 400, train_loss_gae=225.49290, time=1.39334
Optimization Finished!
Test AUC score: 0.9593848679828179
Test AP score: 0.9624760908405264
