diff --git a/README.md b/README.md
index 51986a59..efa7a8f6 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,12 @@ The package consists of the following operations:
 * [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html)
 * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html)
 * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html)
+* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html)
+
+In addition, we provide composite functions which make use of `scatter_*` operations under the hood:
+
+* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax)
+* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax)
 
 All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations.
 
diff --git a/docs/source/composite/softmax.rst b/docs/source/composite/softmax.rst
new file mode 100644
index 00000000..4f03820d
--- /dev/null
+++ b/docs/source/composite/softmax.rst
@@ -0,0 +1,9 @@
+Scatter Softmax
+===============
+
+.. automodule:: torch_scatter.composite
+    :noindex:
+
+.. autofunction:: scatter_softmax
+
+.. autofunction:: scatter_log_softmax
diff --git a/docs/source/functions/logsumexp.rst b/docs/source/functions/logsumexp.rst
new file mode 100644
index 00000000..88aea63b
--- /dev/null
+++ b/docs/source/functions/logsumexp.rst
@@ -0,0 +1,7 @@
+Scatter LogSumExp
+=================
+
+.. automodule:: torch_scatter
+    :noindex:
+
+.. autofunction:: scatter_logsumexp
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 50228ac9..eaa2574c 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i
    :caption: Package reference
 
    functions/*
+   composite/*
 
 Indices and tables
 ==================
diff --git a/test/composite/test_softmax.py b/test/composite/test_softmax.py
new file mode 100644
index 00000000..385c1949
--- /dev/null
+++ b/test/composite/test_softmax.py
@@ -0,0 +1,47 @@
+from itertools import product
+
+import pytest
+import torch
+from torch_scatter.composite import scatter_log_softmax, scatter_softmax
+
+from test.utils import devices, tensor, grad_dtypes
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_softmax(dtype, device):
+    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_softmax(src, index)
+
+    out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
+    out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1)
+    out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
+                         dim=-1)
+
+    expected = torch.stack([
+        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
+    ], dim=0)
+
+    assert torch.allclose(out, expected)
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_log_softmax(dtype, device):
+    src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_log_softmax(src, index)
+
+    out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1)
+    out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1)
+    out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype),
+                             dim=-1)
+
+    expected = torch.stack([
+        out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1]
+    ], dim=0)
+
+    assert torch.allclose(out, expected)
diff --git a/test/test_logsumexp.py b/test/test_logsumexp.py
new file mode 100644
index 00000000..9d87e2f7
--- /dev/null
+++ b/test/test_logsumexp.py
@@ -0,0 +1,24 @@
+from itertools import product
+
+import torch
+import pytest
+from torch_scatter import scatter_logsumexp
+
+from .utils import devices, tensor, grad_dtypes
+
+
+@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
+def test_logsumexp(dtype, device):
+    src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device)
+    index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device)
+
+    out = scatter_logsumexp(src, index)
+
+    out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1)
+    out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1)
+    out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1)
+    out3 = torch.tensor(torch.finfo(dtype).min, dtype=dtype)
+    out4 = torch.tensor(-1, dtype=dtype)
+
+    expected = torch.stack([out0, out1, out2, out3, out4], dim=0)
+    assert torch.allclose(out, expected)
diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py
index c3e57a89..868f8187 100644
--- a/torch_scatter/__init__.py
+++ b/torch_scatter/__init__.py
@@ -6,6 +6,8 @@
 from .std import scatter_std
 from .max import scatter_max
 from .min import scatter_min
+from .logsumexp import scatter_logsumexp
+import torch_scatter.composite
 
 __version__ = '1.3.2'
 
@@ -18,5 +20,7 @@
     'scatter_std',
     'scatter_max',
     'scatter_min',
+    'scatter_logsumexp',
+    'torch_scatter',
     '__version__',
 ]
diff --git a/torch_scatter/composite/__init__.py b/torch_scatter/composite/__init__.py
new file mode 100644
index 00000000..74cdb7db
--- /dev/null
+++ b/torch_scatter/composite/__init__.py
@@ -0,0 +1,6 @@
+from .softmax import scatter_log_softmax, scatter_softmax
+
+__all__ = [
+    'scatter_softmax',
+    'scatter_log_softmax',
+]
diff --git a/torch_scatter/composite/softmax.py b/torch_scatter/composite/softmax.py
new file mode 100644
index 00000000..78757f65
--- /dev/null
+++ b/torch_scatter/composite/softmax.py
@@ -0,0 +1,86 @@
+import torch
+
+from torch_scatter import scatter_add, scatter_max
+
+
+def scatter_softmax(src, index, dim=-1, eps=1e-12):
+    r"""
+    Softmax operation over all values in :attr:`src` tensor that share indices
+    specified in the :attr:`index` tensor along a given axis :attr:`dim`.
+
+    For one-dimensional tensors, the operation computes
+
+    .. math::
+        \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i =
+        \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements to scatter.
+        dim (int, optional): The axis along which to index.
+            (default: :obj:`-1`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)
+
+    :rtype: :class:`Tensor`
+    """
+    if not torch.is_floating_point(src):
+        raise ValueError('`scatter_softmax` can only be computed over tensors '
+                         'with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+
+    recentered_scores = src - max_per_src_element
+    recentered_scores_exp = recentered_scores.exp()
+
+    sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim)
+    normalizing_constants = (sum_per_index + eps).gather(dim, index)
+
+    return recentered_scores_exp / normalizing_constants
+
+
+def scatter_log_softmax(src, index, dim=-1, eps=1e-12):
+    r"""
+    Log-softmax operation over all values in :attr:`src` tensor that share
+    indices specified in the :attr:`index` tensor along a given axis
+    :attr:`dim`.
+
+    For one-dimensional tensors, the operation computes
+
+    .. math::
+        \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i =
+        \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)}
+        \right)
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements to scatter.
+        dim (int, optional): The axis along which to index.
+            (default: :obj:`-1`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)
+
+    :rtype: :class:`Tensor`
+    """
+    if not torch.is_floating_point(src):
+        raise ValueError('`scatter_log_softmax` can only be computed over '
+                         'tensors with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+
+    recentered_scores = src - max_per_src_element
+
+    sum_per_index = scatter_add(src=recentered_scores.exp(), index=index,
+                                dim=dim)
+
+    normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index)
+
+    return recentered_scores - normalizing_constants
diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py
new file mode 100644
index 00000000..16e9d182
--- /dev/null
+++ b/torch_scatter/logsumexp.py
@@ -0,0 +1,54 @@
+import torch
+
+from . import scatter_add, scatter_max
+
+
+def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None,
+                      fill_value=None, eps=1e-12):
+    r"""Fills :attr:`out` with the log of summed exponentials of all values
+    from the :attr:`src` tensor at the indices specified in the :attr:`index`
+    tensor along a given axis :attr:`dim`.
+    If multiple indices reference the same location, their
+    **exponential contributions add**
+    (`cf.` :meth:`~torch_scatter.scatter_add`).
+
+    For one-dimensional tensors, the operation computes
+
+    .. math::
+        \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j
+        \exp(\mathrm{src}_j) \right)
+
+    where :math:`\sum_j` is over :math:`j` such that
+    :math:`\mathrm{index}_j = i`.
+
+    Args:
+        src (Tensor): The source tensor.
+        index (LongTensor): The indices of elements to scatter.
+        dim (int, optional): The axis along which to index.
+            (default: :obj:`-1`)
+        out (Tensor, optional): The destination tensor. (default: :obj:`None`)
+        dim_size (int, optional): If :attr:`out` is not given, automatically
+            create output with size :attr:`dim_size` at dimension :attr:`dim`.
+            If :attr:`dim_size` is not given, a minimal sized output tensor is
+            returned. (default: :obj:`None`)
+        fill_value (int, optional): If :attr:`out` is not given, automatically
+            fill output tensor with :attr:`fill_value`. (default: :obj:`None`)
+        eps (float, optional): Small value to ensure numerical stability.
+            (default: :obj:`1e-12`)
+
+    :rtype: :class:`Tensor`
+    """
+    if not torch.is_floating_point(src):
+        raise ValueError('`scatter_logsumexp` can only be computed over '
+                         'tensors with floating point data types.')
+
+    max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size,
+                                         fill_value)
+    max_per_src_element = max_value_per_index.gather(dim, index)
+    recentered_scores = src - max_per_src_element
+    out = (out - max_per_src_element).exp() if out is not None else None
+
+    sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out,
+                                dim_size, fill_value=0)
+
+    return torch.log(sum_per_index + eps) + max_value_per_index
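
A minimal usage sketch of the operations added above (not taken from the diff or its tests; the tensors are illustrative and assume torch_scatter is installed from this branch):

import torch
from torch_scatter import scatter_logsumexp
from torch_scatter.composite import scatter_log_softmax, scatter_softmax

# Two groups: positions 0 and 2 share index 0; positions 1, 3 and 4 share index 1.
src = torch.tensor([0.5, 0.0, 0.5, -2.1, 3.2])
index = torch.tensor([0, 1, 0, 1, 1])

# Entries that share an index are normalized together, so each group's
# softmax outputs sum to one; the result has the same shape as `src`.
print(scatter_softmax(src, index))

# Log-space variant of the same grouping.
print(scatter_log_softmax(src, index))

# One value per group: index 0 -> log(exp(0.5) + exp(0.5)),
# index 1 -> log(exp(0.0) + exp(-2.1) + exp(3.2)).
print(scatter_logsumexp(src, index))

Both composites and scatter_logsumexp subtract the per-group maximum before exponentiating, so the exponentials cannot overflow, and the small eps added to each group sum keeps the final log and division well-defined.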