Skip to content

Commit

Permalink
Fixup external labels prefetching in splitter
Browse files: browse the repository at this point in the history
  • Loading branch information
plstcharles committed Oct 22, 2018
1 parent 7188465 commit cb89e30
Showing 1 changed file with 26 additions and 3 deletions.
29 changes: 26 additions & 3 deletions src/thelper/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,18 +528,41 @@ def get_split(self, datasets, task):
A three-element tuple containing the maps of the training, validation, and test sets
respectively. These maps associate dataset names to a list of sample indices.
"""
dataset_sizes = {dataset_name: len(dataset) for dataset_name, dataset in datasets.items()}
dataset_sizes = {}
global_size = 0
for dataset_name, dataset in datasets.items():
if not isinstance(dataset, thelper.data.Dataset) and not isinstance(dataset, thelper.data.ExternalDataset):
raise AssertionError("unexpected dataset type for '%s'" % dataset_name)
dataset_sizes[dataset_name] = len(dataset)
global_size += dataset_sizes[dataset_name]
global_size = sum(len(dataset) for dataset in datasets.values())
logger.info("splitting datasets with parsed sizes = %s" % str(dataset_sizes))
if isinstance(task, thelper.tasks.Classification):
# note: with the current implementation, all class sets will be shuffled the same way... (shouldn't matter, right?)
global_class_names = task.get_class_names()
logger.info("will split evenly over %d classes..." % len(global_class_names))
sample_maps = {dataset_name: task.get_class_sample_map(dataset.samples) for dataset_name, dataset in datasets.items()}
sample_maps = {}
for dataset_name, dataset in datasets.items():
if isinstance(dataset, thelper.data.ExternalDataset):
if hasattr(dataset.samples, "samples"):
sample_maps[dataset_name] = task.get_class_sample_map(dataset.samples.samples)
else:
logger.warning(("must fully parse the external dataset '%s' for intra-class shuffling;" % dataset_name) +
" this might take a while! (consider making a dataset interface that can return labels only)")
label_keys = task.get_gt_key() if isinstance(task.get_gt_key(), list) else [task.get_gt_key()]
samples = []
for sample in dataset:
for key in label_keys:
if key in sample:
samples.append({key: sample[key]})
break # by default, stop after finding first match
sample_maps[dataset_name] = task.get_class_sample_map(samples)
elif isinstance(dataset, thelper.data.Dataset):
sample_maps[dataset_name] = task.get_class_sample_map(dataset.samples)
train_idxs, valid_idxs, test_idxs = {}, {}, {}
for class_name in global_class_names:
curr_class_samples, curr_class_size = {}, {}
for dataset_name, dataset in datasets.items():
for dataset_name in datasets:
class_samples = sample_maps[dataset_name][class_name] if class_name in sample_maps[dataset_name] else []
samples_pairs = list(zip(class_samples, [class_name] * len(class_samples)))
curr_class_samples[dataset_name] = samples_pairs
Expand Down

0 comments on commit cb89e30

Please sign in to comment.