In [1]:
import pickle
import random
import numpy as np
import networkx as nx
from sklearn.decomposition import PCA

import torch

from smartsampling import *
from evaluation import *

In [2]:
# Experiment-wide seed and compute device (GPU when available, else CPU).
seed = 1
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch, NumPy's global RNG and Python's `random`; the three generators
# are independent, so the order in which they are seeded does not matter.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if device == "cuda":
    # Additionally seed the CUDA generator explicitly when a GPU is in use.
    torch.cuda.manual_seed(seed)

In [3]:
# Load the raw node feature matrix and reduce it to 100 PCA components.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('data/cora/node_features.pkl', 'rb') as fh:
    node_features = pickle.load(fh)
# `.toarray()` densifies the matrix; presumably node_features is a scipy sparse
# matrix -- confirm against the data-preparation code.
text_features = node_features.toarray()

# PCA's default 'auto' solver may choose a randomized SVD. Pinning random_state
# makes this cell give the same embedding when re-run on its own, instead of
# depending on where NumPy's global RNG happens to be.
pca = PCA(n_components=100, random_state=seed)
pca.fit(text_features)
emb_features = pca.transform(text_features)

# Number of rows of the feature matrix, used as the node count downstream.
# (`device` was already chosen in the seeding cell and is not redefined here.)
nnodes = text_features.shape[0]

## Two nodes per walk, strongly favoring 1-hop neighbors

In [4]:
# --- Sampler search space ---------------------------------------------------
# Presumably the candidate values SmartSampling may choose from; a single value
# is active per list here, with the broader grids kept in the trailing comments.
# NOTE(review): p/q look like node2vec-style return / in-out parameters and c
# like the walk length -- confirm against SmartSampling.
p_space = [9e50]  # [0.2, 0.5, 1, 2, 5]
q_space = [9e50]  # [0.2, 0.5, 1, 2, 5]
c_space = [2]     # [1, 2, 3, 4, 5]

# --- Sampling / training ----------------------------------------------------
sampler_divisor = 1                        # nsamplers = nnodes / sampler_divisor
nsamplers = int(nnodes / sampler_divisor)  # equals nnodes with the default divisor
nwalk = 100       # number of model.train(1, ...) rounds after the initial round
nnega = 5         # presumably negatives per positive sample -- confirm
penalty = 3e-4    # passed unchanged to every model.train call
dropout = 0.4

# --- Optimizer --------------------------------------------------------------
lr = 0.005
weight_decay = 0.0005

# NOTE(review): `ratio` and `nfold` are not used anywhere in this notebook.
ratio = 0.1
nfold = 5

In [5]:
# Link-prediction data: lp_train is handed to SmartSampling, lp_test is scored
# by lp_evaluate after every training round.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
with open('data/cora/lp_train.pkl', 'rb') as fh:
    lp_train = pickle.load(fh)
with open('data/cora/lp_test.pkl', 'rb') as fh:
    lp_test = pickle.load(fh)

# Pre-built split; only X_train / y_train are used later in this notebook.
with open('simulation/train_test.pkl', 'rb') as fh:
    X_train, y_train, X_test, y_test = pickle.load(fh)

# Node-classification inputs, disabled together with nc_evaluate in the training
# cell. networkx >= 3.0 removed from_scipy_sparse_matrix; from_scipy_sparse_array
# is its replacement.
# node_adjs = nx.from_scipy_sparse_array(pickle.load(open('data/cora/node_adjs.pkl', 'rb')))
# node_labels = pickle.load(open('data/cora/node_labels.pkl', 'rb'))

In [6]:
# Build the sampler/model. Arguments are positional, in the exact order the
# original call used; the grouping below is only for readability.
model = SmartSampling(
    text_features, emb_features,       # raw and PCA-reduced node features
    p_space, q_space, c_space,         # sampler search space
    nnodes, nsamplers, nnega,          # sizes / sampling counts
    lr, weight_decay, dropout,         # optimizer + regularization
    device,
    lp_train, X_train, y_train,        # training data
)

In [7]:
# Run nwalk + 1 training rounds: round 0 calls model.train(0, ...), every later
# round calls model.train(1, ...). After each round the embeddings are scored on
# the link-prediction test set; the (round, acc, f1) triples are printed and also
# kept in `lp_history` so the trajectory can be inspected after the loop.
# NOTE(review): the meaning of train's first argument lives in smartsampling --
# confirm there.
lp_history = []
for i in range(nwalk + 1):
    train_mode = 0 if i == 0 else 1
    embeddings = model.train(train_mode, penalty)
    # Node-classification evaluation is disabled (node_labels is not loaded):
    # nc_acc, nc_f1 = nc_evaluate(embeddings, node_labels)
    # print('Walk: {}'.format(i), nc_acc, nc_f1)
    lp_acc, lp_f1 = lp_evaluate(embeddings, lp_test)
    lp_history.append((i, lp_acc, lp_f1))
    print("Walk: {}".format(i), lp_acc, lp_f1)

Walk: 0 0.6803364879074658 0.7179962894248609
Epoch: 1, Loss: 0.7041, Acc: 0.3871, F1: 0.5582, Gain: 1.0160 || Nonneighbor_Avg: 0.4657, Neighbor_Avg: 1.2703, Time: 10.7902s
Walk: 1 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.7025, Acc: 0.3871, F1: 0.5582, Gain: 0.8604 || Nonneighbor_Avg: 0.4557, Neighbor_Avg: 1.2725, Time: 10.5685s
Walk: 2 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6988, Acc: 0.3871, F1: 0.5582, Gain: 0.8436 || Nonneighbor_Avg: 0.4572, Neighbor_Avg: 1.2821, Time: 10.3698s
Walk: 3 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6975, Acc: 0.3871, F1: 0.5582, Gain: 0.8384 || Nonneighbor_Avg: 0.4694, Neighbor_Avg: 1.2703, Time: 10.5859s
Walk: 4 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6944, Acc: 0.3871, F1: 0.5582, Gain: 0.8328 || Nonneighbor_Avg: 0.4605, Neighbor_Avg: 1.2751, Time: 10.5378s
Walk: 5 0.49947423764458465 0.6661991584852734
Epoch: 1, Loss: 0.6914, Acc: 0.3871, F1: 0.5582, Gain: 0.8247 || Nonneighbor_Avg: 0.

Epoch: 1, Loss: 0.6442, Acc: 0.3948, F1: 0.5611, Gain: 0.6235 || Nonneighbor_Avg: 0.4712, Neighbor_Avg: 1.2747, Time: 10.6919s
Walk: 48 0.5068349106203995 0.6690190543401552
Epoch: 1, Loss: 0.6470, Acc: 0.3942, F1: 0.5610, Gain: 0.6287 || Nonneighbor_Avg: 0.4542, Neighbor_Avg: 1.2740, Time: 10.6094s
Walk: 49 0.5068349106203995 0.6694855532064834
Epoch: 1, Loss: 0.6461, Acc: 0.3944, F1: 0.5610, Gain: 0.6301 || Nonneighbor_Avg: 0.4465, Neighbor_Avg: 1.2747, Time: 10.5098s
Walk: 50 0.5057833859095688 0.6690140845070423
Epoch: 1, Loss: 0.6451, Acc: 0.3953, F1: 0.5614, Gain: 0.6307 || Nonneighbor_Avg: 0.4490, Neighbor_Avg: 1.2755, Time: 10.4563s
Walk: 51 0.5047318611987381 0.6685432793807178
Epoch: 1, Loss: 0.6449, Acc: 0.3941, F1: 0.5609, Gain: 0.6272 || Nonneighbor_Avg: 0.4612, Neighbor_Avg: 1.2770, Time: 10.5262s
Walk: 52 0.5026288117770767 0.6676036542515812
Epoch: 1, Loss: 0.6437, Acc: 0.3971, F1: 0.5621, Gain: 0.6303 || Nonneighbor_Avg: 0.4708, Neighbor_Avg: 1.2714, Time: 10.5807s
Wal

Epoch: 1, Loss: 0.6419, Acc: 0.3977, F1: 0.5623, Gain: 0.6221 || Nonneighbor_Avg: 0.4531, Neighbor_Avg: 1.2744, Time: 10.5787s
Walk: 96 0.508937960042061 0.669964664310954
Epoch: 1, Loss: 0.6442, Acc: 0.3984, F1: 0.5626, Gain: 0.6291 || Nonneighbor_Avg: 0.4727, Neighbor_Avg: 1.2707, Time: 10.4819s
Walk: 97 0.5026288117770767 0.6671358198451794
Epoch: 1, Loss: 0.6447, Acc: 0.4009, F1: 0.5633, Gain: 0.6437 || Nonneighbor_Avg: 0.4542, Neighbor_Avg: 1.2773, Time: 10.4849s
Walk: 98 0.5057833859095688 0.6680790960451978
Epoch: 1, Loss: 0.6435, Acc: 0.3958, F1: 0.5614, Gain: 0.6382 || Nonneighbor_Avg: 0.4535, Neighbor_Avg: 1.2818, Time: 10.4816s
Walk: 99 0.5005257623554153 0.6661981728742095
Epoch: 1, Loss: 0.6420, Acc: 0.3986, F1: 0.5626, Gain: 0.6362 || Nonneighbor_Avg: 0.4498, Neighbor_Avg: 1.2751, Time: 10.6278s
Walk: 100 0.5078864353312302 0.6699576868829337
