In [1]:
import pickle
import random
import numpy as np
import networkx as nx
from sklearn.decomposition import PCA

import torch

from smartsampling import *
from evaluation import *

In [2]:
# Global seed shared by every RNG the notebook (and its helper modules) may draw from.
seed = 1
device = "cuda" if torch.cuda.is_available() else "cpu"

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device == "cuda":
    # torch.cuda.manual_seed only seeds the *current* GPU; seed all visible devices.
    torch.cuda.manual_seed_all(seed)
# cuDNN may otherwise select non-deterministic / autotuned kernels, so reruns on GPU
# would not reproduce the same numbers even with the seeds above.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Load the node feature matrix and densify it (`.toarray()` implies a scipy sparse matrix).
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('data/cora/node_features.pkl', 'rb') as f:
    node_features = pickle.load(f)
text_features = node_features.toarray()

# Project to 100 dimensions. random_state is passed explicitly because PCA may choose a
# randomized SVD solver for inputs of this size; without it the projection would depend
# on whatever state the global NumPy RNG happens to be in when this cell runs.
pca = PCA(n_components=100, random_state=seed)
emb_features = pca.fit_transform(text_features)

nnodes = text_features.shape[0]
# `device` is already set in the seeding cell; it is intentionally not redefined here.

## Walks of 4 nodes each, favoring 1-hop neighbors

In [4]:
# Walk-bias search spaces. Only one value per space is used in this run; the trailing
# comments keep the wider grids that were swept previously.
# NOTE(review): p/q presumably follow node2vec's return / in-out parameters — confirm in smartsampling.
p_space = [1e10]  # sweep grid: [0.2, 0.5, 1, 2, 5]
q_space = [1e10]  # sweep grid: [0.2, 0.5, 1, 2, 5]
c_space = [5]     # sweep grid: [1, 2, 3, 4, 5]

# One sampler per node by default; raise `sampler_divisor` to use nnodes // k samplers.
# (Integer division replaces the former `int(nnodes/1)` float round-trip.)
sampler_divisor = 1
nsamplers = nnodes // sampler_divisor
nwalk = 100       # number of training walks run after the initial walk 0
nnega = 5         # presumably negative samples per positive pair — confirm in smartsampling
penalty = 3e-4    # passed to model.train on every walk
dropout = 0.4

# Optimiser settings passed to SmartSampling.
lr = 0.005
weight_decay = 0.0005

# NOTE(review): `ratio` and `nfold` are not referenced by any visible cell — confirm they are still needed.
ratio = 0.1
nfold = 5

In [5]:
# Link-prediction splits plus the pre-built training/test arrays consumed by SmartSampling.
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open('data/cora/lp_train.pkl', 'rb') as f:
    lp_train = pickle.load(f)
with open('data/cora/lp_test.pkl', 'rb') as f:
    lp_test = pickle.load(f)

with open('simulation/train_test.pkl', 'rb') as f:
    X_train, y_train, X_test, y_test = pickle.load(f)

# Inputs for the (currently disabled) node-classification evaluation.
# NetworkX >= 3.0 removed from_scipy_sparse_matrix; use from_scipy_sparse_array instead.
# node_adjs = nx.from_scipy_sparse_array(pickle.load(open('data/cora/node_adjs.pkl', 'rb')))
# node_labels = pickle.load(open('data/cora/node_labels.pkl', 'rb'))

In [6]:
# Construct the SmartSampling model. Every argument is positional, so the order below
# must match the constructor signature in the project-local `smartsampling` module.
model = SmartSampling(
    text_features, emb_features,        # dense raw features and their PCA projection
    p_space, q_space, c_space,          # walk-bias search spaces
    nnodes, nsamplers, nnega,           # node count, sampler count, negatives setting
    lr, weight_decay, dropout, device,  # optimisation settings and target device
    lp_train, X_train, y_train,         # training inputs
)

In [7]:
# Walk 0 calls model.train with first argument 0; every later walk (1..nwalk) passes 1.
# The meaning of that argument is defined in smartsampling and is not visible here.
# Each walk's link-prediction (accuracy, F1) is kept in `lp_history` for later plotting.
# NOTE(review): the recorded output shows LP accuracy dropping from ~0.68 after walk 0
# to ~0.50 from walk 1 onward — investigate before trusting the trained embeddings.
lp_history = []
for i in range(nwalk + 1):
    embeddings = model.train(0 if i == 0 else 1, penalty)
    # Node classification (disabled; originally evaluated after walk 0 only):
    # nc_acc, nc_f1 = nc_evaluate(embeddings, node_labels)
    # print('Walk: {}'.format(i), nc_acc, nc_f1)
    lp_acc, lp_f1 = lp_evaluate(embeddings, lp_test)
    lp_history.append((i, lp_acc, lp_f1))
    print("Walk: {}".format(i), lp_acc, lp_f1)

Walk: 0 0.6803364879074658 0.7179962894248609
Epoch: 1, Loss: 0.7018, Acc: 0.3871, F1: 0.5582, Gain: 1.0163 || Nonneighbor_Avg: 1.4974, Neighbor_Avg: 1.5203, Time: 11.3339s
Walk: 1 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6998, Acc: 0.3871, F1: 0.5582, Gain: 0.8661 || Nonneighbor_Avg: 1.5089, Neighbor_Avg: 1.5314, Time: 11.0744s
Walk: 2 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6956, Acc: 0.3871, F1: 0.5582, Gain: 0.8515 || Nonneighbor_Avg: 1.5281, Neighbor_Avg: 1.5251, Time: 10.8987s
Walk: 3 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6919, Acc: 0.3871, F1: 0.5582, Gain: 0.8448 || Nonneighbor_Avg: 1.4749, Neighbor_Avg: 1.5487, Time: 10.9997s
Walk: 4 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6879, Acc: 0.3871, F1: 0.5582, Gain: 0.8269 || Nonneighbor_Avg: 1.5100, Neighbor_Avg: 1.5343, Time: 10.9891s
Walk: 5 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6870, Acc: 0.3871, F1: 0.5582, Gain: 0.8234 || Nonneighbor_Avg: 1.

Epoch: 1, Loss: 0.6385, Acc: 0.3935, F1: 0.5606, Gain: 0.6211 || Nonneighbor_Avg: 1.5244, Neighbor_Avg: 1.5270, Time: 10.9279s
Walk: 48 0.5036803364879074 0.6680731364275667
Epoch: 1, Loss: 0.6394, Acc: 0.3935, F1: 0.5605, Gain: 0.6235 || Nonneighbor_Avg: 1.4849, Neighbor_Avg: 1.5388, Time: 11.1185s
Walk: 49 0.501577287066246 0.6671348314606742
Epoch: 1, Loss: 0.6413, Acc: 0.3929, F1: 0.5603, Gain: 0.6272 || Nonneighbor_Avg: 1.5233, Neighbor_Avg: 1.5358, Time: 10.9665s
Walk: 50 0.5078864353312302 0.6699576868829337
Epoch: 1, Loss: 0.6394, Acc: 0.3928, F1: 0.5604, Gain: 0.6218 || Nonneighbor_Avg: 1.4664, Neighbor_Avg: 1.5499, Time: 10.9913s
Walk: 51 0.5047318611987381 0.6685432793807178
Epoch: 1, Loss: 0.6403, Acc: 0.3934, F1: 0.5606, Gain: 0.6213 || Nonneighbor_Avg: 1.5078, Neighbor_Avg: 1.5414, Time: 10.9538s
Walk: 52 0.5036803364879074 0.6680731364275667
Epoch: 1, Loss: 0.6374, Acc: 0.3949, F1: 0.5611, Gain: 0.6245 || Nonneighbor_Avg: 1.4631, Neighbor_Avg: 1.5583, Time: 11.2051s
Walk

Epoch: 1, Loss: 0.6358, Acc: 0.3992, F1: 0.5627, Gain: 0.6191 || Nonneighbor_Avg: 1.5188, Neighbor_Avg: 1.5318, Time: 11.0692s
Walk: 96 0.5120925341745531 0.6704545454545454
Epoch: 1, Loss: 0.6389, Acc: 0.3990, F1: 0.5628, Gain: 0.6300 || Nonneighbor_Avg: 1.5447, Neighbor_Avg: 1.5192, Time: 10.8450s
Walk: 97 0.5036803364879074 0.6680731364275667
Epoch: 1, Loss: 0.6377, Acc: 0.3987, F1: 0.5623, Gain: 0.6378 || Nonneighbor_Avg: 1.5270, Neighbor_Avg: 1.5306, Time: 11.0758s
Walk: 98 0.5099894847528917 0.6704384724186705
Epoch: 1, Loss: 0.6397, Acc: 0.3973, F1: 0.5621, Gain: 0.6354 || Nonneighbor_Avg: 1.5037, Neighbor_Avg: 1.5340, Time: 11.1313s
Walk: 99 0.5152471083070452 0.6732813607370659
Epoch: 1, Loss: 0.6399, Acc: 0.3987, F1: 0.5626, Gain: 0.6459 || Nonneighbor_Avg: 1.5240, Neighbor_Avg: 1.5451, Time: 11.1566s
Walk: 100 0.5047318611987381 0.6685432793807178
