Skip to content

Commit

Permalink
Improved handling of NumNeighbors serialization (#6646)
Browse files Browse the repository at this point in the history
Introduces an `EdgeTypeStr` object that is used for serializing
`NumNeighbors` instances.
  • Loading branch information
rusty1s committed Feb 8, 2023
1 parent 9e2c604 commit 1cd0f46
Show file tree
Hide file tree
Showing 5 changed files with 135 additions and 70 deletions.
18 changes: 2 additions & 16 deletions test/sampler/test_sampler_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_heterogeneous_num_neighbors_list():
assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]}

values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')])
assert values == {'A__B': [25, 10], 'B__A': [25, 10]}
assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]}

assert num_neighbors.num_hops == 2

Expand All @@ -46,20 +46,6 @@ def test_heterogeneous_num_neighbors_dict_and_default():
assert values == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]}

values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')])
assert values == {'A__B': [25, 10], 'B__A': [-1, -1]}
assert values == {'A__to__B': [25, 10], 'B__to__A': [-1, -1]}

assert num_neighbors.num_hops == 2


def test_num_neighbors_config():
num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1])

config = num_neighbors.config()
assert len(config) == 3
assert config['_target_'] == 'torch_geometric.sampler.base.NumNeighbors'
assert config['values'] == {'A__B': [25, 10]}
assert config['default'] == [-1, -1]

num_neighbors = NumNeighbors.from_config(config)
assert num_neighbors.values == {('A', 'B'): [25, 10]}
assert num_neighbors.default == [-1, -1]
36 changes: 36 additions & 0 deletions test/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import pytest

from torch_geometric.typing import EdgeTypeStr


def test_edge_type_str():
edge_type_str = EdgeTypeStr('a__links__b')
assert isinstance(edge_type_str, str)
assert edge_type_str == 'a__links__b'
assert edge_type_str.to_tuple() == ('a', 'links', 'b')

edge_type_str = EdgeTypeStr('a', 'b')
assert isinstance(edge_type_str, str)
assert edge_type_str == 'a__to__b'
assert edge_type_str.to_tuple() == ('a', 'to', 'b')

edge_type_str = EdgeTypeStr(('a', 'b'))
assert isinstance(edge_type_str, str)
assert edge_type_str == 'a__to__b'
assert edge_type_str.to_tuple() == ('a', 'to', 'b')

edge_type_str = EdgeTypeStr('a', 'links', 'b')
assert isinstance(edge_type_str, str)
assert edge_type_str == 'a__links__b'
assert edge_type_str.to_tuple() == ('a', 'links', 'b')

edge_type_str = EdgeTypeStr(('a', 'links', 'b'))
assert isinstance(edge_type_str, str)
assert edge_type_str == 'a__links__b'
assert edge_type_str.to_tuple() == ('a', 'links', 'b')

with pytest.raises(ValueError, match="invalid edge type"):
EdgeTypeStr('a', 'b', 'c', 'd')

with pytest.raises(ValueError, match="Cannot convert the edge type"):
EdgeTypeStr('a__b__c__d').to_tuple()
6 changes: 2 additions & 4 deletions torch_geometric/data/hetero_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from torch_geometric.data.graph_store import EdgeLayout
from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
from torch_geometric.typing import (
DEFAULT_REL,
EdgeTensorType,
EdgeType,
FeatureTensorType,
Expand Down Expand Up @@ -102,9 +103,6 @@ class HeteroData(BaseData, FeatureStore, GraphStore):
{ 'edge_index': edge_index_author_paper }
})
"""

DEFAULT_REL = 'to'

def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
super().__init__()

Expand Down Expand Up @@ -435,7 +433,7 @@ def _to_canonical(self, *args: Tuple[QueryType]) -> NodeOrEdgeType:
args = edge_types[0]
return args
elif len(edge_types) == 0:
args = (args[0], self.DEFAULT_REL, args[1])
args = (args[0], DEFAULT_REL, args[1])
return args

return args
Expand Down
107 changes: 57 additions & 50 deletions torch_geometric/sampler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from torch import Tensor

from torch_geometric.data import Data, FeatureStore, GraphStore, HeteroData
from torch_geometric.typing import EdgeType, NodeType, OptTensor
from torch_geometric.typing import EdgeType, EdgeTypeStr, NodeType, OptTensor
from torch_geometric.utils.mixin import CastMixin


Expand Down Expand Up @@ -193,52 +193,82 @@ class NumNeighbors:
default (List[int], optional): The default number of neighbors for edge
types not specified in :obj:`values`. (default: :obj:`None`)
"""
values: Union[List[int], Dict[EdgeType, List[int]]]
values: Union[List[int], Dict[EdgeTypeStr, List[int]]]
default: Optional[List[int]] = None

def __post_init__(self):
if isinstance(self.values, (tuple, list)) and self.default is not None:
def __init__(
self,
values: Union[List[int], Dict[EdgeType, List[int]]],
default: Optional[List[int]] = None,
):
if isinstance(values, (tuple, list)) and default is not None:
raise ValueError(f"'default' must be set to 'None' in case a "
f"single list is given as the number of "
f"neighbors (got '{type(self.default)})'")
f"neighbors (got '{type(default)})'")

def get_values(
self,
edge_types: Optional[List[EdgeType]] = None,
) -> Union[List[int], Dict[EdgeType, List[int]]]:
r"""Returns the number of neighbors.
if isinstance(values, dict):
values = {EdgeTypeStr(key): value for key, value in values.items()}

Args:
edge_types (List[Tuple[str, str, str]], optional): The edge types
to generate the number of neighbors for. (default: :obj:`None`)
"""
if '_values' in self.__dict__:
return self.__dict__['_values']
# Write to `__dict__` since dataclass is annotated with `frozen=True`:
self.__dict__['values'] = values
self.__dict__['default'] = default

values = self.values
def _get_values(
self,
edge_types: Optional[List[EdgeType]] = None,
mapped: bool = False,
) -> Union[List[int], Dict[Union[EdgeType, EdgeTypeStr], List[int]]]:

if edge_types is not None:
if isinstance(values, (tuple, list)):
default = values
values = {}
else: # isinstance(self.values, dict):
if isinstance(self.values, (tuple, list)):
default = self.values
elif isinstance(self.values, dict):
default = self.default
values = copy.copy(values)
else:
assert False

out = {}
for edge_type in edge_types:
if edge_type not in values:
edge_type_str = EdgeTypeStr(edge_type)
if edge_type_str in self.values:
out[edge_type_str if mapped else edge_type] = (
self.values[edge_type_str])
else:
if default is None:
raise ValueError(f"Missing number of neighbors for "
f"edge type '{edge_type}'")
values[edge_type] = default
out[edge_type_str if mapped else edge_type] = default

if isinstance(values, dict):
num_hops = set(len(v) for v in values.values())
elif isinstance(self.values, dict) and not mapped:
out = {key.to_tuple(): value for key, value in self.values.items()}

else:
out = copy.copy(self.values)

if isinstance(out, dict):
num_hops = set(len(v) for v in out.values())
if len(num_hops) > 1:
raise ValueError(f"Number of hops must be the same across all "
f"edge types (got {len(num_hops)} different "
f"number of hops)")

return out

def get_values(
self,
edge_types: Optional[List[EdgeType]] = None,
) -> Union[List[int], Dict[EdgeType, List[int]]]:
r"""Returns the number of neighbors.
Args:
edge_types (List[Tuple[str, str, str]], optional): The edge types
to generate the number of neighbors for. (default: :obj:`None`)
"""
if '_values' in self.__dict__:
return self.__dict__['_values']

values = self._get_values(edge_types, mapped=False)

self.__dict__['_values'] = values
return values

Expand All @@ -257,9 +287,7 @@ def get_mapped_values(
if '_mapped_values' in self.__dict__:
return self.__dict__['_mapped_values']

values = self.get_values(edge_types)
if isinstance(values, dict):
values = {'__'.join(key): value for key, value in values.items()}
values = self._get_values(edge_types, mapped=True)

self.__dict__['_mapped_values'] = values
return values
Expand All @@ -282,27 +310,6 @@ def __len__(self) -> int:
r"""Returns the number of hops."""
return self.num_hops

def config(self) -> Dict[str, Any]:
values = self.values
if isinstance(values, dict):
values = {'__'.join(k): v for k, v in values.items()}

cls_name = f'{self.__class__.__module__}.{self.__class__.__name__}'

return {
'_target_': cls_name,
'values': values,
'default': self.default,
}

@classmethod
def from_config(cls, cfg: Dict[str, Any]) -> 'NumNeighbors':
values = cfg['values']
if isinstance(values, dict):
values = {tuple(k.split('__')): v for k, v in values.items()}

return cls(values, cfg.get('default'))


class NegativeSamplingMode(Enum):
# 'binary': Randomly sample negative edges in the graph.
Expand Down
38 changes: 38 additions & 0 deletions torch_geometric/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,44 @@ def from_edge_index(cls, *args, **kwargs) -> 'SparseTensor':
# `data[('author', 'writes', 'paper')]
EdgeType = Tuple[str, str, str]

DEFAULT_REL = 'to'
EDGE_TYPE_STR_SPLIT = '__'


class EdgeTypeStr(str):
r"""A helper class to construct serializable edge types by merging an edge
type tuple into a single string."""
def __new__(cls, *args):
if isinstance(args[0], (list, tuple)):
# Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`:
args = tuple(args[0])

if len(args) == 1 and isinstance(args[0], str):
args = args[0] # An edge type string was passed.

elif len(args) == 2 and all(isinstance(arg, str) for arg in args):
# A `(src, dst)` edge type was passed - add `DEFAULT_REL`:
args = (args[0], DEFAULT_REL, args[1])
args = EDGE_TYPE_STR_SPLIT.join(args)

elif len(args) == 3 and all(isinstance(arg, str) for arg in args):
# A `(src, rel, dst)` edge type was passed:
args = EDGE_TYPE_STR_SPLIT.join(args)

else:
raise ValueError(f"Encountered invalid edge type '{args}'")

return str.__new__(cls, args)

def to_tuple(self) -> EdgeType:
r"""Returns the original edge type."""
out = tuple(self.split(EDGE_TYPE_STR_SPLIT))
if len(out) != 3:
raise ValueError(f"Cannot convert the edge type '{self}' to a "
f"tuple since it holds invalid characters")
return out


# There exist some short-cuts to query edge-types (given that the full triplet
# can be uniquely reconstructed, e.g.:
# * via str: `data['writes']`
Expand Down

0 comments on commit 1cd0f46

Please sign in to comment.