From 7cbb95a700de1b1c6d415eff0a2991a4ad0668f2 Mon Sep 17 00:00:00 2001 From: Noah Greenwald Date: Mon, 5 Oct 2020 14:10:00 -0700 Subject: [PATCH] Adding batch-size to dataset_splitter (#120) * add safeguards for small split sizes * switch from repeat to tile * remove print statement --- caliban_toolbox/dataset_splitter.py | 42 ++++++++++++++++++++++-- caliban_toolbox/dataset_splitter_test.py | 23 +++++++++++++ 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/caliban_toolbox/dataset_splitter.py b/caliban_toolbox/dataset_splitter.py index ee93c62..b273a62 100644 --- a/caliban_toolbox/dataset_splitter.py +++ b/caliban_toolbox/dataset_splitter.py @@ -58,18 +58,54 @@ def _validate_dict(self, train_dict): if 'X' not in train_dict or 'y' not in train_dict: raise ValueError('X and y must be keys in the training dictionary') - def split(self, train_dict): + def _duplicate_indices(self, indices, min_size): + """Duplicates supplied indices so that there are min_size number + + Args: + indices: array specifying indices of images to be included + min_size: minimum number of images in split + + Returns: + array: duplicate indices + """ + + multiplier = int(np.ceil(min_size / len(indices))) + new_indices = np.tile(indices, multiplier) + new_indices = new_indices[:min_size] + + return new_indices + + def split(self, train_dict, min_size=1): + """Split training dict + + Args: + train_dict: dictionary containing paired X and y data + min_size: minimum number of images for each split. If supplied split size leads to a + split with fewer than min_size, duplicates included images up to specified count + + Returns: + dict: dict of dicts containing each split + """ self._validate_dict(train_dict) X = train_dict['X'] y = train_dict['y'] N_batches = X.shape[0] index = np.arange(N_batches) + + # randomize index so that we can take sequentially larger splits permuted_index = np.random.RandomState(seed=self.seed).permutation(index) + split_dict = {} for split in self.splits: - new_train_dict = {} - train_size = int(split * N_batches) + # minimum of 1 image per split + train_size = max(int(split * N_batches), 1) split_idx = permuted_index[0:train_size] + + # duplicate indices up to minimum batch size if necessary + if len(split_idx) < min_size: + split_idx = self._duplicate_indices(indices=split_idx, min_size=min_size) + + new_train_dict = {} new_train_dict['X'] = X[split_idx] new_train_dict['y'] = y[split_idx] split_dict[str(split)] = new_train_dict diff --git a/caliban_toolbox/dataset_splitter_test.py b/caliban_toolbox/dataset_splitter_test.py index aa6027c..e831a0f 100644 --- a/caliban_toolbox/dataset_splitter_test.py +++ b/caliban_toolbox/dataset_splitter_test.py @@ -75,6 +75,19 @@ def test__validate_dict(): ds._validate_dict(invalid_dict) +def test__duplicate_indices(): + test_indices = [np.arange(5), np.arange(1), np.arange(7)] + min_size = 8 + + for test_idx in test_indices: + ds = DatasetSplitter() + duplicated_indices = ds._duplicate_indices(indices=test_idx, min_size=min_size) + + assert len(duplicated_indices) == min_size + # all of the same indices are still present + assert set(test_idx) == set(duplicated_indices) + + def test_split(): X_vals = np.arange(100) y_vals = np.arange(100, 200) @@ -126,3 +139,13 @@ def test_split(): for data in current_split: assert not np.array_equal(current_split[data], original_split[data]) + + # split corresponding to fewer than 1 image returns a single image + splits = [0.001, 0.3, 1] + ds = DatasetSplitter(splits=splits, seed=0) + split_dict = ds.split(train_dict=data_dict) + assert len(split_dict['0.001']['X']) == 1 + + # setting minimum size + split_dict = ds.split(train_dict=data_dict, min_size=10) + assert len(split_dict['0.001']['X']) == 10