Skip to content

Commit

Permalink
add get_fname helper function (#3659)
Browse files Browse the repository at this point in the history
  • Loading branch information
JiaxuanYou committed Dec 9, 2021
1 parent 560df33 commit 1bf018d
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
4 changes: 3 additions & 1 deletion torch_geometric/graphgym/__init__.py
Expand Up @@ -3,7 +3,8 @@
from .utils import * # noqa
from .checkpoint import load_ckpt, save_ckpt, clean_ckpt
from .cmd_args import parse_args
from .config import cfg, set_cfg, load_cfg, dump_cfg, set_run_dir, set_agg_dir
from .config import (cfg, set_cfg, load_cfg, dump_cfg, set_run_dir,
set_agg_dir, get_fname)
from .init import init_weights
from .loader import create_loader
from .logger import set_printing, create_logger
Expand All @@ -28,6 +29,7 @@
'dump_cfg',
'set_run_dir',
'set_agg_dir',
'get_fname',
'init_weights',
'create_loader',
'set_printing',
Expand Down
25 changes: 16 additions & 9 deletions torch_geometric/graphgym/config.py
Expand Up @@ -489,20 +489,31 @@ def makedirs_rm_exist(dir):
os.makedirs(dir, exist_ok=True)


def set_run_dir(out_dir, fname):
def get_fname(fname):
r"""
Create the directory for each random seed experiment run
Extract filename from file name path
Args:
out_dir (string): Directory for output, specified in :obj:`cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
fname = fname.split('/')[-1]
if fname.endswith('.yaml'):
fname = fname[:-5]
elif fname.endswith('.yml'):
fname = fname[:-4]
return fname


def set_run_dir(out_dir, fname):
r"""
Create the directory for each random seed experiment run
Args:
out_dir (string): Directory for output, specified in :obj:`cfg.out_dir`
fname (string): Filename for the yaml format configuration file
"""
fname = get_fname(fname)
cfg.run_dir = os.path.join(out_dir, fname, str(cfg.seed))
# Make output directory
if cfg.train.auto_resume:
Expand All @@ -521,11 +532,7 @@ def set_agg_dir(out_dir, fname):
fname (string): Filename for the yaml format configuration file
"""
fname = fname.split('/')[-1]
if fname.endswith('.yaml'):
fname = fname[:-5]
elif fname.endswith('.yml'):
fname = fname[:-4]
fname = get_fname(fname)
return os.path.join(out_dir, fname)


Expand Down

0 comments on commit 1bf018d

Please sign in to comment.