7 changes: 1 addition & 6 deletions .coveragerc
@@ -3,10 +3,5 @@ source=torch_scatter
 [report]
 exclude_lines =
     pragma: no cover
-    cuda
-    forward
-    backward
-    apply
     torch.jit.script
+    raise
-    min_value
-    max_value
1 change: 1 addition & 0 deletions .travis.yml
@@ -39,6 +39,7 @@ install:
 - pip install codecov
 - pip install sphinx
 - pip install sphinx_rtd_theme
+- pip install sphinx-autodoc-typehints
 script:
 - python -c "import torch; print(torch.__version__)"
 - pycodestyle .
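A note beyond this diff: the new `sphinx-autodoc-typehints` dependency only takes effect once the extension is enabled in the Sphinx configuration. A minimal sketch of what that looks like; the `docs/source/conf.py` path and the surrounding extension list are assumptions, not taken from this PR:

```
# docs/source/conf.py (hypothetical excerpt, for illustration only)
extensions = [
    'sphinx.ext.autodoc',          # assumed to be enabled already
    'sphinx_autodoc_typehints',    # moves PEP 484 type hints into the rendered docs
]
```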
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,4 +1,4 @@
-Copyright (c) 2019 Matthias Fey <matthias.fey@tu-dortmund.de>
+Copyright (c) 2020 Matthias Fey <matthias.fey@tu-dortmund.de>

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
39 changes: 16 additions & 23 deletions README.md
@@ -22,34 +22,27 @@

 **[Documentation](https://pytorch-scatter.readthedocs.io)**

-This package consists of a small extension library of highly optimized sparse update (scatter) operations for the use in [PyTorch](http://pytorch.org/), which are missing in the main package.
-Scatter operations can be roughly described as reduce operations based on a given "group-index" tensor.
-The package consists of the following operations:
+This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in [PyTorch](http://pytorch.org/), which are missing in the main package.
+Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor.
+Segment operations require the "group-index" tensor to be sorted, whereas scatter operations are not subject to this requirement.

-* [**Scatter Add**](https://pytorch-scatter.readthedocs.io/en/latest/functions/add.html)
-* [**Scatter Sub**](https://pytorch-scatter.readthedocs.io/en/latest/functions/sub.html)
-* [**Scatter Mul**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mul.html)
-* [**Scatter Div**](https://pytorch-scatter.readthedocs.io/en/latest/functions/div.html)
-* [**Scatter Mean**](https://pytorch-scatter.readthedocs.io/en/latest/functions/mean.html)
-* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
-* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
-* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
-* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
+The package consists of the following operations with reduction types `"sum"|"mean"|"min"|"max"`:

-In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
+* [**scatter**](https://pytorch-scatter.readthedocs.io/en/latest/functions/scatter.html) based on arbitrary indices
+* [**segment_coo**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_coo.html) based on sorted indices
+* [**segment_csr**](https://pytorch-scatter.readthedocs.io/en/latest/functions/segment_csr.html) based on compressed indices via pointers

-* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
-* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
+In addition, we provide the following **composite functions** which make use of `scatter_*` operations under the hood: `scatter_std`, `scatter_logsumexp`, `scatter_softmax` and `scatter_log_softmax`.

-All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
+All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
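As an aside, on a sorted index all three operations listed above agree, which is exactly what the updated `benchmark/scatter_segment.py` below asserts for `add`, `mean`, `min` and `max`. A minimal sketch of that equivalence under the new `reduce`-based API, with made-up tensor values:

```
import torch
from torch_scatter import scatter, segment_coo, segment_csr

src = torch.randn(6, 4)
index = torch.tensor([0, 0, 1, 1, 1, 2])   # sorted "group-index" tensor
indptr = torch.tensor([0, 2, 5, 6])        # the same groups as compressed pointers

out1 = scatter(src, index, dim=0, reduce='sum')  # allows arbitrary (unsorted) indices
out2 = segment_coo(src, index, reduce='sum')     # requires sorted indices
out3 = segment_csr(src, indptr, reduce='sum')    # requires compressed pointers

assert torch.allclose(out1, out2) and torch.allclose(out1, out3)
```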

 ## Installation

-Ensure that at least PyTorch 1.1.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:
+Ensure that at least PyTorch 1.3.0 is installed and verify that `cuda/bin` and `cuda/include` are in your `$PATH` and `$CPATH` respectively, *e.g.*:

 ```
 $ python -c "import torch; print(torch.__version__)"
->>> 1.1.0
+>>> 1.3.0

 $ echo $PATH
 >>> /usr/local/cuda/bin:...
@@ -81,17 +74,17 @@ from torch_scatter import scatter_max
 src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
 index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])

-out, argmax = scatter_max(src, index, fill_value=0)
+out, argmax = scatter_max(src, index, dim=-1)
 ```

 ```
 print(out)
-tensor([[ 0,  0,  4,  3,  2,  0],
-        [ 2,  4,  3,  0,  0,  0]])
+tensor([[0, 0, 4, 3, 2, 0],
+        [2, 4, 3, 0, 0, 0]])

 print(argmax)
-tensor([[-1, -1,  3,  4,  0,  1],
-        [ 1,  4,  3, -1, -1, -1]])
+tensor([[5, 5, 3, 4, 0, 1],
+        [1, 4, 3, 5, 5, 5]])
 ```
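One observation on the updated output: the expected `argmax` now fills slots that receive no value with `5`, which equals `src.size(-1)`, an intentionally out-of-bounds index, where the old API used `-1`. If that reading is right, invalid entries can be masked out along these lines (a sketch, not part of this diff):

```
import torch
from torch_scatter import scatter_max

src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)

invalid = argmax == src.size(-1)         # slots no value was scattered to
argmax = argmax.masked_fill(invalid, 0)  # or any sentinel that suits the caller
```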

 ## Running tests
35 changes: 13 additions & 22 deletions benchmark/scatter_segment.py
@@ -7,9 +7,7 @@
 import torch
 from scipy.io import loadmat

-import torch_scatter
-from torch_scatter import scatter_add, scatter_mean, scatter_min, scatter_max
-from torch_scatter import segment_coo, segment_csr
+from torch_scatter import scatter, segment_coo, segment_csr

 short_rows = [
     ('DIMACS10', 'citationCiteseer'),
@@ -47,34 +45,30 @@ def correctness(dataset):
     x = torch.randn((row.size(0), size), device=args.device)
     x = x.squeeze(-1) if size == 1 else x

-    out1 = scatter_add(x, row, dim=0, dim_size=dim_size)
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='add')
     out2 = segment_coo(x, row, dim_size=dim_size, reduce='add')
     out3 = segment_csr(x, rowptr, reduce='add')

     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)

-    out1 = scatter_mean(x, row, dim=0, dim_size=dim_size)
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='mean')
     out2 = segment_coo(x, row, dim_size=dim_size, reduce='mean')
     out3 = segment_csr(x, rowptr, reduce='mean')

     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)

     x = x.abs_().mul_(-1)

-    out1, _ = scatter_min(x, row, 0, torch.zeros_like(out1))
-    out2, _ = segment_coo(x, row, reduce='min')
-    out3, _ = segment_csr(x, rowptr, reduce='min')
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='min')
+    out2 = segment_coo(x, row, reduce='min')
+    out3 = segment_csr(x, rowptr, reduce='min')

     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)

     x = x.abs_()

-    out1, _ = scatter_max(x, row, 0, torch.zeros_like(out1))
-    out2, _ = segment_coo(x, row, reduce='max')
-    out3, _ = segment_csr(x, rowptr, reduce='max')
+    out1 = scatter(x, row, dim=0, dim_size=dim_size, reduce='max')
+    out2 = segment_coo(x, row, reduce='max')
+    out3 = segment_csr(x, rowptr, reduce='max')

     assert torch.allclose(out1, out2, atol=1e-4)
     assert torch.allclose(out1, out3, atol=1e-4)
@@ -117,17 +111,15 @@ def timing(dataset):
     mat = loadmat(f'{name}.mat')['Problem'][0][0][2].tocsr()
     rowptr = torch.from_numpy(mat.indptr).to(args.device, torch.long)
     row = torch.from_numpy(mat.tocoo().row).to(args.device, torch.long)
-    row_perm = row[torch.randperm(row.size(0))]
+    row2 = row[torch.randperm(row.size(0))]
     dim_size = rowptr.size(0) - 1
     avg_row_len = row.size(0) / dim_size

     def sca_row(x):
-        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
-        return op(x, row, dim=0, dim_size=dim_size)
+        return scatter(x, row, dim=0, dim_size=dim_size, reduce=args.reduce)

     def sca_col(x):
-        op = getattr(torch_scatter, f'scatter_{args.scatter_reduce}')
-        return op(x, row_perm, dim=0, dim_size=dim_size)
+        return scatter(x, row2, dim=0, dim_size=dim_size, reduce=args.reduce)

     def seg_coo(x):
         return segment_coo(x, row, reduce=args.reduce)
@@ -205,11 +197,10 @@ def dense2(x):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--reduce', type=str, required=True,
-                        choices=['sum', 'mean', 'min', 'max'])
+                        choices=['sum', 'add', 'mean', 'min', 'max'])
     parser.add_argument('--with_backward', action='store_true')
     parser.add_argument('--device', type=str, default='cuda')
     args = parser.parse_args()
-    args.scatter_reduce = 'add' if args.reduce == 'sum' else args.reduce
     iters = 1 if args.device == 'cpu' else 20
     sizes = [1, 16, 32, 64, 128, 256, 512]
     sizes = sizes[:3] if args.device == 'cpu' else sizes
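A note on the two index layouts the benchmark feeds these kernels: `row` carries one sorted group index per nonzero (COO), while `rowptr` stores compressed per-group offsets (CSR), which is what lets `segment_csr` avoid reading per-entry indices at all. A small sketch of the correspondence, with invented values (not part of the diff):

```
import torch

# COO-style row indices (sorted), as consumed by scatter/segment_coo:
row = torch.tensor([0, 0, 1, 1, 1, 3])

# CSR-style pointers, as consumed by segment_csr: group i spans
# positions rowptr[i]:rowptr[i + 1] of row. Note the empty group 2.
rowptr = torch.tensor([0, 2, 5, 5, 6])

# Converting pointers back to per-entry indices recovers `row`:
counts = rowptr[1:] - rowptr[:-1]
row_from_ptr = torch.repeat_interleave(torch.arange(rowptr.numel() - 1), counts)
assert torch.equal(row_from_ptr, row)
```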
5 changes: 0 additions & 5 deletions cpu/compat.h

This file was deleted.

120 changes: 0 additions & 120 deletions cpu/dim_apply.h

This file was deleted.
