In [1]:
import pandas
el = pandas.read_parquet('/datasets/abarghi/ogbn_papers100M/parquet/paper__cites__paper/edge_index.parquet')

In [2]:
import torch
from torch_geometric.data import HeteroData

sz = max(el.src.max(), el.dst.max()) - 1

data = HeteroData()

data.put_edge_index(
    torch.stack([torch.as_tensor(el.src.values), torch.as_tensor(el.dst.values)]),
    edge_type=('paper','cites','paper'),
    layout='coo',
    is_sorted=False,
    size=(sz, sz)
)

data['paper']['num_nodes'] = sz


data

HeteroData(
  paper={ num_nodes=111059954 },
  (paper, cites, paper)={ edge_index=[2, 1615685872] }
)

In [3]:
from torch_geometric.sampler import NeighborSampler

sampler_disjoint = NeighborSampler(
    data,
    num_neighbors={('paper','cites','paper'):[3]},
    disjoint=True,
    replace=False,
)

sampler_default = NeighborSampler(
    data,
    num_neighbors={('paper','cites','paper'):[3]},
    disjoint=False,
    replace=False,
)

from torch_geometric.sampler import NodeSamplerInput

sampler_input = NodeSamplerInput(
    None,
    node=torch.tensor([0,1,7]),
    input_type='paper'
)


In [4]:
%%timeit
sampler_default.sample_from_nodes(sampler_input)

537 ms ± 30.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
%%timeit
sampler_disjoint.sample_from_nodes(sampler_input)

110 µs ± 317 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [6]:
print('converting to csc...')
from torch_geometric.utils.sparse import index2ptr
src = torch.as_tensor(el['src'].values)
dst = torch.as_tensor(el['dst'].values)
dst = index2ptr(dst, max(src.max(), dst.max()) - 1)


converting to csc...


In [7]:
%%timeit
torch.ops.pyg.neighbor_sample(
    dst,
    src,
    torch.arange(5),  # seed
    [10,25],
    time=None,
    seed_time=None,
    csc=True,  # csc
    replace=False,
    directed=True,
    disjoint=False,
    temporal_strategy='uniform',
    return_edge_id=True,  # return_edge_id
)

506 ms ± 2.62 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
%%timeit
torch.ops.pyg.neighbor_sample(
    dst,
    src,
    torch.arange(5),  # seed
    [10,25],
    time=None,
    seed_time=None,
    csc=True,  # csc
    replace=False,
    directed=True,
    disjoint=True,
    temporal_strategy='uniform',
    return_edge_id=True,  # return_edge_id
)

27.9 µs ± 189 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
