Change the way job_dir and run_name are used in eval
Both are now required; checkpoints and logs will be loaded from and stored to
`<job_dir>/<run_name>`.
Agustín Azzinnari authored and vierja committed Feb 8, 2018
1 parent e813db5 commit c5338ea
Showing 1 changed file with 21 additions and 31 deletions.
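For context, a minimal sketch of the path handling this commit introduces, mirroring the diff below; the `jobs` and `my-run` values are illustrative placeholders, not taken from the commit.

```python
import os

import tensorflow as tf

# Illustrative config values (in eval.py these come from config.train.*).
job_dir = 'jobs'
run_name = 'my-run'

# Both settings are now required; there is no fallback to job_dir alone.
if not job_dir or not run_name:
    raise KeyError('`job_dir` and `run_name` should be set.')

# Checkpoints and evaluation logs live under <job_dir>/<run_name>.
run_dir = os.path.join(job_dir, run_name)

# The checkpoint state is read straight from the run directory
# (returns None if nothing has been written there yet).
ckpt = tf.train.get_checkpoint_state(run_dir)
```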
luminoth/eval.py (52 changes: 21 additions & 31 deletions)
@@ -14,27 +14,30 @@
 @click.command(help='Evaluate trained (or training) models')
 @click.option('dataset_split', '--split', default='val', help='Dataset split to use.') # noqa
 @click.option('config_files', '--config', '-c', required=True, multiple=True, help='Config to use.') # noqa
-@click.option('--job-dir', help='Directory from where to read saved models and write evaluation logs.') # noqa
 @click.option('--watch/--no-watch', default=True, help='Keep watching checkpoint directory for new files.') # noqa
 @click.option('--from-global-step', type=int, default=None, help='Consider only checkpoints after this global step') # noqa
 @click.option('override_params', '--override', '-o', multiple=True, help='Override model config params.') # noqa
 @click.option('--files-per-class', type=int, default=10, help='How many files per class display in every epoch.') # noqa
 @click.option('--iou-threshold', type=float, default=0.5, help='IoU threshold to use.') # noqa
 @click.option('--min-probability', type=float, default=0.5, help='Minimum probability to use.') # noqa
-def evaluate(dataset_split, config_files, job_dir, watch,
-             from_global_step, override_params, files_per_class,
-             iou_threshold, min_probability):
-    """
-    Evaluate models using dataset.
-    """
+def evaluate(dataset_split, config_files, watch, from_global_step,
+             override_params, files_per_class, iou_threshold, min_probability):
+    """Evaluate models using dataset."""

     # If the config file is empty, our config will be the base_config for the
     # default model.
     try:
         config = get_config(config_files, override_params=override_params)
     except KeyError:
         raise KeyError('model.type should be set on the custom config.')

-    config.train.job_dir = job_dir or config.train.job_dir
+    if not config.train.job_dir:
+        raise KeyError('`job_dir` should be set.')
+    if not config.train.run_name:
+        raise KeyError('`run_name` should be set.')
+
+    # `run_dir` is where the actual checkpoint and logs are located.
+    run_dir = os.path.join(config.train.job_dir, config.train.run_name)

     # Only activate debug for if needed for debug visualization mode.
     if not config.train.debug:
@@ -137,15 +140,15 @@ def evaluate(dataset_split, config_files, job_dir, watch,

     # Use global writer for all checkpoints. We don't want to write different
     # files for each checkpoint.
-    writer = tf.summary.FileWriter(config.train.job_dir)
+    writer = tf.summary.FileWriter(run_dir)

     files_to_visualize = {}

     last_global_step = from_global_step
     while True:
         # Get the checkpoint files to evaluate.
         try:
-            checkpoints = get_checkpoints(config, last_global_step)
+            checkpoints = get_checkpoints(run_dir, last_global_step)
         except ValueError as e:
             if not watch:
                 tf.logging.error('Missing checkpoint.')
@@ -199,11 +202,11 @@ def evaluate(dataset_split, config_files, job_dir, watch,
         time.sleep(60)


-def get_checkpoints(config, from_global_step=None):
+def get_checkpoints(run_dir, from_global_step=None):
     """Return all available checkpoints.

     Args:
-        config: Run configuration file, where the checkpoint dir is present.
+        run_dir: Directory where the checkpoints are located.
         from_global_step (int): Only return checkpoints after this global step.
             The comparison is *strict*. If ``None``, returns all available
             checkpoints.
@@ -213,49 +216,36 @@ def get_checkpoints(config, from_global_step=None):
         checkpoints found.

     Raises:
-        ValueError: If there are no checkpoints on the ``train.job_dir`` key
-            of `config`.
+        ValueError: If there are no checkpoints in ``run_dir``.
     """
     # The latest checkpoint file should be the last item of
     # `all_model_checkpoint_paths`, according to the CheckpointState protobuf
     # definition.
-    job_dir = config.train.job_dir
-    if config.train.run_name:
-        job_dir = os.path.join(job_dir, config.train.run_name)
-
-    ckpt = tf.train.get_checkpoint_state(job_dir)
+    # TODO: Must check if the checkpoints are complete somehow.
+    ckpt = tf.train.get_checkpoint_state(run_dir)
     if not ckpt or not ckpt.all_model_checkpoint_paths:
-        raise ValueError('Could not find checkpoint in {}.'.format(
-            job_dir
-        ))
+        raise ValueError('Could not find checkpoint in {}.'.format(run_dir))

     # TODO: Any other way to get the global_step?
     checkpoints = [
         {'global_step': int(path.split('-')[-1]), 'file': path}
         for path in ckpt.all_model_checkpoint_paths
     ]

-    # Get the run name from the checkpoint path. Do it before filtering the
-    # list, as it may end up empty.
-    # TODO: Can't it be set somewhere else?
-    config.train.run_name = os.path.split(
-        os.path.dirname(checkpoints[0]['file'])
-    )[-1]
-
     if from_global_step is not None:
         checkpoints = [
             c for c in checkpoints
             if c['global_step'] > from_global_step
         ]

         tf.logging.info(
-            'Found %s checkpoints in job_dir with global_step > %s',
+            'Found %s checkpoints in run_dir with global_step > %s',
             len(checkpoints), from_global_step,
         )

     else:
         tf.logging.info(
-            'Found {} checkpoints in job_dir'.format(len(checkpoints))
+            'Found {} checkpoints in run_dir'.format(len(checkpoints))
         )

     return checkpoints
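As a usage note, `get_checkpoints` still derives the global step from the numeric suffix of each checkpoint path and filters strictly against `from_global_step`, as in the hunk above. A small self-contained sketch with made-up paths (the real list comes from `ckpt.all_model_checkpoint_paths`):

```python
# Hypothetical checkpoint paths; in eval.py they come from
# tf.train.get_checkpoint_state(run_dir).all_model_checkpoint_paths.
paths = [
    'jobs/my-run/model.ckpt-1000',
    'jobs/my-run/model.ckpt-2000',
    'jobs/my-run/model.ckpt-3000',
]

# Global step is parsed from the suffix after the last '-' in the path.
checkpoints = [
    {'global_step': int(path.split('-')[-1]), 'file': path}
    for path in paths
]

# The from_global_step comparison is strict, so only newer checkpoints remain.
from_global_step = 1000
newer = [c for c in checkpoints if c['global_step'] > from_global_step]
assert [c['global_step'] for c in newer] == [2000, 3000]
```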
