diff --git a/src/vak/cli/__init__.py b/src/vak/cli/__init__.py
index 19722ee96..32608ca2d 100644
--- a/src/vak/cli/__init__.py
+++ b/src/vak/cli/__init__.py
@@ -1,7 +1,7 @@
 """command-line interface functions for training, creating learning curves, etc."""
-from . import cli, eval, learncurve, predict, prep, train
+from . import cli, eval, learncurve, predict, prep, train, train_checkpoint
 
 
 __all__ = [
@@ -11,4 +11,5 @@
     "predict",
     "prep",
     "train",
+    "train_checkpoint",
 ]
diff --git a/src/vak/cli/cli.py b/src/vak/cli/cli.py
index 1ea78d808..883f58488 100644
--- a/src/vak/cli/cli.py
+++ b/src/vak/cli/cli.py
@@ -3,7 +3,7 @@
 from .learncurve import learning_curve
 from .predict import predict
 from .prep import prep
-
+from .train_checkpoint import train_checkpoint
 
 COMMAND_FUNCTION_MAP = {
     "prep": prep,
@@ -11,6 +11,7 @@
     "eval": eval,
     "predict": predict,
     "learncurve": learning_curve,
+    "train_checkpoint": train_checkpoint,
 }
 
 CLI_COMMANDS = tuple(COMMAND_FUNCTION_MAP.keys())
@@ -22,7 +23,7 @@ def cli(command, config_file):
     Parameters
     ----------
     command : string
-        One of {'prep', 'train', 'eval', 'predict', 'learncurve'}
+        One of {'prep', 'train', 'eval', 'predict', 'learncurve', 'train_checkpoint'}
     config_file : str, Path
         path to a config.toml file
     """
diff --git a/src/vak/cli/train_checkpoint.py b/src/vak/cli/train_checkpoint.py
new file mode 100644
index 000000000..16d8f5bbc
--- /dev/null
+++ b/src/vak/cli/train_checkpoint.py
@@ -0,0 +1,75 @@
+from pathlib import Path
+import shutil
+
+from .. import config
+from .. import core
+from .. import logging
+from ..paths import generate_results_dir_name_as_path
+from ..timenow import get_timenow_as_str
+
+
+def train_checkpoint(toml_path):
+    """train models using training set specified in config.toml file.
+    Starts from a checkpoint given in config.toml file.
+    Function called by command-line interface.
+    Updated by K.L.Provost 8 Dec 2021
+
+    Parameters
+    ----------
+    toml_path : str, Path
+        path to a configuration file in TOML format.
+
+    Returns
+    -------
+    None
+
+    Trains models from checkpoints, saves results in new directory within root_results_dir specified
+    in config.toml file, and adds path to that new directory to config.toml file.
+ """ + toml_path = Path(toml_path) + cfg = config.parse.from_toml_path(toml_path) + + if cfg.train_checkpoint is None: + raise ValueError( + f"train_checkpoint called with a config.toml file that does not have a TRAIN_CHECKPOINT section: {toml_path}" + ) + + # ---- set up directory to save output ----------------------------------------------------------------------------- + results_path = generate_results_dir_name_as_path(cfg.train_checkpoint.root_results_dir) + results_path.mkdir(parents=True) + # copy config file into results dir now that we've made the dir + shutil.copy(toml_path, results_path) + + # ---- set up logging ---------------------------------------------------------------------------------------------- + logger = logging.get_logger( + log_dst=results_path, + caller="train_checkpoint", + timestamp=get_timenow_as_str(), + logger_name=__name__, + ) + logger.info("Logging results to {}".format(results_path)) + + model_config_map = config.models.map_from_path(toml_path, cfg.train_checkpoint.models) + + core.train_checkpoint( + model_config_map=model_config_map, + csv_path=cfg.train_checkpoint.csv_path, + labelset=cfg.prep.labelset, + window_size=cfg.dataloader.window_size, + batch_size=cfg.train_checkpoint.batch_size, + num_epochs=cfg.train_checkpoint.num_epochs, + num_workers=cfg.train_checkpoint.num_workers, + checkpoint_path=cfg.train_checkpoint.checkpoint_path, + labelmap_path=cfg.train_checkpoint.labelmap_path, + spect_scaler_path=cfg.train_checkpoint.spect_scaler_path, + results_path=results_path, + spect_key=cfg.spect_params.spect_key, + timebins_key=cfg.spect_params.timebins_key, + normalize_spectrograms=cfg.train_checkpoint.normalize_spectrograms, + shuffle=cfg.train_checkpoint.shuffle, + val_step=cfg.train_checkpoint.val_step, + ckpt_step=cfg.train_checkpoint.ckpt_step, + patience=cfg.train_checkpoint.patience, + device=cfg.train_checkpoint.device, + logger=logger, + ) diff --git a/src/vak/config/__init__.py b/src/vak/config/__init__.py index 132b5366d..c2d0128a6 100644 --- a/src/vak/config/__init__.py +++ b/src/vak/config/__init__.py @@ -10,5 +10,6 @@ prep, spect_params, train, + train_checkpoint, validators, ) diff --git a/src/vak/config/config.py b/src/vak/config/config.py index 9a2dbcca0..a0e10e561 100644 --- a/src/vak/config/config.py +++ b/src/vak/config/config.py @@ -8,6 +8,7 @@ from .prep import PrepConfig from .spect_params import SpectParamsConfig from .train import TrainConfig +from .train_checkpoint import Train_CheckpointConfig @attr.s @@ -30,6 +31,8 @@ class Config: represents ``[PREDICT]`` section of config.toml file. 
     learncurve : vak.config.learncurve.LearncurveConfig
         represents ``[LEARNCURVE]`` section of config.toml file
+    train_checkpoint : vak.config.train_checkpoint.Train_CheckpointConfig
+        represents ``[TRAIN_CHECKPOINT]`` section of config.toml file
     """
 
     spect_params = attr.ib(
@@ -41,6 +44,7 @@ class Config:
     prep = attr.ib(validator=optional(instance_of(PrepConfig)), default=None)
     train = attr.ib(validator=optional(instance_of(TrainConfig)), default=None)
+    train_checkpoint = attr.ib(validator=optional(instance_of(Train_CheckpointConfig)), default=None)
     eval = attr.ib(validator=optional(instance_of(EvalConfig)), default=None)
     predict = attr.ib(validator=optional(instance_of(PredictConfig)), default=None)
     learncurve = attr.ib(
diff --git a/src/vak/config/parse.py b/src/vak/config/parse.py
index 46e9086e3..8de7f97a0 100644
--- a/src/vak/config/parse.py
+++ b/src/vak/config/parse.py
@@ -11,6 +11,7 @@
 from .prep import PrepConfig
 from .spect_params import SpectParamsConfig
 from .train import TrainConfig
+from .train_checkpoint import Train_CheckpointConfig
 from .validators import are_sections_valid, are_options_valid
 
 SECTION_CLASSES = {
@@ -21,6 +22,7 @@
     "PREP": PrepConfig,
     "SPECT_PARAMS": SpectParamsConfig,
     "TRAIN": TrainConfig,
+    "TRAIN_CHECKPOINT": Train_CheckpointConfig,
 }
 
 REQUIRED_OPTIONS = {
@@ -51,6 +53,12 @@
         "models",
         "root_results_dir",
     ],
+    "TRAIN_CHECKPOINT": [
+        "models",
+        "root_results_dir",
+        "checkpoint_path",
+        "labelmap_path",
+    ],
 }
diff --git a/src/vak/config/train_checkpoint.py b/src/vak/config/train_checkpoint.py
new file mode 100644
index 000000000..65e9319b9
--- /dev/null
+++ b/src/vak/config/train_checkpoint.py
@@ -0,0 +1,130 @@
+"""parses [TRAIN_CHECKPOINT] section of config"""
+import attr
+from attr import converters, validators
+from attr.validators import instance_of
+
+from .validators import is_a_directory, is_a_file, is_valid_model_name
+from .. import device
+from ..converters import bool_from_str, comma_separated_list, expanded_user_path
+
+
+@attr.s
+class Train_CheckpointConfig:
+    """class that represents [TRAIN_CHECKPOINT] section of config.toml file
+    Updated by K.L.Provost 8 Dec 2021
+
+    Attributes
+    ----------
+    models : list
+        comma-separated list of model names.
+        e.g., 'models = TweetyNet, GRUNet, ConvNet'
+    csv_path : str
+        path to where dataset was saved as a csv.
+    num_epochs : int
+        number of training epochs. One epoch = one iteration through the entire
+        training set.
+    batch_size : int
+        number of samples per batch presented to models during training.
+    root_results_dir : str
+        directory in which results will be created.
+        The vak.cli.train_checkpoint function will create
+        a subdirectory in this directory each time it runs.
+    num_workers : int
+        Number of processes to use for parallel loading of data.
+        Argument to torch.DataLoader.
+    device : str
+        Device on which to work with model + data.
+        Defaults to 'cuda' if torch.cuda.is_available is True.
+    shuffle: bool
+        if True, shuffle training data before each epoch. Default is True.
+    normalize_spectrograms : bool
+        if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
+        Normalization is done by subtracting off the mean for each frequency bin
+        of the training set and then dividing by the std for that frequency bin.
+        This same normalization is then applied to validation + test data.
+    val_step : int
+        Step on which to estimate accuracy using validation set.
+        If val_step is n, then validation is carried out every time
+        the global step / n is a whole number, i.e., when the global step modulo n is 0.
+        Default is None, in which case no validation is done.
+    ckpt_step : int
+        Step on which to save to checkpoint file.
+        If ckpt_step is n, then a checkpoint is saved every time
+        the global step / n is a whole number, i.e., when the global step modulo n is 0.
+        Default is None, in which case checkpoint is only saved at the last epoch.
+    patience : int
+        number of validation steps to wait without performance on the
+        validation set improving before stopping the training.
+        Default is None, in which case training only stops after the specified number of epochs.
+    checkpoint_path : str
+        path to a checkpoint file saved by Torch, from which to reload the model.
+    labelmap_path : str
+        path to 'labelmap.json' file.
+    spect_scaler_path : str
+        path to a saved SpectScaler object used to normalize spectrograms.
+        If spectrograms were normalized and this is not provided, will give
+        incorrect results.
+    """
+
+    # required
+    models = attr.ib(
+        converter=comma_separated_list,
+        validator=[instance_of(list), is_valid_model_name],
+    )
+    num_epochs = attr.ib(converter=int, validator=instance_of(int))
+    batch_size = attr.ib(converter=int, validator=instance_of(int))
+    root_results_dir = attr.ib(converter=expanded_user_path, validator=is_a_directory)
+
+    checkpoint_path = attr.ib(converter=expanded_user_path, validator=is_a_file)
+    labelmap_path = attr.ib(converter=expanded_user_path, validator=is_a_file)
+
+    # optional
+    # csv_path is actually 'required' but we can't enforce that here because cli.prep looks at
+    # what sections are defined to figure out where to add csv_path after it creates the csv
+    csv_path = attr.ib(
+        converter=converters.optional(expanded_user_path),
+        validator=validators.optional(is_a_file),
+        default=None,
+    )
+
+    results_dirname = attr.ib(
+        converter=converters.optional(expanded_user_path),
+        validator=validators.optional(is_a_directory),
+        default=None,
+    )
+
+    normalize_spectrograms = attr.ib(
+        converter=bool_from_str,
+        validator=validators.optional(instance_of(bool)),
+        default=False,
+    )
+
+    num_workers = attr.ib(validator=instance_of(int), default=2)
+    device = attr.ib(validator=instance_of(str), default=device.get_default())
+    shuffle = attr.ib(
+        converter=bool_from_str, validator=instance_of(bool), default=True
+    )
+
+    val_step = attr.ib(
+        converter=converters.optional(int),
+        validator=validators.optional(instance_of(int)),
+        default=None,
+    )
+    ckpt_step = attr.ib(
+        converter=converters.optional(int),
+        validator=validators.optional(instance_of(int)),
+        default=None,
+    )
+    patience = attr.ib(
+        converter=converters.optional(int),
+        validator=validators.optional(instance_of(int)),
+        default=None,
+    )
+
+    spect_scaler_path = attr.ib(
+        converter=converters.optional(expanded_user_path),
+        validator=validators.optional(is_a_file),
+        default=None,
+    )
+
+
diff --git a/src/vak/core/__init__.py b/src/vak/core/__init__.py
index 936154f8e..a23804791 100644
--- a/src/vak/core/__init__.py
+++ b/src/vak/core/__init__.py
@@ -3,3 +3,4 @@
 from .predict import predict
 from .prep import prep
 from .train import train
+from .train_checkpoint import train_checkpoint
diff --git a/src/vak/core/prep.py b/src/vak/core/prep.py
index e169e07dc..5adbd2a9a 100644
--- a/src/vak/core/prep.py
+++ b/src/vak/core/prep.py
@@ -14,6 +14,7 @@
         "learncurve",
         "predict",
         "train",
+        "train_checkpoint",
     ]
 )
@@ -43,7 +44,7 @@ def prep(
     data_dir : str, Path
         path to directory with files from which to make dataset
     purpose : str
-        one of {'train', 'predict', 'learncurve'}
+        one of {'train', 'predict', 'learncurve', 'train_checkpoint'}
     output_dir : str
         Path to location where data sets should be saved.
         Default is None, in which case data sets to `data_dir`.
@@ -173,7 +174,7 @@ def prep(
     # ---- figure out if we're going to split into train / val / test sets ---------------------------------------------
     # catch case where user specified duration for just training set, raise a helpful error instead of failing silently
-    if (purpose == "train" or purpose == "learncurve") and (
+    if (purpose == "train" or purpose == "learncurve" or purpose == "train_checkpoint") and (
         (train_dur is not None and train_dur > 0)
         and (val_dur is None or val_dur == 0)
         and (test_dur is None or val_dur == 0)
diff --git a/src/vak/core/train_checkpoint.py b/src/vak/core/train_checkpoint.py
new file mode 100644
index 000000000..a2adf9143
--- /dev/null
+++ b/src/vak/core/train_checkpoint.py
@@ -0,0 +1,333 @@
+import json
+from pathlib import Path
+
+import joblib
+import pandas as pd
+import torch.utils.data
+
+from .. import csv
+from .. import labels
+from .. import models
+from .. import tensorboard
+from .. import transforms
+from ..datasets.window_dataset import WindowDataset
+from ..datasets.vocal_dataset import VocalDataset
+from ..device import get_default as get_default_device
+from ..io import dataframe
+from ..logging import log_or_print
+from ..paths import generate_results_dir_name_as_path
+
+
+def train_checkpoint(
+    model_config_map,
+    csv_path,
+    labelset,
+    window_size,
+    batch_size,
+    num_epochs,
+    num_workers,
+    checkpoint_path,
+    labelmap_path,
+    spect_scaler_path=None,
+    root_results_dir=None,
+    results_path=None,
+    spect_key="s",
+    timebins_key="t",
+    normalize_spectrograms=True,
+    spect_id_vector=None,
+    spect_inds_vector=None,
+    x_inds=None,
+    shuffle=True,
+    val_step=None,
+    ckpt_step=None,
+    patience=None,
+    device=None,
+    logger=None,
+):
+    """train models using training set specified in config.toml file from a checkpoint.
+    Updated by K.L.Provost 8 Dec 2021
+
+    Parameters
+    ----------
+    model_config_map : dict
+        where each key-value pair is model name : dict of config parameters
+    csv_path : str
+        path to where dataset was saved as a csv.
+    labelset : set
+        of str or int, the set of labels that correspond to annotated segments
+        that a network should learn to segment and classify. Note that if there
+        are segments that are not annotated, e.g. silent gaps between songbird
+        syllables, then `vak` will assign a dummy label to those segments
+        -- you don't have to give them a label here.
+    window_size : int
+        size of windows taken from spectrograms, in number of time bins,
+        shown to neural networks
+    batch_size : int
+        number of samples per batch presented to models during training.
+    num_epochs : int
+        number of training epochs. One epoch = one iteration through the entire
+        training set.
+    num_workers : int
+        Number of processes to use for parallel loading of data.
+        Argument to torch.DataLoader.
+    root_results_dir : str, pathlib.Path
+        Root directory in which a new directory will be created where results will be saved.
+    results_path : str, pathlib.Path
+        Directory where results will be saved. If specified, this parameter overrides root_results_dir.
+    spect_key : str
+        key for accessing spectrogram in files. Default is 's'.
+    timebins_key : str
+        key for accessing vector of time bins in files. Default is 't'.
+    device : str
+        Device on which to work with model + data.
+        Default is None. If None, then a device will be selected with vak.device.get_default.
+        That function defaults to 'cuda' if torch.cuda.is_available is True.
+    shuffle: bool
+        if True, shuffle training data before each epoch. Default is True.
+    normalize_spectrograms : bool
+        if True, use spect.utils.data.SpectScaler to normalize the spectrograms.
+        Normalization is done by subtracting off the mean for each frequency bin
+        of the training set and then dividing by the std for that frequency bin.
+        This same normalization is then applied to validation + test data.
+    spect_id_vector : numpy.ndarray
+        Parameter for WindowDataset. Represents the 'id' of any spectrogram,
+        i.e., the index into spect_paths that will let us load it.
+        Default is None.
+    spect_inds_vector : numpy.ndarray
+        Parameter for WindowDataset. Same length as spect_id_vector
+        but values represent indices within each spectrogram.
+        Default is None.
+    x_inds : numpy.ndarray
+        Parameter for WindowDataset.
+        Indices of each window in the dataset. The value at x[0]
+        represents the start index of the first window; using that
+        value, we can index into spect_id_vector to get the path
+        of the spectrogram file to load, and we can index into
+        spect_inds_vector to index into the spectrogram itself
+        and get the window.
+        Default is None.
+    val_step : int
+        Step on which to estimate accuracy using validation set.
+        If val_step is n, then validation is carried out every time
+        the global step / n is a whole number, i.e., when the global step modulo n is 0.
+        Default is None, in which case no validation is done.
+    ckpt_step : int
+        Step on which to save to checkpoint file.
+        If ckpt_step is n, then a checkpoint is saved every time
+        the global step / n is a whole number, i.e., when the global step modulo n is 0.
+        Default is None, in which case checkpoint is only saved at the last epoch.
+    patience : int
+        number of validation steps to wait without performance on the
+        validation set improving before stopping the training.
+        Default is None, in which case training only stops after the specified number of epochs.
+    checkpoint_path : str, pathlib.Path
+        path to a checkpoint file saved by Torch, from which to reload the model.
+    labelmap_path : str, pathlib.Path
+        path to 'labelmap.json' file.
+    spect_scaler_path : str, pathlib.Path
+        path to a saved SpectScaler object used to normalize spectrograms.
+        If spectrograms were normalized and this is not provided, will give
+        incorrect results.
+        Default is None.
+
+    Other Parameters
+    ----------------
+    logger : logging.Logger
+        instance created by vak.logging.get_logger. Default is None.
+
+    Returns
+    -------
+    None
+
+    Trains models, saves results in new directory within root_results_dir
+    """
+    log_or_print(
+        f"Loading dataset from .csv path: {csv_path}", logger=logger, level="info"
+    )
+    dataset_df = pd.read_csv(csv_path)
+    # ---------------- pre-conditions ----------------------------------------------------------------------------------
+    if val_step and not dataset_df["split"].str.contains("val").any():
+        raise ValueError(
+            f"val_step set to {val_step} but dataset does not contain a validation set; "
+            f"please run `vak prep` with a config.toml file that specifies a duration for the validation set."
+ ) + + # ---- set up directory to save output ----------------------------------------------------------------------------- + if results_path: + results_path = Path(results_path).expanduser().resolve() + if not results_path.is_dir(): + raise NotADirectoryError( + f"results_path not recognized as a directory: {results_path}" + ) + else: + results_path = generate_results_dir_name_as_path(root_results_dir) + results_path.mkdir() + + timebin_dur = dataframe.validate_and_get_timebin_dur(dataset_df) + log_or_print( + f"Size of timebin in spectrograms from dataset, in seconds: {timebin_dur}", + logger=logger, + level="info", + ) + + # ---------------- load training data ----------------------------------------------------------------------------- + log_or_print(f"using training dataset from {csv_path}", logger=logger, level="info") + # below, if we're going to train network to predict unlabeled segments, then + # we need to include a class for those unlabeled segments in labelmap, + # the mapping from labelset provided by user to a set of consecutive + # integers that the network learns to predict + train_dur = dataframe.split_dur(dataset_df, "train") + log_or_print( + f"Total duration of training split from dataset (in s): {train_dur}", + logger=logger, + level="info", + ) + + has_unlabeled = csv.has_unlabeled(csv_path, labelset, timebins_key) + if has_unlabeled: + map_unlabeled = True + else: + map_unlabeled = False + #labelmap = labels.to_map(labelset, map_unlabeled=map_unlabeled) + #log_or_print( + # f"number of classes in labelmap: {len(labelmap)}", logger=logger, level="info" + #) + ## save labelmap in case we need it later + #with open(results_path.joinpath("labelmap.json"), "w") as f: + # json.dump(labelmap, f) + + log_or_print( + f"loading labelmap from path: {labelmap_path}", logger=logger, level="info" + ) + with labelmap_path.open("r") as f: + labelmap = json.load(f) + + if spect_scaler_path: + log_or_print( + f"loading spect scaler from path: {spect_scaler_path}", + logger=logger, + level="info", + ) + spect_standardizer = joblib.load(spect_scaler_path) + else: + log_or_print( + f"not using a spect scaler", + logger=logger, + level="info", + ) + spect_standardizer = None + + # get transforms just before creating datasets with them + #if normalize_spectrograms: + # # we instantiate this transform here because we want to save it + # # and don't want to add more parameters to `transforms.split.get_defaults` function + # # and make too tight a coupling between this function and that one. 
+    #    # Trade off is that this is pretty verbose (even ignoring my comments)
+    #    log_or_print("will normalize spectrograms", logger=logger, level="info")
+    #    spect_standardizer = transforms.StandardizeSpect.fit_df(
+    #        dataset_df, spect_key=spect_key
+    #    )
+    #    joblib.dump(spect_standardizer, results_path.joinpath("StandardizeSpect"))
+    #else:
+    #    spect_standardizer = None
+    transform, target_transform = transforms.get_defaults("train", spect_standardizer)
+
+    train_dataset = WindowDataset.from_csv(
+        csv_path=csv_path,
+        x_inds=x_inds,
+        spect_id_vector=spect_id_vector,
+        spect_inds_vector=spect_inds_vector,
+        split="train",
+        labelmap=labelmap,
+        window_size=window_size,
+        spect_key=spect_key,
+        timebins_key=timebins_key,
+        transform=transform,
+        target_transform=target_transform,
+    )
+    log_or_print(
+        f"Duration of WindowDataset used for training, in seconds: {train_dataset.duration()}",
+        logger=logger,
+        level="info",
+    )
+    train_data = torch.utils.data.DataLoader(
+        dataset=train_dataset,
+        shuffle=shuffle,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+
+    # ---------------- load validation set (if there is one) -----------------------------------------------------------
+    if val_step:
+        item_transform = transforms.get_defaults(
+            "eval",
+            spect_standardizer,
+            window_size=window_size,
+            return_padding_mask=True,
+        )
+        val_dataset = VocalDataset.from_csv(
+            csv_path=csv_path,
+            split="val",
+            labelmap=labelmap,
+            spect_key=spect_key,
+            timebins_key=timebins_key,
+            item_transform=item_transform,
+        )
+        val_data = torch.utils.data.DataLoader(
+            dataset=val_dataset,
+            shuffle=False,
+            # batch size 1 because each spectrogram reshaped into a batch of windows
+            batch_size=1,
+            num_workers=num_workers,
+        )
+        val_dur = dataframe.split_dur(dataset_df, "val")
+        log_or_print(
+            f"Total duration of validation split from dataset (in s): {val_dur}",
+            logger=logger,
+            level="info",
+        )
+
+        log_or_print(
+            f"will measure error on validation set every {val_step} steps of training",
+            logger=logger,
+            level="info",
+        )
+    else:
+        val_data = None
+
+    if device is None:
+        device = get_default_device()
+
+    models_map = models.from_model_config_map(
+        model_config_map,
+        num_classes=len(labelmap),
+        input_shape=train_dataset.shape,
+        logger=logger,
+    )
+    for model_name, model in models_map.items():
+        log_or_print(
+            f"loading checkpoint for {model_name} from path: {checkpoint_path}",
+            logger=logger,
+            level="info",
+        )
+        model.load(checkpoint_path, device=device)
+        ## makes a new folder so as to not overwrite the old checkpoint
+        results_model_root = results_path.joinpath(model_name)
+        results_model_root.mkdir()
+        ckpt_root = results_model_root.joinpath("checkpoints")
+        ckpt_root.mkdir()
+        log_or_print(f"training from checkpoint {model_name}", logger=logger, level="info")
+        writer = tensorboard.get_summary_writer(
+            log_dir=results_model_root, filename_suffix=model_name
+        )
+        model.summary_writer = writer
+        model.fit(
+            train_data=train_data,
+            num_epochs=num_epochs,
+            ckpt_root=ckpt_root,
+            val_data=val_data,
+            val_step=val_step,
+            ckpt_step=ckpt_step,
+            patience=patience,
+            device=device,
+        )
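
For reference, a minimal usage sketch of the new subcommand once this patch is applied. It is not part of the patch: the config file name is hypothetical, and it assumes that file already has [PREP], [SPECT_PARAMS], and [DATALOADER] sections plus a [TRAIN_CHECKPOINT] section supplying the options listed in REQUIRED_OPTIONS above (models, root_results_dir, checkpoint_path, labelmap_path) along with csv_path.

    # usage sketch, not part of the patch; the .toml file name is hypothetical
    from vak.cli.cli import cli

    # "train_checkpoint" is one of the keys added to COMMAND_FUNCTION_MAP above;
    # cli() is expected to dispatch it to vak.cli.train_checkpoint.train_checkpoint,
    # which parses the config and calls vak.core.train_checkpoint
    cli(command="train_checkpoint", config_file="bird0_train_checkpoint.toml")

If vak's console entry point builds its argument choices from CLI_COMMANDS, the same run should also be available at the shell as `vak train_checkpoint bird0_train_checkpoint.toml`; that wiring is outside this diff.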