diff --git a/CHANGELOG.md b/CHANGELOG.md index 271acd853b5d..4dbe093e9bb7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -114,6 +114,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added implementation of `log_softmax` in `torch_geometric.utils` ([#8909](https://github.com/pyg-team/pytorch_geometric/pull/8909)) - Added an example for recommender systems, including k-NN search and retrieval metrics ([#8546](https://github.com/pyg-team/pytorch_geometric/pull/8546)) - Added multi-GPU evaluation in distributed sampling example ([#8880](https://github.com/pyg-team/pytorch_geometric/pull/8880)) - Added end-to-end example for distributed CPU training ([#8713](https://github.com/pyg-team/pytorch_geometric/pull/8713)) @@ -248,7 +249,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925)) - Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917)) - Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918)) -- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) +- Added support for floating-point slicing in `Dataset`, _e.g._, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915)) - Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895)) - Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827)) - Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864)) @@ -266,7 +267,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656)) - Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649)) - Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647)) -- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) +- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700)) - Added a `LightGCN` example on the `AmazonBook` dataset ([#7603](https://github.com/pyg-team/pytorch_geometric/pull/7603)) - Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594)) - Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601)) @@ -326,7 +327,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed -- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166)) +- Fixed `HeteroConv` for layers that have a non-default argument order, _e.g._, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166)) - Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163)) - Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145)) - Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143)) @@ -335,7 +336,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942)) - Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737)) - Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955)) -- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956) +- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956)) - Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953)) - Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941)) - Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923)) diff --git a/test/utils/test_log_softmax.py b/test/utils/test_log_softmax.py new file mode 100644 index 000000000000..feda7bbdab89 --- /dev/null +++ b/test/utils/test_log_softmax.py @@ -0,0 +1,156 @@ +import pytest +import torch + +import torch_geometric.typing +from torch_geometric.profile import benchmark +from torch_geometric.utils import log_softmax + +# NOTE The `ptr`-based code path of `log_softmax` currently requires +# `torch-scatter`; the `pyg_lib.ops.log_softmax_csr` kernel +# (`WITH_LOG_SOFTMAX`) is not wired up in `log_softmax` yet: +CALCULATION_VIA_PTR_AVAILABLE = torch_geometric.typing.WITH_TORCH_SCATTER + +ATOL = 1e-4 +RTOL = 1e-4 + + +def test_log_softmax(): + src = torch.tensor([1.0, 1.0, 1.0, 1.0]) + index = torch.tensor([0, 0, 1, 2]) + ptr = torch.tensor([0, 2, 3, 4]) + + out = log_softmax(src, index) + assert torch.allclose(out, torch.tensor([-0.6931, -0.6931, 0.0000, + 0.0000]), atol=ATOL, rtol=RTOL) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose(log_softmax(src, ptr=ptr), out, atol=ATOL, + rtol=RTOL) + else: + with pytest.raises(NotImplementedError, match="requires 'index'"): + log_softmax(src, ptr=ptr) + + src = src.view(-1, 1) + out = log_softmax(src, index) + assert torch.allclose( + out, + torch.tensor([[-0.6931], [-0.6931], [0.0000], [0.0000]]), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose(log_softmax(src, None, ptr), out, atol=ATOL, + rtol=RTOL) + + jit = torch.jit.script(log_softmax) + assert
torch.allclose(jit(src, index), out, atol=ATOL, rtol=RTOL) + + +def test_log_softmax_backward(): + src_sparse = torch.rand(4, 8, requires_grad=True) + index = torch.tensor([0, 0, 1, 1]) + src_dense = src_sparse.clone().detach().view(2, 2, src_sparse.size(-1)) + src_dense.requires_grad_(True) + + out_sparse = log_softmax(src_sparse, index) + out_sparse.sum().backward() + out_dense = torch.log_softmax(src_dense, dim=1) + out_dense.sum().backward() + + assert torch.allclose(out_sparse, out_dense.view_as(out_sparse), atol=ATOL) + assert torch.allclose(src_sparse.grad, src_dense.grad.view_as(src_sparse), + atol=ATOL) + + +def test_log_softmax_dim(): + index = torch.tensor([0, 0, 0, 0]) + ptr = torch.tensor([0, 4]) + + src = torch.randn(4) + assert torch.allclose( + log_softmax(src, index, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + + src = torch.randn(4, 16) + assert torch.allclose( + log_softmax(src, index, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=0), + torch.log_softmax(src, dim=0), + atol=ATOL, + rtol=RTOL, + ) + + src = torch.randn(4, 4) + assert torch.allclose( + log_softmax(src, index, dim=-1), + torch.log_softmax(src, dim=-1), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=-1), + torch.log_softmax(src, dim=-1), + atol=ATOL, + rtol=RTOL, + ) + + src = torch.randn(4, 4, 16) + assert torch.allclose( + log_softmax(src, index, dim=1), + torch.log_softmax(src, dim=1), + atol=ATOL, + rtol=RTOL, + ) + if CALCULATION_VIA_PTR_AVAILABLE: + assert torch.allclose( + log_softmax(src, ptr=ptr, dim=1), + torch.log_softmax(src, dim=1), + atol=ATOL, + rtol=RTOL, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--backward", action="store_true") + args = parser.parse_args() + + num_nodes, num_edges = 10_000, 200_000 + x = torch.randn(num_edges, 64, device=args.device) + index = torch.randint(num_nodes, (num_edges, ), device=args.device) + + compiled_log_softmax = torch.compile(log_softmax) + + def dense_log_softmax(x, index): + x = x.view(num_nodes, -1, x.size(-1)) + return torch.log_softmax(x, dim=-1) + + benchmark( + funcs=[dense_log_softmax, log_softmax, compiled_log_softmax], + func_names=["Dense Log Softmax", "Vanilla", "Compiled"], + args=(x, index), + num_steps=50 if args.device == "cpu" else 500, + num_warmups=10 if args.device == "cpu" else 100, + backward=args.backward, + ) diff --git a/torch_geometric/typing.py b/torch_geometric/typing.py index cf4bf7bd8dfd..a88c32cb26d3 100644 --- a/torch_geometric/typing.py +++ b/torch_geometric/typing.py @@ -39,29 +39,31 @@ try: import pyg_lib # noqa + WITH_PYG_LIB = True - WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, 'grouped_matmul') - WITH_SEGMM = hasattr(pyg_lib.ops, 'segment_matmul') - if WITH_SEGMM and 'pytest' in sys.modules and torch.cuda.is_available(): + WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, "grouped_matmul") + WITH_SEGMM = hasattr(pyg_lib.ops, "segment_matmul") + if WITH_SEGMM and "pytest" in sys.modules and torch.cuda.is_available(): # 
NOTE `segment_matmul` is currently bugged on older NVIDIA cards, which # makes our GPU tests on CI crash. Check whether this error is present on the # current GPU and disable `WITH_SEGMM`/`WITH_GMM` if necessary. # TODO Drop this code block once `segment_matmul` is fixed. try: - x = torch.randn(3, 4, device='cuda') - ptr = torch.tensor([0, 2, 3], device='cuda') - weight = torch.randn(2, 4, 4, device='cuda') + x = torch.randn(3, 4, device="cuda") + ptr = torch.tensor([0, 2, 3], device="cuda") + weight = torch.randn(2, 4, 4, device="cuda") out = pyg_lib.ops.segment_matmul(x, ptr, weight) except RuntimeError: WITH_GMM = False WITH_SEGMM = False - WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add') - WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr') - WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort') - WITH_METIS = hasattr(pyg_lib, 'partition') - WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature( + WITH_SAMPLED_OP = hasattr(pyg_lib.ops, "sampled_add") + WITH_SOFTMAX = hasattr(pyg_lib.ops, "softmax_csr") + WITH_LOG_SOFTMAX = hasattr(pyg_lib.ops, "log_softmax_csr") + WITH_INDEX_SORT = hasattr(pyg_lib.ops, "index_sort") + WITH_METIS = hasattr(pyg_lib, "partition") + WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ("edge_time" in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) - WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature( + WITH_WEIGHTED_NEIGHBOR_SAMPLE = ("edge_weight" in inspect.signature( pyg_lib.sampler.neighbor_sample).parameters) except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -73,6 +75,7 @@ WITH_SEGMM = False WITH_SAMPLED_OP = False WITH_SOFTMAX = False + WITH_LOG_SOFTMAX = False WITH_INDEX_SORT = False WITH_METIS = False WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False @@ -80,6 +83,7 @@ try: import torch_scatter # noqa + WITH_TORCH_SCATTER = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -90,8 +94,9 @@ try: import torch_cluster # noqa + WITH_TORCH_CLUSTER = True - WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__ + WITH_TORCH_CLUSTER_BATCH_SIZE = "batch_size" in torch_cluster.knn.__doc__ except Exception as e: if not isinstance(e, ImportError): # pragma: no cover warnings.warn(f"An issue occurred while importing 'torch-cluster'. 
" @@ -107,6 +112,7 @@ def __getattr__(self, key: str) -> Any: try: import torch_spline_conv # noqa + WITH_TORCH_SPLINE_CONV = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -118,6 +124,7 @@ def __getattr__(self, key: str) -> Any: try: import torch_sparse # noqa from torch_sparse import SparseStorage, SparseTensor + WITH_TORCH_SPARSE = True except Exception as e: if not isinstance(e, ImportError): # pragma: no cover @@ -170,7 +177,7 @@ def from_edge_index( sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, is_sorted: bool = False, trust_data: bool = False, - ) -> 'SparseTensor': + ) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") @property @@ -179,7 +186,7 @@ def storage(self) -> SparseStorage: @classmethod def from_dense(self, mat: Tensor, - has_value: bool = True) -> 'SparseTensor': + has_value: bool = True) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") def size(self, dim: int) -> int: @@ -195,11 +202,11 @@ def has_value(self) -> bool: raise ImportError("'SparseTensor' requires 'torch-sparse'") def set_value(self, value: Optional[Tensor], - layout: Optional[str] = None) -> 'SparseTensor': + layout: Optional[str] = None) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") def fill_value(self, fill_value: float, - dtype: Optional[torch.dtype] = None) -> 'SparseTensor': + dtype: Optional[torch.dtype] = None) -> "SparseTensor": raise ImportError("'SparseTensor' requires 'torch-sparse'") def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]: @@ -249,6 +256,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor, try: import torch_frame # noqa + WITH_TORCH_FRAME = True from torch_frame import TensorFrame except Exception: @@ -261,6 +269,7 @@ class TensorFrame: # type: ignore try: import intel_extension_for_pytorch # noqa + WITH_IPEX = True except Exception: WITH_IPEX = False @@ -279,6 +288,7 @@ def __init__( def t(self) -> Tensor: # Only support accessing its transpose: from torch_geometric.utils import to_torch_csr_tensor + size = self.size return to_torch_csr_tensor( self.edge_index.flip([0]), @@ -298,15 +308,15 @@ def t(self) -> Tensor: # Only support accessing its transpose: NodeOrEdgeType = Union[NodeType, EdgeType] -DEFAULT_REL = 'to' -EDGE_TYPE_STR_SPLIT = '__' +DEFAULT_REL = "to" +EDGE_TYPE_STR_SPLIT = "__" class EdgeTypeStr(str): r"""A helper class to construct serializable edge types by merging an edge type tuple into a single string. 
""" - def __new__(cls, *args: Any) -> 'EdgeTypeStr': + def __new__(cls, *args: Any) -> "EdgeTypeStr": if isinstance(args[0], (list, tuple)): # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`: args = tuple(args[0]) diff --git a/torch_geometric/utils/__init__.py b/torch_geometric/utils/__init__.py index ecae6cc1c3a9..8b2723ac77e5 100644 --- a/torch_geometric/utils/__init__.py +++ b/torch_geometric/utils/__init__.py @@ -8,16 +8,26 @@ from .functions import cumsum from ._degree import degree from ._softmax import softmax +from ._log_softmax import log_softmax from ._lexsort import lexsort from ._sort_edge_index import sort_edge_index from ._coalesce import coalesce from .undirected import is_undirected, to_undirected -from .loop import (contains_self_loops, remove_self_loops, - segregate_self_loops, add_self_loops, - add_remaining_self_loops, get_self_loop_attr) +from .loop import ( + contains_self_loops, + remove_self_loops, + segregate_self_loops, + add_self_loops, + add_remaining_self_loops, + get_self_loop_attr, +) from .isolated import contains_isolated_nodes, remove_isolated_nodes -from ._subgraph import (get_num_hops, subgraph, k_hop_subgraph, - bipartite_subgraph) +from ._subgraph import ( + get_num_hops, + subgraph, + k_hop_subgraph, + bipartite_subgraph, +) from .dropout import dropout_adj, dropout_node, dropout_edge, dropout_path from ._homophily import homophily from ._assortativity import assortativity @@ -29,10 +39,16 @@ from ._to_dense_batch import to_dense_batch from ._to_dense_adj import to_dense_adj from .nested import to_nested_tensor, from_nested_tensor -from .sparse import (dense_to_sparse, is_sparse, is_torch_sparse_tensor, - to_torch_coo_tensor, to_torch_csr_tensor, - to_torch_csc_tensor, to_torch_sparse_tensor, - to_edge_index) +from .sparse import ( + dense_to_sparse, + is_sparse, + is_torch_sparse_tensor, + to_torch_coo_tensor, + to_torch_csr_tensor, + to_torch_csc_tensor, + to_torch_sparse_tensor, + to_edge_index, +) from ._spmm import spmm from ._unbatch import unbatch, unbatch_edge_index from ._one_hot import one_hot @@ -67,6 +83,7 @@ 'cumsum', 'degree', 'softmax', + 'log_softmax', 'lexsort', 'sort_edge_index', 'coalesce', @@ -153,4 +170,4 @@ # `structured_negative_sampling_feasible` is a long name and thus destroys the # documentation rendering. We remove it for now from the documentation: classes = copy.copy(__all__) -classes.remove('structured_negative_sampling_feasible') +classes.remove("structured_negative_sampling_feasible") diff --git a/torch_geometric/utils/_log_softmax.py b/torch_geometric/utils/_log_softmax.py new file mode 100644 index 000000000000..38f54baadfc1 --- /dev/null +++ b/torch_geometric/utils/_log_softmax.py @@ -0,0 +1,94 @@ +from typing import Optional + +import torch + +import torch_geometric.typing +from torch_geometric import is_compiling +from torch_geometric.utils import scatter, segment +from torch_geometric.utils.num_nodes import maybe_num_nodes + + +def log_softmax( + src: torch.Tensor, + index: Optional[torch.Tensor] = None, + ptr: Optional[torch.Tensor] = None, + num_nodes: Optional[int] = None, + dim: int = 0, +) -> torch.Tensor: + r"""Computes a sparsely evaluated log_softmax. + + Given a value tensor :attr:`src`, this function first groups the values + along the specified dimension based on the indices specified in + :attr:`index` or sorted inputs in CSR representation given by :attr:`ptr`, + and then proceeds to compute the log_softmax individually for each group. 
+ + The log_softmax operation is defined as the logarithm of the softmax + probabilities and is numerically more stable than computing the softmax + and its logarithm in two separate steps. + + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + log_softmax. When specified, :obj:`src` values are grouped by + :obj:`index` to compute the log_softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the log_softmax based on + sorted inputs in CSR representation. This allows for efficient + computation over contiguous ranges of nodes. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.*, the maximum + value + 1 of :attr:`index`. When not given, it is inferred from + :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. For most use + cases, this should be set to 0, as log_softmax is typically applied + along the node dimension in graph neural networks. + (default: :obj:`0`) + + :rtype: :class:`Tensor` + + Examples: + >>> src = torch.tensor([1., 2., 3., 4.]) + >>> index = torch.tensor([0, 0, 1, 1]) + >>> ptr = torch.tensor([0, 2, 4]) + >>> log_softmax(src, index) + tensor([-1.3133, -0.3133, -1.3133, -0.3133]) + + >>> log_softmax(src, None, ptr) + tensor([-1.3133, -0.3133, -1.3133, -0.3133]) + + >>> src = torch.tensor([[1., 1., 2., 2.], + ... [2., 1., 2., 1.], + ... [1., 2., 1., 2.], + ... [2., 2., 1., 1.]]) + >>> log_softmax(src, index, dim=-1) + tensor([[-0.6931, -0.6931, -0.6931, -0.6931], + [-0.3133, -1.3133, -0.3133, -1.3133], + [-1.3133, -0.3133, -1.3133, -0.3133], + [-0.6931, -0.6931, -0.6931, -0.6931]]) + """ + if (ptr is not None and torch_geometric.typing.WITH_TORCH_SCATTER + and not is_compiling()): + dim = dim + src.dim() if dim < 0 else dim + size = ([1] * dim) + [-1] + count = ptr[1:] - ptr[:-1] + ptr = ptr.view(size) + src_max = segment(src.detach(), ptr, reduce="max") + src_max = src_max.repeat_interleave(count, dim=dim) + out = src - src_max + out_exp = out.exp() + out_sum = segment(out_exp, ptr, reduce="sum") + 1e-16 + out_sum = out_sum.repeat_interleave(count, dim=dim) + log_out_sum = out_sum.log() + out = out - log_out_sum + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src.detach(), index, dim, dim_size=N, reduce="max") + out = src - src_max.index_select(dim, index) + out_exp = out.exp() + out_sum = scatter(out_exp, index, dim, dim_size=N, + reduce="sum") + 1e-16 + out_sum = out_sum.index_select(dim, index) + log_out_sum = out_sum.log() + out = out - log_out_sum + else: + raise NotImplementedError( + "'log_softmax' requires 'index' to be specified") + + return out
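For readers who want to try the new utility end-to-end, here is a minimal usage sketch (not part of the diff; the dense per-group reference below is only for illustration, and the `ptr` variant assumes a `torch-scatter` installation, matching the availability flag used in the tests above):

import torch

import torch_geometric.typing
from torch_geometric.utils import log_softmax

# Four edge scores grouped onto two target nodes: {1., 2.} and {3., 4.}.
src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])

out = log_softmax(src, index)

# Reference: log-softmax computed independently for each group.
expected = torch.cat([
    torch.log_softmax(src[:2], dim=0),  # group 0 (entries 0-1)
    torch.log_softmax(src[2:], dim=0),  # group 1 (entries 2-3)
])
assert torch.allclose(out, expected, atol=1e-4)
print(out)  # tensor([-1.3133, -0.3133, -1.3133, -0.3133])

# The CSR form describes the same grouping via offsets into the sorted input:
ptr = torch.tensor([0, 2, 4])
if torch_geometric.typing.WITH_TORCH_SCATTER:
    assert torch.allclose(log_softmax(src, ptr=ptr), out, atol=1e-4)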