diff --git a/caliban_toolbox/dataset_splitter.py b/caliban_toolbox/dataset_splitter.py
index b273a62..7ef266e 100644
--- a/caliban_toolbox/dataset_splitter.py
+++ b/caliban_toolbox/dataset_splitter.py
@@ -27,37 +27,53 @@ class DatasetSplitter(object):
-    def __init__(self, seed=0, splits=None):
+    def __init__(self, seed=0):
         """Class to split a dataset into sequentially increasing tranches for model training
 
         Args:
             seed: random seed for splitting
-            splits: list of proportions for each split
-
-        Raises:
-            ValueError: If splits are not sequentially increasing between (0, 1]
         """
         self.seed = seed
 
-        if splits is None:
-            self.splits = [0.05, 0.10, 0.25, 0.5, 0.75, 1]
-        else:
-            splits.sort()
-            if splits[0] <= 0:
-                raise ValueError('Smallest split must be non-zero, got {}'.format(splits[0]))
-            if splits[-1] > 1:
-                raise ValueError('Largest split cannot be greater than 1, '
-                                 'got {}'.format(splits[-1]))
-            ids, counts = np.unique(splits, return_counts=True)
-            if np.any(counts != 1):
-                raise ValueError('Duplicate splits are not allowed, each split must be uniqe')
-            self.splits = splits
-
     def _validate_dict(self, train_dict):
         if 'X' not in train_dict or 'y' not in train_dict:
             raise ValueError('X and y must be keys in the training dictionary')
 
+    def _validate_split_counts(self, split_counts):
+        """Ensures that split_counts are properly formatted"""
+
+        split_counts.sort()
+        if split_counts[0] <= 0:
+            raise ValueError('Smallest split_count must be greater than 0, '
+                             'got {}'.format(split_counts[0]))
+
+        ids, counts = np.unique(split_counts, return_counts=True)
+        if np.any(counts != 1):
+            raise ValueError('Duplicate split_counts are not allowed, '
+                             'each split must be unique')
+        dtypes = [isinstance(x, int) for x in split_counts]
+        if not np.all(dtypes):
+            raise ValueError('All split_counts must be integers')
+
+        return split_counts
+
+    def _validate_split_proportions(self, split_proportions):
+        """Ensures that split_proportions are properly formatted"""
+
+        split_proportions.sort()
+        if split_proportions[0] <= 0:
+            raise ValueError('Smallest split_proportion must be greater than 0, '
+                             'got {}'.format(split_proportions[0]))
+        if split_proportions[-1] > 1:
+            raise ValueError('Largest split_proportion cannot be greater than 1, '
+                             'got {}'.format(split_proportions[-1]))
+        ids, counts = np.unique(split_proportions, return_counts=True)
+        if np.any(counts != 1):
+            raise ValueError('Duplicate split_proportions are not allowed, each split must be unique')
+
+        return split_proportions
+
     def _duplicate_indices(self, indices, min_size):
         """Duplicates supplied indices to that there are min_size number
 
@@ -75,31 +91,53 @@ def _duplicate_indices(self, indices, min_size):
 
         return new_indices
 
-    def split(self, train_dict, min_size=1):
+    def split(self, input_dict, split_counts=None, split_proportions=None, min_size=1):
         """Split training dict
 
         Args:
-            train_dict: dictionary containing paired X and y data
+            input_dict: dictionary containing paired X and y data
+            split_counts: list with number of images from total dataset in each split
+            split_proportions: list with fraction of total dataset in each split
             min_size: minimum number of images for each split.
                 If supplied split size leads to a split with fewer than min_size,
                 duplicates included images up to specified count
 
         Returns:
             dict: dict of dicts containing each split
+
+        Raises:
+            ValueError: If neither or both of split_counts and split_proportions are supplied
         """
-        self._validate_dict(train_dict)
-        X = train_dict['X']
-        y = train_dict['y']
+        self._validate_dict(input_dict)
+
+        X = input_dict['X']
+        y = input_dict['y']
         N_batches = X.shape[0]
-        index = np.arange(N_batches)
+
+        if split_counts is None and split_proportions is None:
+            raise ValueError('Either split_counts or split_proportions must be supplied')
+
+        if split_counts is not None:
+            if split_proportions is not None:
+                raise ValueError('Either split_counts or split_proportions must be supplied, '
+                                 'not both')
+            # get counts per split and key used to store the split
+            split_counts = self._validate_split_counts(split_counts=split_counts)
+            split_keys = split_counts
+
+        if split_proportions is not None:
+            split_props = self._validate_split_proportions(split_proportions=split_proportions)
+
+            # get counts per split and key used to store the split
+            split_counts = [max(int(N_batches * split_prop), 1) for split_prop in split_props]
+            split_keys = split_props
 
         # randomize index so that we can take sequentially larger splits
+        index = np.arange(N_batches)
         permuted_index = np.random.RandomState(seed=self.seed).permutation(index)
 
         split_dict = {}
-        for split in self.splits:
-            # minimum of 1 image per split
-            train_size = max(int(split * N_batches), 1)
-            split_idx = permuted_index[0:train_size]
+        for idx, val in enumerate(split_counts):
+            split_idx = permuted_index[0:val]
 
             # duplicate indices up to minimum batch size if necessary
             if len(split_idx) < min_size:
@@ -108,6 +146,6 @@ def split(self, train_dict, min_size=1):
             new_train_dict = {}
             new_train_dict['X'] = X[split_idx]
             new_train_dict['y'] = y[split_idx]
-            split_dict[str(split)] = new_train_dict
+            split_dict[str(split_keys[idx])] = new_train_dict
 
         return split_dict
diff --git a/caliban_toolbox/dataset_splitter_test.py b/caliban_toolbox/dataset_splitter_test.py
index e831a0f..bed6192 100644
--- a/caliban_toolbox/dataset_splitter_test.py
+++ b/caliban_toolbox/dataset_splitter_test.py
@@ -31,34 +31,57 @@
 
 
 def test__init__():
-    seed = 123,
-    splits = [0.5, 0.75, 1]
+    seed = 123
+    ds = DatasetSplitter(seed=seed)
+    assert ds.seed == seed
 
-    ds = DatasetSplitter(seed=seed, splits=splits)
-    assert ds.seed == seed
-    assert ds.splits == splits
 
+def test__validate_split_counts():
+    # unsorted split_counts get sorted
+    ds = DatasetSplitter()
+    split_counts = [5, 1, 10]
+    valid_counts = ds._validate_split_counts(split_counts=split_counts)
+
+    assert valid_counts == sorted(valid_counts)
+
+    with pytest.raises(ValueError):
+        # first split_count is size 0
+        split_counts = [0, 1, 4]
+        _ = ds._validate_split_counts(split_counts=split_counts)
+
+    with pytest.raises(ValueError):
+        # duplicate split_counts
+        split_counts = [4, 8, 8]
+        _ = ds._validate_split_counts(split_counts=split_counts)
+
+    with pytest.raises(ValueError):
+        # non-integer split_counts
+        split_counts = [4, 0.25, 7]
+        _ = ds._validate_split_counts(split_counts=split_counts)
 
-    # unsorted splits get sorted
-    splits = [0.8, 0.3, 0.5]
-    ds = DatasetSplitter(seed=seed, splits=splits)
-    assert ds.splits == sorted(ds.splits)
 
+def test__validate_split_proportions():
+    # unsorted split_proportions get sorted
+    ds = DatasetSplitter()
+    split_proportions = [0.8, 0.3, 0.5]
+    valid_proportions = ds._validate_split_proportions(split_proportions=split_proportions)
+
+    assert valid_proportions == sorted(valid_proportions)
 
     with pytest.raises(ValueError):
-        # first split is size 0
-        splits = [0, 0.25, 0.5]
-        _ = DatasetSplitter(splits=splits)
+        # first split_proportion is size 0
+        split_proportions = [0, 0.25, 0.5]
+        _ = ds._validate_split_proportions(split_proportions=split_proportions)
 
     with pytest.raises(ValueError):
-        # last split is greater than 1
-        splits = [0.1, 0.25, 1.5]
-        _ = DatasetSplitter(splits=splits)
+        # last split_proportion is greater than 1
+        split_proportions = [0.1, 0.25, 1.5]
+        _ = ds._validate_split_proportions(split_proportions=split_proportions)
 
     with pytest.raises(ValueError):
-        # duplicate splits
-        splits = [0.1, 0.1, 1]
-        _ = DatasetSplitter(splits=splits)
+        # duplicate split_proportions
+        split_proportions = [0.1, 0.1, 1]
+        _ = ds._validate_split_proportions(split_proportions=split_proportions)
 
 
 def test__validate_dict():
@@ -88,18 +111,18 @@ def test__duplicate_indices():
         assert set(test_idx) == set(duplicated_indices)
 
 
-def test_split():
+def test_split_by_proportion():
     X_vals = np.arange(100)
     y_vals = np.arange(100, 200)
 
     data_dict = {'X': X_vals, 'y': y_vals}
 
-    splits = [0.1, 0.5, 1]
-    ds = DatasetSplitter(splits=splits, seed=0)
-    split_dict = ds.split(train_dict=data_dict)
+    split_proportions = [0.1, 0.5, 1]
+    ds = DatasetSplitter(seed=0)
+    split_dict = ds.split(input_dict=data_dict, split_proportions=split_proportions)
 
     split_x_vals, split_y_vals = [], []
-    for split in splits:
+    for split in split_proportions:
         current_split = split_dict[str(split)]
 
         assert len(current_split['X']) == int(100 * split)
@@ -121,8 +144,8 @@ def test_split():
         split_y_vals = current_y_vals
 
     # same seed should produce same values
-    ds = DatasetSplitter(splits=splits, seed=0)
-    split_dict_same_seed = ds.split(train_dict=data_dict)
+    ds = DatasetSplitter(seed=0)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_proportions=split_proportions)
     for split in split_dict_same_seed:
         current_split = split_dict_same_seed[split]
         original_split = split_dict[split]
@@ -130,9 +153,9 @@ def test_split():
         for data in current_split:
             assert np.array_equal(current_split[data], original_split[data])
 
-    # differet seed should produce different values
-    ds = DatasetSplitter(splits=splits, seed=1)
-    split_dict_same_seed = ds.split(train_dict=data_dict)
+    # different seed should produce different values
+    ds = DatasetSplitter(seed=1)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_proportions=split_proportions)
     for split in split_dict_same_seed:
         current_split = split_dict_same_seed[split]
         original_split = split_dict[split]
@@ -141,11 +164,68 @@ def test_split():
             assert not np.array_equal(current_split[data], original_split[data])
 
     # split corresponding to fewer than 1 image returns a single image
-    splits = [0.001, 0.3, 1]
-    ds = DatasetSplitter(splits=splits, seed=0)
-    split_dict = ds.split(train_dict=data_dict)
+    split_proportions = [0.001, 0.3, 1]
+    ds = DatasetSplitter(seed=0)
+    split_dict = ds.split(input_dict=data_dict, split_proportions=split_proportions)
     assert len(split_dict['0.001']['X']) == 1
 
     # setting minimum size
-    split_dict = ds.split(train_dict=data_dict, min_size=10)
+    split_dict = ds.split(input_dict=data_dict, min_size=10, split_proportions=split_proportions)
     assert len(split_dict['0.001']['X']) == 10
+
+
+def test_split_by_count():
+    X_vals = np.arange(100)
+    y_vals = np.arange(100, 200)
+
+    data_dict = {'X': X_vals, 'y': y_vals}
+
+    split_counts = [4, 50, 100]
+    ds = DatasetSplitter(seed=0)
+    split_dict = ds.split(input_dict=data_dict, split_counts=split_counts)
+
+    split_x_vals, split_y_vals = [], []
+    for split in split_counts:
+        current_split = split_dict[str(split)]
+
+        assert len(current_split['X']) == split
+
+        if split_x_vals == []:
+            # first split
+            split_x_vals = current_split['X']
+            split_y_vals = current_split['y']
+        else:
+            # make sure all previous values are in the current split
+            current_x_vals = current_split['X']
+            current_y_vals = current_split['y']
+
+            assert np.all(np.isin(split_x_vals, current_x_vals))
+            assert np.all(np.isin(split_y_vals, current_y_vals))
+
+            # update counter with current values
+            split_x_vals = current_x_vals
+            split_y_vals = current_y_vals
+
+    # same seed should produce same values
+    ds = DatasetSplitter(seed=0)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_counts=split_counts)
+    for split in split_dict_same_seed:
+        current_split = split_dict_same_seed[split]
+        original_split = split_dict[split]
+
+        for data in current_split:
+            assert np.array_equal(current_split[data], original_split[data])
+
+    # different seed should produce different values
+    ds = DatasetSplitter(seed=1)
+    split_dict_diff_seed = ds.split(input_dict=data_dict, split_counts=split_counts)
+    for split in split_dict_diff_seed:
+        current_split = split_dict_diff_seed[split]
+        original_split = split_dict[split]
+
+        for data in current_split:
+            assert not np.array_equal(current_split[data], original_split[data])
+
+    # setting minimum size
+    split_dict = ds.split(input_dict=data_dict, min_size=10, split_counts=split_counts)
+    assert len(split_dict['4']['X']) == 10
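
For context, a minimal usage sketch of the reworked `split` API introduced by this patch; the array shapes and split values below are illustrative only:

```python
import numpy as np

from caliban_toolbox.dataset_splitter import DatasetSplitter

# paired X/y data; only the batch axis (axis 0) matters for splitting
data = {'X': np.zeros((100, 256, 256, 1)), 'y': np.zeros((100, 256, 256, 1))}

ds = DatasetSplitter(seed=0)

# splits specified as absolute image counts, keyed by str(count)
by_count = ds.split(input_dict=data, split_counts=[5, 25, 100])
assert by_count['5']['X'].shape[0] == 5

# splits specified as fractions of the dataset, keyed by str(proportion)
by_prop = ds.split(input_dict=data, split_proportions=[0.05, 0.25, 1])
assert by_prop['0.05']['X'].shape[0] == 5

# min_size pads an undersized split by duplicating its images
padded = ds.split(input_dict=data, split_counts=[2, 100], min_size=10)
assert padded['2']['X'].shape[0] == 10
```

Each tranche is keyed by the string form of the requested count or proportion, so downstream code can look up a split without knowing which of the two modes produced it.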