In [None]:
!cat reddit.csv | awk '{print $1, $2, 1}' > reddit.ijv
!csrcnv reddit.ijv 6 reddit.adj 3
!gpmetis reddit.adj 4

In [2]:
from scipy import sparse as spsp
import numpy as np

In [3]:
# Raw edge list: whitespace-separated numeric columns, one edge per row
# (the same file the shell cell feeds to awk). Loaded as float64.
mat = np.loadtxt(fname='reddit.csv')

In [7]:
# Build the adjacency matrix: one unit-weight entry per row of `mat`
# (column 0 = source node ID, column 1 = destination node ID).
edge_src = mat[:, 0].astype(np.int64)
edge_dst = mat[:, 1].astype(np.int64)
# Force a square shape. If scipy infers the shape, rows and columns are sized
# independently, and the matrix is non-square whenever the largest node ID
# appears in only one of the two columns. MetisPartition asserts squareness.
n_nodes = int(max(edge_src.max(), edge_dst.max())) + 1
spm = spsp.coo_matrix((np.ones(mat.shape[0]), (edge_src, edge_dst)),
                      shape=(n_nodes, n_nodes))
# Converting to CSR sums duplicate (src, dst) entries and enables fast row slicing.
spm = spm.tocsr()

In [36]:
# Must match the partition count passed to gpmetis in the first cell, which
# writes its assignment to reddit.adj.part.<num_parts>.
num_parts = 4
part_file = 'reddit.adj.part.{}'.format(num_parts)
# One partition ID per node (loaded as float64 by np.loadtxt).
part = np.loadtxt(part_file)

In [50]:
class MetisPartition:
    """Split a square sparse adjacency matrix into partitions with halo nodes.

    For every partition ``i`` this records:

    * inner edges: edges whose source and destination both belong to ``i``;
    * cut edges: edges from a node of ``i`` to a node outside ``i``;
    * the partition's own nodes plus its halo nodes, i.e. the nodes outside
      ``i`` that some row of ``i`` has an edge to.

    Args:
        spm: square scipy sparse matrix of shape ``(num_nodes, num_nodes)``.
            Rows are fancy-indexed (``spm[nodes]``), so a CSR matrix is
            presumably expected.
        part: 1-D array giving the partition ID of every node (here the
            gpmetis assignment loaded from ``reddit.adj.part.<num_parts>``).
        num_parts: number of partitions; IDs ``0 .. num_parts-1`` are extracted.
    """

    def __init__(self, spm, part, num_parts):
        num_nodes = spm.shape[0]
        assert num_nodes == spm.shape[1]
        # Keep the global node count. get_part needs it to build full-size
        # matrices even when the highest-numbered node has no edges.
        self.num_nodes = num_nodes
        self.inner_edges = []   # per partition: COO of inner edges, local (row, col) indices
        self.cut_edges = []     # per partition: COO of cut edges, rows local to nodes, cols local to halo
        self.part_nodes = []    # per partition: (inner node IDs, halo node IDs), both sorted
        self.map2parts = part

        for i in range(num_parts):
            print('part', i)
            nodes = np.nonzero(part == i)[0]
            spm_part = spm[nodes]
            print('#nodes:', spm_part.shape)
            print('#edges:', spm_part.nnz)
            print('#inside edges:', np.sum(spm_part[:,nodes]))
            self.inner_edges.append(spsp.coo_matrix(spm_part[:,nodes]))
            # Column sums of this partition's rows: > 0 exactly for the nodes
            # that some node of the partition has an edge to.
            spm_part1 = spm_part.transpose()
            deg = spm_part1.dot(np.ones(spm_part1.shape[1]))
            col_nodes = np.nonzero(deg > 0)[0]
            all_nodes = np.unique(np.concatenate([nodes, col_nodes]))
            print('all nodes:', len(all_nodes))
            # Halo = reachable nodes that are not owned by this partition.
            halo_nodes = np.setdiff1d(all_nodes, nodes)
            print('halo nodes:', len(halo_nodes))
            print('edge cut:', np.sum(spm_part[:,halo_nodes]))
            self.cut_edges.append(spsp.coo_matrix(spm_part[:,halo_nodes]))
            self.part_nodes.append((nodes, halo_nodes))

    def get_part(self, i):
        """Return the local adjacency matrix of partition ``i`` including its halo.

        Returns:
            A tuple ``(adj, nodes, loc)``:

            * ``adj``: CSR matrix over ``nodes``. It holds the partition's
              inner edges plus every cut edge in both directions, one
              unit-weight entry per stored edge.
            * ``nodes``: sorted global IDs of the inner and halo nodes.
            * ``loc``: the partition ID (from ``part``) of each entry of ``nodes``.

        Raises:
            AssertionError: if the rebuilt matrix does not contain exactly the
                stored inner and cut edges.
        """
        inner_nodes, halo_nodes = self.part_nodes[i]
        # Symmetric if the input graph is symmetric.
        inner_row, inner_col = self.inner_edges[i].row, self.inner_edges[i].col
        inner_row = inner_nodes[inner_row]
        inner_col = inner_nodes[inner_col]
        # Asymmetric: rows are inner nodes, cols are halo nodes.
        cut_row, cut_col = self.cut_edges[i].row, self.cut_edges[i].col
        cut_row = inner_nodes[cut_row]
        cut_col = halo_nodes[cut_col]
        # The cut edges only run inner -> halo; add the reverse direction so
        # they are undirected like the inner edges.
        row = np.concatenate([inner_row, cut_row, cut_col])
        col = np.concatenate([inner_col, cut_col, cut_row])
        # Use an explicit global shape. If scipy infers it from max(row/col),
        # the slicing below goes out of bounds whenever the partition's
        # largest node ID has no edges (e.g. an isolated node).
        spm_part = spsp.coo_matrix((np.ones(len(row)), (row, col)),
                                   shape=(self.num_nodes, self.num_nodes)).tocsr()
        assert np.sum(spm_part[inner_nodes][:,inner_nodes]) == len(inner_row)
        assert np.sum(spm_part[inner_nodes][:,halo_nodes]) == len(cut_row)
        assert np.sum(spm_part[halo_nodes][:,inner_nodes]) == len(cut_row)
        nodes = np.concatenate([inner_nodes, halo_nodes])
        nodes = np.sort(nodes)
        return spm_part[nodes][:,nodes], nodes, self.map2parts[nodes]

In [51]:
# Split the full adjacency matrix by the gpmetis assignment; prints per-partition stats.
metis_parts = MetisPartition(spm=spm, part=part, num_parts=num_parts)

part 0
#nodes: (58075, 232965)
#edges: 29023404
#inside edges: 26036992.0
all nodes: 138336
halo nodes: 80261
edge cut: 2986412.0
part 1
#nodes: (59991, 232965)
#edges: 46928783
#inside edges: 41347084.0
all nodes: 182845
halo nodes: 122854
edge cut: 5581699.0
part 2
#nodes: (56803, 232965)
#edges: 19071516
#inside edges: 16206226.0
all nodes: 140812
halo nodes: 84009
edge cut: 2865290.0
part 3
#nodes: (58096, 232965)
#edges: 19592189
#inside edges: 17715632.0
all nodes: 135112
halo nodes: 77016
edge cut: 1876557.0


In [55]:
import pickle

# Save each partition as (local adjacency matrix, global node IDs, partition ID
# of each node). The loop uses `part_adj` rather than `part` so the METIS
# assignment vector `part` loaded earlier is not overwritten. Otherwise
# re-running the MetisPartition cell would receive a sparse matrix as `part`.
for i in range(num_parts):
    part_adj, part_nodes, part_loc = metis_parts.get_part(i)
    # Close each file deterministically instead of relying on garbage collection.
    with open('reddit_part_{}.pkl'.format(i), 'wb') as f:
        pickle.dump((part_adj, part_nodes, part_loc), f)

In [57]:
# Reddit dataset from DGL; its node features, labels and split masks are
# exported in the next cell. The constructor prints the summary shown below.
from dgl.data import RedditDataset

data = RedditDataset()


Finished data loading.
  NumNodes: 232965
  NumEdges: 114615892
  NumFeats: 602
  NumClasses: 41
  NumTrainingSamples: 153431
  NumValidationSamples: 23831
  NumTestSamples: 55703


In [61]:
# Node-level data for the whole graph, saved next to the per-partition files.
ndata = {'feature': data.features,
         'label': data.labels,
         'train_mask': data.train_mask,
         'val_mask': data.val_mask,
         'test_mask': data.test_mask}
# Close the file deterministically instead of relying on garbage collection.
with open('reddit_ndata.pkl', 'wb') as f:
    pickle.dump(ndata, f)