
Add split check to guarantee reprod train/eval
plstcharles committed Oct 19, 2018
1 parent 8c9f3f1 commit 08a7690
102 changes: 78 additions & 24 deletions src/thelper/data.py
@@ -10,6 +10,7 @@
import logging
import os
import platform
import sys
import time
from abc import ABC
from abc import abstractmethod
@@ -147,28 +148,78 @@ def load(config, data_root, save_dir=None):
if not isinstance(datasets_config, dict):
raise AssertionError("invalid datasets config type")
datasets, task = load_datasets(datasets_config, data_root, data_config.get_base_transforms())
logger.info("task info: %s" % str(task))
if save_dir is not None:
with open(os.path.join(save_dir, "logs", "task.log"), "a+") as fd:
fd.write("session: %s-%s\n" % (session_name, logstamp))
fd.write(str(task) + "\n")
for dataset_name, dataset in datasets.items():
logger.info("dataset '%s' info: %s" % (dataset_name, str(dataset)))
if save_dir is not None:
with open(os.path.join(save_dir, "logs", dataset_name + ".log"), "a+") as fd:
fd.write("session: %s-%s\n" % (session_name, logstamp))
fd.write(str(dataset) + "\n")
if hasattr(dataset, "samples") and isinstance(dataset.samples, list):
for idx, sample in enumerate(dataset.samples):
fd.write("%d: %s\n" % (idx, str(sample)))
logger.debug("splitting datasets and creating loaders")
logger.info("task info: %s" % str(task))
logger.debug("splitting datasets and creating loaders...")
train_idxs, valid_idxs, test_idxs = data_config.get_split(datasets, task)
if save_dir is not None:
with open(os.path.join(save_dir, "logs", "split.log"), "a+") as fd:
with open(os.path.join(save_dir, "logs", "task.log"), "a+") as fd:
fd.write("session: %s-%s\n" % (session_name, logstamp))
fd.write("train:\n" + str(train_idxs) + "\n")
fd.write("valid:\n" + str(valid_idxs) + "\n")
fd.write("test:\n" + str(test_idxs) + "\n")
fd.write(str(task) + "\n")
for dataset_name, dataset in datasets.items():
dataset_log_file = os.path.join(save_dir, "logs", dataset_name + ".log")
if not data_config.skip_verif and os.path.isfile(dataset_log_file):
logger.info("verifying sample list for dataset '%s'..." % dataset_name)
with open(dataset_log_file, "r") as fd:
log_content = fd.read()
if not log_content or log_content[0] != "{":
# could not find new style (json) dataset log, cannot easily parse and compare this log
logger.warning("cannot verify that old split is similar to new split, log is out-of-date")
continue
log_content = json.loads(log_content)
if "samples" not in log_content or not isinstance(log_content["samples"], list):
raise AssertionError("unexpected dataset log content (bad 'samples' field)")
samples_old = log_content["samples"]
samples_new = dataset.samples if hasattr(dataset, "samples") and isinstance(dataset.samples, list) else []
if len(samples_old) != len(samples_new):
answer = thelper.utils.query_yes_no(
"Old sample list for dataset '%s' mismatch with current sample list; proceed anyway?" % dataset_name)
if not answer:
logger.error("sample list mismatch with previous run; user aborted")
sys.exit(1)
break
else:
breaking = False
for set_name, idxs in zip(["train_idxs", "valid_idxs", "test_idxs"],
[train_idxs[dataset_name], valid_idxs[dataset_name], test_idxs[dataset_name]]):
# index values were paired in tuples earlier, 0=idx, 1=label
if log_content[set_name] != [idx for idx, _ in idxs]:
answer = thelper.utils.query_yes_no(
"Old indices list for dataset '%s' mismatch with current indices list ('%s'); proceed anyway?"
% (dataset_name, set_name))
if not answer:
logger.error("indices list mismatch with previous run; user aborted")
sys.exit(1)
breaking = True
break
if not breaking:
for idx, (sample_new, sample_old) in enumerate(zip(samples_new, samples_old)):
if str(sample_new) != sample_old:
answer = thelper.utils.query_yes_no(
"Old sample #%d for dataset '%s' mismatch with current #%d; proceed anyway?"
"\n\told: %s\n\tnew: %s" % (idx, dataset_name, idx, str(sample_old), str(sample_new)))
if not answer:
logger.error("sample list mismatch with previous run; user aborted")
sys.exit(1)
break
for dataset_name, dataset in datasets.items():
dataset_log_file = os.path.join(save_dir, "logs", dataset_name + ".log")
samples = dataset.samples if hasattr(dataset, "samples") and isinstance(dataset.samples, list) else []
log_content = {
"metadata": {
"session_name": session_name,
"logstamp": logstamp,
"dataset": str(dataset),
},
"samples": [str(sample) for sample in samples],
# index values were paired in tuples earlier, 0=idx, 1=label
"train_idxs": [idx for idx, _ in train_idxs[dataset_name]],
"valid_idxs": [idx for idx, _ in valid_idxs[dataset_name]],
"test_idxs": [idx for idx, _ in test_idxs[dataset_name]]
}
# always overwrite the dataset log; appending on every run would let it grow without bound
with open(dataset_log_file, "w") as fd:
    json.dump(log_content, fd, indent=4, sort_keys=False)
train_loader, valid_loader, test_loader = data_config.get_loaders(datasets, train_idxs, valid_idxs, test_idxs)
return task, train_loader, valid_loader, test_loader
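For reference, the verification pass above expects to find a per-dataset JSON log shaped like the `log_content` dict serialized at the end of `load`. A minimal sketch of such a file, with all values invented purely for illustration:

# hypothetical contents of <save_dir>/logs/<dataset_name>.log (all values invented)
log_example = {
    "metadata": {
        "session_name": "my-session",        # hypothetical session name
        "logstamp": "20181019-120000",       # hypothetical log timestamp
        "dataset": "MyDataset(...)",         # the str(dataset) summary
    },
    # stringified samples, compared one-by-one against str(sample) on resume
    "samples": [
        "{'image': 'img000.jpg', 'label': 'cat'}",
        "{'image': 'img001.jpg', 'label': 'dog'}",
    ],
    # bare indices only; the (idx, label) pairing is dropped before dumping
    "train_idxs": [1],
    "valid_idxs": [0],
    "test_idxs": [],
}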

@@ -303,6 +354,10 @@ class distribution. See :class:`thelper.samplers.WeightedSubsetRandomSampler` fo
validation data loader. These proportions are given in a dictionary format (``name: ratio``).
- ``test_split`` (optional): provides the proportion of samples of each dataset to hand off to the
test data loader. These proportions are given in a dictionary format (``name: ratio``).
- ``skip_verif`` (optional, default=True): specifies whether to skip verification of the dataset split
when resuming a session; verification parses the log files generated earlier to confirm that the
current split still matches the old one (see the example config below).
- ``skip_split_norm`` (optional, default=False): specifies whether to skip the interactive prompt asking
to normalize split ratios that do not sum to one.
.. seealso::
:func:`thelper.data.load`
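A minimal sketch of how these two flags could appear in a loaders config, based only on the option names documented above; the dataset name ("mnist") and the ratio values are invented:

# hypothetical config fragment; only the keys come from the docstring above
loaders_config = {
    "train_split": {"mnist": 0.8},
    "valid_split": {"mnist": 0.1},
    "test_split": {"mnist": 0.1},
    "skip_verif": False,      # opt back into split verification on resume (default skips it)
    "skip_split_norm": True,  # never prompt about normalizing ratios that sum below 1
}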
@@ -399,12 +454,13 @@ def get_ratios_split(prefix, config):
if not self.train_split and not self.valid_split and not self.test_split:
raise AssertionError("data config must define a split for at least one loader type (train/valid/test)")
self.total_usage = Counter(self.train_split) + Counter(self.valid_split) + Counter(self.test_split)
self.skip_split_norm = thelper.utils.str2bool(config["skip_split_norm"]) if "skip_split_norm" in config else False
for name, usage in self.total_usage.items():
if usage != 1:
normalize_ratios = None
if usage < 0:
raise AssertionError("ratio should never be negative...")
elif usage > 0 and usage < 1:
elif 0 < usage < 1 and not self.skip_split_norm:
time.sleep(0.25) # to make sure all debug/info prints are done, and we see the question
normalize_ratios = thelper.utils.query_yes_no(
"Dataset split for '%s' has a ratio sum less than 1; do you want to normalize the split?" % name)
@@ -417,6 +473,7 @@ def get_ratios_split(prefix, config):
self.valid_split[name] /= usage
if name in self.test_split:
self.test_split[name] /= usage
self.skip_verif = thelper.utils.str2bool(config["skip_verif"]) if "skip_verif" in config else True

def _get_raw_split(self, indices):
for name in self.total_usage:
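To make the normalization branch above concrete (ratios invented): a dataset used at 0.6 for training and 0.2 for validation has a total usage of 0.8, so accepting the prompt rescales each ratio by 1/0.8:

# invented ratios illustrating the usage normalization above
train_ratio, valid_ratio = 0.6, 0.2
usage = train_ratio + valid_ratio   # 0.8, i.e. 0 < usage < 1, so the prompt is shown
train_ratio /= usage                # 0.75
valid_ratio /= usage                # 0.25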
@@ -470,7 +527,7 @@ def get_split(self, datasets, task):
global_size = sum(len(dataset) for dataset in datasets.values())
logger.info("splitting datasets with parsed sizes = %s" % str(dataset_sizes))
if isinstance(task, thelper.tasks.Classification):
# note: with current impl, all classes will be shuffle the same way... (shouldnt matter, right?)
# note: with current impl, all class sets will be shuffled the same way... (shouldn't matter, right?)
global_class_names = task.get_class_names()
logger.info("will split evenly over %d classes..." % len(global_class_names))
sample_maps = {dataset_name: task.get_class_sample_map(dataset.samples) for dataset_name, dataset in datasets.items()}
@@ -496,10 +553,6 @@ def get_split(self, datasets, task):
idxs_dict_list[dataset_name] += class_idxs_dict_list[dataset_name]
else:
idxs_dict_list[dataset_name] = class_idxs_dict_list[dataset_name]
# one last intra-dataset shuffle for good measure; samples of the same class should not always be fed consecutively
for dataset_name in datasets:
for idxs_dict_list in [train_idxs, valid_idxs, test_idxs]:
np.random.shuffle(idxs_dict_list[dataset_name])
else: # task is not classif-related, no balancing to be done
dataset_indices = {}
for dataset_name in datasets:
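A rough sketch of the class-balanced branch above, not the exact implementation (which pairs indices with labels and spans multiple datasets): each class's index list is shuffled, then sliced according to the train/valid/test ratios so every class is represented proportionally in each loader.

import numpy as np

def split_one_class(class_idxs, train_ratio, valid_ratio):
    # illustrative helper (hypothetical, not part of thelper): shuffle one
    # class's indices, slice by ratio, and give the leftovers to the test split
    idxs = list(class_idxs)
    np.random.shuffle(idxs)
    n_train = int(round(train_ratio * len(idxs)))
    n_valid = int(round(valid_ratio * len(idxs)))
    return (idxs[:n_train],
            idxs[n_train:n_train + n_valid],
            idxs[n_train + n_valid:])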
@@ -547,8 +600,9 @@ def get_loaders(self, datasets, train_idxs, valid_idxs, test_idxs):
else:
dataset.transforms = train_augs_copy
for sample_idx_idx in range(len(sample_idxs)):
# values were paired in tuples earlier, 0=idx, 1=label
loader_sample_idxs.append(sample_idxs[sample_idx_idx][0] + loader_sample_idx_offset)
loader_sample_classes.append(sample_idxs[sample_idx_idx][1]) # values were paired earlier, 0=idx, 1=label
loader_sample_classes.append(sample_idxs[sample_idx_idx][1])
loader_sample_idx_offset += len(dataset)
loader_datasets.append(dataset)
if len(loader_datasets) > 0:
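The offset bookkeeping above turns per-dataset sample indices into indices over the concatenation of all loader datasets; a tiny standalone illustration with invented sizes and selections:

# invented sizes/indices showing the global-offset logic above
dataset_lens = [100, 50]        # two hypothetical datasets, concatenated in order
local_idxs = [[3, 7], [1, 4]]   # selected sample indices within each dataset
offset, global_idxs = 0, []
for length, idxs in zip(dataset_lens, local_idxs):
    global_idxs += [idx + offset for idx in idxs]
    offset += length
assert global_idxs == [3, 7, 101, 104]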
