Add EdgeIndex.add_ and EdgeIndex.sub_ (#9239)
rusty1s committed Apr 25, 2024
1 parent 3573d21 commit adb3866
Showing 2 changed files with 139 additions and 57 deletions.
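In short: EdgeIndex now supports in-place addition and subtraction while preserving whatever cached metadata survives the operation. A minimal usage sketch, mirroring the values used in the updated tests below (the constructor call itself is assumed and not part of this diff):

from torch_geometric import EdgeIndex

# Assumed setup: a small 3x3 sparse matrix, as in test_add/test_sub.
adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))

adj += 2                            # dispatches to the new aten.add_.Tensor override
assert isinstance(adj, EdgeIndex)   # the subclass is preserved
assert adj.sparse_size() == (5, 5)  # a scalar shift also shifts the sparse size

adj -= 2                            # dispatches to the new aten.sub_.Tensor override
assert adj.sparse_size() == (3, 3)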
28 changes: 16 additions & 12 deletions test/test_edge_index.py
@@ -708,12 +708,14 @@ def test_add(dtype, device, is_undirected):
    assert not out.is_undirected
    assert out.sparse_size() == (6, 6)

    # TODO Bring back.
    # adj += 2
    # assert isinstance(adj, EdgeIndex)
    # assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    # assert adj.is_undirected == is_undirected
    # assert adj.sparse_size() == (5, 5)
    adj += 2
    assert isinstance(adj, EdgeIndex)
    assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert adj.is_undirected == is_undirected
    assert adj.sparse_size() == (5, 5)

    with pytest.raises(RuntimeError, match="can't be cast"):
        adj += 2.5


@withCUDA
@@ -753,12 +755,14 @@ def test_sub(dtype, device, is_undirected):
    assert not out.is_undirected
    assert out.sparse_size() == (None, None)

    # TODO Bring back.
    # adj -= 2
    # assert isinstance(adj, EdgeIndex)
    # assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    # assert adj.is_undirected == is_undirected
    # assert adj.sparse_size() == (5, 5)
    adj -= 2
    assert isinstance(adj, EdgeIndex)
    assert adj.equal(tensor([[2, 3, 3, 4], [3, 2, 4, 3]], device=device))
    assert adj.is_undirected == is_undirected
    assert adj.sparse_size() == (5, 5)

    with pytest.raises(RuntimeError, match="can't be cast"):
        adj -= 2.5


@withCUDA
168 changes: 123 additions & 45 deletions torch_geometric/edge_index.py
@@ -722,6 +722,25 @@ def fill_cache_(self, no_transpose: bool = False) -> 'EdgeIndex':

    # Methods #################################################################

    def share_memory_(self) -> 'EdgeIndex':
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        if self._T_perm is not None:
            self._T_perm.share_memory_()
        if self._T_index[0] is not None:
            self._T_index[0].share_memory_()
        if self._T_index[1] is not None:
            self._T_index[1].share_memory_()
        if self._T_indptr is not None:
            self._T_indptr.share_memory_()
        if self._value is not None:
            self._value.share_memory_()
        return self

    def is_shared(self) -> bool:
        return self._data.is_shared()
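The two methods above mirror torch.Tensor.share_memory_() and torch.Tensor.is_shared(), additionally moving any cached buffers (indptr, transpose permutation/indices, values) into shared memory. A minimal sketch, with assumed constructor values:

from torch_geometric import EdgeIndex

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3))
adj.share_memory_()     # shares the underlying data plus any populated caches; returns self
assert adj.is_shared()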

    def as_tensor(self) -> Tensor:
        r"""Zero-copies the :class:`EdgeIndex` representation back to a
        :class:`torch.Tensor` representation.
@@ -1111,6 +1130,8 @@ def sparse_narrow(
        edge_index._indptr = colptr
        return edge_index

    # PyTorch/Python builtins #################################################

    def __tensor_flatten__(self) -> Tuple[List[str], Tuple[Any, ...]]:
        attrs = ['_data']
        if self._indptr is not None:
@@ -1208,7 +1229,9 @@ def __repr__(self) -> str: # type: ignore
        return torch._tensor_str._add_suffixes(prefix + tensor_str, suffixes,
                                               indent, force_newline=False)

    def shallow_copy(self) -> 'EdgeIndex':
    # Helpers #################################################################

    def _shallow_copy(self) -> 'EdgeIndex':
        out = EdgeIndex(self._data)
        out._sparse_size = self._sparse_size
        out._sort_order = self._sort_order
@@ -1221,25 +1244,18 @@ def shallow_copy(self) -> 'EdgeIndex':
        out._cat_metadata = self._cat_metadata
        return out

    def share_memory_(self) -> 'EdgeIndex':
        self._data.share_memory_()
        if self._indptr is not None:
            self._indptr.share_memory_()
        if self._T_perm is not None:
            self._T_perm.share_memory_()
        if self._T_index[0] is not None:
            self._T_index[0].share_memory_()
        if self._T_index[1] is not None:
            self._T_index[1].share_memory_()
        if self._T_indptr is not None:
            self._T_indptr.share_memory_()
        if self._value is not None:
            self._value.share_memory_()
    def _clear_metadata(self) -> 'EdgeIndex':
        self._sparse_size = (None, None)
        self._sort_order = None
        self._is_undirected = False
        self._indptr = None
        self._T_perm = None
        self._T_index = (None, None)
        self._T_indptr = None
        self._value = None
        self._cat_metadata = None
        return self
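_clear_metadata() is the counterpart used by the new in-place operators below: they stash the metadata that can survive the operation, wipe everything, mutate the underlying data, and then restore only what is still valid. A small sketch of its effect in isolation (it is a private helper, called here purely for illustration; constructor values assumed):

from torch_geometric import EdgeIndex

adj = EdgeIndex([[0, 1], [1, 0]], sparse_size=(2, 2), is_undirected=True)
adj._clear_metadata()                     # drops all cached/derived state
assert adj.sparse_size() == (None, None)  # sparse size is forgotten
assert not adj.is_undirected              # undirectedness flag is reset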

    def is_shared(self) -> bool:
        return self._data.is_shared()


class SortReturnType(NamedTuple):
    values: EdgeIndex
@@ -1435,7 +1451,7 @@ def _slice(

    if ((start is None or start <= 0)
            and (end is None or end > input.size(dim)) and step == 1):
        return input.shallow_copy() # No-op.
        return input._shallow_copy() # No-op.

    out = aten.slice.Tensor(input._data, dim, start, end, step)

@@ -1539,19 +1555,52 @@ def _add(
    return out


# @implements(aten.add_.Tensor)
# def add_(
# input: EdgeIndex,
# other: Union[int, Tensor, EdgeIndex],
# *,
# alpha: int = 1,
# ) -> Tensor:
# aten.add_.Tensor(
# input._data,
# other._data if isinstance(other, EdgeIndex) else other,
# alpha=alpha,
# )
# return input
@implements(aten.add_.Tensor)
def add_(
    input: EdgeIndex,
    other: Union[int, Tensor, EdgeIndex],
    *,
    alpha: int = 1,
) -> EdgeIndex:

    sparse_size = input._sparse_size
    sort_order = input._sort_order
    is_undirected = input._is_undirected
    T_perm = input._T_perm
    input._clear_metadata()

    aten.add_.Tensor(
        input._data,
        other._data if isinstance(other, EdgeIndex) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        size = maybe_add(sparse_size, other, alpha)
        assert len(size) == 2
        input._sparse_size = size
        input._sort_order = sort_order
        input._is_undirected = is_undirected
        input._T_perm = T_perm

    elif isinstance(other, Tensor) and other.size() == (2, 1):
        size = maybe_add(sparse_size, other.view(-1).tolist(), alpha)
        assert len(size) == 2
        input._sparse_size = size
        input._sort_order = sort_order
        if torch.equal(other[0], other[1]):
            input._is_undirected = is_undirected
        input._T_perm = T_perm

    elif isinstance(other, EdgeIndex):
        size = maybe_add(sparse_size, other._sparse_size, alpha)
        assert len(size) == 2
        input._sparse_size = size

    return input
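As a usage sketch of the metadata handling above (constructor values assumed; each assertion follows one of the branches in add_): a scalar offset shifts the sparse size while keeping sort order and undirectedness, whereas a (2, 1) tensor offset with different row and column shifts can no longer guarantee undirectedness.

import torch
from torch_geometric import EdgeIndex

adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_size=(3, 3),
                sort_order='row', is_undirected=True)

adj.add_(2)                          # scalar branch
assert adj.sparse_size() == (5, 5)   # both dimensions shifted by 2
assert adj.is_undirected             # still undirected

adj.add_(torch.tensor([[2], [0]]))   # (2, 1) tensor branch: shift rows by 2, columns by 0
assert adj.sparse_size() == (7, 5)
assert not adj.is_undirected         # row and column offsets differ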


@implements(aten.sub.Tensor)
Expand Down Expand Up @@ -1598,19 +1647,48 @@ def _sub(
    return out


# @implements(aten.sub_.Tensor)
# def sub_(
# input: EdgeIndex,
# other: Union[int, Tensor, EdgeIndex],
# *,
# alpha: int = 1,
# ) -> EdgeIndex:
# aten.sub_.Tensor(
# input._data,
# other._data if isinstance(other, EdgeIndex) else other,
# alpha=alpha,
# )
# return input
@implements(aten.sub_.Tensor)
def sub_(
    input: EdgeIndex,
    other: Union[int, Tensor, EdgeIndex],
    *,
    alpha: int = 1,
) -> EdgeIndex:

    sparse_size = input._sparse_size
    sort_order = input._sort_order
    is_undirected = input._is_undirected
    T_perm = input._T_perm
    input._clear_metadata()

    aten.sub_.Tensor(
        input._data,
        other._data if isinstance(other, EdgeIndex) else other,
        alpha=alpha,
    )

    if isinstance(other, Tensor) and other.numel() <= 1:
        other = int(other)

    if isinstance(other, int):
        size = maybe_sub(sparse_size, other, alpha)
        assert len(size) == 2
        input._sparse_size = size
        input._sort_order = sort_order
        input._is_undirected = is_undirected
        input._T_perm = T_perm

    elif isinstance(other, Tensor) and other.size() == (2, 1):
        size = maybe_sub(sparse_size, other.view(-1).tolist(), alpha)
        assert len(size) == 2
        input._sparse_size = size
        input._sort_order = sort_order
        if torch.equal(other[0], other[1]):
            input._is_undirected = is_undirected
        input._T_perm = T_perm

    return input
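sub_ mirrors add_ for scalar and (2, 1) tensor operands; unlike add_, it has no branch that restores the sparse size when the operand is another EdgeIndex. A small round-trip sketch with assumed constructor values:

import torch
from torch_geometric import EdgeIndex

adj = EdgeIndex([[2, 3, 3, 4], [3, 2, 4, 3]], sparse_size=(5, 5))
adj.sub_(2)                          # in-place scalar shift
assert adj.equal(torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]))
assert adj.sparse_size() == (3, 3)   # sparse size shrinks accordingly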


# Sparse-Dense Matrix Multiplication ##########################################

