In [None]:
# Rewrite the edge list as "col1 col2 1" triples: awk splits each line of
# reddit.csv on whitespace and prints the first two fields plus a constant 1.
!cat reddit.csv | awk '{print $1, $2, 1}' > reddit.ijv
# Convert the triples file into reddit.adj for the partitioner.
# NOTE(review): the numeric arguments 6 and 3 presumably select csrcnv's
# input and output formats -- confirm against csrcnv's usage text.
!csrcnv reddit.ijv 6 reddit.adj 3
# Partition reddit.adj into 4 parts with METIS; a later cell reads the
# resulting assignment from 'reddit.adj.part.4'.
!gpmetis reddit.adj 4

In [None]:
from scipy import sparse as spsp
import numpy as np

In [None]:
# Load the edge list into a 2-D array; columns 0 and 1 are later used as the
# row and column indices of the sparse adjacency matrix.
# np.loadtxt splits on whitespace by default, so despite the .csv extension
# the file is read as whitespace-separated values.
mat = np.loadtxt('reddit.csv')

In [None]:
# Assemble the sparse adjacency matrix from the edge list: one entry of value
# 1.0 per line of `mat`, placed at (column 0, column 1) after casting the
# indices to int64. The matrix shape is inferred from the largest index, and
# duplicate (row, col) pairs are summed by SciPy during the CSR conversion.
row_idx = mat[:, 0].astype(np.int64)
col_idx = mat[:, 1].astype(np.int64)
edge_vals = np.ones(mat.shape[0])
spm = spsp.coo_matrix((edge_vals, (row_idx, col_idx))).tocsr()

In [None]:
import dgl
# Rebind spm to the sparse matrix stored in 'reddit.npz' and wrap it in a
# read-only DGLGraph.
# NOTE(review): no cell visible in this notebook writes 'reddit.npz', and this
# assignment discards the spm that was built from 'reddit.csv' -- confirm where
# the file comes from and that it describes the same graph.
spm = spsp.load_npz('reddit.npz')
g = dgl.DGLGraph(spm, readonly=True)

In [None]:
# Number of partitions; it selects which gpmetis output file is read
# ('reddit.adj.part.<num_parts>') and must match the count given to gpmetis.
num_parts = 4
# One value per line; np.loadtxt returns them as floats by default.
node_locs = np.loadtxt('reddit.adj.part.%d' % num_parts)

In [None]:
# Split g into per-partition subgraphs according to the assignment in
# node_locs. Later cells iterate the result with `.items()` and read the
# 'inner_node' / 'inner_edge' flags from each subgraph.
# NOTE(review): the last argument (2) is presumably the number of halo hops
# kept around each partition -- confirm against the DGL documentation.
parts = dgl.transform.partition_graph_with_halo(g, node_locs, 2)

In [None]:
import torch as th
import pickle

# Print the size of the full graph, then for every partition print its size,
# how many of its nodes/edges are flagged as "inner", and pickle the triple
# (adjacency matrix, parent node IDs, partition assignment of those nodes).
print('#nodes: {}, #edges: {}'.format(g.number_of_nodes(), g.number_of_edges()))
part_results = []  # not used by the cells visible here; kept for compatibility
for i, part in parts.items():
    print('part:', i)
    print('#nodes: {}, #inner nodes: {}'.format(part.number_of_nodes(),
                                                th.sum(part.ndata['inner_node'])))
    print('#edges: {}, #inner edges: {}'.format(part.number_of_edges(),
                                                th.sum(part.edata['inner_edge'])))
    out_spm = part.adjacency_matrix_scipy(transpose=True)
    # IDs of this partition's nodes in the parent graph g.
    part_nodes = part.parent_nid.numpy()
    # Partition assignment of each of those nodes, looked up in node_locs.
    part_loc = node_locs[part_nodes]
    print(out_spm.shape, part_nodes.shape, part_loc.shape)
    # Use a context manager so the file is flushed and closed deterministically
    # instead of relying on garbage collection of an anonymous file object.
    with open('reddit_part_{}.pkl'.format(i), 'wb') as part_file:
        pickle.dump((out_spm, part_nodes, part_loc), part_file)

In [None]:
# Export every partition as text files:
#   reddit_part_<i>.csv      -- one "row col" pair per stored matrix entry
#   reddit_part_map_<i>.txt  -- parent-graph node ID of each partition node
for i, part in parts.items():
    out_spm = part.adjacency_matrix_scipy(transpose=True).tocoo()
    # Stack the COO row/column index vectors side by side into an (nnz, 2) array.
    out_mat = np.column_stack((out_spm.row, out_spm.col))
    # Fixed: the original printed `out.mat.shape`, which raises NameError.
    print(out_mat.shape)
    np.savetxt('reddit_part_{}.csv'.format(i), out_mat)
    part_nodes = part.parent_nid.numpy()
    # Fixed: the original never applied .format(i), so every partition
    # overwrote a single file literally named 'reddit_part_map_{}.txt'.
    np.savetxt('reddit_part_map_{}.txt'.format(i), part_nodes)

In [None]:
from dgl.data import RedditDataset
# Instantiate DGL's Reddit dataset; its features, labels and train/val/test
# masks are read from this object when the node data is pickled.
data = RedditDataset()


In [None]:
# Collect the per-node data of the Reddit dataset in one dictionary and
# pickle it to 'reddit_ndata.pkl'.
ndata = {'feature': data.features,
         'label': data.labels,
         'train_mask': data.train_mask,
         'val_mask': data.val_mask,
         'test_mask': data.test_mask}
# Use a context manager so the file is flushed and closed deterministically
# instead of relying on garbage collection of an anonymous file object.
with open('reddit_ndata.pkl', 'wb') as ndata_file:
    pickle.dump(ndata, ndata_file)

In [None]:
import pickle

def get_nodeflow(g, node_ids, num_layers):
    """Return the first NodeFlow sampled from ``g`` for the given seed nodes.

    A ``dgl.contrib.sampling.NeighborSampler`` is created with:

    * batch size equal to ``len(node_ids)``,
    * ``expand_factor`` equal to the number of nodes in ``g``
      (presumably so that no neighbour is dropped by sampling -- confirm
      against the DGL documentation),
    * ``num_hops`` equal to ``num_layers``,
    * ``seed_nodes`` equal to ``node_ids``.

    Parameters
    ----------
    g : graph object accepted by NeighborSampler; must provide
        ``number_of_nodes()``.
    node_ids : sized collection of seed node IDs.
    num_layers : number of hops passed to the sampler.

    Returns
    -------
    The first item yielded by iterating the sampler.

    Raises
    ------
    StopIteration
        If the sampler yields nothing.
    """
    seed_count = len(node_ids)
    sampler_options = {
        'expand_factor': g.number_of_nodes(),
        'num_hops': num_layers,
        'seed_nodes': node_ids,
    }
    sampler_cls = dgl.contrib.sampling.NeighborSampler
    nodeflow_iter = iter(sampler_cls(g, seed_count, **sampler_options))
    return next(nodeflow_iter)

# Sanity check: for every partition, a NodeFlow sampled on the full graph g
# from the nodes assigned to that partition must contain, layer by layer, the
# same parent-graph node IDs as a NodeFlow sampled on the pickled partition
# subgraph from its own locally assigned nodes.
for part_id in range(num_parts):
    print(part_id)
    # pickle.load can execute arbitrary code; only load files written by this
    # notebook. The context manager closes the file deterministically.
    with open('reddit_part_{}.pkl'.format(part_id), 'rb') as part_file:
        out_spm, part_nodes, part_loc = pickle.load(part_file)
    subg = dgl.DGLGraph(out_spm, readonly=True)
    # Seeds in the full graph: nodes METIS assigned to this partition.
    node_ids = np.nonzero(node_locs == part_id)[0]
    # Seeds in the subgraph: local indices of nodes assigned to this partition.
    lnode_ids = np.nonzero(part_loc == part_id)[0]
    nf = get_nodeflow(g, node_ids, 2)
    lnf = get_nodeflow(subg, lnode_ids, 2)
    # A separate loop variable is used here; the original reused `i` and
    # clobbered the partition index inside the outer loop body.
    for layer_id in range(nf.num_layers):
        layer_nids1 = nf.layer_parent_nid(layer_id).detach().numpy()
        # Convert to NumPy before indexing, exactly as for layer_nids1.
        layer_nids2 = lnf.layer_parent_nid(layer_id).detach().numpy()
        # Translate subgraph-local IDs to parent-graph IDs.
        layer_nids2 = part_nodes[layer_nids2]
        # array_equal also returns False (instead of erroring or broadcasting)
        # when the two layers have different sizes.
        assert np.array_equal(np.sort(layer_nids1), np.sort(layer_nids2))
