Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable index_sort #306

Merged
merged 3 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions torch_sparse/storage.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import warnings
from typing import Optional, List, Tuple
from typing import List, Optional, Tuple

import torch
from torch_scatter import segment_csr, scatter_add
from torch_sparse.utils import Final
from torch_scatter import scatter_add, segment_csr

from torch_sparse.utils import Final, index_sort

layouts: Final[List[str]] = ['coo', 'csr', 'csc']

Expand Down Expand Up @@ -151,7 +152,8 @@ def __init__(
idx[1:] *= self._sparse_sizes[1]
idx[1:] += self._col
if (idx[1:] < idx[:-1]).any():
perm = idx[1:].argsort()
max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
_, perm = index_sort(idx[1:], max_value)
self._row = self.row()[perm]
self._col = self._col[perm]
if value is not None:
Expand All @@ -163,10 +165,20 @@ def __init__(
def empty(self):
row = torch.tensor([], dtype=torch.long)
col = torch.tensor([], dtype=torch.long)
return SparseStorage(row=row, rowptr=None, col=col, value=None,
sparse_sizes=(0, 0), rowcount=None, colptr=None,
colcount=None, csr2csc=None, csc2csr=None,
is_sorted=True, trust_data=True)
return SparseStorage(
row=row,
rowptr=None,
col=col,
value=None,
sparse_sizes=(0, 0),
rowcount=None,
colptr=None,
colcount=None,
csr2csc=None,
csc2csr=None,
is_sorted=True,
trust_data=True,
)

def has_row(self) -> bool:
return self._row is not None
Expand Down Expand Up @@ -209,8 +221,11 @@ def has_value(self) -> bool:
def value(self) -> Optional[torch.Tensor]:
return self._value

def set_value_(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
def set_value_(
self,
value: Optional[torch.Tensor],
layout: Optional[str] = None,
):
if value is not None:
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
Expand All @@ -221,8 +236,11 @@ def set_value_(self, value: Optional[torch.Tensor],
self._value = value
return self

def set_value(self, value: Optional[torch.Tensor],
layout: Optional[str] = None):
def set_value(
self,
value: Optional[torch.Tensor],
layout: Optional[str] = None,
):
if value is not None:
if get_layout(layout) == 'csc':
value = value[self.csc2csr()]
Expand Down Expand Up @@ -375,8 +393,11 @@ def colcount(self) -> torch.Tensor:
if colptr is not None:
colcount = colptr[1:] - colptr[:-1]
else:
colcount = scatter_add(torch.ones_like(self._col), self._col,
dim_size=self._sparse_sizes[1])
colcount = scatter_add(
torch.ones_like(self._col),
self._col,
dim_size=self._sparse_sizes[1],
)
self._colcount = colcount
return colcount

Expand All @@ -389,7 +410,8 @@ def csr2csc(self) -> torch.Tensor:
return csr2csc

idx = self._sparse_sizes[0] * self._col + self.row()
csr2csc = idx.argsort()
max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
_, csr2csc = index_sort(idx, max_value)
self._csr2csc = csr2csc
return csr2csc

Expand All @@ -401,7 +423,8 @@ def csc2csr(self) -> torch.Tensor:
if csc2csr is not None:
return csc2csr

csc2csr = self.csr2csc().argsort()
max_value = self._sparse_sizes[0] * self._sparse_sizes[1]
_, csc2csr = index_sort(self.csr2csc(), max_value)
self._csc2csr = csc2csr
return csc2csr

Expand Down Expand Up @@ -543,7 +566,8 @@ def type(self, dtype: torch.dtype, non_blocking: bool = False):
else:
return self.set_value(
value.to(dtype=dtype, non_blocking=non_blocking),
layout='coo')
layout='coo',
)
else:
return self

Expand Down
8 changes: 8 additions & 0 deletions torch_sparse/typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
try:
import pyg_lib # noqa
WITH_PYG_LIB = True
WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
except ImportError:
pyg_lib = object
WITH_PYG_LIB = False
WITH_INDEX_SORT = False
17 changes: 16 additions & 1 deletion torch_sparse/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
from typing import Any
from typing import Any, Optional, Tuple

import torch

import torch_sparse.typing
from torch_sparse.typing import pyg_lib

try:
from typing_extensions import Final # noqa
except ImportError:
from torch.jit import Final # noqa


def index_sort(
inputs: torch.Tensor,
max_value: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
r"""See pyg-lib documentation for more details:
https://pyg-lib.readthedocs.io/en/latest/modules/ops.html"""
if not torch_sparse.typing.WITH_INDEX_SORT: # pragma: no cover
return inputs.sort()
return pyg_lib.ops.index_sort(inputs, max_value)


def is_scalar(other: Any) -> bool:
return isinstance(other, int) or isinstance(other, float)