In [37]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import copy
import os.path as osp
import os
import argparse
import scipy
import numpy as np
import json

from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv

# `os.path.abspath('')` is the current working directory; it is stored in
# `__file__` (presumably because a notebook kernel does not define it -- confirm).
__file__ = os.path.abspath('')

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# The dataset directory is `<parent of the working directory>/data/Reddit`.
_parent_dir = osp.dirname(osp.realpath(__file__))
dataset_path = osp.join(_parent_dir, 'data', 'Reddit')
dataset = Reddit(dataset_path)

In [38]:
# Build COO / CSC / CSR views of the graph's edge list.
print('Creating coo/csc/csr format of dataset...')
num_nodes = dataset[0].num_nodes

# Row 0 and row 1 of edge_index become the row and column coordinates.
row_idx, col_idx = dataset[0].edge_index.numpy()

# Every stored entry carries the value 1 (same integer dtype as the indices).
v = np.ones_like(row_idx)
coo = scipy.sparse.coo_matrix((v, (row_idx, col_idx)), shape=(num_nodes, num_nodes))
csc, csr = coo.tocsc(), coo.tocsr()
print('Done!')

Creating coo/csc/csr format of dataset...
Done!


In [39]:

# Save csc-formatted dataset: pointer/index arrays, features and labels are
# each written to their own raw memmap file under `dataset_path`.
def _write_memmap(path, values, shape, dtype):
    """Create a writable memmap at `path`, copy `values` into it, flush and return it."""
    target = np.memmap(path, mode='w+', shape=shape, dtype=dtype)
    target[:] = values[:]
    target.flush()
    return target


# Arrays to persist.
features = dataset[0].x
labels = dataset[0].y
indptr = csc.indptr.astype(np.int64)
indices = csc.indices.astype(np.int64)

# Output locations (conf.json and split_idx.pth are written further below).
os.makedirs(dataset_path, exist_ok=True)
indptr_path = osp.join(dataset_path, 'indptr.dat')
indices_path = osp.join(dataset_path, 'indices.dat')
features_path = osp.join(dataset_path, 'features.dat')
labels_path = osp.join(dataset_path, 'labels.dat')
conf_path = osp.join(dataset_path, 'conf.json')
split_idx_path = osp.join(dataset_path, 'split_idx.pth')

print('Saving indptr...')
indptr_mmap = _write_memmap(indptr_path, indptr, indptr.shape, indptr.dtype)
print('Done!')

print('Saving indices...')
indices_mmap = _write_memmap(indices_path, indices, indices.shape, indices.dtype)
print('Done!')

print('Saving features...')
features_mmap = _write_memmap(features_path, features, dataset[0].x.shape, np.float32)
print('Done!')

print('Saving labels...')
# Labels are cast to float32 before being written, so the file holds floats.
labels = labels.type(torch.float32)
labels_mmap = _write_memmap(labels_path, labels, dataset[0].y.shape, np.float32)
print('Done!')

# Describe the raw files (shapes and dtypes) so a reader can map them back
# into arrays without the original dataset object.
print('Making conf file...')
mmap_config = {
    'num_nodes': int(dataset[0].num_nodes),
    'indptr_shape': tuple(indptr.shape),
    'indptr_dtype': str(indptr.dtype),
    'indices_shape': tuple(indices.shape),
    'indices_dtype': str(indices.dtype),
    'features_shape': tuple(features_mmap.shape),
    'features_dtype': str(features_mmap.dtype),
    'labels_shape': tuple(labels_mmap.shape),
    'labels_dtype': str(labels_mmap.dtype),
    'num_classes': int(dataset.num_classes),
}
# Write through a context manager so the handle is closed (and the JSON fully
# flushed to disk) before any later cell reads the file back.
with open(conf_path, 'w') as conf_file:
    json.dump(mmap_config, conf_file)
print('Done!')

print('Saving split index...')
# Store the dataset's train/test/val masks under the keys 'train', 'test', 'valid'.
splits = {
    key: getattr(dataset[0], mask_attr)
    for key, mask_attr in (('train', 'train_mask'), ('test', 'test_mask'), ('valid', 'val_mask'))
}
torch.save(splits, split_idx_path)
print('Done!')

# Neighbor-cache score: ratio of per-node CSR row counts to CSC column counts.
print('Calculating score for neighbor cache construction...')
score_path = os.path.join(dataset_path, 'nc_score.pth')
csr_indptr_tensor = torch.from_numpy(csr.indptr.astype(np.int64))
csc_indptr_tensor = torch.from_numpy(csc.indptr.astype(np.int64))

# Adjacent differences of a pointer array give the number of stored entries
# per column (CSC) or per row (CSR); eps keeps the denominator non-zero.
eps = 1e-8
in_num_neighbors = torch.diff(csc_indptr_tensor) + eps
out_num_neighbors = torch.diff(csr_indptr_tensor) + eps
score = out_num_neighbors / in_num_neighbors
print('Done!')

print('Saving score...')
torch.save(score, score_path)
print('Done!')


Saving indptr...
Done!
Saving indices...
Done!
Saving features...
Done!
Saving labels...
Done!
Making conf file...
Done!
Saving split index...
Done!
Calculating score for neighbor cache construction...
Done!
Saving score...
Done!


In [40]:
def get_mmap_dataset(path='../data/Reddit'):
    """Load the raw array files stored under ``path`` and return them as tensors.

    Shapes and dtypes are taken from ``conf.json`` in the same directory.
    ``indptr`` and ``labels`` are read fully into memory with ``np.fromfile``;
    ``indices`` and ``features`` are opened as read-only ``np.memmap`` arrays
    and wrapped with ``torch.from_numpy``.

    Args:
        path: Directory containing ``indptr.dat``, ``indices.dat``,
            ``features.dat``, ``labels.dat``, ``conf.json`` and
            ``split_idx.pth``.

    Returns:
        A 10-tuple ``(indptr, indices, features, labels, num_features,
        num_classes, num_nodes, train_idx, val_idx, test_idx)``. The first four
        are tensors, the next three are values read from ``conf.json``, and the
        last three are the ``'train'``, ``'valid'`` and ``'test'`` entries of
        the object stored in ``split_idx.pth``.
    """
    indptr_path = os.path.join(path, 'indptr.dat')
    indices_path = os.path.join(path, 'indices.dat')
    features_path = os.path.join(path, 'features.dat')
    labels_path = os.path.join(path, 'labels.dat')
    conf_path = os.path.join(path, 'conf.json')
    split_idx_path = os.path.join(path, 'split_idx.pth')

    # Close the config file as soon as it has been parsed.
    with open(conf_path, 'r') as conf_file:
        conf = json.load(conf_file)

    # Only the (potentially large) indices and features arrays are memory-mapped;
    # indptr and labels are loaded eagerly.
    indptr = np.fromfile(indptr_path, dtype=conf['indptr_dtype']).reshape(tuple(conf['indptr_shape']))
    indices = np.memmap(indices_path, mode='r', shape=tuple(conf['indices_shape']), dtype=conf['indices_dtype'])
    # Debug echo of the pointer array.
    print(indptr)
    features_shape = conf['features_shape']
    features = np.memmap(features_path, mode='r', shape=tuple(features_shape), dtype=conf['features_dtype'])
    # Read one label per node and keep only the first dimension of the stored shape.
    labels = np.fromfile(labels_path, dtype=conf['labels_dtype'], count=conf['num_nodes']).reshape(tuple([conf['labels_shape'][0]]))

    indptr = torch.from_numpy(indptr)
    indices = torch.from_numpy(indices)
    features = torch.from_numpy(features)
    labels = torch.from_numpy(labels)

    num_nodes = conf['num_nodes']
    num_features = conf['features_shape'][1]
    num_classes = conf['num_classes']

    split_idx = torch.load(split_idx_path)
    train_idx = split_idx['train']
    val_idx = split_idx['valid']
    test_idx = split_idx['test']

    return indptr, indices, features, labels, num_features, num_classes, num_nodes, train_idx, val_idx, test_idx

In [41]:
# Load everything from the default location and show the length of the
# 'train' split entry along its first dimension.
(indptr, indices, x, y,
 num_features, num_classes, num_nodes,
 train_idx, valid_idx, test_idx) = get_mmap_dataset()
train_idx.shape[0]

[        0      2204      2358 ... 114615262 114615401 114615892]


232965

In [42]:
from customNeighborSampler import MMAPNeighborSampler

# Comma-separated string parsed into the integer list passed as `sizes`
# (presumably per-hop neighbor counts -- confirm against MMAPNeighborSampler).
size = "10,10"
sizes = list(map(int, size.split(',')))

train_loader = MMAPNeighborSampler(
    indptr,
    indices,
    node_idx=train_idx,
    sizes=sizes,
    batch_size=1000,
    shuffle=False,
    num_workers=32,
)

# Reload the arrays from disk after the loader has been constructed.
(indptr, indices, x, y,
 num_features, num_classes, num_nodes,
 train_idx, valid_idx, test_idx) = get_mmap_dataset()

[        0      2204      2358 ... 114615262 114615401 114615892]


In [43]:
# Number of entries in the loader's flattened `node_idx`.
train_loader.node_idx.view(-1).size(0)

232965

In [44]:
# Iterate the loader once, echoing each batch's adjacency list and summing
# the batch sizes it reports.
total = 0
for step, batch in enumerate(train_loader):
    batch_size, ids, adjs = batch
    print(adjs)
    total = total + batch_size

[Adj(adj_t=SparseTensor(row=tensor([    0,     0,     0,  ..., 10114, 10114, 10114]),
             col=tensor([10115, 10116, 10117,  ..., 68171, 68172, 68173]),
             size=(10115, 68174), nnz=100795, density=0.01%), e_id=None, size=(68174, 10115)), Adj(adj_t=SparseTensor(row=tensor([  0,   0,   0,  ..., 999, 999, 999]),
             col=tensor([ 1000,  1001,  1002,  ..., 10112, 10113, 10114]),
             size=(1000, 10115), nnz=9748, density=0.10%), e_id=None, size=(10115, 1000))]
[Adj(adj_t=SparseTensor(row=tensor([    0,     0,     0,  ..., 10098, 10098, 10098]),
             col=tensor([ 1009,  5089, 10099,  ..., 68395, 68396, 68397]),
             size=(10099, 68398), nnz=100556, density=0.01%), e_id=None, size=(68398, 10099)), Adj(adj_t=SparseTensor(row=tensor([  0,   0,   0,  ..., 999, 999, 999]),
             col=tensor([ 1000,  1001,  1002,  ..., 10096, 10097, 10098]),
             size=(1000, 10099), nnz=9697, density=0.10%), e_id=None, size=(10099, 1000))]
[Adj(adj_t

In [35]:
# Display the sum of batch sizes accumulated by the loader loop.
# NOTE(review): this cell's execution count (35) is lower than the cells above it,
# so the shown value may come from an earlier run -- re-run in order to confirm.
total

153431