Skip to content

Commit

Permalink
fix(data): make FeatureStore and GraphStore proper abstract base …
Browse files Browse the repository at this point in the history
…classes (#9002)

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
mananshah99 and rusty1s committed Mar 2, 2024
1 parent 849ca0a commit 87cde51
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions torch_geometric/data/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
* Async `put` and `get` functionality
"""
import copy
from abc import abstractmethod
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Tuple, Union
Expand Down Expand Up @@ -262,7 +262,7 @@ def __repr__(self) -> str:
# libraries use customized logic during mini-batch for `Mapping` base classes.


class FeatureStore:
class FeatureStore(ABC):
r"""An abstract base class to access features from a remote feature store.
Args:
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/data/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
the graph in interesting manners based on the provided metadata.
"""
import copy
from abc import abstractmethod
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(
self.size = size


class GraphStore:
class GraphStore(ABC):
r"""An abstract base class to access edges from a remote graph store.
Args:
Expand Down
8 changes: 4 additions & 4 deletions torch_geometric/datasets/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ def __init__(
assert self.name in ['aifb', 'am', 'mutag', 'bgs']
super().__init__(root, transform, pre_transform,
force_reload=force_reload)
self.load(
self.processed_paths[0],
data_cls=HeteroData if hetero else Data,
)
if hetero:
self.load(self.processed_paths[0], data_cls=HeteroData)
else:
self.load(self.processed_paths[0], data_cls=Data)

@property
def raw_dir(self) -> str:
Expand Down

0 comments on commit 87cde51

Please sign in to comment.