Skip to content

Commit

Permalink
Add doc & fixups for cli/data/optim/samplers/tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
plstcharles committed Oct 12, 2018
1 parent 1d71fca commit 196c1a9
Show file tree
Hide file tree
Showing 7 changed files with 820 additions and 38 deletions.
84 changes: 81 additions & 3 deletions src/thelper/cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
"""
Command-line module, for use with __main__ entrypoint or external apps.
Command-line module, for use with a ``__main__`` entrypoint.
This module contains the primary functions used to create or resume a training session. The three
basic arguments that need to be provided by the user to create a session are the session configuration
(dict), the dataset root directory (string), and the save root directory (string). See the docstrings
of :func:`thelper.cli.create_session` and :func:`thelper.cli.resume_session` for more information.
"""

import argparse
Expand All @@ -16,14 +21,32 @@


def create_session(config, data_root, save_dir):
"""Creates a session to train a model.
All generated outputs (model checkpoints and logs) will be saved in a directory named after the
session (the name itself is specified in ``config``), and located in ``save_dir``.
Args:
config: a dictionary that provides all required data configuration and trainer parameters; see
:class:`thelper.train.Trainer` and :func:`thelper.data.load` for more information. Here, it
is only expected to contain a ``name`` field that specifies the name of the session.
data_root: the path to the dataset root directory that will be passed to the dataset interfaces
for them to figure out where the training/validation/testing data is located. This path may
be unused if the dataset interfaces already know where to look via config parameters.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
"""
logger = thelper.utils.get_func_logger()
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
logger.info("Creating new training session '%s'..." % session_name)
if "cudnn_benchmark" in config and thelper.utils.str2bool(config["cudnn_benchmark"]):
logger.debug("activating benchmark mode for cudnn")
torch.backends.cudnn.benchmark = True
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config)
logger.info("Creating new training session '%s'..." % session_name)
logger.debug("session will be saved at '%s'" % save_dir)
task, train_loader, valid_loader, test_loader = thelper.data.load(config, data_root, save_dir)
model = thelper.modules.load_model(config, task, save_dir)
loaders = (train_loader, valid_loader, test_loader)
Expand All @@ -35,6 +58,20 @@ def create_session(config, data_root, save_dir):


def visualize_data(config, data_root):
"""Displays the images used in a training session.
This mode does not generate any output, and is only used to visualize the transformed images used
in a training session. This is useful to debug the data augmentation and base transformation pipelines
and make sure the modified images are valid. It does not attempt to load a model or instantiate a
trainer, meaning the related fields are not required inside ``config``.
Args:
config: a dictionary that provides all required data configuration parameters; see
:func:`thelper.data.load` for more information.
data_root: the path to the dataset root directory that will be passed to the dataset interfaces
for them to figure out where the training/validation/testing data is located. This path may
be unused if the dataset interfaces already know where to look via config parameters.
"""
logger = thelper.utils.get_func_logger()
logger.info("creating visualization session...")
task, train_loader, valid_loader, test_loader = thelper.data.load(config, data_root)
Expand Down Expand Up @@ -68,6 +105,34 @@ def visualize_data(config, data_root):


def resume_session(ckptdata, data_root, save_dir, config=None, eval_only=False):
"""Resumes a previously created training session.
Since the saved checkpoints contain the original session's configuration, the ``config`` argument
can be set to ``None`` if the session should simply pick up where it was interrupted. Otherwise,
the ``config`` argument can be set to a new configuration that will override the older one. This is
useful when fine-tuning a model, or when testing on a new dataset.
.. warning::
If a session is resumed with an overriding configuration, the user must make sure that the
inputs/outputs of the older model are compatible with the new parameters. For example, with
classifiers, this means that the number of classes parsed by the dataset (and thus to be
predicted by the model) should remain the same. This is a limitation of the framework that
should be addressed in a future update.
Args:
ckptdata: raw checkpoint data loaded via ``torch.load()``; it will be parsed by the various
parts of the framework that need to reload their previous state.
data_root: the path to the dataset root directory that will be passed to the dataset interfaces
for them to figure out where the training/validation/testing data is located. This path may
be unused if the dataset interfaces already know where to look via config parameters.
save_dir: the path to the root directory where the session directory should be saved. Note that
this is not the path to the session directory itself, but its parent, which may also contain
other session directories.
config: an optional dictionary that provides all required data configuration and trainer parameters; see
:class:`thelper.train.Trainer` and :func:`thelper.data.load` for more information. Here, it
is only expected to contain a ``name`` field that specifies the name of the session. If left unset
(``None``), the configuration stored in the checkpoint data (``ckptdata["config"]``) is used instead.
eval_only: specifies whether training should be resumed or the model should only be evaluated.
"""
logger = thelper.utils.get_func_logger()
if not config:
if "config" not in ckptdata or not ckptdata["config"]:
Expand All @@ -76,10 +141,12 @@ def resume_session(ckptdata, data_root, save_dir, config=None, eval_only=False):
if "name" not in config or not config["name"]:
raise AssertionError("config missing 'name' field")
session_name = config["name"]
logger.info("loading training session '%s' objects..." % session_name)
if "cudnn_benchmark" in config and thelper.utils.str2bool(config["cudnn_benchmark"]):
logger.debug("activating benchmark mode for cudnn")
torch.backends.cudnn.benchmark = True
save_dir = thelper.utils.get_save_dir(save_dir, session_name, config, resume=True)
logger.info("loading training session '%s' objects..." % session_name)
logger.debug("session will be saved at '%s'" % save_dir)
task, train_loader, valid_loader, test_loader = thelper.data.load(config, data_root, save_dir)
if "task" not in ckptdata:
logger.warning("cannot verify that checkpoint task is same as current task, might cause key or class mapping issues")
Expand All @@ -99,6 +166,17 @@ def resume_session(ckptdata, data_root, save_dir, config=None, eval_only=False):


def main(args=None):
"""Main entrypoint to use with console applications.
This function parses command line arguments and dispatches the execution based on the selected
operating mode (new session, resume session, or visualize). Run with ``--help`` for information
on the available arguments.
.. seealso::
:func:`thelper.cli.create_session`
:func:`thelper.cli.resume_session`
:func:`thelper.cli.visualize_data`
"""
ap = argparse.ArgumentParser(description='thelper model trainer application')
ap.add_argument("--version", default=False, action="store_true", help="prints the version of the library and exits")
ap.add_argument("-l", "--log", default="thelper.log", type=str, help="path to the top-level log file (default: 'thelper.log')")
Expand Down

0 comments on commit 196c1a9

Please sign in to comment.