added log_softmax to utils with tests #8909

Open
wants to merge 19 commits into
base: master
9 changes: 5 additions & 4 deletions CHANGELOG.md
@@ -75,6 +75,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added implementation of `log_softmax` in `torch_geometric.utils` ([#8909](https://github.com/pyg-team/pytorch_geometric/pull/8909))
 - Added an example for recommender systems, including k-NN search and retrieval metrics ([#8546](https://github.com/pyg-team/pytorch_geometric/pull/8546))
 - Added multi-GPU evaluation in distributed sampling example ([#8880](https://github.com/pyg-team/pytorch_geometric/pull/8880))
 - Added end-to-end example for distributed CPU training ([#8713](https://github.com/pyg-team/pytorch_geometric/pull/8713))
@@ -209,7 +210,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added possibility to run training benchmarks on XPU device ([#7925](https://github.com/pyg-team/pytorch_geometric/pull/7925))
 - Added `utils.ppr` for personalized PageRank computation ([#7917](https://github.com/pyg-team/pytorch_geometric/pull/7917))
 - Added support for XPU device in `PrefetchLoader` ([#7918](https://github.com/pyg-team/pytorch_geometric/pull/7918))
-- Added support for floating-point slicing in `Dataset`, *e.g.*, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
+- Added support for floating-point slicing in `Dataset`, _e.g._, `dataset[:0.9]` ([#7915](https://github.com/pyg-team/pytorch_geometric/pull/7915))
 - Added nightly GPU tests ([#7895](https://github.com/pyg-team/pytorch_geometric/pull/7895))
 - Added the `HalfHop` graph upsampling augmentation ([#7827](https://github.com/pyg-team/pytorch_geometric/pull/7827))
 - Added the `Wikidata5M` dataset ([#7864](https://github.com/pyg-team/pytorch_geometric/pull/7864))
@@ -227,7 +228,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added back support for PyTorch >= 1.11.0 ([#7656](https://github.com/pyg-team/pytorch_geometric/pull/7656))
 - Added `Data.sort()` and `HeteroData.sort()` functionalities ([#7649](https://github.com/pyg-team/pytorch_geometric/pull/7649))
 - Added `torch.nested_tensor` support in `Data` and `Batch` ([#7643](https://github.com/pyg-team/pytorch_geometric/pull/7643), [#7647](https://github.com/pyg-team/pytorch_geometric/pull/7647))
-- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700))
+- Added `interval` argument to `Cartesian`, `LocalCartesian` and `Distance` transformations ([#7533](https://github.com/pyg-team/pytorch_geometric/pull/7533), [#7614](https://github.com/pyg-team/pytorch_geometric/pull/7614), [#7700](https://github.com/pyg-team/pytorch_geometric/pull/7700))
 - Added a `LightGCN` example on the `AmazonBook` dataset ([7603](https://github.com/pyg-team/pytorch_geometric/pull/7603))
 - Added a tutorial on hierarchical neighborhood sampling ([#7594](https://github.com/pyg-team/pytorch_geometric/pull/7594))
 - Enabled different attention modes in `HypergraphConv` via the `attention_mode` argument ([#7601](https://github.com/pyg-team/pytorch_geometric/pull/7601))
@@ -287,7 +288,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Changed

-- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166))
+- Fixed `HeteroConv` for layers that have a non-default argument order, _e.g._, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166))
 - Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163))
 - Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145))
 - Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143))
@@ -296,7 +297,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed the `trim_to_layer` function to filter out non-reachable node and edge types when operating on heterogeneous graphs ([#7942](https://github.com/pyg-team/pytorch_geometric/pull/7942))
 - Accelerated and simplified `top_k` computation in `TopKPooling` ([#7737](https://github.com/pyg-team/pytorch_geometric/pull/7737))
 - Updated `GIN` implementation in kernel benchmarks to have sequential batchnorms ([#7955](https://github.com/pyg-team/pytorch_geometric/pull/7955))
-- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956)
+- Fixed bugs in benchmarks caused by a lack of the device conditions for CPU and unexpected `cache` argument in heterogeneous models ([#7956](https://github.com/pyg-team/pytorch_geometric/pull/7956)
 - Fixed a bug in which `batch.e_id` was not correctly computed on unsorted graph inputs ([#7953](https://github.com/pyg-team/pytorch_geometric/pull/7953))
 - Fixed `from_networkx` conversion from `nx.stochastic_block_model` graphs ([#7941](https://github.com/pyg-team/pytorch_geometric/pull/7941))
 - Fixed the usage of `bias_initializer` in `HeteroLinear` ([#7923](https://github.com/pyg-team/pytorch_geometric/pull/7923))
156 changes: 156 additions & 0 deletions test/utils/test_log_softmax.py
@@ -0,0 +1,156 @@
import pytest
import torch

import torch_geometric.typing
from torch_geometric.profile import benchmark
from torch_geometric.utils import log_softmax

CALCULATION_VIA_PTR_AVAILABLE = (torch_geometric.typing.WITH_LOG_SOFTMAX
                                 or torch_geometric.typing.WITH_TORCH_SCATTER)

ATOL = 1e-4
RTOL = 1e-4


def test_log_softmax():
    src = torch.tensor([1.0, 1.0, 1.0, 1.0])
    index = torch.tensor([0, 0, 1, 2])
    ptr = torch.tensor([0, 2, 3, 4])

    out = log_softmax(src, index)
    assert torch.allclose(out, torch.tensor([-0.6931, -0.6931, 0.0000,
                                             0.0000]), atol=ATOL, rtol=RTOL)
    if CALCULATION_VIA_PTR_AVAILABLE:
        assert torch.allclose(log_softmax(src, ptr=ptr), out, atol=ATOL,
                              rtol=RTOL)
    else:
        with pytest.raises(NotImplementedError, match="requires 'index'"):
            log_softmax(src, ptr=ptr)

    src = src.view(-1, 1)
    out = log_softmax(src, index)
    assert torch.allclose(
        out,
        torch.tensor([[-0.6931], [-0.6931], [0.0000], [0.0000]]),
        atol=ATOL,
        rtol=RTOL,
    )
    if CALCULATION_VIA_PTR_AVAILABLE:
        assert torch.allclose(log_softmax(src, None, ptr), out, atol=ATOL,
                              rtol=RTOL)

    jit = torch.jit.script(log_softmax)
    assert torch.allclose(jit(src, index), out, atol=ATOL, rtol=RTOL)


def test_log_softmax_backward():
    src_sparse = torch.rand(4, 8, requires_grad=True)
    index = torch.tensor([0, 0, 1, 1])
    src_dense = src_sparse.clone().detach().view(2, 2, src_sparse.size(-1))
    src_dense.requires_grad_(True)

    out_sparse = log_softmax(src_sparse, index)
    out_sparse.sum().backward()
    out_dense = torch.log_softmax(src_dense, dim=1)
    out_dense.sum().backward()

    assert torch.allclose(out_sparse, out_dense.view_as(out_sparse), atol=ATOL)
    assert torch.allclose(src_sparse.grad, src_dense.grad.view_as(src_sparse),
                          atol=ATOL)


def test_log_softmax_dim():
    index = torch.tensor([0, 0, 0, 0])
    ptr = torch.tensor([0, 4])

    src = torch.randn(4)
    assert torch.allclose(
        log_softmax(src, index, dim=0),
        torch.log_softmax(src, dim=0),
        atol=ATOL,
        rtol=RTOL,
    )
    if CALCULATION_VIA_PTR_AVAILABLE:
        assert torch.allclose(
            log_softmax(src, ptr=ptr, dim=0),
            torch.log_softmax(src, dim=0),
            atol=ATOL,
            rtol=RTOL,
        )

    src = torch.randn(4, 16)
    assert torch.allclose(
        log_softmax(src, index, dim=0),
        torch.log_softmax(src, dim=0),
        atol=ATOL,
        rtol=RTOL,
    )
    if CALCULATION_VIA_PTR_AVAILABLE:
        assert torch.allclose(
            log_softmax(src, ptr=ptr, dim=0),
            torch.log_softmax(src, dim=0),
            atol=ATOL,
            rtol=RTOL,
        )

    src = torch.randn(4, 4)
    assert torch.allclose(
        log_softmax(src, index, dim=-1),
        torch.log_softmax(src, dim=-1),
        atol=ATOL,
        rtol=RTOL,
    )
    if CALCULATION_VIA_PTR_AVAILABLE:
        assert torch.allclose(
            log_softmax(src, ptr=ptr, dim=-1),
            torch.log_softmax(src, dim=-1),
            atol=ATOL,
            rtol=RTOL,
        )

    src = torch.randn(4, 4, 16)
    assert torch.allclose(
        log_softmax(src, index, dim=1),
        torch.log_softmax(src, dim=1),
        atol=ATOL,
        rtol=RTOL,
    )
    if CALCULATION_VIA_PTR_AVAILABLE:
        assert torch.allclose(
            log_softmax(src, ptr=ptr, dim=1),
            torch.log_softmax(src, dim=1),
            atol=ATOL,
            rtol=RTOL,
        )


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--backward", action="store_true")
    args = parser.parse_args()

    num_nodes, num_edges = 10_000, 200_000
    x = torch.randn(num_edges, 64, device=args.device)
    index = torch.randint(num_nodes, (num_edges, ), device=args.device)

    compiled_log_softmax = torch.compile(log_softmax)

    def dense_softmax(x, index):
        x = x.view(num_nodes, -1, x.size(-1))
        return x.softmax(dim=-1)

    def dense_log_softmax(x, index):
        x = x.view(num_nodes, -1, x.size(-1))
        return torch.log_softmax(x, dim=-1)

    benchmark(
        funcs=[dense_log_softmax, log_softmax, compiled_log_softmax],
        func_names=["Dense Log Softmax", "Vanilla", "Compiled"],
        args=(x, index),
        num_steps=50 if args.device == "cpu" else 500,
        num_warmups=10 if args.device == "cpu" else 100,
        backward=args.backward,
    )
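For readers who want to check the expected values in `test_log_softmax` by hand, the following is a minimal pure-Python sketch of a group-wise log-softmax. It is not the implementation added in this PR (which is not shown on this page); the helper names `ptr_to_index` and `grouped_log_softmax` are hypothetical, and the sketch only assumes that `index` assigns each element to a group and that `ptr` holds CSR-style group boundaries, as the test inputs suggest.

```python
import math
from collections import defaultdict


def ptr_to_index(ptr):
    """Expand CSR-style boundaries into one group id per element."""
    index = []
    for group, (start, end) in enumerate(zip(ptr[:-1], ptr[1:])):
        index.extend([group] * (end - start))
    return index


def grouped_log_softmax(src, index):
    """Log-softmax of `src`, normalized separately within each group."""
    # Per-group maximum, subtracted before exponentiating for stability.
    group_max = defaultdict(lambda: -math.inf)
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)

    # Per-group sum of the shifted exponentials.
    group_sum = defaultdict(float)
    for value, group in zip(src, index):
        group_sum[group] += math.exp(value - group_max[group])

    # log_softmax(x_i) = x_i - max_g - log(sum_j exp(x_j - max_g))
    return [
        value - group_max[group] - math.log(group_sum[group])
        for value, group in zip(src, index)
    ]


# Same inputs as in `test_log_softmax` above.
src = [1.0, 1.0, 1.0, 1.0]
out = grouped_log_softmax(src, [0, 0, 1, 2])
out_via_ptr = grouped_log_softmax(src, ptr_to_index([0, 2, 3, 4]))
```

The first group holds two equal values, so each gets `-log(2)`, roughly `-0.6931`. The two single-element groups get `0`, which agrees with the tensor asserted in the test.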
50 changes: 30 additions & 20 deletions torch_geometric/typing.py
@@ -38,29 +38,31 @@

 try:
     import pyg_lib  # noqa
+
     WITH_PYG_LIB = True
-    WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, 'grouped_matmul')
-    WITH_SEGMM = hasattr(pyg_lib.ops, 'segment_matmul')
-    if WITH_SEGMM and 'pytest' in sys.modules and torch.cuda.is_available():
+    WITH_GMM = WITH_PT20 and hasattr(pyg_lib.ops, "grouped_matmul")
+    WITH_SEGMM = hasattr(pyg_lib.ops, "segment_matmul")
+    if WITH_SEGMM and "pytest" in sys.modules and torch.cuda.is_available():
         # NOTE `segment_matmul` is currently bugged on older NVIDIA cards which
         # let our GPU tests on CI crash. Try if this error is present on the
         # current GPU and disable `WITH_SEGMM`/`WITH_GMM` if necessary.
         # TODO Drop this code block once `segment_matmul` is fixed.
         try:
-            x = torch.randn(3, 4, device='cuda')
-            ptr = torch.tensor([0, 2, 3], device='cuda')
-            weight = torch.randn(2, 4, 4, device='cuda')
+            x = torch.randn(3, 4, device="cuda")
+            ptr = torch.tensor([0, 2, 3], device="cuda")
+            weight = torch.randn(2, 4, 4, device="cuda")
             out = pyg_lib.ops.segment_matmul(x, ptr, weight)
         except RuntimeError:
             WITH_GMM = False
             WITH_SEGMM = False
-    WITH_SAMPLED_OP = hasattr(pyg_lib.ops, 'sampled_add')
-    WITH_SOFTMAX = hasattr(pyg_lib.ops, 'softmax_csr')
-    WITH_INDEX_SORT = hasattr(pyg_lib.ops, 'index_sort')
-    WITH_METIS = hasattr(pyg_lib, 'partition')
-    WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ('edge_time' in inspect.signature(
+    WITH_SAMPLED_OP = hasattr(pyg_lib.ops, "sampled_add")
+    WITH_SOFTMAX = hasattr(pyg_lib.ops, "softmax_csr")
+    WITH_LOG_SOFTMAX = hasattr(pyg_lib.ops, "log_softmax_csr")
+    WITH_INDEX_SORT = hasattr(pyg_lib.ops, "index_sort")
+    WITH_METIS = hasattr(pyg_lib, "partition")
+    WITH_EDGE_TIME_NEIGHBOR_SAMPLE = ("edge_time" in inspect.signature(
         pyg_lib.sampler.neighbor_sample).parameters)
-    WITH_WEIGHTED_NEIGHBOR_SAMPLE = ('edge_weight' in inspect.signature(
+    WITH_WEIGHTED_NEIGHBOR_SAMPLE = ("edge_weight" in inspect.signature(
         pyg_lib.sampler.neighbor_sample).parameters)
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
@@ -72,13 +74,15 @@
     WITH_SEGMM = False
     WITH_SAMPLED_OP = False
     WITH_SOFTMAX = False
+    WITH_LOG_SOFTMAX = False
     WITH_INDEX_SORT = False
     WITH_METIS = False
     WITH_EDGE_TIME_NEIGHBOR_SAMPLE = False
     WITH_WEIGHTED_NEIGHBOR_SAMPLE = False

 try:
     import torch_scatter  # noqa
+
     WITH_TORCH_SCATTER = True
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
@@ -89,8 +93,9 @@

 try:
     import torch_cluster  # noqa
+
     WITH_TORCH_CLUSTER = True
-    WITH_TORCH_CLUSTER_BATCH_SIZE = 'batch_size' in torch_cluster.knn.__doc__
+    WITH_TORCH_CLUSTER_BATCH_SIZE = "batch_size" in torch_cluster.knn.__doc__
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
         warnings.warn(f"An issue occurred while importing 'torch-cluster'. "
@@ -106,6 +111,7 @@ def __getattr__(self, key: str) -> Any:

 try:
     import torch_spline_conv  # noqa
+
     WITH_TORCH_SPLINE_CONV = True
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
@@ -117,6 +123,7 @@ def __getattr__(self, key: str) -> Any:
 try:
     import torch_sparse  # noqa
     from torch_sparse import SparseStorage, SparseTensor
+
     WITH_TORCH_SPARSE = True
 except Exception as e:
     if not isinstance(e, ImportError):  # pragma: no cover
@@ -169,7 +176,7 @@ def from_edge_index(
             sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None,
             is_sorted: bool = False,
             trust_data: bool = False,
-        ) -> 'SparseTensor':
+        ) -> "SparseTensor":
             raise ImportError("'SparseTensor' requires 'torch-sparse'")

         @property
@@ -178,7 +185,7 @@ def storage(self) -> SparseStorage:

         @classmethod
         def from_dense(self, mat: Tensor,
-                       has_value: bool = True) -> 'SparseTensor':
+                       has_value: bool = True) -> "SparseTensor":
             raise ImportError("'SparseTensor' requires 'torch-sparse'")

         def size(self, dim: int) -> int:
@@ -194,11 +201,11 @@ def has_value(self) -> bool:
             raise ImportError("'SparseTensor' requires 'torch-sparse'")

         def set_value(self, value: Optional[Tensor],
-                      layout: Optional[str] = None) -> 'SparseTensor':
+                      layout: Optional[str] = None) -> "SparseTensor":
             raise ImportError("'SparseTensor' requires 'torch-sparse'")

         def fill_value(self, fill_value: float,
-                       dtype: Optional[torch.dtype] = None) -> 'SparseTensor':
+                       dtype: Optional[torch.dtype] = None) -> "SparseTensor":
             raise ImportError("'SparseTensor' requires 'torch-sparse'")

         def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]:
@@ -248,6 +255,7 @@ def masked_select_nnz(src: SparseTensor, mask: Tensor,

 try:
     import torch_frame  # noqa
+
     WITH_TORCH_FRAME = True
     from torch_frame import TensorFrame
 except Exception:
@@ -260,6 +268,7 @@ class TensorFrame: # type: ignore

 try:
     import intel_extension_for_pytorch  # noqa
+
     WITH_IPEX = True
 except Exception:
     WITH_IPEX = False
@@ -278,6 +287,7 @@ def __init__(

     def t(self) -> Tensor:  # Only support accessing its transpose:
         from torch_geometric.utils import to_torch_csr_tensor
+
         size = self.size
         return to_torch_csr_tensor(
             self.edge_index.flip([0]),
@@ -297,15 +307,15 @@ def t(self) -> Tensor: # Only support accessing its transpose:

 NodeOrEdgeType = Union[NodeType, EdgeType]

-DEFAULT_REL = 'to'
-EDGE_TYPE_STR_SPLIT = '__'
+DEFAULT_REL = "to"
+EDGE_TYPE_STR_SPLIT = "__"


 class EdgeTypeStr(str):
     r"""A helper class to construct serializable edge types by merging an edge
     type tuple into a single string.
     """
-    def __new__(cls, *args: Any) -> 'EdgeTypeStr':
+    def __new__(cls, *args: Any) -> "EdgeTypeStr":
         if isinstance(args[0], (list, tuple)):
             # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
             args = tuple(args[0])
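The `WITH_LOG_SOFTMAX` flag added to `typing.py` in this diff follows the file's existing pattern: probe an optional package once at import time and record whether a given operator is available. Below is a minimal, standard-library-only sketch of that pattern. The helper `has_op`, the function `pick_backend`, and the package name `hypothetical_missing_pkg` are illustrative assumptions, not part of PyG. The sketch also simplifies error handling: it only catches `ImportError`, whereas `typing.py` catches `Exception` and warns on anything that is not an `ImportError`.

```python
import importlib


def has_op(module_name, attr_path):
    """Return True if the module imports and exposes the dotted attribute."""
    try:
        obj = importlib.import_module(module_name)
    except ImportError:
        # Simplification: only a missing package is handled here.
        return False
    for name in attr_path.split("."):
        if not hasattr(obj, name):
            return False
        obj = getattr(obj, name)
    return True


# `math` stands in for an optional package that happens to be installed.
WITH_FSUM = has_op("math", "fsum")  # package and attribute both exist
WITH_MISSING_OP = has_op("math", "log_softmax_csr")  # attribute is absent
# Hypothetical package name, assumed not to be installed.
WITH_MISSING_PKG = has_op("hypothetical_missing_pkg", "ops.log_softmax_csr")


def pick_backend(fast_kernel_available, index_given):
    """Choose a code path from a capability flag (illustrative only)."""
    if fast_kernel_available:
        return "fused"
    if not index_given:
        raise NotImplementedError("This sketch requires 'index'")
    return "fallback"
```

Computing the flag once at import time keeps call sites to a cheap boolean check. This is also how the test module above decides whether the `ptr`-only code path can be exercised.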