-
Notifications
You must be signed in to change notification settings - Fork 145
/
add.py
108 lines (86 loc) · 3.5 KB
/
add.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
from typing import Optional
import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor
def _ones_for_missing_value(ref: Tensor, nnz: int) -> Tensor:
    """Return all-ones values for a value-less operand with ``nnz`` entries.

    The result copies the trailing dimensions, dtype and device of ``ref``
    so it can be concatenated with ``ref`` along dimension 0.
    """
    sizes = list(ref.size())
    sizes[0] = nnz
    return torch.ones(sizes, dtype=ref.dtype, device=ref.device)


# TorchScript overload stubs: the bodies must stay a bare `pass`; the actual
# implementation is the undecorated `add` defined right after them.
@torch.jit._overload  # noqa: F811
def add(src, other):  # noqa: F811
    # type: (SparseTensor, Tensor) -> SparseTensor
    pass


@torch.jit._overload  # noqa: F811
def add(src, other):  # noqa: F811
    # type: (SparseTensor, SparseTensor) -> SparseTensor
    pass


def add(src, other):  # noqa: F811
    """Add a dense tensor or another sparse tensor to ``src``.

    Args:
        src: The sparse tensor to add to.
        other: Either a dense ``Tensor`` that broadcasts row-wise (size
            ``(src.size(0), 1, ...)``) or column-wise (size
            ``(1, src.size(1), ...)``) over the stored entries of ``src``,
            or another ``SparseTensor``.

    Returns:
        For a dense ``other``: the result of ``src.set_value(...)`` holding
        the summed values (a missing value in ``src`` is treated as 1).
        For a sparse ``other``: a tensor built from the concatenated entries
        of both operands and passed through ``coalesce(reduce='sum')``; its
        sparse size is the element-wise maximum of both operands' sizes.

    Raises:
        ValueError: If a dense ``other`` has neither of the supported sizes.
        NotImplementedError: If ``other`` is neither a ``Tensor`` nor a
            ``SparseTensor``.
    """
    if isinstance(other, Tensor):
        rowptr, col, value = src.csr()
        if other.size(0) == src.size(0) and other.size(1) == 1:  # Row-wise.
            other = gather_csr(other.squeeze(1), rowptr)
        elif other.size(0) == 1 and other.size(1) == src.size(1):  # Col-wise.
            other = other.squeeze(0)[col]
        else:
            raise ValueError(
                f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
                f'(1, {src.size(1)}, ...), but got size {other.size()}.')
        # `other` is a freshly gathered/indexed tensor at this point, so the
        # in-place additions below do not touch the caller's input.
        if value is not None:
            value = other.to(value.dtype).add_(value)
        else:
            value = other.add_(1)
        return src.set_value(value, layout='coo')

    elif isinstance(other, SparseTensor):
        rowA, colA, valueA = src.coo()
        rowB, colB, valueB = other.coo()

        row = torch.cat([rowA, rowB], dim=0)
        col = torch.cat([colA, colB], dim=0)

        value: Optional[Tensor] = None
        if valueA is not None and valueB is not None:
            value = torch.cat([valueA, valueB], dim=0)
        elif valueA is not None:
            # Only `src` stores values: treat the entries of `other` as ones
            # (as the dense branch does) instead of dropping `src`'s values.
            value = torch.cat(
                [valueA, _ones_for_missing_value(valueA, rowB.numel())],
                dim=0)
        elif valueB is not None:
            # Only `other` stores values: treat the entries of `src` as ones.
            value = torch.cat(
                [_ones_for_missing_value(valueB, rowA.numel()), valueB],
                dim=0)
        # If neither operand stores values, the result stays value-less.

        M = max(src.size(0), other.size(0))
        N = max(src.size(1), other.size(1))
        sparse_sizes = (M, N)

        out = SparseTensor(row=row, col=col, value=value,
                           sparse_sizes=sparse_sizes)
        # Entries present in both operands appear twice after concatenation;
        # coalescing with a sum reduction merges them.
        out = out.coalesce(reduce='sum')
        return out

    else:
        raise NotImplementedError
def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
    """In-place counterpart of adding a broadcastable dense tensor to ``src``.

    Args:
        src: The sparse tensor whose values are updated via ``set_value_``.
        other: Dense tensor of size ``(src.size(0), 1, ...)`` (row-wise) or
            ``(1, src.size(1), ...)`` (column-wise).

    Returns:
        The result of ``src.set_value_(...)`` with the summed values.

    Raises:
        ValueError: If ``other`` has neither of the supported sizes.
    """
    rowptr, col, values = src.csr()
    num_rows = src.size(0)
    num_cols = src.size(1)

    if other.size(0) == num_rows and other.size(1) == 1:
        # Row-wise: expand one entry per row to one entry per stored element.
        addend = gather_csr(other.squeeze(1), rowptr)
    elif other.size(0) == 1 and other.size(1) == num_cols:
        # Col-wise: pick the entry belonging to each stored column index.
        addend = other.squeeze(0)[col]
    else:
        raise ValueError(
            f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
            f'(1, {src.size(1)}, ...), but got size {other.size()}.')

    if values is None:
        # No stored values: existing entries count as 1.
        return src.set_value_(addend.add_(1), layout='coo')

    # Accumulate directly into the stored values.
    return src.set_value_(values.add_(addend.to(values.dtype)), layout='coo')
def add_nnz(src: SparseTensor, other: torch.Tensor,
            layout: Optional[str] = None) -> SparseTensor:
    """Add ``other`` directly to the stored values of ``src`` (out-of-place).

    Args:
        src: The sparse tensor whose stored values are read.
        other: Tensor added to the stored values; presumably one entry per
            stored element (not checked here).
        layout: Forwarded unchanged to ``src.set_value``.

    Returns:
        The result of ``src.set_value(...)`` with the summed values. If
        ``src`` stores no values, ``other + 1`` is used.
    """
    current = src.storage.value()
    if current is None:
        summed = other.add(1)
    else:
        summed = current.add(other.to(current.dtype))
    return src.set_value(summed, layout=layout)
def add_nnz_(src: SparseTensor, other: torch.Tensor,
             layout: Optional[str] = None) -> SparseTensor:
    """Add ``other`` directly to the stored values of ``src`` (in-place).

    Args:
        src: The sparse tensor whose values are updated via ``set_value_``.
        other: Tensor added to the stored values; presumably one entry per
            stored element (not checked here).
        layout: Forwarded unchanged to ``src.set_value_``.

    Returns:
        The result of ``src.set_value_(...)`` with the summed values.
    """
    current = src.storage.value()
    if current is None:
        # No stored values: build a new tensor so the caller's `other` is
        # left untouched.
        updated = other.add(1)
    else:
        # Accumulate into the existing value tensor.
        updated = current.add_(other.to(current.dtype))
    return src.set_value_(updated, layout=layout)
# Attach the module-level addition functions to `SparseTensor` as methods.
# Each lambda only forwards its arguments to the function of the same name.
SparseTensor.add = lambda self, other: add(self, other)
SparseTensor.add_ = lambda self, other: add_(self, other)
SparseTensor.add_nnz = lambda self, other, layout=None: add_nnz(
self, other, layout)
SparseTensor.add_nnz_ = lambda self, other, layout=None: add_nnz_(
self, other, layout)
# Operator support: `a + b` uses the out-of-place `add`.
SparseTensor.__add__ = SparseTensor.add
# `__radd__` reuses `add`, so `x + sparse` is evaluated as `add(sparse, x)`.
SparseTensor.__radd__ = SparseTensor.add
# `a += b` uses the in-place `add_`.
SparseTensor.__iadd__ = SparseTensor.add_