In [1]:
from __future__ import division
from __future__ import print_function
import os, sys
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter(action='ignore', category=UserWarning)

import argparse
import time
import random
import numpy as np
import scipy.sparse as sp
import torch

# Seed every random number generator this notebook can touch so that
# "Restart Kernel -> Run All" reproduces the same split, initialisation and sampling.
SEED = 4
random.seed(SEED)                 # stdlib RNG: imported above but previously left unseeded
np.random.seed(SEED)              # legacy global NumPy RNG (used by np.random.randint in training)
torch.manual_seed(SEED)           # PyTorch default generator
torch.cuda.manual_seed_all(SEED)  # every visible GPU; silently ignored when CUDA is unavailable

from torch import optim
import torch.nn.functional as F
from model import ROD_lp
from optimizer import loss_function
from utils import *
from tqdm import tqdm
from sklearn.preprocessing import normalize, MinMaxScaler
from sklearn import metrics
import copy

In [2]:
# Hyper-parameters are declared through argparse so the notebook mirrors a CLI script.
# `parse_args(args=[])` ignores the kernel's own argv and therefore always yields the defaults.
parser = argparse.ArgumentParser()
parser.add_argument('--num_hops', type=int, default=5,
                    help='number of feature-propagation steps (num_hops + 1 feature views are built)')
parser.add_argument('--epochs', type=int, default=400,
                    help='number of training epochs')
# nargs='+' keeps the parsed value a list of ints, matching the list default.
# With plain type=int a value given on the command line became a bare int and
# `[feat_dim] + args.dims` raised a TypeError.
parser.add_argument('--dims', type=int, nargs='+', default=[1024],
                    help='layer widths appended after the input feature dimension')
parser.add_argument('--lr', type=float, default=0.001,
                    help='learning rate passed to Adam')
parser.add_argument('--batch_size', type=int, default=500,
                    help='number of positive (and of negative) node pairs per batch')
parser.add_argument('--dataset', type=str, default='cora',
                    help='dataset name handed to load_data')
parser.add_argument('--device', type=int, default=0,
                    help='CUDA device index (ignored when CUDA is unavailable)')
parser.add_argument('--upd', type=int, default=10,
                    help='run validation every `upd` epochs')
args = parser.parse_args(args=[])
print("Using {} dataset".format(args.dataset))

Using cora dataset


### 1. Data Fetching

In [3]:
# Use the requested GPU when CUDA is present, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{args.device}")
else:
    device = torch.device("cpu")

adj, features, true_labels, idx_train, idx_val, idx_test = load_data(args.dataset)
n_nodes, feat_dim = features.shape
dims = [feat_dim] + args.dims

# Remove self-loops: subtract the diagonal, then drop the explicit zeros left behind.
self_loops = sp.dia_matrix((adj.diagonal()[np.newaxis, :], [0]), shape=adj.shape)
adj = adj - self_loops
adj.eliminate_zeros()
adj_orig = adj  # loop-free graph, kept for scoring validation/test edges later

# Hold out validation/test edges; everything below only sees the training graph.
(adj_train,
 train_edges,
 val_edges, val_edges_false,
 test_edges, test_edges_false) = mask_test_edges(adj)
adj = adj_train
n = adj.shape[0]

adj_normalized = preprocess_graph(adj, norm='sym', renorm=True)
features = sp.csr_matrix(features).toarray()

# feature_list[k] is the raw feature matrix multiplied k times by the normalised adjacency.
feature_list = [features]
while len(feature_list) < args.num_hops + 1:
    feature_list.append(adj_normalized.dot(feature_list[-1]))
input_feature = [torch.FloatTensor(feat).to(device) for feat in feature_list]

# Dense training adjacency with self-loops added back; used as the reconstruction target.
adj_1st = (adj + sp.eye(n)).toarray()
adj_label = torch.FloatTensor(adj_1st)

# Number of non-zero target entries (kept as a 0-d LongTensor, as before).
pos_num = adj_label.sum().long()
neg_num = pos_num

model = ROD_lp(dims, args.num_hops)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

### 2. Data Preprocessing

In [4]:
# One similarity vector per hop: row-normalise the hop's features on the CPU,
# take all pairwise inner products, and flatten the n x n result to 1-D so that
# flat index k corresponds to the node pair (k // n_nodes, k % n_nodes).
sm_sim_mx_list = [
    torch.mm(unit_feat, unit_feat.t()).reshape([-1, ])
    for unit_feat in (
        F.normalize(input_feature[hop].cpu().data) for hop in range(args.num_hops + 1)
    )
]

# Flatten the target adjacency the same way and move model + targets to the device.
adj_label = adj_label.reshape([-1, ]).to(device)
model = model.to(device)

# Per hop: the pos_num most similar pairs act as pseudo-positives and the
# pos_num*200 least similar pairs form the pool negatives are sampled from.
# argpartition returns these unordered, which is all the training loop needs.
# NOTE(review): the inputs are torch tensors and pos_num is a 0-d tensor, so this
# relies on NumPy converting both implicitly -- confirm when upgrading NumPy/PyTorch.
pos_inds_list = [np.argpartition(-sim, pos_num)[:pos_num] for sim in sm_sim_mx_list]
neg_inds_list = [np.argpartition(sim, pos_num*200)[:pos_num*200] for sim in sm_sim_mx_list]

# All hops concatenated back to back.
pos_inds_ensemble = np.concatenate(pos_inds_list, axis=0)
neg_inds_ensemble = np.concatenate(neg_inds_list, axis=0)

length_ensemble, length_ensemble_neg = len(pos_inds_ensemble), len(neg_inds_ensemble)
length, length_neg = len(pos_inds_list[0]), len(neg_inds_list[0])

# Device-resident copies of the positive indices (per hop and concatenated).
pos_inds_cuda_list = [torch.LongTensor(hop_pos).to(device) for hop_pos in pos_inds_list]
pos_inds_ensemble_cuda = torch.LongTensor(pos_inds_ensemble).to(device)

batch_size = args.batch_size

### 3. Link Prediction

In [5]:
# Train ROD_lp with one optimizer step per epoch (the loss is accumulated over all
# batches and hops before backward), validate every `args.upd` epochs, and keep the
# embeddings with the best validation AUC + AP for the final test evaluation.
best_lp = 0.
best_emb_list = []
tqdm.write('Start Training...')
for epoch in tqdm(range(args.epochs)):
    model.train()
    optimizer.zero_grad()

    t = time.time()
    z_list = model(input_feature)  # one embedding matrix per hop

    loss = 0.

    # Pick one random contiguous window of each hop's negative pool, the same
    # length as the positive set, so positives and negatives are balanced.
    ran_head = np.random.randint(0, length_neg - length - 1)
    neg_window = np.arange(ran_head, ran_head + length)
    sampled_neg_list = [
        torch.LongTensor(neg_inds_list[i][neg_window]).to(device)
        for i in range(args.num_hops + 1)
    ]

    # This draw used to select a window of "ensemble" negatives whose indices were
    # overwritten before anything read them. The dead tensors are gone, but the
    # draw is kept so the global NumPy RNG stream -- and with it every later
    # `ran_head` -- stays the same for a given SEED.
    np.random.randint(0, length_ensemble_neg - length_ensemble - 1)

    # Iterate over full batches only. The previous loop tracked a separate ensemble
    # counter that was initialised with batch_size*num_hops but advanced by
    # batch_size*(num_hops+1); that could admit a truncated final batch whose
    # predictions were shorter than the fixed-size hard labels below.
    for start in range(0, length - batch_size + 1, batch_size):
        end = start + batch_size

        batch_pred_ensemble_list = []
        for i in range(args.num_hops + 1):
            # First half of the batch: pseudo-positive pairs; second half: sampled negatives.
            sampled_inds = torch.cat((pos_inds_cuda_list[i][start:end], sampled_neg_list[i][start:end]), 0)
            # Decode flat pair indices back into (row node, column node).
            xind = sampled_inds // n_nodes
            yind = sampled_inds % n_nodes
            zx = torch.index_select(z_list[i], 0, xind)
            zy = torch.index_select(z_list[i], 0, yind)
            # Inner-product score of each node pair. The same tensor feeds the
            # ensemble below (it was previously recomputed from identical inputs).
            batch_pred = (zx * zy).sum(1)
            batch_pred_ensemble_list.append(batch_pred)

            # Three targets for the same predictions:
            #   hard pseudo-labels (1 for the positive half, 0 for the negative half),
            #   the training adjacency entries, and the feature-similarity values.
            batch_label = torch.cat((torch.ones(end - start), torch.zeros(end - start))).to(device)
            batch_label_original = torch.index_select(adj_label, 0, sampled_inds)
            # Gather the similarity values where the similarity vector lives and move
            # only this batch to the device, instead of copying the whole n*n vector
            # for every hop of every batch.
            sim_flat = sm_sim_mx_list[i]
            batch_label_soft = torch.index_select(sim_flat, 0, sampled_inds.to(sim_flat.device)).to(device)

            # Per-sample weight for the pseudo-label loss: the raw score for positives
            # and 1 - score for negatives; `.data` detaches it from the graph.
            weight = torch.cat((batch_pred[:batch_size], 1 - batch_pred[batch_size:]), 0).data

            loss += loss_function(adj_preds=batch_pred, adj_labels=batch_label_original)
            loss += 0.2*loss_function(adj_preds=batch_pred, adj_labels=batch_label, weight=weight)
            loss += 0.2*F.mse_loss(batch_pred, batch_label_soft)

        # Attention over hops: each hop's score goes through model.lr_att2 and a
        # sigmoid, then a softmax across hops gives per-sample mixing weights.
        attention_scores = [
            torch.sigmoid(model.lr_att2(hop_pred.view(-1, 1))).view(hop_pred.shape[0], 1)
            for hop_pred in batch_pred_ensemble_list
        ]
        W = F.softmax(torch.cat(attention_scores, dim=1), 1)

        pred_ensemble = torch.mul(batch_pred_ensemble_list[0], W[:, 0])
        for i in range(1, args.num_hops + 1):
            pred_ensemble += torch.mul(batch_pred_ensemble_list[i], W[:, i])

        # teacher loss: pull every hop towards the attention-weighted ensemble.
        # NOTE(review): `batch_label_original` here is whatever the LAST hop gathered,
        # although each hop scores its own set of node pairs -- confirm this is intended.
        for i in range(args.num_hops + 1):
            loss += 0.1*F.kl_div(F.log_softmax(batch_pred_ensemble_list[i], dim=-1), F.softmax(pred_ensemble, dim=-1), reduction='mean')
            loss += 0.1*loss_function(adj_preds=batch_pred_ensemble_list[i], adj_labels=pred_ensemble)
            loss += loss_function(adj_preds=batch_pred_ensemble_list[i], adj_labels=batch_label_original)

    loss.backward()
    cur_loss = loss.item()
    optimizer.step()

    if (epoch + 1) % args.upd == 0:
        model.eval()
        # Evaluation needs no gradients; skipping the autograd graph saves memory and time.
        with torch.no_grad():
            z_list = [zz.detach().cpu().numpy() for zz in model(input_feature)]

        val_auc, val_ap = get_roc_score_ensemble(z_list, adj_orig, val_edges, val_edges_false)
        # `>=` means a tie is won by the most recent checkpoint.
        if val_auc + val_ap >= best_lp:
            best_lp = val_auc + val_ap
            best_emb_list = list(z_list)
        tqdm.write("Epoch: {}, train_loss_gae={:.5f}, time={:.5f}".format(
            epoch + 1, cur_loss, time.time() - t))

tqdm.write("Optimization Finished!")
auc_score, ap_score = get_roc_score_ensemble(best_emb_list, adj_orig, test_edges, test_edges_false)
tqdm.write('Test AUC score: ' + str(auc_score))
tqdm.write('Test AP score: ' + str(ap_score))


  0%|          | 0/400 [00:00<?, ?it/s]

Start Training...


  2%|▎         | 10/400 [00:05<03:52,  1.68it/s]

Epoch: 10, train_loss_gae=282.18314, time=0.92969


  5%|▌         | 20/400 [00:10<04:16,  1.48it/s]

Epoch: 20, train_loss_gae=274.31812, time=1.10809


  8%|▊         | 30/400 [00:15<03:51,  1.60it/s]

Epoch: 30, train_loss_gae=267.98447, time=0.98193


 10%|█         | 40/400 [00:20<03:57,  1.52it/s]

Epoch: 40, train_loss_gae=262.95465, time=1.07796


 12%|█▎        | 50/400 [00:26<03:35,  1.63it/s]

Epoch: 50, train_loss_gae=260.54636, time=0.95535


 15%|█▌        | 60/400 [00:31<03:32,  1.60it/s]

Epoch: 60, train_loss_gae=256.96902, time=0.97111


 18%|█▊        | 70/400 [00:36<03:37,  1.52it/s]

Epoch: 70, train_loss_gae=255.19661, time=1.08256


 20%|██        | 80/400 [00:41<03:24,  1.57it/s]

Epoch: 80, train_loss_gae=252.25406, time=1.03497


 22%|██▎       | 90/400 [00:46<03:06,  1.66it/s]

Epoch: 90, train_loss_gae=251.05368, time=0.92694


 25%|██▌       | 100/400 [00:51<02:54,  1.72it/s]

Epoch: 100, train_loss_gae=248.65565, time=0.89220


 28%|██▊       | 110/400 [00:56<02:51,  1.69it/s]

Epoch: 110, train_loss_gae=248.84816, time=0.93304


 30%|███       | 120/400 [01:01<02:47,  1.67it/s]

Epoch: 120, train_loss_gae=245.43610, time=0.95098


 32%|███▎      | 130/400 [01:06<02:38,  1.71it/s]

Epoch: 130, train_loss_gae=243.90959, time=0.90385


 35%|███▌      | 140/400 [01:11<02:31,  1.71it/s]

Epoch: 140, train_loss_gae=242.45346, time=0.90405


 38%|███▊      | 150/400 [01:15<02:25,  1.72it/s]

Epoch: 150, train_loss_gae=241.46948, time=0.89645


 40%|████      | 160/400 [01:20<02:20,  1.71it/s]

Epoch: 160, train_loss_gae=241.79732, time=0.92180


 42%|████▎     | 170/400 [01:25<02:27,  1.56it/s]

Epoch: 170, train_loss_gae=241.08754, time=1.10694


 45%|████▌     | 180/400 [01:30<02:07,  1.73it/s]

Epoch: 180, train_loss_gae=240.55112, time=0.90286


 48%|████▊     | 190/400 [01:35<02:03,  1.71it/s]

Epoch: 190, train_loss_gae=238.82446, time=0.91873


 50%|█████     | 200/400 [01:40<01:55,  1.73it/s]

Epoch: 200, train_loss_gae=237.97897, time=0.90107


 52%|█████▎    | 210/400 [01:44<01:51,  1.71it/s]

Epoch: 210, train_loss_gae=236.97346, time=0.92358


 55%|█████▌    | 220/400 [01:49<01:44,  1.72it/s]

Epoch: 220, train_loss_gae=237.07791, time=0.92098


 57%|█████▊    | 230/400 [01:54<01:40,  1.69it/s]

Epoch: 230, train_loss_gae=236.54646, time=0.93684


 60%|██████    | 240/400 [01:59<01:33,  1.72it/s]

Epoch: 240, train_loss_gae=238.63480, time=0.90264


 62%|██████▎   | 250/400 [02:04<01:30,  1.67it/s]

Epoch: 250, train_loss_gae=237.31531, time=0.94227


 65%|██████▌   | 260/400 [02:09<01:26,  1.63it/s]

Epoch: 260, train_loss_gae=233.23711, time=1.00346


 68%|██████▊   | 270/400 [02:14<01:18,  1.65it/s]

Epoch: 270, train_loss_gae=233.55034, time=0.97825


 70%|███████   | 280/400 [02:19<01:10,  1.69it/s]

Epoch: 280, train_loss_gae=233.70340, time=0.93087


 72%|███████▎  | 290/400 [02:24<01:04,  1.71it/s]

Epoch: 290, train_loss_gae=233.00595, time=0.90710


 75%|███████▌  | 300/400 [02:28<00:59,  1.67it/s]

Epoch: 300, train_loss_gae=235.20827, time=0.95313


 78%|███████▊  | 310/400 [02:33<00:54,  1.66it/s]

Epoch: 310, train_loss_gae=232.91017, time=0.96240


 80%|████████  | 320/400 [02:38<00:47,  1.68it/s]

Epoch: 320, train_loss_gae=231.46730, time=0.94379


 82%|████████▎ | 330/400 [02:43<00:41,  1.68it/s]

Epoch: 330, train_loss_gae=231.91656, time=0.93897


 85%|████████▌ | 340/400 [02:48<00:35,  1.69it/s]

Epoch: 340, train_loss_gae=230.14760, time=0.93192


 88%|████████▊ | 350/400 [02:53<00:31,  1.59it/s]

Epoch: 350, train_loss_gae=230.48987, time=1.05451


 90%|█████████ | 360/400 [02:58<00:24,  1.65it/s]

Epoch: 360, train_loss_gae=231.71518, time=0.97203


 92%|█████████▎| 370/400 [03:03<00:17,  1.71it/s]

Epoch: 370, train_loss_gae=231.18553, time=0.90440


 95%|█████████▌| 380/400 [03:08<00:11,  1.70it/s]

Epoch: 380, train_loss_gae=231.42209, time=0.92595


 98%|█████████▊| 390/400 [03:13<00:05,  1.69it/s]

Epoch: 390, train_loss_gae=231.00861, time=0.92533


100%|██████████| 400/400 [03:17<00:00,  2.02it/s]


Epoch: 400, train_loss_gae=230.78877, time=0.90777
Optimization Finished!
Test AUC score: 0.9599717710429951
Test AP score: 0.9640584566332792
