Skip to content

Commit

Permalink
Fix map_annotation when annotation is {list,dict} instead of `typ…
Browse files Browse the repository at this point in the history
…ing.{List,Dict}` (#9269)

Fixes this error:

```python
from dataclasses import dataclass
from typing import Dict

from torch_geometric.config_store import register


@register()
@DataClass
class MyClass1:
    something: Dict[str, int]


@register()
@DataClass
class MyClass2:
    something: dict[str, int]
```

```
Traceback (most recent call last):
  File "/home/aki/work/github.com/pyg-team/pytorch_geometric/test_config.py", line 13, in <module>
    @register()
     ^^^^^^^^^^
  File "/home/aki/work/github.com/pyg-team/pytorch_geometric/torch_geometric/config_store.py", line 360, in bounded_register
    register(cls=cls, data_cls=data_cls, group=group, **kwargs)
  File "/home/aki/work/github.com/pyg-team/pytorch_geometric/torch_geometric/config_store.py", line 345, in register
    data_cls = to_dataclass(cls, **kwargs)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aki/work/github.com/pyg-team/pytorch_geometric/torch_geometric/config_store.py", line 255, in to_dataclass
    annotation = map_annotation(annotation, mapping=MAPPING)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aki/work/github.com/pyg-team/pytorch_geometric/torch_geometric/config_store.py", line 170, in map_annotation
    annotation.__args__ = tuple(map_annotation(a, mapping) for a in args)
    ^^^^^^^^^^^^^^^^^^^
AttributeError: readonly attribute
```

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
3 people committed May 2, 2024
1 parent ae185ba commit 2310348
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 13 deletions.
20 changes: 18 additions & 2 deletions test/test_config_store.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Any
from typing import Any, Dict, List, Tuple

from torch_geometric.config_store import (
class_from_dataclass,
clear_config_store,
dataclass_from_class,
fill_config_store,
get_config_store,
map_annotation,
register,
to_dataclass,
)
from torch_geometric.testing import withPackage
from torch_geometric.testing import minPython, withPackage
from torch_geometric.transforms import AddSelfLoops


Expand Down Expand Up @@ -44,6 +45,21 @@ def test_to_dataclass():
"AddSelfLoops')")


@minPython('3.10')
def test_map_annotation():
mapping = {int: Any}
assert map_annotation(dict[str, int], mapping) == dict[str, Any]
assert map_annotation(Dict[str, float], mapping) == Dict[str, float]
assert map_annotation(List[str], mapping) == List[str]
assert map_annotation(List[int], mapping) == List[Any]
assert map_annotation(Tuple[int], mapping) == Tuple[Any]
assert map_annotation(dict[str, int], mapping) == dict[str, Any]
assert map_annotation(dict[str, float], mapping) == dict[str, float]
assert map_annotation(list[str], mapping) == list[str]
assert map_annotation(list[int], mapping) == list[Any]
assert map_annotation(tuple[int], mapping) == tuple[Any]


def test_register():
register(AddSelfLoops, group='transform')
assert 'transform' in get_config_store().repo
Expand Down
16 changes: 11 additions & 5 deletions torch_geometric/config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,18 @@ def map_annotation(
annotation: Any,
mapping: Optional[Dict[Any, Any]] = None,
) -> Any:

origin = getattr(annotation, '__origin__', None)
args = getattr(annotation, '__args__', [])
if origin == Union or origin == list or origin == dict:
annotation = copy.copy(annotation)
annotation.__args__ = tuple(map_annotation(a, mapping) for a in args)
args = getattr(annotation, '__args__', tuple())
if origin in {Union, list, dict, tuple}:
new_args = tuple(map_annotation(a, mapping) for a in args)
if type(annotation).__name__ == 'GenericAlias':
# If annotated with `list[...]` or `dict[...]` (>= Python 3.10):
annotation = origin[new_args]
else:
# If annotated with `typing.List[...]` or `typing.Dict[...]`:
annotation = copy.copy(annotation)
annotation.__args__ = new_args

return annotation

if mapping is not None and annotation in mapping:
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
onlyDistributedTest,
onlyLinux,
noWindows,
onlyPython,
minPython,
onlyCUDA,
onlyXPU,
onlyOnline,
Expand Down Expand Up @@ -40,7 +40,7 @@
'onlyDistributedTest',
'onlyLinux',
'noWindows',
'onlyPython',
'minPython',
'onlyCUDA',
'onlyXPU',
'onlyOnline',
Expand Down
16 changes: 12 additions & 4 deletions torch_geometric/testing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,23 @@ def noWindows(func: Callable) -> Callable:
)(func)


def onlyPython(*args: str) -> Callable:
def minPython(version: str) -> Callable:
r"""A decorator to run tests on specific :python:`Python` versions only."""
def decorator(func: Callable) -> Callable:
import pytest

python_version = f'{sys.version_info.major}.{sys.version_info.minor}'
major, minor = version.split('.')

skip = False
if sys.version_info.major < int(major):
skip = True
if (sys.version_info.major == int(major)
and sys.version_info.minor < int(minor)):
skip = True

return pytest.mark.skipif(
python_version not in args,
reason=f"Python {python_version} not supported",
skip,
reason=f"Python {version} required",
)(func)

return decorator
Expand Down

0 comments on commit 2310348

Please sign in to comment.