In [1]:
import os
import torch
import torch.nn as nn 
import time
from tqdm import tqdm
from model import *
from utils import *
from copy import copy

from torch_geometric.utils import subgraph
from torch_geometric.nn import Node2Vec
from torch.utils.data import DataLoader
from torch_geometric.datasets import Planetoid, DeezerEurope, Twitch, CitationFull
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.utils import structured_negative_sampling



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(x, data):
    """Score candidate edges by embedding dot product and report ranking metrics.

    Args:
        x: node embedding tensor; rows are gathered with ``x[idx]`` for each
            endpoint index.
        data: object exposing ``edge_label_index`` (unpacked into a source and a
            target index tensor) and ``edge_label`` (ground-truth labels, cast
            to ``long`` here).

    Returns:
        ``(auc, ap)``: sklearn ROC-AUC and average precision of the dot-product
        scores against ``data.edge_label``.
    """
    from sklearn.metrics import average_precision_score, roc_auc_score

    src_idx, dst_idx = data.edge_label_index
    y_true = data.edge_label.long().cpu()

    # Detach + move to CPU so sklearn receives plain CPU tensors with no autograd history.
    src_vecs = x[src_idx].detach().cpu()
    dst_vecs = x[dst_idx].detach().cpu()
    edge_scores = (src_vecs * dst_vecs).sum(dim=-1)

    return (
        roc_auc_score(y_true=y_true, y_score=edge_scores),
        average_precision_score(y_true=y_true, y_score=edge_scores),
    )


In [17]:
import pickle    
def save_file(data, path):
    """Pickle ``data`` to the file at ``path``, overwriting any existing file.

    Args:
        data: any picklable object.
        path: destination file path.
    """
    # (A stray pasted path line used to follow the `with` block here; it was a
    # SyntaxError that prevented this entire cell from running.)
    with open(path, 'wb') as f:
        pickle.dump(data, f)

def open_file(path):
    """Load and return the single pickled object stored at ``path``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)
# Dataset to load, and which split variant: inductive split files carry an 'ind' prefix.
name = 'crocodile'
inductive = True

# NOTE(review): this is an absolute path at the filesystem root, whereas the cached `dist`
# file lives under /home/aemad/PycharmProjects/project_slkd/datasplits/ — confirm which
# directory actually holds the split pickles.
DATA_DIR = '/datasplits/'
split_prefix = 'ind' if inductive else ''

# Loaded in train / valid / test order; file names are
# <DATA_DIR><prefix><name>_<split>_data.pickle.
train_data, valid_data, test_data = (
    open_file(f'{DATA_DIR}{split_prefix}{name}_{split}_data.pickle')
    for split in ('train', 'valid', 'test')
)


In [18]:
# Compute device: CUDA GPU with index 1.
device = torch.device('cuda', 1)
# NOTE(review): not read by any of the visible cells.
neg_num = 1

# Three-way weights, ordered (node, attr, inter):
#   theta_list  -> passed as `thetas` to DEAL.default_loss during training
#   lambda_list -> passed to DEAL.evaluate_inductive at evaluation time
theta_list = (0.1, 0.85, 0.05)
lambda_list = (0.1, 0.85, 0.05)

In [19]:
# Distance data for the training graph, passed to DEAL.default_loss during training.
#   save = True  -> recompute from train_data and write the cache file
#   save = False -> load the previously written cache file
save = True

# `dist` is computed from train_data, which is loaded from a different split file for the
# inductive ('ind'-prefixed) and transductive variants, so the variant must be part of the
# cache key. Previously both shared '<name>_dist.pickle' and could load each other's data.
dist_path = os.path.join('/home/aemad/PycharmProjects/project_slkd/datasplits',
                         ('ind' if inductive else '') + name + '_dist.pickle')
if save:
    dist = precompute_dist_data(train_data.edge_index, num_nodes=train_data.num_nodes)
    # The write used to be commented out, so save = True never produced the file that the
    # save = False branch reads.
    os.makedirs(os.path.dirname(dist_path), exist_ok=True)
    save_file(dist, dist_path)
else:
    dist = open_file(dist_path)

100%|███████████████████████████████████████████| 10/10 [00:25<00:00,  2.50s/it]
9218it [00:30, 299.05it/s]


In [20]:
from copy import deepcopy

In [None]:
# Grid search over the DEAL loss weights gamma1 / gamma2. The checkpoint with the best
# validation AP across ALL configurations and epochs is kept in `best_model` / `best`.
SEED = 10
epochs = 2000
best_model = None
best = 0.0

# Loop-invariant training inputs, built once instead of on every epoch (previously
# torch.tensor(dist) re-copied the full distance data to the device 2000 times).
# NOTE(review): assumes default_loss does not modify these tensors in place — confirm.
train_graph = train_data.to(device)
train_pairs = train_data.edge_label_index.T.to(device)
train_labels = train_data.edge_label.to(device)
dist_tensor = torch.tensor(dist).to(device)
train_num = train_data.edge_label_index.T.shape[0]

for gamma1 in [1]:
    for gamma2 in [70]:
        print(gamma1, gamma2)
        # Seed right before building the model so every grid configuration starts from the
        # same initialisation (previously only the first configuration was seeded before init).
        torch.manual_seed(SEED)
        deal = DEAL(128, train_data.x.shape[1], train_data.x.shape[0], device, None)
        torch.manual_seed(SEED)
        optimizer2 = torch.optim.Adam(deal.parameters(), lr=0.01)
        losses = []

        for epoch in range(epochs):
            deal.train()
            loss = deal.default_loss(train_pairs, train_labels, dist_tensor, train_graph,
                                     thetas=theta_list, train_num=train_num,
                                     gamma1=gamma1, gamma2=gamma2)

            optimizer2.zero_grad()
            loss.backward()
            optimizer2.step()
            losses.append(loss.item())

            # Validation pass: embed with lambda_list weights and score the validation edges.
            with torch.no_grad():
                deal.eval()
                score, x = deal.evaluate_inductive(valid_data.edge_label_index.T, valid_data, lambda_list)
                auc, ap = evaluate(x, valid_data)

            # Model selection on validation AP; deepcopy so later updates don't alter the snapshot.
            if ap > best:
                best_model = deepcopy(deal)
                best = ap

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch + 1}/{epochs}, "
                      f"student loss: {loss.item():.8f}, "
                      f"validation AUC: {auc}, AP: {ap}")



1 70
Epoch 10/2000, student loss: 0.30404227, validation AUC: 0.9014934732902116, AP: 0.8970569167371251
Epoch 20/2000, student loss: 0.29194818, validation AUC: 0.9025853290914754, AP: 0.8978674739613355
Epoch 30/2000, student loss: 0.28439066, validation AUC: 0.9038615419110447, AP: 0.8971540668282896
Epoch 40/2000, student loss: 0.27838250, validation AUC: 0.9067380243368672, AP: 0.897289889296387
Epoch 50/2000, student loss: 0.27322020, validation AUC: 0.908384400376181, AP: 0.897265779303192
Epoch 60/2000, student loss: 0.26755992, validation AUC: 0.9085992137639837, AP: 0.8962896309646667
Epoch 70/2000, student loss: 0.26355225, validation AUC: 0.9103789604608861, AP: 0.8968181285424687
Epoch 80/2000, student loss: 0.25878078, validation AUC: 0.9104482915857988, AP: 0.8968752785827956
Epoch 90/2000, student loss: 0.25480788, validation AUC: 0.9105207395996813, AP: 0.8959089477893136
Epoch 100/2000, student loss: 0.25128477, validation AUC: 0.9104761622064924, AP: 0.89500750116095

In [13]:
# Final evaluation of the best-AP checkpoint selected during training.
best_model.eval()

# Sanity check: recompute the validation AUC for the selected checkpoint. Runs under
# no_grad (the original did not), so no autograd graph is built for this forward pass.
with torch.no_grad():
    score, x = best_model.evaluate_inductive(valid_data.edge_label_index.T, valid_data, lambda_list)
    auc, ap = evaluate(x.detach(), valid_data)
print(auc)

# Testing: time the inductive forward pass on the test split, then score it.
with torch.no_grad():
    t_0 = time.perf_counter()
    score, x = best_model.evaluate_inductive(test_data.edge_label_index.T, test_data, lambda_list)
    if x.is_cuda:
        # CUDA kernels run asynchronously; wait for them so the timing covers the real work.
        torch.cuda.synchronize(x.device)
    t_1 = time.perf_counter()
    elapsed_time = round((t_1 - t_0) * 10 ** 3, 3)  # milliseconds
    print(f"inference time: {elapsed_time} ms")

    auc, ap = evaluate(x.detach(), test_data)
    print(f"testing AUC: {auc}, AP: {ap}")

0.8087824488544746
12.425
testing AUC: 0.6490326047686854, AP: 0.6361673036992709
