ENH: Adding training checkpoint feature #504

Closed
3 changes: 2 additions & 1 deletion src/vak/cli/__init__.py
@@ -1,7 +1,7 @@
"""command-line interface functions for training,
creating learning curves, etc."""

from . import cli, eval, learncurve, predict, prep, train
from . import cli, eval, learncurve, predict, prep, train, train_checkpoint


__all__ = [
@@ -11,4 +11,5 @@
"predict",
"prep",
"train",
"train_checkpoint",
]
5 changes: 3 additions & 2 deletions src/vak/cli/cli.py
@@ -3,14 +3,15 @@
from .learncurve import learning_curve
from .predict import predict
from .prep import prep

from .train_checkpoint import train_checkpoint

COMMAND_FUNCTION_MAP = {
"prep": prep,
"train": train,
"eval": eval,
"predict": predict,
"learncurve": learning_curve,
"train_checkpoint": train_checkpoint,
}

CLI_COMMANDS = tuple(COMMAND_FUNCTION_MAP.keys())
@@ -22,7 +23,7 @@ def cli(command, config_file):
Parameters
----------
command : string
One of {'prep', 'train', 'eval', 'predict', 'learncurve'}
One of {'prep', 'train', 'eval', 'predict', 'learncurve', 'train_checkpoint'}
config_file : str, Path
path to a config.toml file
"""
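For context, a minimal sketch of how the new command reaches its handler. The lookup below only uses the COMMAND_FUNCTION_MAP defined above; the exact call that cli() makes (presumably handler(toml_path=config_file)) is collapsed in this diff, so that part is an assumption.

from vak.cli.cli import COMMAND_FUNCTION_MAP

# after this change the new command resolves like any other CLI command
handler = COMMAND_FUNCTION_MAP["train_checkpoint"]
print(handler.__name__)  # -> "train_checkpoint"

# cli() would then presumably invoke it with the path to a config file, e.g.:
# handler(toml_path="config.toml")  # placeholder path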
75 changes: 75 additions & 0 deletions src/vak/cli/train_checkpoint.py
@@ -0,0 +1,75 @@
from pathlib import Path
import shutil

from .. import config
from .. import core
from .. import logging
from ..paths import generate_results_dir_name_as_path
from ..timenow import get_timenow_as_str


def train_checkpoint(toml_path):
"""train models using training set specified in config.toml file.
Starts from a checkpoint given in config.toml file.
Function called by command-line interface.
Updated by K.L.Provost 8 Dec 2021

Parameters
----------
toml_path : str, Path
path to a configuration file in TOML format.

Returns
-------
None

Trains models from checkpoints, saves results in new directory within root_results_dir specified
in config.toml file, and adds path to that new directory to config.toml file.
"""
toml_path = Path(toml_path)
cfg = config.parse.from_toml_path(toml_path)

if cfg.train_checkpoint is None:
raise ValueError(
f"train_checkpoint called with a config.toml file that does not have a TRAIN_CHECKPOINT section: {toml_path}"
)

# ---- set up directory to save output -----------------------------------------------------------------------------
results_path = generate_results_dir_name_as_path(cfg.train_checkpoint.root_results_dir)
results_path.mkdir(parents=True)
# copy config file into results dir now that we've made the dir
shutil.copy(toml_path, results_path)

# ---- set up logging ----------------------------------------------------------------------------------------------
logger = logging.get_logger(
log_dst=results_path,
caller="train_checkpoint",
timestamp=get_timenow_as_str(),
logger_name=__name__,
)
logger.info("Logging results to {}".format(results_path))

model_config_map = config.models.map_from_path(toml_path, cfg.train_checkpoint.models)

core.train_checkpoint(
model_config_map=model_config_map,
csv_path=cfg.train_checkpoint.csv_path,
labelset=cfg.prep.labelset,
window_size=cfg.dataloader.window_size,
batch_size=cfg.train_checkpoint.batch_size,
num_epochs=cfg.train_checkpoint.num_epochs,
num_workers=cfg.train_checkpoint.num_workers,
checkpoint_path=cfg.train_checkpoint.checkpoint_path,
labelmap_path=cfg.train_checkpoint.labelmap_path,
spect_scaler_path=cfg.train_checkpoint.spect_scaler_path,
results_path=results_path,
spect_key=cfg.spect_params.spect_key,
timebins_key=cfg.spect_params.timebins_key,
normalize_spectrograms=cfg.train_checkpoint.normalize_spectrograms,
shuffle=cfg.train_checkpoint.shuffle,
val_step=cfg.train_checkpoint.val_step,
ckpt_step=cfg.train_checkpoint.ckpt_step,
patience=cfg.train_checkpoint.patience,
device=cfg.train_checkpoint.device,
logger=logger,
)
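A hedged usage sketch for the new entry point. The option names come from the core.train_checkpoint call above and the Train_CheckpointConfig attributes added later in this diff; the exact set of sections a working config needs is an assumption, and every path is a placeholder that must already exist, because the config validators (is_a_file / is_a_directory) check them at parse time.

from pathlib import Path
from vak.cli.train_checkpoint import train_checkpoint

config_text = """
[TRAIN_CHECKPOINT]
models = "TweetyNet"
root_results_dir = "./results"                                          # placeholder directory
checkpoint_path = "./results/run1/TweetyNet/checkpoints/checkpoint.pt"  # placeholder file
labelmap_path = "./results/run1/labelmap.json"                          # placeholder file
csv_path = "./prep/dataset.csv"                                         # placeholder; written by 'vak prep'
batch_size = 8
num_epochs = 2
"""

toml_path = Path("train_checkpoint_config.toml")
toml_path.write_text(config_text)
# the config also needs [PREP], [DATALOADER], and [SPECT_PARAMS] sections,
# since labelset, window_size, spect_key, and timebins_key are read from them
train_checkpoint(toml_path)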
1 change: 1 addition & 0 deletions src/vak/config/__init__.py
@@ -10,5 +10,6 @@
prep,
spect_params,
train,
train_checkpoint,
validators,
)
4 changes: 4 additions & 0 deletions src/vak/config/config.py
@@ -8,6 +8,7 @@
from .prep import PrepConfig
from .spect_params import SpectParamsConfig
from .train import TrainConfig
from .train_checkpoint import Train_CheckpointConfig


@attr.s
@@ -30,6 +31,8 @@ class Config:
represents ``[PREDICT]`` section of config.toml file.
learncurve : vak.config.learncurve.LearncurveConfig
represents ``[LEARNCURVE]`` section of config.toml file
train_checkpoint : vak.config.train_checkpoint.Train_CheckpointConfig
represents ``[TRAIN_CHECKPOINT]`` section of config.toml file
"""

spect_params = attr.ib(
@@ -41,6 +44,7 @@

prep = attr.ib(validator=optional(instance_of(PrepConfig)), default=None)
train = attr.ib(validator=optional(instance_of(TrainConfig)), default=None)
train_checkpoint = attr.ib(validator=optional(instance_of(Train_CheckpointConfig)), default=None)
eval = attr.ib(validator=optional(instance_of(EvalConfig)), default=None)
predict = attr.ib(validator=optional(instance_of(PredictConfig)), default=None)
learncurve = attr.ib(
8 changes: 8 additions & 0 deletions src/vak/config/parse.py
@@ -11,6 +11,7 @@
from .prep import PrepConfig
from .spect_params import SpectParamsConfig
from .train import TrainConfig
from .train_checkpoint import Train_CheckpointConfig
from .validators import are_sections_valid, are_options_valid

SECTION_CLASSES = {
@@ -21,6 +22,7 @@
"PREP": PrepConfig,
"SPECT_PARAMS": SpectParamsConfig,
"TRAIN": TrainConfig,
"TRAIN_CHECKPOINT": Train_CheckpointConfig,
}

REQUIRED_OPTIONS = {
@@ -51,6 +53,12 @@
"models",
"root_results_dir",
],
"TRAIN_CHECKPOINT": [
"models",
"root_results_dir",
"checkpoint_path",
> "labelmap_path",
],
}


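A rough sketch, not vak's actual validator (are_sections_valid and are_options_valid live in config/validators.py and are not shown in this diff), of how a required-options table like the one above is typically enforced:

def check_required_options(config_dict: dict, section_name: str, required_options: dict) -> None:
    """raise if a config section is missing any option listed as required for it"""
    required = required_options.get(section_name, [])
    section = config_dict.get(section_name, {})
    missing = [option for option in required if option not in section]
    if missing:
        raise KeyError(
            f"the {section_name} section is missing required option(s): {missing}"
        )

# e.g., a [TRAIN_CHECKPOINT] section that omits 'checkpoint_path' or 'labelmap_path'
# would fail such a check before any training starts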
130 changes: 130 additions & 0 deletions src/vak/config/train_checkpoint.py
@@ -0,0 +1,130 @@
"""parses [TRAIN_CHECKPOINT] section of config"""
import attr
from attr import converters, validators
from attr.validators import instance_of

from .validators import is_a_directory, is_a_file, is_valid_model_name
from .. import device
from ..converters import bool_from_str, comma_separated_list, expanded_user_path


@attr.s
class Train_CheckpointConfig:
"""class that represents [TRAIN_CHECKPOINT] section of config.toml file
Updated by K.L.Provost 8 Dec 2021

Attributes
----------
models : list
comma-separated list of model names.
e.g., 'models = TweetyNet, GRUNet, ConvNet'
csv_path : str
path to where dataset was saved as a csv.
num_epochs : int
number of training epochs. One epoch = one iteration through the entire
training set.
batch_size : int
number of samples per batch presented to models during training.
root_results_dir : str
directory in which results will be created.
The vak.cli.train_checkpoint function will create
a subdirectory in this directory each time it runs.
num_workers : int
Number of processes to use for parallel loading of data.
Argument to torch.DataLoader.
device : str
Device on which to work with model + data.
Defaults to 'cuda' if torch.cuda.is_available is True.
shuffle: bool
if True, shuffle training data before each epoch. Default is True.
normalize_spectrograms : bool
if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
Normalization is done by subtracting off the mean for each frequency bin
of the training set and then dividing by the std for that frequency bin.
This same normalization is then applied to validation + test data.
val_step : int
Step on which to estimate accuracy using validation set.
If val_step is n, then validation is carried out every time
the global step / n is a whole number, i.e., when the global step modulo val_step is 0.
Default is None, in which case no validation is done.
ckpt_step : int
Step on which to save to checkpoint file.
If ckpt_step is n, then a checkpoint is saved every time
the global step / n is a whole number, i.e., when the global step modulo ckpt_step is 0.
Default is None, in which case checkpoint is only saved at the last epoch.
patience : int
number of validation steps to wait without performance on the
validation set improving before stopping the training.
Default is None, in which case training only stops after the specified number of epochs.
checkpoint_path : str
path to a checkpoint file saved by Torch, to reload model
labelmap_path : str
path to 'labelmap.json' file.
spect_scaler_path : str
path to a saved SpectScaler object used to normalize spectrograms.
If spectrograms were normalized and this is not provided, will give
incorrect results.
"""

# required
models = attr.ib(
converter=comma_separated_list,
validator=[instance_of(list), is_valid_model_name],
)
num_epochs = attr.ib(converter=int, validator=instance_of(int))
batch_size = attr.ib(converter=int, validator=instance_of(int))
root_results_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory)

checkpoint_path = attr.ib(converter=expanded_user_path, validator=is_a_file)
labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file)

# optional
# csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
# what sections are defined to figure out where to add csv_path after it creates the csv
csv_path = attr.ib(
converter=converters.optional(expanded_user_path),
validator=validators.optional(is_a_file),
default=None,
)

results_dirname = attr.ib(
converter=converters.optional(expanded_user_path),
validator=validators.optional(is_a_directory),
default=None,
)

normalize_spectrograms = attr.ib(
converter=bool_from_str,
validator=validators.optional(instance_of(bool)),
default=False,
)

num_workers = attr.ib(validator=instance_of(int), default=2)
device = attr.ib(validator=instance_of(str), default=device.get_default())
shuffle = attr.ib(
converter=bool_from_str, validator=instance_of(bool), default=True
)

val_step = attr.ib(
converter=converters.optional(int),
validator=validators.optional(instance_of(int)),
default=None,
)
ckpt_step = attr.ib(
converter=converters.optional(int),
validator=validators.optional(instance_of(int)),
default=None,
)
patience = attr.ib(
converter=converters.optional(int),
validator=validators.optional(instance_of(int)),
default=None,
)

spect_scaler_path = attr.ib(
converter=converters.optional(expanded_user_path),
validator=validators.optional(is_a_file),
default=None,
)
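The val_step, ckpt_step, and patience attributes above all use the same attrs idiom: an optional converter plus an optional validator, so the option can be left out of the config entirely. A minimal standalone illustration of that idiom (not part of this pull request):

import attr
from attr import converters, validators
from attr.validators import instance_of


@attr.s
class OptionalIntDemo:
    # None passes straight through; any other value is coerced to int,
    # then checked by the validator
    val_step = attr.ib(
        converter=converters.optional(int),
        validator=validators.optional(instance_of(int)),
        default=None,
    )


print(OptionalIntDemo().val_step)               # None -> option omitted, validation skipped
print(OptionalIntDemo(val_step="50").val_step)  # 50   -> value coerced to int, then validated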


1 change: 1 addition & 0 deletions src/vak/core/__init__.py
@@ -3,3 +3,4 @@
from .predict import predict
from .prep import prep
from .train import train
from .train_checkpoint import train_checkpoint
5 changes: 3 additions & 2 deletions src/vak/core/prep.py
@@ -14,6 +14,7 @@
"learncurve",
"predict",
"train",
"train_checkpoint",
]
)

@@ -43,7 +44,7 @@ def prep(
data_dir : str, Path
path to directory with files from which to make dataset
purpose : str
one of {'train', 'predict', 'learncurve'}
one of {'train', 'predict', 'learncurve', 'train_checkpoint'}
output_dir : str
Path to location where data sets should be saved.
Default is None, in which case data sets are saved to `data_dir`.
@@ -173,7 +174,7 @@ def prep(

# ---- figure out if we're going to split into train / val / test sets ---------------------------------------------
# catch case where user specified duration for just training set, raise a helpful error instead of failing silently
if (purpose == "train" or purpose == "learncurve") and (
if (purpose == "train" or purpose == "learncurve" or purpose == "train_checkpoint") and (
(train_dur is not None and train_dur > 0)
and (val_dur is None or val_dur == 0)
and (test_dur is None or test_dur == 0)
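For context, a rough restatement (not the vak implementation) of the split logic this last hunk extends: datasets prepped for 'train_checkpoint' now get the same train/val/test handling as 'train' and 'learncurve'.

def needs_train_val_test_split(purpose: str) -> bool:
    # mirrors the condition above: these purposes require splitting the dataset
    return purpose in ("train", "learncurve", "train_checkpoint")


assert needs_train_val_test_split("train_checkpoint")
assert not needs_train_val_test_split("predict")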