In [14]:
import networkx as nx
import numpy as np
from pyspark.sql import SparkSession

from snpp.cores.lowrank import alq_spark
from snpp.utils.matrix import split_train_test, load_sparse_csr
from snpp.utils.signed_graph import fill_diagonal

In [11]:
dataset = 'slashdot'
lambda_ = 0.1
k = 40
max_iter = 100
random_seed = 123456

In [15]:
sc.setCheckpointDir('.checkpoint')  # stackoverflow errors
    
    
m = load_sparse_csr('data/{}.npz'.format(dataset))
train_m, test_m = split_train_test(m, [.9, .1])

train_m = fill_diagonal(train_m)
targets = list(zip(*test_m.nonzero()))



In [27]:
print(train_m.shape)
print(test_m.shape)
print(train_m[0, 0])

train_m = fill_diagonal(train_m)
print(train_m[0, 0])


(77357, 77357)
(77357, 77357)
0.0


100%|██████████| 77357/77357 [00:01<00:00, 57344.64it/s]


1.0


In [40]:
X, Y = alq_spark(train_m, k=k, sc=sc,
                 lambda_=lambda_, iterations=max_iter,
                 seed=random_seed)



In [29]:
print(X.shape)
print(m.shape)

(77357, 40)
(77357, 77357)


In [41]:
Xb, Yb = sc.broadcast(X), sc.broadcast(np.transpose(Y))
preds = sc.parallelize(targets).map(
    lambda e: (e[0], e[1], np.sign(np.dot(Xb.value[e[0]], Yb.value[e[1]])))
).collect()

In [37]:
print(preds[:100])

[(0, 5, 1.0), (0, 26, 1.0), (1, 92, 1.0), (1, 112, 1.0), (1, 614, 1.0), (1, 621, 1.0), (1, 638, 1.0), (1, 650, 1.0), (1, 652, 1.0), (1, 684, 1.0), (4, 625, 1.0), (4, 705, 1.0), (4, 729, 1.0), (4, 745, 1.0), (4, 748, -1.0), (6, 753, 1.0), (6, 765, 1.0), (6, 771, 1.0), (8, 44, 1.0), (8, 54, 1.0), (8, 65, 1.0), (8, 85, -1.0), (8, 803, 1.0), (8, 816, 1.0), (8, 818, 1.0), (8, 826, 1.0), (8, 828, 1.0), (8, 837, 1.0), (8, 860, 1.0), (8, 864, 1.0), (8, 878, 1.0), (8, 886, 1.0), (8, 891, 1.0), (8, 896, 1.0), (8, 897, 1.0), (8, 912, 1.0), (8, 917, 1.0), (8, 939, 1.0), (8, 955, 1.0), (8, 958, 1.0), (8, 959, 1.0), (8, 961, 1.0), (8, 988, 1.0), (8, 990, 1.0), (8, 997, 1.0), (8, 1003, 1.0), (8, 1006, 1.0), (8, 1015, 1.0), (8, 1028, 1.0), (8, 1041, 1.0), (8, 1075, -1.0), (10, 1088, -1.0), (11, 1098, 1.0), (11, 1100, 1.0), (11, 1102, 1.0), (12, 1108, 1.0), (12, 1109, 1.0), (12, 1111, 1.0), (12, 1118, 1.0), (12, 1128, 1.0), (12, 1134, 1.0), (12, 1138, 1.0), (12, 1146, 1.0), (12, 1158, 1.0), (13, 746, 1

In [42]:
test_m = test_m.todok()
truth = set((i, j, test_m[i, j]) for i, j in targets)
assert len(truth) == len(preds)
print('=> final accuracy {}'.format(len(truth.intersection(preds)) / len(truth)))

=> final accuracy 0.8554918889620194
