From dd50d35f588a7aaaa03d0b36d775995edc1946cf Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Mon, 4 Nov 2019 19:40:03 +0000 Subject: [PATCH 1/8] A first round of implementation of scatter_logsumexp/softmax/logsoftmax ops. --- docs/source/functions/logsumexp.rst | 7 +++ torch_scatter/__init__.py | 2 + torch_scatter/composite/__init__.py | 6 ++ torch_scatter/composite/softmax.py | 85 +++++++++++++++++++++++++++++ torch_scatter/logsumexp.py | 66 ++++++++++++++++++++++ 5 files changed, 166 insertions(+) create mode 100644 docs/source/functions/logsumexp.rst create mode 100644 torch_scatter/composite/__init__.py create mode 100644 torch_scatter/composite/softmax.py create mode 100644 torch_scatter/logsumexp.py diff --git a/docs/source/functions/logsumexp.rst b/docs/source/functions/logsumexp.rst new file mode 100644 index 00000000..900a0918 --- /dev/null +++ b/docs/source/functions/logsumexp.rst @@ -0,0 +1,7 @@ +Scatter LogSumExp +=========== + +.. automodule:: torch_scatter + :noindex: + +.. autofunction:: scatter_logsumexp diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index c3e57a89..2c34f755 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -6,6 +6,7 @@ from .std import scatter_std from .max import scatter_max from .min import scatter_min +from .logsumexp import scatter_logsumexp __version__ = '1.3.2' @@ -18,5 +19,6 @@ 'scatter_std', 'scatter_max', 'scatter_min', + 'scatter_logsumexp', '__version__', ] diff --git a/torch_scatter/composite/__init__.py b/torch_scatter/composite/__init__.py new file mode 100644 index 00000000..ab3d1dbc --- /dev/null +++ b/torch_scatter/composite/__init__.py @@ -0,0 +1,6 @@ +from .softmax import scatter_log_softmax, scatter_softmax + +__all__ = [ + 'scatter_softmax', + 'scatter_log_softmax' +] \ No newline at end of file diff --git a/torch_scatter/composite/softmax.py b/torch_scatter/composite/softmax.py new file mode 100644 index 00000000..1369e77c --- /dev/null +++ b/torch_scatter/composite/softmax.py @@ -0,0 +1,85 @@ +import torch + +from torch_scatter.logsumexp import _scatter_logsumexp + +def scatter_log_softmax(src, index, dim=-1, dim_size=None): + r""" + Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`.If multiple indices reference the same location, their + **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). + + For one-dimensional tensors, the operation computes + + .. math:: + \mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j) + + where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + + Compute a numerically safe log softmax operation + from the :attr:`src` tensor into :attr:`out` at the indices + specified in the :attr:`index` tensor along a given axis :attr:`dim`. For + each value in :attr:`src`, its output index is specified by its index in + :attr:`input` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements to scatter. + dim (int, optional): The axis along which to index. + (default: :obj:`-1`) + dim_size (int, optional): If :attr:`out` is not given, automatically + create output with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor is + returned. 
(default: :obj:`None`) + fill_value (int, optional): If :attr:`out` is not given, automatically + fill output tensor with :attr:`fill_value`. If set to :obj:`None`, + the output tensor is filled with the smallest possible value of + :obj:`src.dtype`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + per_index_logsumexp, recentered_src = _scatter_logsumexp(src, index, dim=dim, dim_size=dim_size) + return recentered_src - per_index_logsumexp.gather(dim, index) + + +def scatter_softmax(src, index, dim=-1, dim_size=None): + r""" + Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`. If multiple indices reference the same location, their + **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). + + For one-dimensional tensors, the operation computes + + .. math:: + \mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)} + + where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. + + Compute a numerically safe softmax operation + from the :attr:`src` tensor into :attr:`out` at the indices + specified in the :attr:`index` tensor along a given axis :attr:`dim`. For + each value in :attr:`src`, its output index is specified by its index in + :attr:`input` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements to scatter. + dim (int, optional): The axis along which to index. + (default: :obj:`-1`) + dim_size (int, optional): If :attr:`out` is not given, automatically + create output with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor is + returned. (default: :obj:`None`) + fill_value (int, optional): If :attr:`out` is not given, automatically + fill output tensor with :attr:`fill_value`. If set to :obj:`None`, + the output tensor is filled with the smallest possible value of + :obj:`src.dtype`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + return scatter_log_softmax(src, index, dim, dim_size).exp() diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py new file mode 100644 index 00000000..327f4230 --- /dev/null +++ b/torch_scatter/logsumexp.py @@ -0,0 +1,66 @@ +import torch + +from . 
import scatter_add, scatter_max + +EPSILON = 1e-16 + +def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None): + if not torch.is_floating_point(src): + raise ValueError('logsumexp can be computed over tensors floating point data types.') + + if fill_value is None: + fill_value = torch.finfo(src.dtype).min + + dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size + max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value) + max_per_src_element = max_value_per_index.gather(dim, index) + + recentered_scores = src - max_per_src_element + + sum_per_index = scatter_add( + src=recentered_scores.exp(), + index=index, + dim=dim, + out=(src - max_per_src_element).exp() if out is not None else None, + dim_size=dim_size, + fill_value=fill_value, + ) + return torch.log(sum_per_index + EPSILON) + max_value_per_index, recentered_scores + +def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None): + r""" + Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the + indices specified in the :attr:`index` tensor along a given axis + :attr:`dim`. If multiple indices reference the same location, their + **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`). + + For one-dimensional tensors, the operation computes + + .. math:: + \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right) + + Compute a numerically safe logsumexp operation + from the :attr:`src` tensor into :attr:`out` at the indices + specified in the :attr:`index` tensor along a given axis :attr:`dim`. For + each value in :attr:`src`, its output index is specified by its index in + :attr:`input` for dimensions outside of :attr:`dim` and by the + corresponding value in :attr:`index` for dimension :attr:`dim`. + + Args: + src (Tensor): The source tensor. + index (LongTensor): The indices of elements to scatter. + dim (int, optional): The axis along which to index. + (default: :obj:`-1`) + out (Tensor, optional): The destination tensor. (default: :obj:`None`) + dim_size (int, optional): If :attr:`out` is not given, automatically + create output with size :attr:`dim_size` at dimension :attr:`dim`. + If :attr:`dim_size` is not given, a minimal sized output tensor is + returned. (default: :obj:`None`) + fill_value (int, optional): If :attr:`out` is not given, automatically + fill output tensor with :attr:`fill_value`. If set to :obj:`None`, + the output tensor is filled with the smallest possible value of + :obj:`src.dtype`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value)[0] From 9c7af8dfefb72d98f166c7477d483829358a06ff Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Mon, 4 Nov 2019 19:53:26 +0000 Subject: [PATCH 2/8] Move epsilon to an argument. --- torch_scatter/logsumexp.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py index 327f4230..7529d115 100644 --- a/torch_scatter/logsumexp.py +++ b/torch_scatter/logsumexp.py @@ -2,9 +2,8 @@ from . 
import scatter_add, scatter_max -EPSILON = 1e-16 -def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None): +def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16): if not torch.is_floating_point(src): raise ValueError('logsumexp can be computed over tensors floating point data types.') @@ -25,9 +24,10 @@ def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=N dim_size=dim_size, fill_value=fill_value, ) - return torch.log(sum_per_index + EPSILON) + max_value_per_index, recentered_scores + return torch.log(sum_per_index + epsilon) + max_value_per_index, recentered_scores -def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None): + +def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16): r""" Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the indices specified in the :attr:`index` tensor along a given axis @@ -63,4 +63,4 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No :rtype: :class:`Tensor` """ - return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value)[0] + return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value, epsilon=epsilon)[0] From eb78c985a25123966c31f22f8a8ab9c38dec5ce8 Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Mon, 4 Nov 2019 20:05:11 +0000 Subject: [PATCH 3/8] Flake linting warnings. --- docs/source/functions/logsumexp.rst | 2 +- torch_scatter/composite/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/functions/logsumexp.rst b/docs/source/functions/logsumexp.rst index 900a0918..88aea63b 100644 --- a/docs/source/functions/logsumexp.rst +++ b/docs/source/functions/logsumexp.rst @@ -1,5 +1,5 @@ Scatter LogSumExp -=========== +================= .. automodule:: torch_scatter :noindex: diff --git a/torch_scatter/composite/__init__.py b/torch_scatter/composite/__init__.py index ab3d1dbc..63ea6bd3 100644 --- a/torch_scatter/composite/__init__.py +++ b/torch_scatter/composite/__init__.py @@ -3,4 +3,4 @@ __all__ = [ 'scatter_softmax', 'scatter_log_softmax' -] \ No newline at end of file +] From 0ef926023e5b4bce2ded46c0293677712357fa3d Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Mon, 4 Nov 2019 21:20:13 +0100 Subject: [PATCH 4/8] Update __init__.py --- torch_scatter/composite/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_scatter/composite/__init__.py b/torch_scatter/composite/__init__.py index 63ea6bd3..74cdb7db 100644 --- a/torch_scatter/composite/__init__.py +++ b/torch_scatter/composite/__init__.py @@ -2,5 +2,5 @@ __all__ = [ 'scatter_softmax', - 'scatter_log_softmax' + 'scatter_log_softmax', ] From 7b14c67136ea0bf5a1ca2e14eca4eb4ccce7d31f Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Tue, 5 Nov 2019 21:58:17 +0000 Subject: [PATCH 5/8] Bug fixes, testing and other minor edits. * `log_softmax` has now stand-alone to save one operation (and fix a bug). * `softmax` is implemented in a similar stand-alone way. * Address some PR comments. 
--- test/test_logsumexp.py | 24 ++++++++++++++ test/test_softmax.py | 51 ++++++++++++++++++++++++++++++ torch_scatter/composite/softmax.py | 47 +++++++++++++++++++++++---- torch_scatter/logsumexp.py | 46 ++++++++++++--------------- 4 files changed, 136 insertions(+), 32 deletions(-) create mode 100644 test/test_logsumexp.py create mode 100644 test/test_softmax.py diff --git a/test/test_logsumexp.py b/test/test_logsumexp.py new file mode 100644 index 00000000..73b1c9b3 --- /dev/null +++ b/test/test_logsumexp.py @@ -0,0 +1,24 @@ +from itertools import product + +import torch +import pytest +from torch_scatter import scatter_max, scatter_logsumexp + +from .utils import devices, tensor + +SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} + +@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) +def test_logsumexp(dtype, device): + src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) + + out = scatter_logsumexp(src, index) + + idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist() + idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() + idx2 = 7 # Single element + idx3 = torch.finfo(dtype).min # Empty index, returns yield value + idx4 = -1 # logsumexp with -inf is the identity + + assert out.tolist() == [idx0, idx1, idx2, idx3, idx4] diff --git a/test/test_softmax.py b/test/test_softmax.py new file mode 100644 index 00000000..f1d38aca --- /dev/null +++ b/test/test_softmax.py @@ -0,0 +1,51 @@ +from itertools import product + +import numpy as np +import pytest +import torch +from torch_scatter.composite import scatter_log_softmax, scatter_softmax + +from .utils import devices, tensor + +SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} + +@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) +def test_log_softmax(dtype, device): + src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) + + out = scatter_log_softmax(src, index) + + # Expected results per index + idx0 = [np.log(0.5), np.log(0.5)] + idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() + idx2 = 0.0 # Single element, has logprob=0 + # index=3 is empty. Should not matter. + idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf + + np.testing.assert_allclose( + out.tolist(), + [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]], + rtol=1e-05, atol=1e-10 + ) + + +@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) +def test_softmax(dtype, device): + src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) + + out = scatter_softmax(src, index) + + # Expected results per index + idx0 = [0.5, 0.5] + idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() + idx2 = 1 # Single element, has prob=1 + # index=3 is empty. Should not matter. 
+ idx4 = [1.0, 0.0] # softmax with -inf yields zero probability + + np.testing.assert_allclose( + out.tolist(), + [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]], + rtol=1e-05, atol=1e-10 + ) \ No newline at end of file diff --git a/torch_scatter/composite/softmax.py b/torch_scatter/composite/softmax.py index 1369e77c..da95df6f 100644 --- a/torch_scatter/composite/softmax.py +++ b/torch_scatter/composite/softmax.py @@ -1,6 +1,6 @@ import torch -from torch_scatter.logsumexp import _scatter_logsumexp +from torch_scatter import scatter_add, scatter_max def scatter_log_softmax(src, index, dim=-1, dim_size=None): r""" @@ -12,7 +12,8 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None): For one-dimensional tensors, the operation computes .. math:: - \mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j) + \mathrm{out}_i = softmax(\mathrm{src}_i) = + \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j) where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that :math:`\mathrm{index}_j = i`. @@ -40,11 +41,26 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None): :rtype: :class:`Tensor` """ - per_index_logsumexp, recentered_src = _scatter_logsumexp(src, index, dim=dim, dim_size=dim_size) - return recentered_src - per_index_logsumexp.gather(dim, index) + if not torch.is_floating_point(src): + raise ValueError('log_softmax can be computed only over tensors with floating point data types.') + max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size) + max_per_src_element = max_value_per_index.gather(dim, index) -def scatter_softmax(src, index, dim=-1, dim_size=None): + recentered_scores = src - max_per_src_element + + sum_per_index = scatter_add( + src=recentered_scores.exp(), + index=index, + dim=dim, + dim_size=dim_size + ) + log_normalizing_constants = sum_per_index.log().gather(dim, index) + + return recentered_scores - log_normalizing_constants + + +def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16): r""" Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the indices specified in the :attr:`index` tensor along a given axis @@ -54,7 +70,8 @@ def scatter_softmax(src, index, dim=-1, dim_size=None): For one-dimensional tensors, the operation computes .. math:: - \mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)} + \mathrm{out}_i = softmax(\mathrm{src}_i) = + \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)} where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that :math:`\mathrm{index}_j = i`. 
@@ -82,4 +99,20 @@ def scatter_softmax(src, index, dim=-1, dim_size=None): :rtype: :class:`Tensor` """ - return scatter_log_softmax(src, index, dim, dim_size).exp() + if not torch.is_floating_point(src): + raise ValueError('softmax can be computed only over tensors with floating point data types.') + + max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size) + max_per_src_element = max_value_per_index.gather(dim, index) + + recentered_scores = src - max_per_src_element + exped_recentered_scores = recentered_scores.exp() + + sum_per_index = scatter_add( + src=exped_recentered_scores, + index=index, + dim=dim, + dim_size=dim_size + ) + normalizing_constant = (sum_per_index + epsilon).gather(dim, index) + return exped_recentered_scores / normalizing_constant diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py index 7529d115..499f1f26 100644 --- a/torch_scatter/logsumexp.py +++ b/torch_scatter/logsumexp.py @@ -3,30 +3,6 @@ from . import scatter_add, scatter_max -def _scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16): - if not torch.is_floating_point(src): - raise ValueError('logsumexp can be computed over tensors floating point data types.') - - if fill_value is None: - fill_value = torch.finfo(src.dtype).min - - dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size - max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value) - max_per_src_element = max_value_per_index.gather(dim, index) - - recentered_scores = src - max_per_src_element - - sum_per_index = scatter_add( - src=recentered_scores.exp(), - index=index, - dim=dim, - out=(src - max_per_src_element).exp() if out is not None else None, - dim_size=dim_size, - fill_value=fill_value, - ) - return torch.log(sum_per_index + epsilon) + max_value_per_index, recentered_scores - - def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16): r""" Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the @@ -63,4 +39,24 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No :rtype: :class:`Tensor` """ - return _scatter_logsumexp(src,index, dim, out, dim_size, fill_value, epsilon=epsilon)[0] + if not torch.is_floating_point(src): + raise ValueError('logsumexp can be computed over tensors with floating point data types.') + + if fill_value is None: + fill_value = torch.finfo(src.dtype).min + + dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size + max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value) + max_per_src_element = max_value_per_index.gather(dim, index) + + recentered_scores = src - max_per_src_element + + sum_per_index = scatter_add( + src=recentered_scores.exp(), + index=index, + dim=dim, + out=(out - max_per_src_element).exp() if out is not None else None, + dim_size=dim_size, + fill_value=0, + ) + return torch.log(sum_per_index + epsilon) + max_value_per_index From 0c127881d33e872cce4e6cacbfd82a462ccb8b21 Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Tue, 5 Nov 2019 22:12:44 +0000 Subject: [PATCH 6/8] Address most flake8, pycodestyle errors. 
--- test/test_logsumexp.py | 3 ++- test/test_softmax.py | 3 ++- torch_scatter/composite/softmax.py | 25 +++++++++++++++++-------- torch_scatter/logsumexp.py | 20 ++++++++++++++------ 4 files changed, 35 insertions(+), 16 deletions(-) diff --git a/test/test_logsumexp.py b/test/test_logsumexp.py index 73b1c9b3..61c4970d 100644 --- a/test/test_logsumexp.py +++ b/test/test_logsumexp.py @@ -2,12 +2,13 @@ import torch import pytest -from torch_scatter import scatter_max, scatter_logsumexp +from torch_scatter import scatter_logsumexp from .utils import devices, tensor SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} + @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) def test_logsumexp(dtype, device): src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) diff --git a/test/test_softmax.py b/test/test_softmax.py index f1d38aca..71e7472b 100644 --- a/test/test_softmax.py +++ b/test/test_softmax.py @@ -9,6 +9,7 @@ SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} + @pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) def test_log_softmax(dtype, device): src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) @@ -48,4 +49,4 @@ def test_softmax(dtype, device): out.tolist(), [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]], rtol=1e-05, atol=1e-10 - ) \ No newline at end of file + ) diff --git a/torch_scatter/composite/softmax.py b/torch_scatter/composite/softmax.py index da95df6f..a382496e 100644 --- a/torch_scatter/composite/softmax.py +++ b/torch_scatter/composite/softmax.py @@ -2,9 +2,11 @@ from torch_scatter import scatter_add, scatter_max + def scatter_log_softmax(src, index, dim=-1, dim_size=None): r""" - Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the + Numerical safe log-softmax of all values from + the :attr:`src` tensor into :attr:`out` at the indices specified in the :attr:`index` tensor along a given axis :attr:`dim`.If multiple indices reference the same location, their **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). @@ -12,7 +14,7 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None): For one-dimensional tensors, the operation computes .. math:: - \mathrm{out}_i = softmax(\mathrm{src}_i) = + \mathrm{out}_i = softmax(\mathrm{src}_i) = \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j) where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that @@ -42,9 +44,12 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None): :rtype: :class:`Tensor` """ if not torch.is_floating_point(src): - raise ValueError('log_softmax can be computed only over tensors with floating point data types.') + raise ValueError('log_softmax can be computed only over ' + 'tensors with floating point data types.') - max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size) + max_value_per_index, _ = scatter_max(src, index, + dim=dim, + dim_size=dim_size) max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element @@ -62,7 +67,8 @@ def scatter_log_softmax(src, index, dim=-1, dim_size=None): def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16): r""" - Numerical safe log-softmax of all values from the :attr:`src` tensor into :attr:`out` at the + Numerical safe log-softmax of all values from + the :attr:`src` tensor into :attr:`out` at the indices specified in the :attr:`index` tensor along a given axis :attr:`dim`. 
If multiple indices reference the same location, their **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). @@ -70,7 +76,7 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16): For one-dimensional tensors, the operation computes .. math:: - \mathrm{out}_i = softmax(\mathrm{src}_i) = + \mathrm{out}_i = softmax(\mathrm{src}_i) = \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)} where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that @@ -100,9 +106,12 @@ def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16): :rtype: :class:`Tensor` """ if not torch.is_floating_point(src): - raise ValueError('softmax can be computed only over tensors with floating point data types.') + raise ValueError('softmax can be computed only over ' + 'tensors with floating point data types.') - max_value_per_index, _ = scatter_max(src, index, dim=dim, dim_size=dim_size) + max_value_per_index, _ = scatter_max(src, index, + dim=dim, + dim_size=dim_size) max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py index 499f1f26..1ac8c62a 100644 --- a/torch_scatter/logsumexp.py +++ b/torch_scatter/logsumexp.py @@ -3,9 +3,11 @@ from . import scatter_add, scatter_max -def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=None, epsilon=1e-16): +def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, + fill_value=None, epsilon=1e-16): r""" - Numerically safe logsumexp of all values from the :attr:`src` tensor into :attr:`out` at the + Numerically safe logsumexp of all values from + the :attr:`src` tensor into :attr:`out` at the indices specified in the :attr:`index` tensor along a given axis :attr:`dim`. If multiple indices reference the same location, their **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`). @@ -13,7 +15,8 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No For one-dimensional tensors, the operation computes .. 
math:: - \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + \sum_j \exp(\mathrm{src}_j) \right) + \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) + + \sum_j \exp(\mathrm{src}_j) \right) Compute a numerically safe logsumexp operation from the :attr:`src` tensor into :attr:`out` at the indices @@ -40,13 +43,18 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, fill_value=No :rtype: :class:`Tensor` """ if not torch.is_floating_point(src): - raise ValueError('logsumexp can be computed over tensors with floating point data types.') + raise ValueError('logsumexp can only be computed over ' + 'tensors with floating point data types.') if fill_value is None: fill_value = torch.finfo(src.dtype).min - dim_size = out.shape[dim] if dim_size is None and out is not None else dim_size - max_value_per_index, _ = scatter_max(src, index, dim=dim, out=out, dim_size=dim_size, fill_value=fill_value) + dim_size = out.shape[dim] \ + if dim_size is None and out is not None else dim_size + + max_value_per_index, _ = scatter_max(src, index, dim=dim, + out=out, dim_size=dim_size, + fill_value=fill_value) max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element From d63eb9c9d65711617c6d2104c0af61ea47392f03 Mon Sep 17 00:00:00 2001 From: Miltos Allamanis Date: Tue, 5 Nov 2019 22:29:36 +0000 Subject: [PATCH 7/8] Remaining flake8 formatting errors --- test/test_logsumexp.py | 11 ++++++++--- test/test_softmax.py | 20 ++++++++++++++------ 2 files changed, 22 insertions(+), 9 deletions(-) diff --git a/test/test_logsumexp.py b/test/test_logsumexp.py index 61c4970d..082e1d4b 100644 --- a/test/test_logsumexp.py +++ b/test/test_logsumexp.py @@ -9,15 +9,20 @@ SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} -@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) +@pytest.mark.parametrize('dtype,device', + product(SUPPORTED_FLOAT_DTYPES, devices)) def test_logsumexp(dtype, device): src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) out = scatter_logsumexp(src, index) - idx0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1).tolist() - idx1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() + idx0 = torch.logsumexp( + torch.tensor([0.5, 0.5], dtype=dtype), + dim=-1).tolist() + idx1 = torch.logsumexp( + torch.tensor([0, -2.1, 3.2], dtype=dtype), + dim=-1).tolist() idx2 = 7 # Single element idx3 = torch.finfo(dtype).min # Empty index, returns yield value idx4 = -1 # logsumexp with -inf is the identity diff --git a/test/test_softmax.py b/test/test_softmax.py index 71e7472b..468166f8 100644 --- a/test/test_softmax.py +++ b/test/test_softmax.py @@ -10,16 +10,20 @@ SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} -@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) +@pytest.mark.parametrize('dtype,device', + product(SUPPORTED_FLOAT_DTYPES, devices)) def test_log_softmax(dtype, device): - src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], + dtype, device) index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) out = scatter_log_softmax(src, index) # Expected results per index idx0 = [np.log(0.5), np.log(0.5)] - idx1 = torch.log_softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() + idx1 = torch.log_softmax( + torch.tensor([0.0, -2.1, 3.2], dtype=dtype), + 
dim=-1).tolist() idx2 = 0.0 # Single element, has logprob=0 # index=3 is empty. Should not matter. idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf @@ -31,16 +35,20 @@ def test_log_softmax(dtype, device): ) -@pytest.mark.parametrize('dtype,device', product(SUPPORTED_FLOAT_DTYPES, devices)) +@pytest.mark.parametrize('dtype,device', + product(SUPPORTED_FLOAT_DTYPES, devices)) def test_softmax(dtype, device): - src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], + dtype, device) index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) out = scatter_softmax(src, index) # Expected results per index idx0 = [0.5, 0.5] - idx1 = torch.softmax(torch.tensor([0.0, -2.1, 3.2], dtype=dtype), dim=-1).tolist() + idx1 = torch.softmax( + torch.tensor([0.0, -2.1, 3.2], dtype=dtype), + dim=-1).tolist() idx2 = 1 # Single element, has prob=1 # index=3 is empty. Should not matter. idx4 = [1.0, 0.0] # softmax with -inf yields zero probability From 62c61224bfe53dba553d39b01ad5cc9d40800854 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 8 Nov 2019 14:24:38 +0100 Subject: [PATCH 8/8] clean up code base / added new functions to readme / added docs for softmax functions --- README.md | 6 ++ docs/source/composite/softmax.rst | 9 +++ docs/source/index.rst | 1 + test/composite/test_softmax.py | 47 +++++++++++++ test/test_logsumexp.py | 24 +++---- test/test_softmax.py | 60 ---------------- torch_scatter/__init__.py | 2 + torch_scatter/composite/softmax.py | 107 +++++++++-------------------- torch_scatter/logsumexp.py | 60 ++++++---------- 9 files changed, 129 insertions(+), 187 deletions(-) create mode 100644 docs/source/composite/softmax.rst create mode 100644 test/composite/test_softmax.py delete mode 100644 test/test_softmax.py diff --git a/README.md b/README.md index 51986a59..efa7a8f6 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,12 @@ The package consists of the following operations: * [**Scatter Std**](https://pytorch-scatter.readthedocs.io/en/latest/functions/std.html) * [**Scatter Min**](https://pytorch-scatter.readthedocs.io/en/latest/functions/min.html) * [**Scatter Max**](https://pytorch-scatter.readthedocs.io/en/latest/functions/max.html) +* [**Scatter LogSumExp**](https://pytorch-scatter.readthedocs.io/en/latest/functions/logsumexp.html) + +In addition, we provide composite functions which make use of `scatter_*` operations under the hood: + +* [**Scatter Softmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_softmax) +* [**Scatter LogSoftmax**](https://pytorch-scatter.readthedocs.io/en/latest/composite/softmax.html#torch_scatter.composite.scatter_log_softmax) All included operations are broadcastable, work on varying data types, and are implemented both for CPU and GPU with corresponding backward implementations. diff --git a/docs/source/composite/softmax.rst b/docs/source/composite/softmax.rst new file mode 100644 index 00000000..4f03820d --- /dev/null +++ b/docs/source/composite/softmax.rst @@ -0,0 +1,9 @@ +Scatter Softmax +=============== + +.. automodule:: torch_scatter.composite + :noindex: + +.. autofunction:: scatter_softmax + +.. 
autofunction:: scatter_log_softmax diff --git a/docs/source/index.rst b/docs/source/index.rst index 50228ac9..eaa2574c 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -14,6 +14,7 @@ All included operations are broadcastable, work on varying data types, and are i :caption: Package reference functions/* + composite/* Indices and tables ================== diff --git a/test/composite/test_softmax.py b/test/composite/test_softmax.py new file mode 100644 index 00000000..385c1949 --- /dev/null +++ b/test/composite/test_softmax.py @@ -0,0 +1,47 @@ +from itertools import product + +import pytest +import torch +from torch_scatter.composite import scatter_log_softmax, scatter_softmax + +from test.utils import devices, tensor, grad_dtypes + + +@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) +def test_softmax(dtype, device): + src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) + + out = scatter_softmax(src, index) + + out0 = torch.softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) + out1 = torch.softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) + out2 = torch.softmax(torch.tensor([7], dtype=dtype), dim=-1) + out4 = torch.softmax(torch.tensor([-1, float('-inf')], dtype=dtype), + dim=-1) + + expected = torch.stack([ + out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] + ], dim=0) + + assert torch.allclose(out, expected) + + +@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) +def test_log_softmax(dtype, device): + src = tensor([0.2, 0, 0.2, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) + index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) + + out = scatter_log_softmax(src, index) + + out0 = torch.log_softmax(torch.tensor([0.2, 0.2], dtype=dtype), dim=-1) + out1 = torch.log_softmax(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) + out2 = torch.log_softmax(torch.tensor([7], dtype=dtype), dim=-1) + out4 = torch.log_softmax(torch.tensor([-1, float('-inf')], dtype=dtype), + dim=-1) + + expected = torch.stack([ + out0[0], out1[0], out0[1], out1[1], out1[2], out2[0], out4[0], out4[1] + ], dim=0) + + assert torch.allclose(out, expected) diff --git a/test/test_logsumexp.py b/test/test_logsumexp.py index 082e1d4b..9d87e2f7 100644 --- a/test/test_logsumexp.py +++ b/test/test_logsumexp.py @@ -4,27 +4,21 @@ import pytest from torch_scatter import scatter_logsumexp -from .utils import devices, tensor +from .utils import devices, tensor, grad_dtypes -SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} - -@pytest.mark.parametrize('dtype,device', - product(SUPPORTED_FLOAT_DTYPES, devices)) +@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices)) def test_logsumexp(dtype, device): src = tensor([0.5, 0, 0.5, -2.1, 3.2, 7, -1, float('-inf')], dtype, device) index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) out = scatter_logsumexp(src, index) - idx0 = torch.logsumexp( - torch.tensor([0.5, 0.5], dtype=dtype), - dim=-1).tolist() - idx1 = torch.logsumexp( - torch.tensor([0, -2.1, 3.2], dtype=dtype), - dim=-1).tolist() - idx2 = 7 # Single element - idx3 = torch.finfo(dtype).min # Empty index, returns yield value - idx4 = -1 # logsumexp with -inf is the identity + out0 = torch.logsumexp(torch.tensor([0.5, 0.5], dtype=dtype), dim=-1) + out1 = torch.logsumexp(torch.tensor([0, -2.1, 3.2], dtype=dtype), dim=-1) + out2 = torch.logsumexp(torch.tensor(7, dtype=dtype), dim=-1) + out3 = 
torch.tensor(torch.finfo(dtype).min, dtype=dtype) + out4 = torch.tensor(-1, dtype=dtype) - assert out.tolist() == [idx0, idx1, idx2, idx3, idx4] + expected = torch.stack([out0, out1, out2, out3, out4], dim=0) + assert torch.allclose(out, expected) diff --git a/test/test_softmax.py b/test/test_softmax.py deleted file mode 100644 index 468166f8..00000000 --- a/test/test_softmax.py +++ /dev/null @@ -1,60 +0,0 @@ -from itertools import product - -import numpy as np -import pytest -import torch -from torch_scatter.composite import scatter_log_softmax, scatter_softmax - -from .utils import devices, tensor - -SUPPORTED_FLOAT_DTYPES = {torch.float32, torch.float64} - - -@pytest.mark.parametrize('dtype,device', - product(SUPPORTED_FLOAT_DTYPES, devices)) -def test_log_softmax(dtype, device): - src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], - dtype, device) - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) - - out = scatter_log_softmax(src, index) - - # Expected results per index - idx0 = [np.log(0.5), np.log(0.5)] - idx1 = torch.log_softmax( - torch.tensor([0.0, -2.1, 3.2], dtype=dtype), - dim=-1).tolist() - idx2 = 0.0 # Single element, has logprob=0 - # index=3 is empty. Should not matter. - idx4 = [0.0, float('-inf')] # log_softmax with -inf preserves the -inf - - np.testing.assert_allclose( - out.tolist(), - [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]], - rtol=1e-05, atol=1e-10 - ) - - -@pytest.mark.parametrize('dtype,device', - product(SUPPORTED_FLOAT_DTYPES, devices)) -def test_softmax(dtype, device): - src = tensor([0.25, 0, 0.25, -2.1, 3.2, 7, -1, float('-inf')], - dtype, device) - index = tensor([0, 1, 0, 1, 1, 2, 4, 4], torch.long, device) - - out = scatter_softmax(src, index) - - # Expected results per index - idx0 = [0.5, 0.5] - idx1 = torch.softmax( - torch.tensor([0.0, -2.1, 3.2], dtype=dtype), - dim=-1).tolist() - idx2 = 1 # Single element, has prob=1 - # index=3 is empty. Should not matter. - idx4 = [1.0, 0.0] # softmax with -inf yields zero probability - - np.testing.assert_allclose( - out.tolist(), - [idx0[0], idx1[0], idx0[1], idx1[1], idx1[2], idx2, idx4[0], idx4[1]], - rtol=1e-05, atol=1e-10 - ) diff --git a/torch_scatter/__init__.py b/torch_scatter/__init__.py index 2c34f755..868f8187 100644 --- a/torch_scatter/__init__.py +++ b/torch_scatter/__init__.py @@ -7,6 +7,7 @@ from .max import scatter_max from .min import scatter_min from .logsumexp import scatter_logsumexp +import torch_scatter.composite __version__ = '1.3.2' @@ -20,5 +21,6 @@ 'scatter_max', 'scatter_min', 'scatter_logsumexp', + 'torch_scatter', '__version__', ] diff --git a/torch_scatter/composite/softmax.py b/torch_scatter/composite/softmax.py index a382496e..78757f65 100644 --- a/torch_scatter/composite/softmax.py +++ b/torch_scatter/composite/softmax.py @@ -3,125 +3,84 @@ from torch_scatter import scatter_add, scatter_max -def scatter_log_softmax(src, index, dim=-1, dim_size=None): +def scatter_softmax(src, index, dim=-1, eps=1e-12): r""" - Numerical safe log-softmax of all values from - the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`.If multiple indices reference the same location, their - **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). + Softmax operation over all values in :attr:`src` tensor that share indices + specified in the :attr:`index` tensor along a given axis :attr:`dim`. For one-dimensional tensors, the operation computes .. 
math:: - \mathrm{out}_i = softmax(\mathrm{src}_i) = - \mathrm{src}_i - \mathrm{logsumexp}_j ( \mathrm{src}_j) + \mathrm{out}_i = {\textrm{softmax}(\mathrm{src})}_i = + \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)} - where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that + where :math:`\sum_j` is over :math:`j` such that :math:`\mathrm{index}_j = i`. - Compute a numerically safe log softmax operation - from the :attr:`src` tensor into :attr:`out` at the indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. For - each value in :attr:`src`, its output index is specified by its index in - :attr:`input` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. - Args: src (Tensor): The source tensor. index (LongTensor): The indices of elements to scatter. dim (int, optional): The axis along which to index. (default: :obj:`-1`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. If set to :obj:`None`, - the output tensor is filled with the smallest possible value of - :obj:`src.dtype`. (default: :obj:`None`) + eps (float, optional): Small value to ensure numerical stability. + (default: :obj:`1e-12`) :rtype: :class:`Tensor` """ if not torch.is_floating_point(src): - raise ValueError('log_softmax can be computed only over ' - 'tensors with floating point data types.') + raise ValueError('`scatter_softmax` can only be computed over tensors ' + 'with floating point data types.') - max_value_per_index, _ = scatter_max(src, index, - dim=dim, - dim_size=dim_size) + max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0) max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element + recentered_scores_exp = recentered_scores.exp() - sum_per_index = scatter_add( - src=recentered_scores.exp(), - index=index, - dim=dim, - dim_size=dim_size - ) - log_normalizing_constants = sum_per_index.log().gather(dim, index) + sum_per_index = scatter_add(recentered_scores_exp, index, dim=dim) + normalizing_constants = (sum_per_index + eps).gather(dim, index) - return recentered_scores - log_normalizing_constants + return recentered_scores_exp / normalizing_constants -def scatter_softmax(src, index, dim=-1, dim_size=None, epsilon=1e-16): +def scatter_log_softmax(src, index, dim=-1, eps=1e-12): r""" - Numerical safe log-softmax of all values from - the :attr:`src` tensor into :attr:`out` at the + Log-softmax operation over all values in :attr:`src` tensor that share indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`. If multiple indices reference the same location, their - **contributions average** (`cf.` :meth:`~torch_scatter.scatter_add`). + :attr:`dim`. For one-dimensional tensors, the operation computes .. math:: - \mathrm{out}_i = softmax(\mathrm{src}_i) = - \frac{\exp(\mathrm{src}_i)}{\mathrm{logsumexp}_j ( \mathrm{src}_j)} + \mathrm{out}_i = {\textrm{log_softmax}(\mathrm{src})}_i = + \log \left( \frac{\exp(\mathrm{src}_i)}{\sum_j \exp(\mathrm{src}_j)} + \right) - where :math:`\mathrm{logsumexp}_j` is over :math:`j` such that + where :math:`\sum_j` is over :math:`j` such that :math:`\mathrm{index}_j = i`. 
- Compute a numerically safe softmax operation - from the :attr:`src` tensor into :attr:`out` at the indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. For - each value in :attr:`src`, its output index is specified by its index in - :attr:`input` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. - Args: src (Tensor): The source tensor. index (LongTensor): The indices of elements to scatter. dim (int, optional): The axis along which to index. (default: :obj:`-1`) - dim_size (int, optional): If :attr:`out` is not given, automatically - create output with size :attr:`dim_size` at dimension :attr:`dim`. - If :attr:`dim_size` is not given, a minimal sized output tensor is - returned. (default: :obj:`None`) - fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. If set to :obj:`None`, - the output tensor is filled with the smallest possible value of - :obj:`src.dtype`. (default: :obj:`None`) + eps (float, optional): Small value to ensure numerical stability. + (default: :obj:`1e-12`) :rtype: :class:`Tensor` """ if not torch.is_floating_point(src): - raise ValueError('softmax can be computed only over ' + raise ValueError('`scatter_log_softmax` can only be computed over ' 'tensors with floating point data types.') - max_value_per_index, _ = scatter_max(src, index, - dim=dim, - dim_size=dim_size) + max_value_per_index, _ = scatter_max(src, index, dim=dim, fill_value=0) max_per_src_element = max_value_per_index.gather(dim, index) recentered_scores = src - max_per_src_element - exped_recentered_scores = recentered_scores.exp() - - sum_per_index = scatter_add( - src=exped_recentered_scores, - index=index, - dim=dim, - dim_size=dim_size - ) - normalizing_constant = (sum_per_index + epsilon).gather(dim, index) - return exped_recentered_scores / normalizing_constant + + sum_per_index = scatter_add(src=recentered_scores.exp(), index=index, + dim=dim) + + normalizing_constants = torch.log(sum_per_index + eps).gather(dim, index) + + return recentered_scores - normalizing_constants diff --git a/torch_scatter/logsumexp.py b/torch_scatter/logsumexp.py index 1ac8c62a..16e9d182 100644 --- a/torch_scatter/logsumexp.py +++ b/torch_scatter/logsumexp.py @@ -4,26 +4,22 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, - fill_value=None, epsilon=1e-16): - r""" - Numerically safe logsumexp of all values from - the :attr:`src` tensor into :attr:`out` at the - indices specified in the :attr:`index` tensor along a given axis - :attr:`dim`. If multiple indices reference the same location, their - **contributions logsumexp** (`cf.` :meth:`~torch_scatter.scatter_add`). + fill_value=None, eps=1e-12): + r"""Fills :attr:`out` with the log of summed exponentials of all values + from the :attr:`src` tensor at the indices specified in the :attr:`index` + tensor along a given axis :attr:`dim`. + If multiple indices reference the same location, their + **exponential contributions add** + (`cf.` :meth:`~torch_scatter.scatter_add`). For one-dimensional tensors, the operation computes .. math:: - \mathrm{out}_i = \log \left( \exp(\mathrm{out}_i) - + \sum_j \exp(\mathrm{src}_j) \right) + \mathrm{out}_i = \log \, \left( \exp(\mathrm{out}_i) + \sum_j + \exp(\mathrm{src}_j) \right) - Compute a numerically safe logsumexp operation - from the :attr:`src` tensor into :attr:`out` at the indices - specified in the :attr:`index` tensor along a given axis :attr:`dim`. 
For - each value in :attr:`src`, its output index is specified by its index in - :attr:`input` for dimensions outside of :attr:`dim` and by the - corresponding value in :attr:`index` for dimension :attr:`dim`. + where :math:`\sum_j` is over :math:`j` such that + :math:`\mathrm{index}_j = i`. Args: src (Tensor): The source tensor. @@ -36,35 +32,23 @@ def scatter_logsumexp(src, index, dim=-1, out=None, dim_size=None, If :attr:`dim_size` is not given, a minimal sized output tensor is returned. (default: :obj:`None`) fill_value (int, optional): If :attr:`out` is not given, automatically - fill output tensor with :attr:`fill_value`. If set to :obj:`None`, - the output tensor is filled with the smallest possible value of - :obj:`src.dtype`. (default: :obj:`None`) + fill output tensor with :attr:`fill_value`. (default: :obj:`None`) + eps (float, optional): Small value to ensure numerical stability. + (default: :obj:`1e-12`) :rtype: :class:`Tensor` """ if not torch.is_floating_point(src): - raise ValueError('logsumexp can only be computed over ' + raise ValueError('`scatter_logsumexp` can only be computed over ' 'tensors with floating point data types.') - if fill_value is None: - fill_value = torch.finfo(src.dtype).min - - dim_size = out.shape[dim] \ - if dim_size is None and out is not None else dim_size - - max_value_per_index, _ = scatter_max(src, index, dim=dim, - out=out, dim_size=dim_size, - fill_value=fill_value) + max_value_per_index, _ = scatter_max(src, index, dim, out, dim_size, + fill_value) max_per_src_element = max_value_per_index.gather(dim, index) - recentered_scores = src - max_per_src_element + out = (out - max_per_src_element).exp() if out is not None else None + + sum_per_index = scatter_add(recentered_scores.exp(), index, dim, out, + dim_size, fill_value=0) - sum_per_index = scatter_add( - src=recentered_scores.exp(), - index=index, - dim=dim, - out=(out - max_per_src_element).exp() if out is not None else None, - dim_size=dim_size, - fill_value=0, - ) - return torch.log(sum_per_index + epsilon) + max_value_per_index + return torch.log(sum_per_index + eps) + max_value_per_index
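
For reference, the snippet below is a minimal usage sketch of the operations this patch series adds, assuming the final signatures from PATCH 8/8 (`scatter_logsumexp` in `torch_scatter`, `scatter_softmax`/`scatter_log_softmax` in `torch_scatter.composite`) and a checkout with these patches applied; the input values are illustrative only, not part of the patches.

    import torch
    from torch_scatter import scatter_logsumexp
    from torch_scatter.composite import scatter_log_softmax, scatter_softmax

    src = torch.tensor([0.2, 0.0, 0.2, -2.1, 3.2])
    index = torch.tensor([0, 1, 0, 1, 1])

    # Segment-wise reduction: one output entry per index value,
    # out[i] = log(sum over j with index[j] == i of exp(src[j])).
    print(scatter_logsumexp(src, index))  # tensor of shape [2]

    # Element-wise outputs with the same shape as `src`; entries that share
    # an index are normalized together, so each group sums to 1.
    probs = scatter_softmax(src, index)
    print(probs, probs[index == 0].sum())  # second value is ~1.0

    # Numerically safer route to log-probabilities than probs.log().
    print(scatter_log_softmax(src, index))

All three functions subtract the per-segment maximum (via `scatter_max`) before exponentiating, which is what keeps them stable for large or `-inf` inputs.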