Let Data and HeteroData implement FeatureStore (#4807)
mananshah99 committed Jun 20, 2022
1 parent c13d62c commit 4b30b6d
Showing 7 changed files with 182 additions and 5 deletions.
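
In short, this commit lets `Data` and `HeteroData` act as `FeatureStore` backends, exposing `put_tensor`, `get_tensor`, and `remove_tensor`. A minimal usage sketch, mirroring the tests added below (call signatures are taken directly from those tests):

```python
import torch
from torch_geometric.data import Data, HeteroData

# Homogeneous graphs address tensors by attribute name only:
data = Data()
data.put_tensor(torch.randn(20, 20), attr_name='x', index=None)
out = data.get_tensor(attr_name='x', index=None)
data.remove_tensor(attr_name='x', index=None)

# Heterogeneous graphs additionally require a group (node type) name:
hetero = HeteroData()
hetero.put_tensor(torch.randn(20, 20), group_name='paper', attr_name='x',
                  index=None)
out = hetero.get_tensor(group_name='paper', attr_name='x', index=None)
hetero.remove_tensor(group_name='paper', attr_name='x', index=None)
```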
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [2.0.5] - 2022-MM-DD
### Added
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
25 changes: 25 additions & 0 deletions test/data/test_data.py
@@ -239,3 +239,28 @@ def my_attr1(self, value):
data.my_attr1 = 2
assert 'my_attr1' not in data._store
assert data.my_attr1 == 2


# Feature Store ###############################################################


def test_basic_feature_store():
data = Data()
x = torch.randn(20, 20)

# Put tensor:
assert data.put_tensor(copy.deepcopy(x), attr_name='x', index=None)
assert torch.equal(data.x, x)

# Put (modify) tensor slice:
x[15:] = 0
data.put_tensor(0, attr_name='x', index=slice(15, None, None))

# Get tensor:
out = data.get_tensor(attr_name='x', index=None)
assert torch.equal(x, out)

# Remove tensor:
assert 'x' in data.__dict__['_store']
data.remove_tensor(attr_name='x', index=None)
assert 'x' not in data.__dict__['_store']
27 changes: 27 additions & 0 deletions test/data/test_hetero_data.py
@@ -400,3 +400,30 @@ def test_hetero_data_to_canonical():

with pytest.raises(TypeError, match="missing 1 required"):
data['user', 'product']


# Feature Store ###############################################################


def test_basic_feature_store():
data = HeteroData()
x = torch.randn(20, 20)

# Put tensor:
assert data.put_tensor(copy.deepcopy(x), group_name='paper', attr_name='x',
index=None)
assert torch.equal(data['paper'].x, x)

# Put (modify) tensor slice:
x[15:] = 0
data.put_tensor(0, group_name='paper', attr_name='x',
index=slice(15, None, None))

# Get tensor:
out = data.get_tensor(group_name='paper', attr_name='x', index=None)
assert torch.equal(x, out)

# Remove tensor:
assert 'x' in data['paper'].__dict__['_mapping']
data.remove_tensor(group_name='paper', attr_name='x', index=None)
assert 'x' not in data['paper'].__dict__['_mapping']
10 changes: 9 additions & 1 deletion torch_geometric/data/batch.py
@@ -23,8 +23,16 @@ def __call__(cls, *args, **kwargs):
new_cls = base_cls
else:
name = f'{base_cls.__name__}{cls.__name__}'

# NOTE `MetaResolver` is necessary to resolve metaclass conflict
# problems between `DynamicInheritance` and the metaclass of
# `base_cls`. In particular, it creates a new common metaclass
# from the defined metaclasses.
class MetaResolver(type(cls), type(base_cls)):
pass

if name not in globals():
-                globals()[name] = type(name, (cls, base_cls), {})
+                globals()[name] = MetaResolver(name, (cls, base_cls), {})
new_cls = globals()[name]

params = list(inspect.signature(base_cls.__init__).parameters.items())
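
For context, a toy reproduction of the metaclass conflict that the `MetaResolver` pattern above avoids (all names here are illustrative, not part of the codebase):

```python
class MetaA(type):
    pass

class MetaB(type):
    pass

class A(metaclass=MetaA):
    pass

class B(metaclass=MetaB):
    pass

# type('AB', (A, B), {})  # TypeError: metaclass conflict

class MetaResolver(MetaA, MetaB):
    """A common sub-metaclass of both metaclasses resolves the conflict."""

AB = MetaResolver('AB', (A, B), {})  # works
```

Python requires the metaclass of a derived class to be a (non-strict) subclass of the metaclasses of all its bases; deriving a common metaclass satisfies that rule.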
67 changes: 65 additions & 2 deletions torch_geometric/data/data.py
@@ -1,5 +1,6 @@
import copy
from collections.abc import Mapping, Sequence
from dataclasses import dataclass
from typing import (
Any,
Callable,
@@ -17,6 +18,12 @@
from torch import Tensor
from torch_sparse import SparseTensor

from torch_geometric.data.feature_store import (
FeatureStore,
FeatureTensorType,
TensorAttr,
_field_status,
)
from torch_geometric.data.storage import (
BaseStorage,
EdgeStorage,
@@ -300,7 +307,16 @@ def contains_self_loops(self) -> bool:
###############################################################################


-class Data(BaseData):
+@dataclass
+class DataTensorAttr(TensorAttr):
+    r"""Attribute class for `Data`, which does not require a `group_name`."""
+    def __init__(self, attr_name=_field_status.UNSET,
+                 index=_field_status.UNSET):
+        # Treat group_name as optional, and move it to the end
+        super().__init__(None, attr_name, index)
+
+
+class Data(BaseData, FeatureStore):
r"""A data object describing a homogeneous graph.
The data object can hold node-level, link-level and graph-level attributes.
In general, :class:`~torch_geometric.data.Data` tries to mimic the
@@ -348,7 +364,10 @@ class Data(BaseData):
def __init__(self, x: OptTensor = None, edge_index: OptTensor = None,
edge_attr: OptTensor = None, y: OptTensor = None,
pos: OptTensor = None, **kwargs):
-        super().__init__()
+        # `Data` doesn't support group_name, so we need to adjust `TensorAttr`
+        # accordingly here to avoid requiring `group_name` to be set:
+        super().__init__(attr_cls=DataTensorAttr)

self.__dict__['_store'] = GlobalStorage(_parent=self)

if x is not None:
@@ -384,6 +403,9 @@ def __setattr__(self, key: str, value: Any):
def __delattr__(self, key: str):
delattr(self._store, key)

# TODO consider supporting the feature store interface for
# __getitem__, __setitem__, and __delitem__ so, for example, we
# can accept key: Union[str, TensorAttr] in __getitem__.
def __getitem__(self, key: str) -> Any:
return self._store[key]

@@ -692,6 +714,47 @@ def num_faces(self) -> Optional[int]:
return self.face.size(self.__cat_dim__('face', self.face))
return None

# FeatureStore interface ###########################################

def items(self):
r"""Returns an `ItemsView` over the stored attributes in the `Data`
object."""
return self._store.items()

def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
r"""Stores a feature tensor in node storage."""
out = getattr(self, attr.attr_name, None)
if out is not None and attr.index is not None:
# Attr name exists, handle index:
out[attr.index] = tensor
else:
# No attr name (or None index), just store tensor:
setattr(self, attr.attr_name, tensor)
return True

def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
r"""Obtains a feature tensor from node storage."""
# Retrieve tensor and index accordingly:
tensor = getattr(self, attr.attr_name, None)
if tensor is not None:
# TODO this behavior is a bit odd, since TensorAttr requires that
# we set `index`. So, we assume here that indexing by `None` is
# equivalent to not indexing at all, which is not in line with
# Python semantics.
return tensor[attr.index] if attr.index is not None else tensor
return None

def _remove_tensor(self, attr: TensorAttr) -> bool:
r"""Deletes a feature tensor from node storage."""
# Remove tensor entirely:
if hasattr(self, attr.attr_name):
delattr(self, attr.attr_name)
return True
return False

def __len__(self) -> int:
return BaseData.__len__(self)


###############################################################################

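
Taken together, `DataTensorAttr` lets `Data` be addressed without a `group_name`, and the `_put_tensor`/`_get_tensor`/`_remove_tensor` hooks implement the CRUD cycle on top of node storage. A round-trip sketch, assuming the public `put_tensor`/`get_tensor` wrappers forward to the hooks shown above (as the tests in this commit do):

```python
import torch
from torch_geometric.data import Data

data = Data(x=torch.zeros(4, 2))

# A non-`None` index performs a partial, in-place write into the tensor:
data.put_tensor(torch.ones(2, 2), attr_name='x', index=slice(0, 2))
assert torch.equal(data.get_tensor(attr_name='x', index=slice(0, 2)),
                   torch.ones(2, 2))

# `index=None` reads the whole tensor; removal drops the attribute:
assert data.get_tensor(attr_name='x', index=None).size() == (4, 2)
data.remove_tensor(attr_name='x', index=None)
```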
2 changes: 1 addition & 1 deletion torch_geometric/data/feature_store.py
@@ -245,7 +245,7 @@ def __init__(self, attr_cls: Any = TensorAttr):
attributes by subclassing :class:`TensorAttr` and passing the subclass
as :obj:`attr_cls`."""
super().__init__()
-        self._attr_cls = attr_cls
+        self.__dict__['_attr_cls'] = attr_cls

# Core (CRUD) #############################################################

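
This one-line change is needed because `Data` and `HeteroData` override `__setattr__` to forward attribute writes into their internal storages, so a plain `self._attr_cls = attr_cls` in `FeatureStore.__init__` would land in the graph store instead of on the instance; writing through `__dict__` bypasses the override, the same trick `Data.__init__` uses for `_store`. An illustrative toy (not the actual classes):

```python
class Store:
    def __init__(self):
        self.__dict__['_mapping'] = {}       # bypasses __setattr__
        self.__dict__['_attr_cls'] = object  # likewise kept off the mapping

    def __setattr__(self, key, value):
        self._mapping[key] = value           # normal writes are redirected

s = Store()
s.x = 1
assert s._mapping == {'x': 1}
assert '_attr_cls' not in s._mapping
```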
55 changes: 54 additions & 1 deletion torch_geometric/data/hetero_data.py
@@ -10,6 +10,11 @@
from torch_sparse import SparseTensor

from torch_geometric.data.data import BaseData, Data, size_repr
from torch_geometric.data.feature_store import (
FeatureStore,
FeatureTensorType,
TensorAttr,
)
from torch_geometric.data.storage import BaseStorage, EdgeStorage, NodeStorage
from torch_geometric.typing import EdgeType, NodeType, QueryType
from torch_geometric.utils import bipartite_subgraph, is_undirected
@@ -18,7 +23,7 @@
NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage]


-class HeteroData(BaseData):
+class HeteroData(BaseData, FeatureStore):
r"""A data object describing a heterogeneous graph, holding multiple node
and/or edge types in disjunct storage objects.
Storage objects can hold either node-level, link-level or graph-level
@@ -92,6 +97,8 @@ class HeteroData(BaseData):
DEFAULT_REL = 'to'

def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs):
super().__init__()

self.__dict__['_global_store'] = BaseStorage(_parent=self)
self.__dict__['_node_store_dict'] = {}
self.__dict__['_edge_store_dict'] = {}
@@ -616,6 +623,52 @@ def _consistent_size(stores: List[BaseStorage]) -> List[str]:

return data

# :obj:`FeatureStore` interface ###########################################

def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool:
r"""Stores a feature tensor in node storage."""
if not attr.is_set('index'):
attr.index = None

out = self._node_store_dict.get(attr.group_name, None)
if out:
# Group name exists, handle index or create new attribute name:
val = getattr(out, attr.attr_name)
if val is not None:
val[attr.index] = tensor
else:
setattr(self[attr.group_name], attr.attr_name, tensor)
else:
# No node storage found, just store tensor in new one:
setattr(self[attr.group_name], attr.attr_name, tensor)
return True

def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]:
r"""Obtains a feature tensor from node storage."""
# Retrieve tensor and index accordingly:
tensor = getattr(self[attr.group_name], attr.attr_name, None)
if tensor is not None:
# TODO this behavior is a bit odd, since TensorAttr requires that
# we set `index`. So, we assume here that indexing by `None` is
# equivalent to not indexing at all, which is not in line with
# Python semantics.
return tensor[attr.index] if attr.index is not None else tensor
return None

def _remove_tensor(self, attr: TensorAttr) -> bool:
r"""Deletes a feature tensor from node storage."""
# Remove tensor entirely:
if hasattr(self[attr.group_name], attr.attr_name):
delattr(self[attr.group_name], attr.attr_name)
return True
return False

def __len__(self) -> int:
return BaseData.__len__(self)

def __iter__(self):
raise NotImplementedError


# Helper functions ############################################################

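
One consequence of `_put_tensor` above: writing under a `group_name` that does not exist yet creates its node storage on the fly, since the fallback branch assigns via `self[attr.group_name]`. A small sketch of the inferred behavior:

```python
import torch
from torch_geometric.data import HeteroData

data = HeteroData()

# 'author' does not exist yet; `self[attr.group_name]` creates its storage:
data.put_tensor(torch.randn(3, 8), group_name='author', attr_name='x',
                index=None)
assert data['author'].x.size() == (3, 8)
```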
