6 changes: 6 additions & 0 deletions README.md
@@ -34,6 +34,12 @@ The package consists of the following operations:
* [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
* [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
* [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)

In addition, we provide composite functions which make use of `scatter_*` operations under the hood:

* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)

All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
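
For a quick impression of how the new operations are meant to be called, here is a minimal usage sketch (illustrative values only; function names and argument order follow the code added in this PR):

```python
import torch
from torch_scatter import scatter_logsumexp
from torch_scatter.composite import scatter_softmax

# Group the five source values into segments 0, 0, 1, 1, 1.
src = torch.tensor([0.5, -1.0, 2.0, 0.0, 3.0])
index = torch.tensor([0, 0, 1, 1, 1])

# Per-segment softmax: entries sharing an index sum to one.
print(scatter_softmax(src, index))

# Per-segment log-sum-exp: one value per distinct index.
print(scatter_logsumexp(src, index))
```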

9 changes: 9 additions & 0 deletions docs/source/composite/softmax.rst
@@ -0,0 +1,9 @@
Scatter Softmax
===============

.. automodule:: torch_scatter.composite
    :noindex:

.. autofunction:: scatter_softmax

.. autofunction:: scatter_log_softmax
7 changes: 7 additions & 0 deletions docs/source/functions/logsumexp.rst
@@ -0,0 +1,7 @@
Scatter LogSumExp
=================

.. automodule:: torch_scatter
    :noindex:

.. autofunction:: scatter_logsumexp
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
   :caption: Package reference

   functions/*
   composite/*

Indices and tables
==================
47 changes: 47 additions & 0 deletions test/composite/test_softmax.py
@@ -0,0 +1,47 @@
from itertools import product

import pytest
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

from test.utils import devices, tensor, grad_dtypes


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_softmax(dtype, device):
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_softmax(src, index)

    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                         dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_log_softmax(dtype, device):
    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_log_softmax(src, index)

    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
                             dim=-1)

    expected = torch.stack([
        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
    ], dim=0)

    assert torch.allclose(out, expected)
24 changes: 24 additions & 0 deletions test/test_logsumexp.py
@@ -0,0 +1,24 @@
from itertools import product

import torch
import pytest
from torch_scatter import scatter_logsumexp

from .utils import devices, tensor, grad_dtypes


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
def test_logsumexp(dtype, device):
    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)

    out = scatter_logsumexp(src, index)

    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
    out4 = torch.tensor(-1, dtype=dtype)

    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
    assert torch.allclose(out, expected)
4 changes: 4 additions & 0 deletions torch_scatter/__init__.py
@@ -6,6 +6,8 @@
from .std import scatter_std
from .max import scatter_max
from .min import scatter_min
from .logsumexp import scatter_logsumexp
import torch_scatter.composite

__version__ = '1.3.2'

@@ -18,5 +20,7 @@
    'scatter_std',
    'scatter_max',
    'scatter_min',
    'scatter_logsumexp',
    'torch_scatter',
    '__version__',
]
6 changes: 6 additions & 0 deletions torch_scatter/composite/__init__.py
@@ -0,0 +1,6 @@
from .softmax import scatter_log_softmax, scatter_softmax

__all__ = [
    'scatter_softmax',
    'scatter_log_softmax',
]
86 changes: 86 additions & 0 deletions torch_scatter/composite/softmax.py
@@ -0,0 +1,86 @@
import torch

from torch_scatter import scatter_add, scatter_max


def scatter_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Softmax operation over all values in :attr:`src` tensor that share indices
    specified in the :attr:`index` tensor along a given axis :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_softmax` can only be computed over tensors '
                         'with floating point data types.')

    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element
    recentered_scores_exp = recentered_scores.exp()

    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
    normalizing_constants = (sum_per_index + eps).gather(dim, index)

    return recentered_scores_exp / normalizing_constants


def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
    r"""
    Log-softmax operation over all values in :attr:`src` tensor that share
    indices specified in the :attr:`index` tensor along a given axis
    :attr:`dim`.

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = {\textrm{log\_softmax}(\mathrm{src})}_i =
        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
        \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_log_softmax` can only be computed over '
                         'tensors with floating point data types.')

    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
    max_per_src_element = max_value_per_index.gather(dim, index)

    recentered_scores = src - max_per_src_element

    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
                                dim=dim)

    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)

    return recentered_scores - normalizing_constants
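
The subtraction of the per-index maximum above is the usual log-sum-exp trick. A small hypothetical sketch of what it buys, using only the two functions defined in this file (the input values are made up):

```python
import torch
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

# Large magnitudes would overflow a naive exp(src) / sum(exp(src)).
src = torch.tensor([1000.0, 1001.0, -5.0])
index = torch.tensor([0, 0, 1])

out = scatter_softmax(src, index)
# Shifting by the per-group maximum keeps exp() finite, so the first
# group still normalizes to [exp(-1), 1] / (exp(-1) + 1).
assert torch.isfinite(out).all()

# Up to the eps term, log-softmax agrees with log(softmax).
assert torch.allclose(scatter_log_softmax(src, index), out.log(), atol=1e-4)
```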
54 changes: 54 additions & 0 deletions torch_scatter/logsumexp.py
@@ -0,0 +1,54 @@
import torch

from . import scatter_add, scatter_max


def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
                      fill_value=None, eps=1e-12):
    r"""Fills :attr:`out` with the log of summed exponentials of all values
    from the :attr:`src` tensor at the indices specified in the :attr:`index`
    tensor along a given axis :attr:`dim`.
    If multiple indices reference the same location, their
    **exponential contributions add**
    (`cf.` :meth:`~torch_scatter.scatter_add`).

    For one-dimensional tensors, the operation computes

    .. math::
        \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
        \exp(\mathrm{src}_j) \right)

    where :math:`\sum_j` is over :math:`j` such that
    :math:`\mathrm{index}_j = i`.

    Args:
        src (Tensor): The source tensor.
        index (LongTensor): The indices of elements to scatter.
        dim (int, optional): The axis along which to index.
            (default: :obj:`-1`)
        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
        dim_size (int, optional): If :attr:`out` is not given, automatically
            create output with size :attr:`dim_size` at dimension :attr:`dim`.
            If :attr:`dim_size` is not given, a minimal sized output tensor is
            returned. (default: :obj:`None`)
        fill_value (int, optional): If :attr:`out` is not given, automatically
            fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
        eps (float, optional): Small value to ensure numerical stability.
            (default: :obj:`1e-12`)

    :rtype: :class:`Tensor`
    """
    if not torch.is_floating_point(src):
        raise ValueError('`scatter_logsumexp` can only be computed over '
                         'tensors with floating point data types.')

    max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
                                         fill_value)
    max_per_src_element = max_value_per_index.gather(dim, index)
    recentered_scores = src - max_per_src_element
    out = (out - max_per_src_element).exp() if out is not None else None

    sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
                                dim_size, fill_value=0)

    return torch.log(sum_per_index + eps) + max_value_per_index
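
A minimal sketch mirroring `test/test_logsumexp.py`: for every output index that actually receives values, the result should agree with a plain `torch.logsumexp` over those values, up to the `eps` term (inputs are illustrative):

```python
import torch
from torch_scatter import scatter_logsumexp

src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

out = scatter_logsumexp(src, index)

# Each output entry is the plain logsumexp over the values routed to it.
expected = torch.stack([
    torch.logsumexp(src[index == 0], dim=-1),
    torch.logsumexp(src[index == 1], dim=-1),
])
assert torch.allclose(out, expected)
```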