
Let Data and HeteroData implement FeatureStore #4807

Merged · 11 commits · Jun 20, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
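In practice, this entry means `Data` and `HeteroData` now expose the `FeatureStore` CRUD methods (`put_tensor`, `get_tensor`, `remove_tensor`) directly. A minimal sketch of the homogeneous case, mirroring the tests below (the heterogeneous variant additionally takes a `group_name`):

```python
import torch

from torch_geometric.data import Data

data = Data()
data.put_tensor(torch.randn(20, 20), attr_name='x', index=None)  # create/update
out = data.get_tensor(attr_name='x', index=None)                 # read
data.remove_tensor(attr_name='x', index=None)                    # delete
```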
25 changes: 25 additions & 0 deletions test/data/test_data.py
@@ -239,3 +239,28 @@ def my_attr1(self, value):
    data.my_attr1 = 2
    assert 'my_attr1' not in data._store
    assert data.my_attr1 == 2


# Feature Store ###############################################################


def test_basic_feature_store():
    data = Data()
    x = torch.randn(20, 20)

    # Put tensor:
    assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None)
    assert torch.equal(data.x, x)

    # Put (modify) tensor slice:
    x[15:] = 0
    data.put_tensor(0, attr_name='x', index=slice(15, None, None))

    # Get tensor:
    out = data.get_tensor(attr_name='x', index=None)
    assert torch.equal(x, out)

    # Remove tensor:
    assert 'x' in data.__dict__['_store']
    data.remove_tensor(attr_name='x', index=None)
    assert 'x' not in data.__dict__['_store']
27 changes: 27 additions & 0 deletions test/data/test_hetero_data.py
@@ -400,3 +400,30 @@ def test_hetero_data_to_canonical():

    with pytest.raises(TypeError, match="missing 1 required"):
        data['user', 'product']


# Feature Store ###############################################################


def test_basic_feature_store():
    data = HeteroData()
    x = torch.randn(20, 20)

    # Put tensor:
    assert data.put_tensor(copy.deepcopy(x), group_name='paper', attr_name='x',
                           index=None)
    assert torch.equal(data['paper'].x, x)

    # Put (modify) tensor slice:
    x[15:] = 0
    data.put_tensor(0, group_name='paper', attr_name='x',
                    index=slice(15, None, None))

    # Get tensor:
    out = data.get_tensor(group_name='paper', attr_name='x', index=None)
    assert torch.equal(x, out)

    # Remove tensor:
    assert 'x' in data['paper'].__dict__['_mapping']
    data.remove_tensor(group_name='paper', attr_name='x', index=None)
    assert 'x' not in data['paper'].__dict__['_mapping']
10 changes: 9 additions & 1 deletion torch_geometric/data/batch.py
@@ -23,8 +23,16 @@ def __call__(cls, *args, **kwargs):
            new_cls = base_cls
        else:
            name = f'{base_cls.__name__}{cls.__name__}'

            # NOTE `MetaResolver` is necessary to resolve metaclass conflict
            # problems between `DynamicInheritance` and the metaclass of
            # `base_cls`. In particular, it creates a new common metaclass
            # from the defined metaclasses.
            class MetaResolver(type(cls), type(base_cls)):
                pass

            if name not in globals():
-                globals()[name] = type(name, (cls, base_cls), {})
+                globals()[name] = MetaResolver(name, (cls, base_cls), {})
            new_cls = globals()[name]

        params = list(inspect.signature(base_cls.__init__).parameters.items())
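For context on the `MetaResolver` change above: Python raises a `TypeError` when a class's bases carry unrelated metaclasses, which is exactly what can happen when `DynamicInheritance` builds the combined class from `cls` and `base_cls`. A standalone sketch of the conflict and its resolution (class names hypothetical):

```python
class MetaA(type):
    pass


class MetaB(type):
    pass


class A(metaclass=MetaA):
    pass


class B(metaclass=MetaB):
    pass


# class C(A, B): ...
# -> TypeError: metaclass conflict: the metaclass of a derived class must be
#    a (non-strict) subclass of the metaclasses of all its bases


class Resolver(MetaA, MetaB):  # common metaclass derived from both
    pass


class C(A, B, metaclass=Resolver):  # now allowed
    pass
```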
67 changes: 65 additions & 2 deletions torch_geometric/data/data.py
@@ -1,5 +1,6 @@
import copy
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (
    Any,
    Callable,
@@ -17,6 +18,12 @@
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data.feature_store import (
    FeatureStore,
    FeatureTensorType,
    TensorAttr,
    _field_status,
)
from torch_geometric.data.storage import (
    BaseStorage,
    EdgeStorage,
@@ -300,7 +307,16 @@ def contains_self_loops(self) -> bool:
###############################################################################


-class Data(BaseData):
@dataclass
class DataTensorAttr(TensorAttr):
    r"""Attribute class for `Data`, which does not require a `group_name`."""
    def __init__(self, attr_name=_field_status.UNSET,
                 index=_field_status.UNSET):
        # Treat group_name as optional, and move it to the end:
        super().__init__(None, attr_name, index)


+class Data(BaseData, FeatureStore):
r"""A data object describing a homogeneous graph.
The data object can hold node-level, link-level and graph-level attributes.
In general, :class:`~torch_geometric.data.Data` tries to mimic the
@@ -348,7 +364,10 @@ class Data(BaseData):
    def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
                 edge_attr: OptTensor = None, y: OptTensor = None,
                 pos: OptTensor = None, **kwargs):
-        super().__init__()
+        # `Data` doesn't support group_name, so we need to adjust `TensorAttr`
+        # accordingly here to avoid requiring `group_name` to be set:
+        super().__init__(attr_cls=DataTensorAttr)

        self.__dict__['_store'] = GlobalStorage(_parent=self)

        if x is not None:
@@ -384,6 +403,9 @@ def __setattr__(self, key: str, value: Any):
    def __delattr__(self, key: str):
        delattr(self._store, key)

    # TODO consider supporting the feature store interface for
    # __getitem__, __setitem__, and __delitem__ so, for example, we
    # can accept key: Union[str, TensorAttr] in __getitem__.
    def __getitem__(self, key: str) -> Any:
        return self._store[key]

@@ -692,6 +714,47 @@ def num_faces(self) -> Optional[int]:
            return self.face.size(self.__cat_dim__('face', self.face))
        return None

    # FeatureStore interface ##################################################

    def items(self):
        r"""Returns an `ItemsView` over the stored attributes in the `Data`
        object."""
        return self._store.items()

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""Stores a feature tensor in node storage."""
        out = getattr(self, attr.attr_name, None)
        if out is not None and attr.index is not None:
            # Attr name exists, handle index:
            out[attr.index] = tensor
        else:
            # No attr name (or None index), just store tensor:
            setattr(self, attr.attr_name, tensor)
        return True

    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""Obtains a feature tensor from node storage."""
        # Retrieve tensor and index accordingly:
        tensor = getattr(self, attr.attr_name, None)
        if tensor is not None:
            # TODO this behavior is a bit odd, since TensorAttr requires that
            # we set `index`. So, we assume here that indexing by `None` is
            # equivalent to not indexing at all, which is not in line with
            # Python semantics.
            return tensor[attr.index] if attr.index is not None else tensor
        return None

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        r"""Deletes a feature tensor from node storage."""
        # Remove tensor entirely:
        if hasattr(self, attr.attr_name):
            delattr(self, attr.attr_name)
            return True
        return False

    def __len__(self) -> int:
        return BaseData.__len__(self)

###############################################################################

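Two behaviors from the `data.py` changes above are worth illustrating: `DataTensorAttr` makes `group_name` optional, so `Data` attributes are addressed by `attr_name` alone, and `_put_tensor`/`_get_tensor` treat `index=None` as "the whole tensor". A minimal sketch, using only calls exercised by the tests:

```python
import torch

from torch_geometric.data import Data

data = Data()

# `group_name` is never needed for `Data`; `attr_name` alone addresses x:
data.put_tensor(torch.arange(10), attr_name='x', index=None)

# `index=None` reads the whole tensor; a concrete index writes in place:
data.put_tensor(0, attr_name='x', index=slice(5, None))
assert data.get_tensor(attr_name='x', index=None)[5:].sum() == 0
assert torch.equal(data.get_tensor(attr_name='x', index=slice(0, 5)),
                   torch.arange(5))
```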
2 changes: 1 addition & 1 deletion torch_geometric/data/feature_store.py
@@ -245,7 +245,7 @@ def __init__(self, attr_cls: Any = TensorAttr):
        attributes by subclassing :class:`TensorAttr` and passing the subclass
        as :obj:`attr_cls`."""
        super().__init__()
-        self._attr_cls = attr_cls
+        self.__dict__['_attr_cls'] = attr_cls

    # Core (CRUD) #############################################################

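This one-line `feature_store.py` change is needed because `Data` and `HeteroData` override `__setattr__` to route attribute assignments into their internal storage (see `__setattr__`/`__delattr__` in the `data.py` diff above). Once they inherit from `FeatureStore`, a plain `self._attr_cls = attr_cls` in `FeatureStore.__init__` would be routed into tensor storage instead of being stored on the instance; writing through `__dict__` bypasses `__setattr__`. A minimal sketch of the mechanism (class names hypothetical):

```python
class Store:
    def __setattr__(self, key, value):
        # Route every attribute assignment into an internal mapping,
        # analogous to how `Data.__setattr__` routes into `self._store`:
        self._mapping[key] = value


store = Store()
store.__dict__['_mapping'] = {}  # bypasses the custom __setattr__
store.x = 1                      # intercepted and routed into the mapping

assert store._mapping == {'x': 1}
assert 'x' not in store.__dict__
```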
55 changes: 54 additions & 1 deletion torch_geometric/data/hetero_data.py
@@ -10,6 +10,11 @@
from torch_sparse import SparseTensor

from torch_geometric.data.data import BaseData, Data, size_repr
from torch_geometric.data.feature_store import (
    FeatureStore,
    FeatureTensorType,
    TensorAttr,
)
from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
from torch_geometric.typing import EdgeType, NodeType, QueryType
from torch_geometric.utils import bipartite_subgraph, is_undirected
@@ -18,7 +23,7 @@
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]


-class HeteroData(BaseData):
+class HeteroData(BaseData, FeatureStore):
r"""A data object describing a heterogeneous graph, holding multiple node
and/or edge types in disjunct storage objects.
Storage objects can hold either node-level, link-level or graph-level
@@ -92,6 +97,8 @@ class HeteroData(BaseData):
    DEFAULT_REL = 'to'

    def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
        super().__init__()

        self.__dict__['_global_store'] = BaseStorage(_parent=self)
        self.__dict__['_node_store_dict'] = {}
        self.__dict__['_edge_store_dict'] = {}
@@ -616,6 +623,52 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:

        return data

    # :obj:`FeatureStore` interface ###########################################

    def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
        r"""Stores a feature tensor in node storage."""
        if not attr.is_set('index'):
            attr.index = None

        out = self._node_store_dict.get(attr.group_name, None)
        if out:
            # Group name exists, handle index or create new attribute name:
            val = getattr(out, attr.attr_name, None)
            if val is not None:
                val[attr.index] = tensor
            else:
                setattr(self[attr.group_name], attr.attr_name, tensor)
        else:
            # No node storage found, just store tensor in a new one:
            setattr(self[attr.group_name], attr.attr_name, tensor)
        return True

    def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
        r"""Obtains a feature tensor from node storage."""
        # Retrieve tensor and index accordingly:
        tensor = getattr(self[attr.group_name], attr.attr_name, None)
        if tensor is not None:
            # TODO this behavior is a bit odd, since TensorAttr requires that
            # we set `index`. So, we assume here that indexing by `None` is
            # equivalent to not indexing at all, which is not in line with
            # Python semantics.
            return tensor[attr.index] if attr.index is not None else tensor
        return None

[Inline review thread on the `index` handling in `_get_tensor` — resolved]
Contributor: If it requires we set `index`, why do we even need to check?
Contributor: Maybe another way to solve this, if it's confusing, is to have a different value for `UNSET` which indicates it's going to index 'all'.
Author: Thanks for the comment! Not sure what you mean by needing to check `index`; this is because `TensorAttr` just requires that its attributes are set (they can be set to `None`). The current way to index all would be `None`, which I think is acceptable, although I'm happy to define a custom value for the `UNSET` enum to indicate all-indexing in a follow-up PR.

    def _remove_tensor(self, attr: TensorAttr) -> bool:
        r"""Deletes a feature tensor from node storage."""
        # Remove tensor entirely:
        if hasattr(self[attr.group_name], attr.attr_name):
            delattr(self[attr.group_name], attr.attr_name)
            return True
        return False

    def __len__(self) -> int:
        return BaseData.__len__(self)

    def __iter__(self):
        raise NotImplementedError


# Helper functions ############################################################

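The net effect of the PR: downstream code can now treat `Data` and `HeteroData` uniformly through the `FeatureStore` interface. A hedged sketch of such a consumer (the `feature_dim` helper is hypothetical, not part of the PR):

```python
import torch

from torch_geometric.data import Data, HeteroData
from torch_geometric.data.feature_store import FeatureStore


def feature_dim(store: FeatureStore, **attr_kwargs) -> int:
    # Works against any FeatureStore, now including Data and HeteroData:
    return store.get_tensor(**attr_kwargs).size(-1)


homo, hetero = Data(), HeteroData()
homo.put_tensor(torch.randn(20, 16), attr_name='x', index=None)
hetero.put_tensor(torch.randn(20, 32), group_name='paper', attr_name='x',
                  index=None)

assert feature_dim(homo, attr_name='x', index=None) == 16
assert feature_dim(hetero, group_name='paper', attr_name='x', index=None) == 32
```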