Improve docs for scatter and gather functions (#49679)
Summary:
- Add warning about non-unique indices
- Add a note that these functions don't broadcast
- Add missing `torch.scatter` and `torch.scatter_add` doc entries
- Fix parameter descriptions
- Improve code examples to make indexing behaviour easier to understand

Closes gh-48214
Closes gh-26191
Closes gh-37130
Closes gh-34062
xref gh-31776

Pull Request resolved: #49679

Reviewed By: mruberry

Differential Revision: D25693660

Pulled By: ngimel

fbshipit-source-id: 4983e7b4efcbdf1ab9f04e58973b4f983e8e43a4
rgommers authored and facebook-github-bot committed Dec 23, 2020
1 parent b338713 commit d99a0c3
Showing 3 changed files with 89 additions and 50 deletions.
2 changes: 2 additions & 0 deletions docs/source/torch.rst
@@ -97,6 +97,8 @@ Indexing, Slicing, Joining, Mutating Ops
nonzero
reshape
row_stack
+ scatter
+ scatter_add
split
squeeze
stack
109 changes: 65 additions & 44 deletions torch/_tensor_docs.py
@@ -3127,14 +3127,25 @@ def callable(a, b) -> number
This is the reverse operation of the manner described in :meth:`~Tensor.gather`.
- :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should have same
- number of dimensions. It is also required that ``index.size(d) <= src.size(d)``
- for all dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all
- dimensions ``d != dim``.
+ :attr:`self`, :attr:`index` and :attr:`src` (if it is a Tensor) should all have
+ the same number of dimensions. It is also required that
+ ``index.size(d) <= src.size(d)`` for all dimensions ``d``, and that
+ ``index.size(d) <= self.size(d)`` for all dimensions ``d != dim``.
+ Note that ``index`` and ``src`` do not broadcast.
Moreover, as for :meth:`~Tensor.gather`, the values of :attr:`index` must be
- between ``0`` and ``self.size(dim) - 1`` inclusive, and all values in a row
- along the specified dimension :attr:`dim` must be unique.
+ between ``0`` and ``self.size(dim) - 1`` inclusive.
+ .. warning::
+     When indices are not unique, the behavior is non-deterministic (one of the
+     values from ``src`` will be picked arbitrarily) and the gradient will be
+     incorrect (it will be propagated to all locations in the source that
+     correspond to the same index)!
+ .. note::
+     The backward pass is implemented only for ``src.shape == index.shape``.
Additionally accepts an optional :attr:`reduce` argument that allows
specification of an optional reduction operation, which is applied to all
@@ -3156,36 +3167,39 @@ def callable(a, b) -> number
Args:
dim (int): the axis along which to index
- index (LongTensor): the indices of elements to scatter,
-     can be either empty or the same size of src.
-     When empty, the operation returns identity
- src (Tensor): the source element(s) to scatter,
-     incase `value` is not specified
- value (float): the source element(s) to scatter,
-     incase `src` is not specified
- reduce (string): reduction operation to apply,
-     can be either 'add' or 'multiply'.
+ index (LongTensor): the indices of elements to scatter, can be either empty
+     or of the same dimensionality as ``src``. When empty, the operation
+     returns ``self`` unchanged.
+ src (Tensor or float): the source element(s) to scatter.
+ reduce (str, optional): reduction operation to apply, can be either
+     ``'add'`` or ``'multiply'``.
Example::
- >>> x = torch.rand(2, 5)
- >>> x
- tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
-         [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
- >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
- tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
-         [ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
-         [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
- >>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
- >>> z
- tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
-         [ 0.0000, 0.0000, 0.0000, 1.2300]])
+ >>> src = torch.arange(1, 11).reshape((2, 5))
+ >>> src
+ tensor([[ 1,  2,  3,  4,  5],
+         [ 6,  7,  8,  9, 10]])
+ >>> index = torch.tensor([[0, 1, 2, 0]])
+ >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(0, index, src)
+ tensor([[1, 0, 0, 4, 0],
+         [0, 2, 0, 0, 0],
+         [0, 0, 3, 0, 0]])
+ >>> index = torch.tensor([[0, 1, 2], [0, 1, 4]])
+ >>> torch.zeros(3, 5, dtype=src.dtype).scatter_(1, index, src)
+ tensor([[1, 2, 3, 0, 0],
+         [6, 7, 0, 0, 8],
+         [0, 0, 0, 0, 0]])
+ >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
+ ...            1.23, reduce='multiply')
+ tensor([[2.0000, 2.0000, 2.4600, 2.0000],
+         [2.0000, 2.0000, 2.0000, 2.4600]])
+ >>> torch.full((2, 4), 2.).scatter_(1, torch.tensor([[2], [3]]),
+ ...            1.23, reduce='add')
+ tensor([[2.0000, 2.0000, 3.2300, 2.0000],
+         [2.0000, 2.0000, 2.0000, 3.2300]])
- >>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
- >>> z
- tensor([[1.0000, 1.0000, 1.2300, 1.0000],
-         [1.0000, 1.0000, 1.0000, 1.2300]])
""")

add_docstr_all('scatter_add_',
@@ -3208,28 +3222,35 @@ def callable(a, b) -> number
:attr:`self`, :attr:`index` and :attr:`src` should have same number of
dimensions. It is also required that ``index.size(d) <= src.size(d)`` for all
dimensions ``d``, and that ``index.size(d) <= self.size(d)`` for all dimensions
- ``d != dim``.
+ ``d != dim``. Note that ``index`` and ``src`` do not broadcast.
+ Note:
+     {forward_reproducibility_note}
+ .. note::
+     The backward pass is implemented only for ``src.shape == index.shape``.
Args:
dim (int): the axis along which to index
- index (LongTensor): the indices of elements to scatter and add,
-     can be either empty or the same size of src.
-     When empty, the operation returns identity.
+ index (LongTensor): the indices of elements to scatter and add, can be
+     either empty or of the same dimensionality as ``src``. When empty, the
+     operation returns ``self`` unchanged.
src (Tensor): the source elements to scatter and add
Example::
- >>> x = torch.rand(2, 5)
- >>> x
- tensor([[0.7404, 0.0427, 0.6480, 0.3806, 0.8328],
-         [0.7953, 0.2009, 0.9154, 0.6782, 0.9620]])
- >>> torch.ones(3, 5).scatter_add_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
- tensor([[1.7404, 1.2009, 1.9154, 1.3806, 1.8328],
-         [1.0000, 1.0427, 1.0000, 1.6782, 1.0000],
-         [1.7953, 1.0000, 1.6480, 1.0000, 1.9620]])
+ >>> src = torch.ones((2, 5))
+ >>> index = torch.tensor([[0, 1, 2, 0, 0]])
+ >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
+ tensor([[1., 0., 0., 1., 1.],
+         [0., 1., 0., 0., 0.],
+         [0., 0., 1., 0., 0.]])
+ >>> index = torch.tensor([[0, 1, 2, 0, 0], [0, 1, 2, 2, 2]])
+ >>> torch.zeros(3, 5, dtype=src.dtype).scatter_add_(0, index, src)
+ tensor([[2., 0., 0., 1., 1.],
+         [0., 2., 0., 0., 0.],
+         [0., 0., 2., 1., 1.]])
""".format(**reproducibility_notes))

28 changes: 22 additions & 6 deletions torch/_torch_docs.py
@@ -3111,7 +3111,6 @@ def merge_dicts(*dicts):
[5, 6, 7, 8]])
""".format(**common_args))

- # TODO: see https://github.com/pytorch/pytorch/issues/43667
add_docstr(torch.gather,
r"""
gather(input, dim, index, *, sparse_grad=False, out=None) -> Tensor
@@ -3128,19 +3127,22 @@ def merge_dicts(*dicts):
:math:`(x_0, x_1..., x_{i-1}, x_i, x_{i+1}, ..., x_{n-1})`
and ``dim = i``, then :attr:`index` must be an :math:`n`-dimensional tensor with
size :math:`(x_0, x_1, ..., x_{i-1}, y, x_{i+1}, ..., x_{n-1})` where :math:`y \geq 1`
- and :attr:`out` will have the same size as :attr:`index`.
- """ + r"""
+ and :attr:`out` will have the same size as :attr:`index`. Note that ``input``
+ and ``index`` do not broadcast against each other.
Args:
input (Tensor): the source tensor
dim (int): the axis along which to index
index (LongTensor): the indices of elements to gather
- sparse_grad(bool,optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
+ Keyword arguments:
+ sparse_grad (bool, optional): If ``True``, gradient w.r.t. :attr:`input` will be a sparse tensor.
out (Tensor, optional): the destination tensor
Example::
- >>> t = torch.tensor([[1,2],[3,4]])
- >>> torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
+ >>> t = torch.tensor([[1, 2], [3, 4]])
+ >>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])
""")
@@ -7338,6 +7340,20 @@ def merge_dicts(*dicts):
tensor([ nan, 1.8351, 0.8053, nan])
""".format(**common_args))

+ add_docstr(torch.scatter,
+            r"""
+ scatter(input, dim, index, src) -> Tensor
+ Out-of-place version of :meth:`torch.Tensor.scatter_`
+ """)

+ add_docstr(torch.scatter_add,
+            r"""
+ scatter_add(input, dim, index, src) -> Tensor
+ Out-of-place version of :meth:`torch.Tensor.scatter_add_`
+ """)

add_docstr(torch.set_flush_denormal,
r"""
set_flush_denormal(mode) -> bool
