In [1]:
import os
import torch
import torch.nn as nn 
import time
from tqdm import tqdm
from copy import copy

from torch_geometric.utils import subgraph
from torch.utils.data import DataLoader
from torch_geometric.datasets import Planetoid, DeezerEurope, Twitch, AttributedGraphDataset
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from torch_geometric.utils import structured_negative_sampling
from torch_geometric.nn import MetaPath2Vec
    
import numpy

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def evaluate(x, data, edge_types, exclude = None):
    from sklearn.metrics import roc_auc_score, average_precision_score
    auc = 0; ap = 0
    count = 0
    for edge_type in edge_types:
            
        edge_index = data[edge_type].edge_label_index
        labels = data[edge_type].edge_label.long()
        
        s, t = edge_index
        src_type, trg_type = edge_type[0], edge_type[2]
        
        if src_type not in x.keys() or trg_type not in x.keys():
                continue
            
        if exclude is not None and (src_type in exclude or trg_type in exclude):
            continue
            
        count += 1
        s_emb = x[src_type][s].detach().cpu()
        t_emb = x[trg_type][t].detach().cpu()
        scores = s_emb.mul(t_emb).sum(dim=-1).cpu().numpy()
        auc += roc_auc_score(y_true=labels, y_score=scores)
        ap += average_precision_score(y_true=labels, y_score=scores)
        
    return auc/count, ap/count


In [3]:
import pickle    
def save_file(data, path):
    with open(path, 'wb') as f:
        pickle.dump(data, f)
    
def open_file(path):
    with open(path, 'rb') as f:
        data = pickle.load(f)
    return data

name = 'IMDB'
train_data = open_file('/home/aemad/PycharmProjects/project_slkd/datasplits/'+name+'_train_data.pickle')
valid_data = open_file('/home/aemad/PycharmProjects/project_slkd/datasplits/'+name+'_valid_data.pickle')
test_data = open_file('/home/aemad/PycharmProjects/project_slkd/datasplits/'+name+'_test_data.pickle')


In [4]:
train_data

HeteroData(
  [1mmovie[0m={
    x=[4932, 3489],
    y=[4932, 5],
    train_mask=[4932],
    test_mask=[4932]
  },
  [1mdirector[0m={ x=[2393, 3341] },
  [1mactor[0m={ x=[6124, 3341] },
  [1mkeyword[0m={
    num_nodes=7971,
    x=[7971, 1]
  },
  [1m(movie, to, director)[0m={
    edge_index=[2, 3453],
    edge_label=[6906],
    edge_label_index=[2, 6906]
  },
  [1m(director, to, movie)[0m={
    edge_index=[2, 3453],
    edge_label=[6906],
    edge_label_index=[2, 6906]
  },
  [1m(movie, >actorh, actor)[0m={
    edge_index=[2, 10347],
    edge_label=[20694],
    edge_label_index=[2, 20694]
  },
  [1m(actor, to, movie)[0m={
    edge_index=[2, 10347],
    edge_label=[20694],
    edge_label_index=[2, 20694]
  },
  [1m(movie, to, keyword)[0m={
    edge_index=[2, 16527],
    edge_label=[33054],
    edge_label_index=[2, 33054]
  },
  [1m(keyword, to, movie)[0m={
    edge_index=[2, 16527],
    edge_label=[33054],
    edge_label_index=[2, 33054]
  }
)

In [102]:
metapaths = [[
    ('author', 'to', 'paper'),
    ('paper', 'to', 'author')], 
    [ ('author', 'to', 'paper'),
    ('paper', 'to', 'venue'), 
    ('venue', 'to', 'paper'), 
    ('paper', 'to', 'author')],
    [('paper', 'to', 'author'),
    ('author', 'to', 'paper')]]

num_nodes_dict = {'author':train_data['author'].num_nodes,
                  'paper':train_data['paper'].num_nodes,
                  'venue':train_data['venue'].num_nodes, 
                  'term':train_data['term'].num_nodes} 

exclude_list = ['term']

In [124]:
metapaths = [[
    ('paper', 'to', 'term'),
    ('term', 'to', 'paper')], 
    [('paper', 'to', 'subject'), 
    ('subject', 'to', 'paper')], 
    [('paper', 'to', 'author'),
    ('author', 'to', 'paper')]]
    
num_nodes_dict = {'author':train_data['author'].num_nodes,
                  'paper':train_data['paper'].num_nodes,
                  'subject':train_data['subject'].num_nodes, 
                  'term':train_data['term'].num_nodes} 

In [26]:
metapaths = [[
    ('movie', 'to', 'director'),
    ('director', 'to', 'movie')], 
    [ ('movie', '>actorh', 'actor'),
    ('actor', 'to', 'movie')],
    [('director', 'to', 'movie'),
     ('movie', 'to', 'keyword'),
    ('keyword', 'to', 'movie'),
    ('movie', 'to', 'director')]]

num_nodes_dict = {'movie':train_data['movie'].num_nodes,
                  'director':train_data['director'].num_nodes,
                  'actor':train_data['actor'].num_nodes, 
                  'keyword':train_data['keyword'].num_nodes} 

exclude_list = ['keyword']


In [8]:
device = torch.device('cuda:1')


In [9]:
best_emb.items()

NameError: name 'best_emb' is not defined

In [10]:
def train(epoch, model, optimizer, loader, log_steps=100, eval_steps=2000):
    model.train()

    total_loss = 0
    for i, (pos_rw, neg_rw) in enumerate(loader):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        
    return total_loss / len(loader)


In [11]:
def map_to_embeddings(emb_dict, selected_types):
    
    embeddings = None
    for nt, vec in emb_dict.items():
        if nt in selected_types:
            if embeddings is None:
                embeddings = vec.to(device)
            else:
                embeddings= torch.cat((embeddings, vec.to(device)), 0)
                
    embeddings = torch.cat((embeddings, torch.rand(1, 128).to(device)), 0) #oov
    
    return embeddings


def get_nodetype_count(metapath):
    count = 0
    start, end = {}, {}
    types = set([x[0] for x in metapath]) | set([x[-1] for x in metapath])
    types = sorted(list(types))
        
    for nt in types:
        start[nt] = count
        count += num_nodes_dict[key]
        end[nt] = count

    return start, end
        
        
def map_to_emb_dict(emb, meta_path):
    start_dict, end_dict = get_nodetype_count(meta_path)
    
    emb_dict = {}
    for k in start_dict.keys():
        start = start_dict[k]
        end = end_dict[k]
        emb_dict[k] = emb[start: end]
        
    return emb_dict

In [12]:
f_embeddings = {}
for nt in num_nodes_dict.keys():
    f_embeddings[nt] = torch.rand(num_nodes_dict[nt], 128)

In [14]:

model = MetaPath2Vec(train_data.edge_index_dict, embedding_dim=128,
                     metapath=metapaths[0], walk_length=40, context_size=5,
                     walks_per_node=40, num_negative_samples=1,
                     sparse=True, num_nodes_dict= num_nodes_dict).to(device)

loader = model.loader(batch_size=128, shuffle=True, num_workers=3)
opt = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

best_emb = None
best = 0
embedding = {}
count = 0

node_types = set([x[0] for x in metapaths[0]]) | set([x[-1] for x in metapaths[0]])
node_types = sorted(list(node_types))

for epoch in range(1, 200):    
    torch.manual_seed(10)
    loss = train(epoch, model, opt, loader)
        
    for nt in node_types:
            embedding[nt] = model(nt).detach()        
       
    auc, ap = evaluate(embedding, valid_data, valid_data.edge_types, exclude = exclude_list)

    if ap > best:
        best = ap 
        best_emb = embedding

    if epoch % 1 == 0:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}'
             f"validation AUC: {auc}, AP: {ap}")

Epoch: 01, Loss: 5.4089validation AUC: 0.5554682389147867, AP: 0.523536785491856
Epoch: 02, Loss: 3.4840validation AUC: 0.5974062843294974, AP: 0.5391789222180876
Epoch: 03, Loss: 2.4965validation AUC: 0.6245942176268982, AP: 0.5519139269525115
Epoch: 04, Loss: 1.8905validation AUC: 0.6481388526593403, AP: 0.5682128885564046
Epoch: 05, Loss: 1.5097validation AUC: 0.669111989763381, AP: 0.5871315185487156
Epoch: 06, Loss: 1.2730validation AUC: 0.6860900476858576, AP: 0.6074192904576943
Epoch: 07, Loss: 1.1263validation AUC: 0.69816580195763, AP: 0.6258219204813296
Epoch: 08, Loss: 1.0322validation AUC: 0.7050923887775715, AP: 0.6385897928999725
Epoch: 09, Loss: 0.9688validation AUC: 0.7099370908746796, AP: 0.648872238669415
Epoch: 10, Loss: 0.9239validation AUC: 0.7133746693053664, AP: 0.6574270667746098
Epoch: 11, Loss: 0.8908validation AUC: 0.7156684454575003, AP: 0.6644548852640668
Epoch: 12, Loss: 0.8655validation AUC: 0.7181555982538501, AP: 0.6701464903809128
Epoch: 13, Loss: 0.84

KeyboardInterrupt: 

In [15]:
for k, emb in best_emb.items():
    f_embeddings[k] = best_emb[k]

In [16]:
         
model = MetaPath2Vec(train_data.edge_index_dict, embedding_dim=128,
                     metapath=metapaths[1], walk_length=40, context_size=5,
                     walks_per_node=40, num_negative_samples=1,
                     sparse=True, num_nodes_dict= num_nodes_dict).to(device)

loader = model.loader(batch_size=128, shuffle=True, num_workers=3)
opt = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

best_emb = None
best = 0
embedding = {}
count = 0

node_types = set([x[0] for x in metapaths[1]]) | set([x[-1] for x in metapaths[1]])
node_types = sorted(list(node_types))

model.embedding.weight.data = map_to_embeddings(f_embeddings, node_types).to(device)
    

for epoch in range(1, 100):    
    torch.manual_seed(10)
    loss = train(epoch, model, opt, loader)
        
    for nt in node_types:
            embedding[nt] = model(nt).detach()        
       
    auc, ap = evaluate(embedding, valid_data, valid_data.edge_types, exclude = exclude_list)

    if ap > best:
        best = ap 
        best_emb = embedding

 
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}'
             f"validation AUC: {auc}, AP: {ap}")



Epoch: 01, Loss: 13.1015validation AUC: 0.6677628397330497, AP: 0.6115886505531895
Epoch: 02, Loss: 12.8393validation AUC: 0.6918688681195619, AP: 0.6332434570280845
Epoch: 03, Loss: 12.7715validation AUC: 0.7084495553348134, AP: 0.6446190344235401
Epoch: 04, Loss: 12.7416validation AUC: 0.7162898132456639, AP: 0.6501552347879904
Epoch: 05, Loss: 12.7169validation AUC: 0.7198957703518953, AP: 0.6537018011281854
Epoch: 06, Loss: 12.7013validation AUC: 0.7215888489220175, AP: 0.6553867417079181
Epoch: 07, Loss: 12.6946validation AUC: 0.7220946867999464, AP: 0.6564924374954351
Epoch: 08, Loss: 12.6887validation AUC: 0.7218256094693218, AP: 0.6569995084277125
Epoch: 09, Loss: 12.6825validation AUC: 0.7219544182085134, AP: 0.6576466461632764
Epoch: 10, Loss: 12.6791validation AUC: 0.7217417233509158, AP: 0.6579795728258337
Epoch: 11, Loss: 12.6763validation AUC: 0.7214171803354437, AP: 0.6581602215991487
Epoch: 12, Loss: 12.6739validation AUC: 0.7210715511918475, AP: 0.6582526624955944
Epoc

KeyboardInterrupt: 

In [147]:
for k, emb in best_emb.items():
    f_embeddings[k] = best_emb[k]

In [27]:
         
model = MetaPath2Vec(train_data.edge_index_dict, embedding_dim=128,
                     metapath=metapaths[2], walk_length=40, context_size=5,
                     walks_per_node=40, num_negative_samples=1,
                     sparse=True, num_nodes_dict= num_nodes_dict).to(device)

loader = model.loader(batch_size=128, shuffle=True, num_workers=3)
opt = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)

best_emb = None
best = 0
embedding = {}
count = 0

node_types = set([x[0] for x in metapaths[2]]) | set([x[-1] for x in metapaths[2]])
node_types = sorted(list(node_types))

model.embedding.weight.data = map_to_embeddings(f_embeddings, node_types).to(device)
    

for epoch in range(1, 100):    
    torch.manual_seed(10)
    loss = train(epoch, model, opt, loader)
        
    for nt in node_types:
            embedding[nt] = model(nt).detach()     
            
    auc, ap = evaluate(embedding, valid_data, valid_data.edge_types, exclude = ['term'])

    if ap > best:
        best = ap 
        best_emb = embedding

 
    if epoch % 1 == 0:
        print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}'
             f"validation AUC: {auc}, AP: {ap}")



Epoch: 01, Loss: 29.6905validation AUC: 0.5250140322667031, AP: 0.519666680457141
Epoch: 02, Loss: 29.6572validation AUC: 0.5345512600353184, AP: 0.5269207688578165
Epoch: 03, Loss: 29.6439validation AUC: 0.5459664972281195, AP: 0.5359845882428744
Epoch: 04, Loss: 29.6358validation AUC: 0.5565373112107997, AP: 0.5441033285602936
Epoch: 05, Loss: 29.6303validation AUC: 0.565888872994841, AP: 0.5518739824333624
Epoch: 06, Loss: 29.6230validation AUC: 0.5737538658497406, AP: 0.5582109915267925
Epoch: 07, Loss: 29.6092validation AUC: 0.5807372809866412, AP: 0.5652798940730128
Epoch: 08, Loss: 29.5909validation AUC: 0.5855000222774078, AP: 0.5701594743081316
Epoch: 09, Loss: 29.5743validation AUC: 0.5903376846796006, AP: 0.5755350220994145
Epoch: 10, Loss: 29.5634validation AUC: 0.5953339171710961, AP: 0.5807984137327168
Epoch: 11, Loss: 29.5575validation AUC: 0.5995213582560531, AP: 0.5854328332486319
Epoch: 12, Loss: 29.5473validation AUC: 0.6034531962990622, AP: 0.5910095103357094
Epoch:

KeyboardInterrupt: 

In [17]:
print(best_emb)
print(f_embeddings)
for k, emb in best_emb.items():
    f_embeddings[k] = best_emb[k]
   

{'actor': tensor([[-0.1108, -0.0613,  0.2515,  ...,  0.1939,  0.3752,  0.2751],
        [ 0.2688,  0.0479, -0.1770,  ..., -0.0361,  0.0785,  0.0300],
        [ 0.1760, -0.3401, -0.2490,  ..., -0.0135,  0.1509, -0.3064],
        ...,
        [ 1.2228,  1.2528,  1.2380,  ...,  1.3253,  1.2618,  0.9634],
        [ 1.8060,  0.9398,  0.7917,  ...,  1.4464,  1.0192,  1.4935],
        [ 1.0809,  0.8574,  0.8718,  ...,  0.1665,  1.7878,  1.7768]],
       device='cuda:1'), 'movie': tensor([[0.8808, 0.9595, 0.8099,  ..., 1.5876, 1.5234, 1.4108],
        [1.4786, 1.0366, 1.5477,  ..., 0.4423, 0.8793, 0.5725],
        [1.3855, 0.9127, 1.1700,  ..., 0.8246, 1.1523, 0.9958],
        ...,
        [0.7609, 0.0691, 0.8617,  ..., 1.2866, 0.2046, 0.1781],
        [1.5306, 0.2063, 0.0761,  ..., 1.4726, 1.5867, 0.5069],
        [0.7584, 0.9116, 0.5399,  ..., 0.8170, 0.1462, 1.2111]],
       device='cuda:1')}
{'movie': tensor([[-9.2838e-02,  7.2854e-02,  5.2074e-02,  ...,  4.4586e-01,
          3.5844e-01, 

In [28]:
with torch.no_grad():
    #best_model.eval()
    auc, ap = evaluate(f_embeddings, test_data, test_data.metadata()[1], exclude = ['term'])
        
    print(f"testing AUC: {auc}, AP: {ap}")

testing AUC: 0.5722435441008574, AP: 0.5506597763046167


In [18]:
i

2