In [12]:
import os.path as osp
import os
import torch
from overflowDataset import OverFlowDataset
from torch_geometric.datasets import JODIEDataset
from tqdm import tqdm
from neighbor_sampler import NeighborSampler
import scipy
import numpy as np
from collections import OrderedDict
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected

# Set arguments
# Percentage of data to sample from
subsetPerc = 1

# CPU cache percentage of nodes
CPUCachePerc = 100

# Datset to use
dataName = 'wiki' # 'overflow', 'taobao' , 'reddit', wiki'


In [13]:
# Load data
__file__ = os.path.abspath('')

print("Loading dataset ", dataName, "...")

if dataName == 'overflow':
    path = '/mnt/raid0nvme1/zz/data/' + 'overflow'
    dataset = OverFlowDataset(path)
    data = dataset[0]
    orig_edge_index = data.edge_index
elif dataName == 'taobao':
    path = '/mnt/raid0nvme1/zz/data/' + 'taobao/taobao.pt'
    data = torch.load(path)
    orig_edge_index = data.edge_index
    data.edge_index = to_undirected(data.edge_index)
elif dataName == 'reddit':
    path = '/mnt/raid0nvme1/zz/data/' + 'JODIE'
    dataset = JODIEDataset(path, name='reddit')
    data_orig = dataset[0]
    data = Data(x=data_orig.msg, edge_index=torch.stack([data_orig.src, data_orig.dst], dim=0), edge_attr=data_orig.t)
    orig_edge_index = data.edge_index
    data.edge_index = to_undirected(data.edge_index)
elif dataName == 'wiki':
    path = '/mnt/raid0nvme1/zz/data/' + 'JODIE'
    dataset = JODIEDataset(path, name='wikipedia')
    data_orig = dataset[0]
    data = Data(x=data_orig.msg, edge_index=torch.stack([data_orig.src, data_orig.dst], dim=0), edge_attr=data_orig.t)
    orig_edge_index = data.edge_index
    data.edge_index = to_undirected(data.edge_index)

print(data)

Loading dataset  wiki ...
Data(x=[157474, 172], edge_index=[2, 36514], edge_attr=[157474])


In [14]:

# Number of edges to sample fro
subset = int(orig_edge_index[0].numel() / (100/subsetPerc))

n1 = torch.unique(orig_edge_index[0])
n2 = torch.unique(orig_edge_index[1])
total_nodes = torch.unique(torch.cat((n1,n2))).numel()

# Assume CPU cache is 10% of data
CPUCacheNum = int(total_nodes / (100/CPUCachePerc))
print("We are using a cache size of ", CPUCacheNum)

# Assume GPU cache is 0.25% of data
GPUCacheNum = int(total_nodes / 200)

node_ids = torch.flatten(orig_edge_index.t())
nodes_to_sample = node_ids[len(node_ids) - subset*2:]
nodes_to_sample_unique_num = torch.unique(nodes_to_sample).numel()
print("Total number of unique nodes in dataset: ", total_nodes)
print("Number of unique nodes in edges we sample: ", nodes_to_sample_unique_num)
print("Number of total edges sampled: ", subset)

#loader = NeighborSampler(data.edge_index, sizes=[10,10], node_idx=nodes_to_sample, batch_size=2)
loader = NeighborSampler(data.edge_index, sizes=[25,10], node_idx=torch.unique(torch.cat((n1,n2))), batch_size=1)

We are using a cache size of  9227
Total number of unique nodes in dataset:  9227
Number of unique nodes in edges we sample:  724
Number of total edges sampled:  1574


In [15]:
sample_cnt = {}
cnt = 0
pbar = tqdm(total=total_nodes)
for batch_size, ids, adjs in loader:
    sample_cnt[cnt] = len(ids)
    cnt +=1
    pbar.update(batch_size)
pbar.close()

100%|██████████| 9227/9227 [00:01<00:00, 7367.60it/s]


In [16]:
sample_cnt = {k: v for k, v in sorted(sample_cnt.items(), key=lambda item: item[1], reverse=True)}

In [18]:
sample_cnt.values()

dict_values([253, 252, 250, 248, 248, 246, 246, 245, 245, 244, 243, 243, 242, 241, 240, 239, 238, 236, 236, 232, 229, 228, 228, 226, 225, 223, 222, 222, 220, 219, 217, 216, 212, 212, 210, 210, 207, 207, 205, 204, 204, 204, 204, 203, 203, 200, 198, 195, 192, 190, 186, 185, 184, 184, 182, 182, 181, 179, 177, 174, 173, 173, 171, 170, 168, 168, 168, 164, 160, 156, 155, 153, 152, 152, 151, 151, 151, 148, 148, 148, 147, 146, 146, 146, 146, 144, 144, 144, 143, 141, 141, 141, 141, 140, 139, 139, 139, 139, 139, 138, 138, 138, 137, 137, 137, 137, 137, 136, 136, 135, 135, 134, 134, 133, 133, 133, 132, 132, 132, 132, 132, 132, 131, 131, 130, 130, 130, 130, 130, 130, 130, 129, 129, 129, 129, 127, 126, 126, 126, 126, 125, 125, 125, 125, 125, 125, 124, 124, 124, 123, 123, 123, 123, 123, 123, 123, 123, 122, 122, 122, 122, 122, 121, 121, 121, 121, 120, 120, 120, 120, 120, 119, 119, 119, 119, 119, 119, 119, 119, 119, 118, 118, 118, 118, 117, 117, 117, 117, 117, 116, 116, 116, 116, 116, 115, 115, 115, 11

In [17]:
import json
path = '/mnt/raid0nvme1/zz/cache_data/'
with open(path + "wiki" + ".json", 'w') as fp:
    json.dump(sample_cnt, fp)