Skip to content

Commit

Permalink
Enable index_sort (#6554)
Browse files Browse the repository at this point in the history
Enable `index_sort` introduced in
[pyg-lib](https://pyg-lib.readthedocs.io/en/latest/modules/ops.html#pyg_lib.ops.index_sort).

Till now, sorting indices (1d tensor) have been working sequentially.
This PR enables the `index_sort` operation, which makes such a sort work
in parallel (parallelized radix sort).
Related PR:
[pytorch_sparse:306](rusty1s/pytorch_sparse#306)

In training benchmarks, data set load time was improved on average by
3.82 times, and training time by 1.06 times.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people authored Feb 1, 2023
1 parent af45d62 commit 46d6102
Show file tree
Hide file tree
Showing 14 changed files with 66 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added accelerated `index_sort` function from `pyg-lib` for faster sorting ([#6554](https://github.com/pyg-team/pytorch_geometric/pull/6554))
- Fix incorrect device in `EquilibriumAggregration` ([#6560](https://github.com/pyg-team/pytorch_geometric/pull/6560))
- Added bipartite graph support in `dense_to_sparse()` ([#6546](https://github.com/pyg-team/pytorch_geometric/pull/6546))
- Add CPU affinity support for more data loaders ([#6534](https://github.com/pyg-team/pytorch_geometric/pull/6534))
Expand Down
8 changes: 6 additions & 2 deletions benchmark/inference/inference_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ def run(args: argparse.ArgumentParser) -> None:
assert dataset_name in supported_sets.keys(
), f"Dataset {dataset_name} isn't supported."
print(f'Dataset: {dataset_name}')
dataset, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor, args.bf16)
load_time = timeit() if args.measure_load_time else nullcontext()
with load_time:
dataset, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor,
args.bf16)
data = dataset.to(device)
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', None) if dataset_name == 'ogbn-mag' else None
Expand Down Expand Up @@ -166,4 +169,5 @@ def run(args: argparse.ArgumentParser) -> None:
help="Use DataLoader affinitzation.")
add('--loader-cores', nargs='+', default=[], type=int,
help="List of CPU core IDs to use for DataLoader workers.")
add('--measure-load-time', action='store_true')
run(argparser.parse_args())
7 changes: 5 additions & 2 deletions benchmark/training/training_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,10 @@ def run(args: argparse.ArgumentParser) -> None:
assert dataset_name in supported_sets.keys(
), f"Dataset {dataset_name} isn't supported."
print(f'Dataset: {dataset_name}')
data, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor, args.bf16)
load_time = timeit() if args.measure_load_time else nullcontext()
with load_time:
data, num_classes = get_dataset(dataset_name, args.root,
args.use_sparse_tensor, args.bf16)
hetero = True if dataset_name == 'ogbn-mag' else False
mask = ('paper', data['paper'].train_mask
) if dataset_name == 'ogbn-mag' else data.train_mask
Expand Down Expand Up @@ -219,6 +221,7 @@ def run(args: argparse.ArgumentParser) -> None:
help="Use DataLoader affinitzation.")
add('--loader-cores', nargs='+', default=[], type=int,
help="List of CPU core IDs to use for DataLoader workers.")
add('--measure-load-time', action='store_true')
args = argparser.parse_args()

run(args)
9 changes: 5 additions & 4 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from torch import Tensor

from torch_geometric.typing import EdgeTensorType, EdgeType, OptTensor
from torch_geometric.utils import index_sort
from torch_geometric.utils.mixin import CastMixin

# The output of converting between two types in the GraphStore is a Tuple of
Expand Down Expand Up @@ -276,20 +277,20 @@ def _edge_to_layout(
col = ptr2ind(col, row.numel())

if attr.layout != EdgeLayout.CSR: # COO->CSR
row, perm = row.sort() # Cannot be sorted by destination.
col = col[perm]
num_rows = attr.size[0] if attr.size else int(row.max()) + 1
row, perm = index_sort(row, max_value=num_rows)
col = col[perm]
row = ind2ptr(row, num_rows)

else: # CSC output requested:
if attr.layout == EdgeLayout.CSR: # CSR->COO
row = ptr2ind(row, col.numel())

if attr.layout != EdgeLayout.CSC: # COO->CSC
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
if not attr.is_sorted: # Not sorted by destination.
col, perm = col.sort()
col, perm = index_sort(col, max_value=num_cols)
row = row[perm]
num_cols = attr.size[1] if attr.size else int(col.max()) + 1
col = ind2ptr(col, num_cols)

if attr.layout != layout and store:
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/datasets/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
download_url,
extract_tar,
)
from torch_geometric.utils import index_sort


class Entities(InMemoryDataset):
Expand Down Expand Up @@ -147,7 +148,7 @@ def process(self):
edges.append([dst, src, 2 * rel + 1])

edges = torch.tensor(edges, dtype=torch.long).t().contiguous()
perm = (N * R * edges[0] + R * edges[1] + edges[2]).argsort()
_, perm = index_sort(N * R * edges[0] + R * edges[1] + edges[2])
edges = edges[:, perm]

edge_index, edge_type = edges[:2], edges[2]
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/datasets/particle.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from torch_geometric.data import Data, Dataset
from torch_geometric.utils import scatter
from torch_geometric.utils import index_sort, scatter


class TrackingData(Data):
Expand Down Expand Up @@ -85,7 +85,7 @@ def get(self, idx):
weight = torch.from_numpy(y['weight'].values).to(torch.float)

# Sort.
perm = (particle_id * hit_id.size(0) + hit_id).argsort()
_, perm = index_sort(particle_id * hit_id.size(0) + hit_id)
hit_id = hit_id[perm]
particle_id = particle_id[perm]
weight = weight[perm]
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/datasets/word_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from torch_geometric.data import Data, InMemoryDataset, download_url
from torch_geometric.utils import index_sort


class WordNet18(InMemoryDataset):
Expand Down Expand Up @@ -81,7 +82,7 @@ def process(self):
test_mask[srcs[0].size(0) + srcs[1].size(0):] = True

num_nodes = max(int(src.max()), int(dst.max())) + 1
perm = (num_nodes * src + dst).argsort()
_, perm = index_sort(num_nodes * src + dst)

edge_index = torch.stack([src[perm], dst[perm]], dim=0)
edge_type = edge_type[perm]
Expand Down Expand Up @@ -191,7 +192,7 @@ def process(self):
test_mask[srcs[0].size(0) + srcs[1].size(0):] = True

num_nodes = max(int(src.max()), int(dst.max())) + 1
perm = (num_nodes * src + dst).argsort()
_, perm = index_sort(num_nodes * src + dst)

edge_index = torch.stack([src[perm], dst[perm]], dim=0)
edge_type = edge_type[perm]
Expand Down
5 changes: 3 additions & 2 deletions torch_geometric/nn/conv/rgcn_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch_geometric.typing
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, pyg_lib
from torch_geometric.utils import scatter
from torch_geometric.utils import index_sort, scatter

from ..inits import glorot, zeros

Expand Down Expand Up @@ -230,7 +230,8 @@ def forward(self, x: Union[OptTensor, Tuple[OptTensor, Tensor]],
and isinstance(edge_index, Tensor)):
if not self.is_sorted:
if (edge_type[1:] < edge_type[:-1]).any():
edge_type, perm = edge_type.sort()
edge_type, perm = index_sort(
edge_type, max_value=self.num_relations)
edge_index = edge_index[:, perm]
edge_type_ptr = torch.ops.torch_sparse.ind2ptr(
edge_type, self.num_relations)
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/nn/dense/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch_geometric.typing
from torch_geometric.nn import inits
from torch_geometric.typing import pyg_lib
from torch_geometric.utils import index_sort


def is_uninitialized_parameter(x: Any) -> bool:
Expand Down Expand Up @@ -255,7 +256,7 @@ def forward(self, x: Tensor, type_vec: Tensor) -> Tensor:
perm: Optional[Tensor] = None
if not self.is_sorted:
if (type_vec[1:] < type_vec[:-1]).any():
type_vec, perm = type_vec.sort()
type_vec, perm = index_sort(type_vec, self.num_types)
x = x[perm]

type_vec_ptr = torch.ops.torch_sparse.ind2ptr(
Expand Down
3 changes: 2 additions & 1 deletion torch_geometric/sampler/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage
from torch_geometric.typing import NodeType, OptTensor
from torch_geometric.utils import index_sort

# Edge Layout Conversion ######################################################

Expand All @@ -17,7 +18,7 @@ def sort_csc(
src_node_time: OptTensor = None,
) -> Tuple[Tensor, Tensor, Tensor]:
if src_node_time is None:
col, perm = col.sort()
col, perm = index_sort(col)
return row[perm], col, perm
else:
# We use `np.lexsort` to sort based on multiple keys.
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .scatter import scatter
from .sort import index_sort
from .degree import degree
from .softmax import softmax
from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path
Expand Down Expand Up @@ -44,6 +45,7 @@

__all__ = [
'scatter',
'index_sort',
'degree',
'softmax',
'dropout_node',
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/utils/coalesce.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import Tensor

from torch_geometric.utils import scatter
from torch_geometric.utils import index_sort, scatter

from .num_nodes import maybe_num_nodes

Expand Down Expand Up @@ -94,7 +94,7 @@ def coalesce(
idx[1:].mul_(num_nodes).add_(edge_index[int(sort_by_row)])

if not is_sorted:
idx[1:], perm = idx[1:].sort()
idx[1:], perm = index_sort(idx[1:], max_value=num_nodes * num_nodes)
edge_index = edge_index[:, perm]
if isinstance(edge_attr, Tensor):
edge_attr = edge_attr[perm]
Expand Down
28 changes: 28 additions & 0 deletions torch_geometric/utils/sort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Optional, Tuple

import torch

from torch_geometric.typing import WITH_PYG_LIB, pyg_lib

WITH_INDEX_SORT = WITH_PYG_LIB and hasattr(torch.ops.pyg, 'index_sort')


def index_sort(
inputs: torch.Tensor,
max_value: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
r"""Sorts the elements of the :obj:`inputs` tensor in ascending order.
It is expected that :obj:`inputs` is one-dimensional and that it only
contains positive integer values. If :obj:`max_value` is given, it can
be used by the underlying algorithm for better performance.
Args:
inputs (torch.Tensor): A vector with positive integer values.
max_value (int, optional): The maximum value stored inside
:obj:`inputs`. This value can be an estimation, but needs to be
greater than or equal to the real maximum.
(default: :obj:`None`)
"""
if not WITH_INDEX_SORT: # pragma: no cover
return inputs.sort()
return pyg_lib.ops.index_sort(inputs, max_value=max_value)
4 changes: 3 additions & 1 deletion torch_geometric/utils/sort_edge_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from torch import Tensor

from torch_geometric.utils import index_sort

from .num_nodes import maybe_num_nodes


Expand Down Expand Up @@ -71,7 +73,7 @@ def sort_edge_index(
idx = edge_index[1 - int(sort_by_row)] * num_nodes
idx += edge_index[int(sort_by_row)]

perm = idx.argsort()
_, perm = index_sort(idx, max_value=num_nodes * num_nodes)

edge_index = edge_index[:, perm]

Expand Down

0 comments on commit 46d6102

Please sign in to comment.