Skip to content

Commit

Permalink
Merge d143b92 into 4c3ccc4
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreenwald committed Oct 5, 2020
2 parents 4c3ccc4 + d143b92 commit bb77342
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
42 changes: 39 additions & 3 deletions caliban_toolbox/dataset_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,54 @@ def _validate_dict(self, train_dict):
if 'X' not in train_dict or 'y' not in train_dict:
raise ValueError('X and y must be keys in the training dictionary')

def split(self, train_dict):
def _duplicate_indices(self, indices, min_size):
"""Duplicates supplied indices to that there are min_size number
Args:
indices: array specifying indices of images to be included
min_size: minimum number of images in split
Returns:
array: duplicate indices
"""

multiplier = int(np.ceil(min_size / len(indices)))
new_indices = np.tile(indices, multiplier)
new_indices = new_indices[:min_size]

return new_indices

def split(self, train_dict, min_size=1):
"""Split training dict
Args:
train_dict: dictionary containing paired X and y data
min_size: minimum number of images for each split. If supplied split size leads to a
split with fewer than min_size, duplicates included images up to specified count
Returns:
dict: dict of dicts containing each split
"""
self._validate_dict(train_dict)
X = train_dict['X']
y = train_dict['y']
N_batches = X.shape[0]
index = np.arange(N_batches)

# randomize index so that we can take sequentially larger splits
permuted_index = np.random.RandomState(seed=self.seed).permutation(index)

split_dict = {}
for split in self.splits:
new_train_dict = {}
train_size = int(split * N_batches)
# minimum of 1 image per split
train_size = max(int(split * N_batches), 1)
split_idx = permuted_index[0:train_size]

# duplicate indices up to minimum batch size if necessary
if len(split_idx) < min_size:
split_idx = self._duplicate_indices(indices=split_idx, min_size=min_size)

new_train_dict = {}
new_train_dict['X'] = X[split_idx]
new_train_dict['y'] = y[split_idx]
split_dict[str(split)] = new_train_dict
Expand Down
24 changes: 24 additions & 0 deletions caliban_toolbox/dataset_splitter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ def test__validate_dict():
ds._validate_dict(invalid_dict)


def test__duplicate_indices():
    target_size = 8

    for starting_indices in [np.arange(5), np.arange(1), np.arange(7)]:
        splitter = DatasetSplitter()
        duplicated = splitter._duplicate_indices(indices=starting_indices,
                                                 min_size=target_size)

        # output is padded out to exactly the requested length
        assert len(duplicated) == target_size

        # duplication only repeats existing values; none are dropped or invented
        assert set(starting_indices) == set(duplicated)


def test_split():
X_vals = np.arange(100)
y_vals = np.arange(100, 200)
Expand Down Expand Up @@ -126,3 +139,14 @@ def test_split():

for data in current_split:
assert not np.array_equal(current_split[data], original_split[data])

# split corresponding to fewer than 1 image returns a single image
splits = [0.001, 0.3, 1]
ds = DatasetSplitter(splits=splits, seed=0)
split_dict = ds.split(train_dict=data_dict)
print(split_dict['0.001']['X'])
assert len(split_dict['0.001']['X']) == 1

# setting minimum size
split_dict = ds.split(train_dict=data_dict, min_size=10)
assert len(split_dict['0.001']['X']) == 10

0 comments on commit bb77342

Please sign in to comment.