Skip to content

Commit

Permalink
Remove MutableMapping base class in FeatureStore (#5210)
Browse files Browse the repository at this point in the history
* update

* changelog

* linting

* add a todo

* typo
  • Loading branch information
rusty1s committed Aug 17, 2022
1 parent d82e224 commit 07bf02f
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 27 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added documentation to the `DataLoaderIterator` class ([#4838](https://github.com/pyg-team/pytorch_geometric/pull/4838))
- Added `GraphStore` support to `Data` and `HeteroData` ([#4816](https://github.com/pyg-team/pytorch_geometric/pull/4816))
- Added `FeatureStore` support to `Data` and `HeteroData` ([#4807](https://github.com/pyg-team/pytorch_geometric/pull/4807), [#4853](https://github.com/pyg-team/pytorch_geometric/pull/4853))
- Added `FeatureStore` and `GraphStore` abstractions ([#4534](https://github.com/pyg-team/pytorch_geometric/pull/4534), [#4568](https://github.com/pyg-team/pytorch_geometric/pull/4568))
- Added `FeatureStore` and `GraphStore` abstractions ([#4534](https://github.com/pyg-team/pytorch_geometric/pull/4534), [#4568](https://github.com/pyg-team/pytorch_geometric/pull/4568), [#5120](https://github.com/pyg-team/pytorch_geometric/pull/5120))
- Added support for dense aggregations in `global_*_pool` ([#4827](https://github.com/pyg-team/pytorch_geometric/pull/4827))
- Added Python version requirement ([#4825](https://github.com/pyg-team/pytorch_geometric/pull/4825))
- Added TorchScript support to `JumpingKnowledge` module ([#4805](https://github.com/pyg-team/pytorch_geometric/pull/4805))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
'yacs',
'hydra-core',
'protobuf<4.21',
'pytorch-lightning==1.6.*',
'pytorch-lightning',
]

full_requires = graphgym_requires + [
Expand Down
10 changes: 8 additions & 2 deletions torch_geometric/data/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
"""
import copy
from abc import abstractmethod
from collections.abc import MutableMapping
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
Expand Down Expand Up @@ -241,7 +240,14 @@ def __repr__(self) -> str:
f'attr={self._attr})')


class FeatureStore(MutableMapping):
# TODO (manan, matthias) Ideally, we want to let `FeatureStore` inherit from
# `MutableMapping` to clearly indicate its behavior and usage to the user.
# However, having `MutableMapping` as a base class leads to strange behavior
# in combination with PyTorch and PyTorch Lightning, in particular since these
# libraries use customized logic during mini-batch for `Mapping` base classes.


class FeatureStore:
def __init__(self, tensor_attr_cls: Any = TensorAttr):
r"""Initializes the feature store. Implementor classes can customize
the ordering and required nature of their :class:`TensorAttr` tensor
Expand Down
23 changes: 0 additions & 23 deletions torch_geometric/loader/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Mapping, Sequence
from inspect import signature
from typing import List, Optional, Union

import torch.utils.data
Expand Down Expand Up @@ -40,28 +39,6 @@ def collate(self, batch): # Deprecated...
return self(batch)


# PyG 'Data' objects are subclasses of MutableMapping, which is an
# instance of collections.abc.Mapping. Currently, PyTorch pin_memory
# for DataLoaders treats the returned batches as Mapping objects and
# calls `pin_memory` on each element in `Data.__dict__`, which is not
# desired behavior if 'Data' has a `pin_memory` function. We patch
# this behavior here by monkeypatching `pin_memory`, but can hopefully patch
# this in PyTorch in the future:
__torch_pin_memory = torch.utils.data._utils.pin_memory.pin_memory
__torch_pin_memory_params = signature(__torch_pin_memory).parameters


def pin_memory(data, device=None):
if hasattr(data, "pin_memory"):
return data.pin_memory()
if len(__torch_pin_memory_params) > 1:
return __torch_pin_memory(data, device)
return __torch_pin_memory(data)


torch.utils.data._utils.pin_memory.pin_memory = pin_memory


class DataLoader(torch.utils.data.DataLoader):
r"""A data loader which merges data objects from a
:class:`torch_geometric.data.Dataset` to a mini-batch.
Expand Down

0 comments on commit 07bf02f

Please sign in to comment.