Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reverse support in aggregation_resolver #5084

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the `bias` vector to the `GCN` model definition in the "Create Message Passing Networks" tutorial ([#4755](https://github.com/pyg-team/pytorch_geometric/pull/4755))
- Added `transforms.RootedSubgraph` interface with two implementations: `RootedEgoNets` and `RootedRWSubgraph` ([#3926](https://github.com/pyg-team/pytorch_geometric/pull/3926))
- Added `ptr` vectors for `follow_batch` attributes within `Batch.from_data_list` ([#4723](https://github.com/pyg-team/pytorch_geometric/pull/4723))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/issues/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]))
- Added `torch_geometric.nn.aggr` package ([#4687](https://github.com/pyg-team/pytorch_geometric/pull/4687), [#4721](https://github.com/pyg-team/pytorch_geometric/pull/4721), [#4731](https://github.com/pyg-team/pytorch_geometric/pull/4731), [#4762](https://github.com/pyg-team/pytorch_geometric/pull/4762), [#4749](https://github.com/pyg-team/pytorch_geometric/pull/4749), [#4779](https://github.com/pyg-team/pytorch_geometric/pull/4779), [#4863](https://github.com/pyg-team/pytorch_geometric/pull/4863), [#4864](https://github.com/pyg-team/pytorch_geometric/pull/4864), [#4865](https://github.com/pyg-team/pytorch_geometric/pull/4865), [#4866](https://github.com/pyg-team/pytorch_geometric/pull/4866), [#4872](https://github.com/pyg-team/pytorch_geometric/pull/4872), [#4934](https://github.com/pyg-team/pytorch_geometric/pull/4934), [#4935](https://github.com/pyg-team/pytorch_geometric/pull/4935), [#4957](https://github.com/pyg-team/pytorch_geometric/pull/4957), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4973](https://github.com/pyg-team/pytorch_geometric/pull/4973), [#4986](https://github.com/pyg-team/pytorch_geometric/pull/4986), [#4995](https://github.com/pyg-team/pytorch_geometric/pull/4995), [#5000](https://github.com/pyg-team/pytorch_geometric/pull/5000), [#5034](https://github.com/pyg-team/pytorch_geometric/pull/5034), [#5036](https://github.com/pyg-team/pytorch_geometric/pull/5036), [#5039](https://github.com/pyg-team/pytorch_geometric/pull/5039), [#4522](https://github.com/pyg-team/pytorch_geometric/pull/4522), [#5033](https://github.com/pyg-team/pytorch_geometric/pull/5033]), [#5084](https://github.com/pyg-team/pytorch_geometric/pull/5084))
- Added the `DimeNet++` model ([#4432](https://github.com/pyg-team/pytorch_geometric/pull/4432), [#4699](https://github.com/pyg-team/pytorch_geometric/pull/4699), [#4700](https://github.com/pyg-team/pytorch_geometric/pull/4700), [#4800](https://github.com/pyg-team/pytorch_geometric/pull/4800))
- Added an example of using PyG with PyTorch Ignite ([#4487](https://github.com/pyg-team/pytorch_geometric/pull/4487))
- Added `GroupAddRev` module with support for reducing training GPU memory ([#4671](https://github.com/pyg-team/pytorch_geometric/pull/4671), [#4701](https://github.com/pyg-team/pytorch_geometric/pull/4701), [#4715](https://github.com/pyg-team/pytorch_geometric/pull/4715), [#4730](https://github.com/pyg-team/pytorch_geometric/pull/4730))
Expand Down
14 changes: 10 additions & 4 deletions test/nn/conv/test_gen_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,18 @@
import torch
from torch_sparse import SparseTensor

from torch_geometric.nn import GENConv
from torch_geometric.nn import (
GENConv,
PowerMeanAggregation,
SoftmaxAggregation,
)
from torch_geometric.testing import is_full_test


@pytest.mark.parametrize('aggr', ['softmax', 'powermean'])
def test_gen_conv(aggr):
@pytest.mark.parametrize('aggr_tuple', [('softmax', SoftmaxAggregation()),
('powermean', PowerMeanAggregation())])
def test_gen_conv(aggr_tuple):
aggr, aggr_module = aggr_tuple
x1 = torch.randn(4, 16)
x2 = torch.randn(2, 16)
edge_index = torch.tensor([[0, 1, 2, 3], [0, 0, 1, 1]])
Expand All @@ -17,7 +23,7 @@ def test_gen_conv(aggr):
adj2 = SparseTensor(row=row, col=col, value=value, sparse_sizes=(4, 4))

conv = GENConv(16, 32, aggr)
assert conv.__repr__() == f'GENConv(16, 32, aggr={aggr})'
assert conv.__repr__() == f'GENConv(16, 32, aggr={str(aggr_module)})'
out11 = conv(x1, edge_index)
assert out11.size() == (4, 32)
assert torch.allclose(conv(x1, edge_index, size=(4, 4)), out11, atol=1e-6)
Expand Down
2 changes: 1 addition & 1 deletion test/nn/conv/test_sage_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_lstm_aggr_sage_conv():
adj = SparseTensor(row=row, col=col, sparse_sizes=(4, 4))

conv = SAGEConv(8, 32, aggr='lstm')
assert str(conv) == 'SAGEConv(8, 32, aggr=lstm)'
assert str(conv) == 'SAGEConv(8, 32, aggr=LSTMAggregation(8, 8))'
out = conv(x, edge_index)
assert out.size() == (4, 32)
assert torch.allclose(conv(x, adj.t()), out)
Expand Down
30 changes: 17 additions & 13 deletions test/nn/test_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,25 @@ def test_activation_resolver():


@pytest.mark.parametrize('aggr_tuple', [
(torch_geometric.nn.aggr.MeanAggregation, 'mean'),
(torch_geometric.nn.aggr.SumAggregation, 'sum'),
(torch_geometric.nn.aggr.SumAggregation, 'add'),
(torch_geometric.nn.aggr.MaxAggregation, 'max'),
(torch_geometric.nn.aggr.MinAggregation, 'min'),
(torch_geometric.nn.aggr.MulAggregation, 'mul'),
(torch_geometric.nn.aggr.VarAggregation, 'var'),
(torch_geometric.nn.aggr.StdAggregation, 'std'),
(torch_geometric.nn.aggr.SoftmaxAggregation, 'softmax'),
(torch_geometric.nn.aggr.PowerMeanAggregation, 'powermean'),
(torch_geometric.nn.aggr.MeanAggregation(), 'mean', ()),
(torch_geometric.nn.aggr.SumAggregation(), 'sum', ()),
(torch_geometric.nn.aggr.MaxAggregation(), 'max', ()),
(torch_geometric.nn.aggr.MinAggregation(), 'min', ()),
(torch_geometric.nn.aggr.MulAggregation(), 'mul', ()),
(torch_geometric.nn.aggr.VarAggregation(), 'var', ()),
(torch_geometric.nn.aggr.StdAggregation(), 'std', ()),
(torch_geometric.nn.aggr.SoftmaxAggregation(), 'softmax', ()),
(torch_geometric.nn.aggr.PowerMeanAggregation(), 'power_mean', ()),
(torch_geometric.nn.aggr.LSTMAggregation(6, 6), 'lstm', (6, 6)),
(torch_geometric.nn.aggr.Set2Set(6, 6), 'set2set', (6, 6)),
])
def test_aggregation_resolver(aggr_tuple):
aggr_module, aggr_repr = aggr_tuple
assert isinstance(aggregation_resolver(aggr_module()), aggr_module)
assert isinstance(aggregation_resolver(aggr_repr), aggr_module)
aggr_module, aggr_repr, aggr_args = aggr_tuple
aggr_cls = type(aggr_module)
assert isinstance(aggregation_resolver(aggr_module), aggr_cls)
assert isinstance(aggregation_resolver(aggr_repr, *aggr_args), aggr_cls)
assert aggregation_resolver(aggr_module, reverse=True) == aggr_repr
assert aggregation_resolver(aggr_repr, reverse=True) == aggr_repr


@pytest.mark.parametrize('norm_tuple', [
Expand Down
7 changes: 4 additions & 3 deletions torch_geometric/nn/conv/message_passing.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,15 @@ def __init__(
super().__init__()

if aggr is None:
self.aggr = None
self.aggr_module = None
self.aggr = None
elif isinstance(aggr, (str, Aggregation)):
self.aggr = str(aggr)
self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {}))
aggr = aggr_resolver(aggr, reverse=True)
self.aggr = aggr if aggr in FUSE_AGGRS else str(self.aggr_module)
elif isinstance(aggr, (tuple, list)):
self.aggr = [str(x) for x in aggr]
self.aggr_module = MultiAggregation(aggr, **(aggr_kwargs or {}))
self.aggr = str(self.aggr_module)
else:
raise ValueError(
f"Only strings, list, tuples and instances of"
Expand Down
61 changes: 55 additions & 6 deletions torch_geometric/nn/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,15 @@ def normalize_string(s: str) -> str:
return s.lower().replace('-', '').replace('_', '').replace(' ', '')


def camel_to_snake(s: str) -> str:
import re
s = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', s)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s).lower()


# Resolvers ###################################################################


def resolver(classes: List[Any], class_dict: Dict[str, Any],
query: Union[Any, str], base_cls: Optional[Any], *args, **kwargs):

Expand Down Expand Up @@ -36,10 +45,43 @@ def resolver(classes: List[Any], class_dict: Dict[str, Any],
assert callable(cls)
return cls

choices = set(cls.__name__ for cls in classes) | set(class_dict.keys())
choices = set(cls.__name__ for cls in classes + list(class_dict.values()))
raise ValueError(f"Could not resolve '{query}' among choices {choices}")


# Reverse Resolvers ###########################################################


def reverse_resolver(classes: List[Any], class_dict: Dict[Any, str],
query: Union[Any, str], base_cls: Optional[Any]):

if isinstance(query, str):
return query

assert callable(query)
query_cls_repr = camel_to_snake(query.__class__.__name__)

if not base_cls:
return query_cls_repr

base_cls_repr = camel_to_snake(base_cls.__name__)

if not isinstance(query, base_cls):
choices = {base_cls_repr} | set(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe we should have a test case covering this? the logic is not trivial

camel_to_snake(cls.__name__).replace(base_cls_repr, '').strip('_')
for cls in classes) | set(class_dict.values())
choices.remove('')
raise ValueError(
f"Could not resolve '{query}' among choices {choices}")

for cls, repr in class_dict.items():
if isinstance(query, cls):
return repr

repr = query_cls_repr.replace(base_cls_repr, '').strip("_")
return repr if repr else base_cls_repr


# Activation Resolver #########################################################


Expand Down Expand Up @@ -80,14 +122,21 @@ def normalization_resolver(query: Union[Any, str], *args, **kwargs):
# Aggregation Resolver ########################################################


def aggregation_resolver(query: Union[Any, str], *args, **kwargs):
def aggregation_resolver(query: Union[Any, str], *args, reverse: bool = False,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not certain I like having a 'reverse' flag because it makes the logic for 'query' a bit complicated. What do you think about splitting this into reverse_aggregation_resolver?

**kwargs):
import torch_geometric.nn.aggr as aggr
base_cls = aggr.Aggregation
aggrs = [
aggr for aggr in vars(aggr).values()
if isinstance(aggr, type) and issubclass(aggr, base_cls)
]
aggr_dict = {
'add': aggr.SumAggregation,
}
return resolver(aggrs, aggr_dict, query, base_cls, *args, **kwargs)
if not reverse:
aggr_dict = {
'add': aggr.SumAggregation,
}
return resolver(aggrs, aggr_dict, query, base_cls, *args, **kwargs)
else:
aggr_dict = {
aggr.Set2Set: 'set2set',
}
return reverse_resolver(aggrs, aggr_dict, query, base_cls)