In [2]:
import os
import pickle
import numpy as np
from sklearn.cluster import KMeans

In [5]:
def load_data(data_name, data_path, test_len = 3): 
    adj_time_list_path = os.path.join(data_path, data_name, "adj_orig_dense_list.pickle")
    with open(adj_time_list_path, 'rb') as handle:
        adj_time_list = pickle.load(handle,encoding="bytes")
    edge_list = set()
    for i in range(len(adj_time_list) - test_len):
        for j in range(len(adj_time_list[i])):
            for k in range(len(adj_time_list[i][j])):
                edge_list.add((str(i) + ' ' + str(k)))
    with open(os.path.join('graph', data_name + ".edgelist"), 'w') as handle:
        for item in edge_list:
            handle.write(item + '\n')
        handle.close()

def cluster(data_name, n_class):
    with open("emb/{}.emb".format(data_name), 'r') as f:
        n_nodes, emb_dim = f.readline().split()
        emb = np.zeros((int(n_nodes), int(emb_dim)))
        for line in f:
            line = line.strip().split()
            emb[int(line[0])] = np.array(line[1:])
    
    kmeans = KMeans(n_clusters=n_class, random_state=0).fit(emb)
    label = kmeans.labels_
    assert len(label) == len(emb)
    return label, emb
    
def main(data_name, data_path, emb_dim = 128, n_class = 3):
    load_data(data_name, data_path)
    os.system("python2 src/main.py --input graph/{}.edgelist --output emb/{}.emb --dimensions {}".format(data_name, data_name, emb_dim))
    label, emb = cluster(data_name, n_class)
    np.save(os.path.join(data_path, data_name, "label.npy"), label)
    np.save(os.path.join(data_path, data_name, "feat.npy"), emb)

In [8]:
data_name = 'reddit'
data_path = "../../data"
main(data_name, data_path)

In [17]:
# to load the data:
# feat = np.load(os.path.join(data_path, data_name, "feat.npy"))
# label = np.load(os.path.join(data_path, data_name, "label.npy"))

In [18]:
a

array([[-0.19666353, -0.55651754,  0.0535409 , ...,  0.00447828,
        -0.20099391, -0.10245145],
       [-0.24548902, -0.5429511 ,  0.10524947, ..., -0.02781564,
        -0.19045416, -0.05871525],
       [-0.1648987 , -0.43083873,  0.07670794, ...,  0.03225543,
        -0.22442435, -0.12292486],
       ...,
       [-0.16394128, -0.43593812,  0.1420592 , ...,  0.19042273,
        -0.3346524 , -0.11045839],
       [-0.1597476 , -0.39228153,  0.11786331, ...,  0.16231477,
        -0.333428  , -0.11982337],
       [-0.1853529 , -0.4134217 ,  0.12221798, ...,  0.16823433,
        -0.3280529 , -0.11259485]])