Skip to content

Commit

Permalink
GraphGym: register_dataset (#3782)
Browse files Browse the repository at this point in the history
* dataset

* typo

* reset
  • Loading branch information
rusty1s committed Dec 31, 2021
1 parent 2ac815f commit 1a5e65c
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
6 changes: 4 additions & 2 deletions torch_geometric/graphgym/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
from .register import (register_base, register_act, register_node_encoder,
register_edge_encoder, register_stage, register_head,
register_layer, register_pooling, register_network,
register_config, register_loader, register_optimizer,
register_scheduler, register_loss, register_train)
register_config, register_dataset, register_loader,
register_optimizer, register_scheduler, register_loss,
register_train)

__all__ = classes = [
'load_ckpt',
Expand Down Expand Up @@ -50,6 +51,7 @@
'register_pooling',
'register_network',
'register_config',
'register_dataset',
'register_loader',
'register_optimizer',
'register_scheduler',
Expand Down
12 changes: 8 additions & 4 deletions torch_geometric/graphgym/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,16 +547,20 @@ def set_agg_dir(out_dir, fname):


def from_config(func):
params = inspect.signature(func).parameters
arg_names = list(params.keys())
defaults = [p.default != inspect.Parameter.empty for p in params.values()]
if inspect.isclass(func):
params = list(inspect.signature(func.__init__).parameters.values())[1:]
else:
params = list(inspect.signature(func).parameters.values())

arg_names = [p.name for p in params]
has_defaults = [p.default != inspect.Parameter.empty for p in params]

@functools.wraps(func)
def wrapper(*args, cfg: Any = None, **kwargs):
if cfg is not None:
cfg = dict(cfg) if isinstance(cfg, Iterable) else asdict(cfg)

iterator = zip(arg_names[len(args):], defaults[len(args):])
iterator = zip(arg_names[len(args):], has_defaults[len(args):])
for arg_name, has_default in iterator:
if arg_name in kwargs:
continue
Expand Down
12 changes: 12 additions & 0 deletions torch_geometric/graphgym/loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Callable

import torch
from torch_geometric.loader import DataLoader

Expand All @@ -20,6 +22,16 @@
from torch_geometric.utils import negative_sampling


def planetoid_dataset(name: str) -> Callable:
return lambda root: Planetoid(root, name)


register.register_dataset('Cora', planetoid_dataset('Cora'))
register.register_dataset('CiteSeer', planetoid_dataset('CiteSeer'))
register.register_dataset('PubMed', planetoid_dataset('PubMed'))
register.register_dataset('PPI', PPI)


def load_pyg(name, dataset_dir):
"""
Load PyG dataset objects. (More PyG datasets will be supported)
Expand Down
6 changes: 6 additions & 0 deletions torch_geometric/graphgym/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
pooling_dict: Dict[str, Any] = {}
network_dict: Dict[str, Any] = {}
config_dict: Dict[str, Any] = {}
dataset_dict: Dict[str, Any] = {}
loader_dict: Dict[str, Any] = {}
optimizer_dict: Dict[str, Any] = {}
scheduler_dict: Dict[str, Any] = {}
Expand Down Expand Up @@ -86,6 +87,11 @@ def register_config(key: str, module: Any = None):
return register_base(config_dict, key, module)


def register_dataset(key: str, module: Any = None):
r"""Registers a dataset in GraphGym."""
return register_base(dataset_dict, key, module)


def register_loader(key: str, module: Any = None):
r"""Registers a data loader in GraphGym."""
return register_base(loader_dict, key, module)
Expand Down

0 comments on commit 1a5e65c

Please sign in to comment.