Adding batch-size to dataset_splitter #120

Merged 3 commits on Oct 5, 2020
42 changes: 39 additions & 3 deletions caliban_toolbox/dataset_splitter.py
@@ -58,18 +58,54 @@ def _validate_dict(self, train_dict):
         if 'X' not in train_dict or 'y' not in train_dict:
             raise ValueError('X and y must be keys in the training dictionary')
 
-    def split(self, train_dict):
+    def _duplicate_indices(self, indices, min_size):
+        """Duplicates supplied indices so that there are min_size of them in total
+
+        Args:
+            indices: array specifying indices of images to be included
+            min_size: minimum number of images in split
+
+        Returns:
+            array: duplicated indices
+        """
+
+        multiplier = int(np.ceil(min_size / len(indices)))
+        new_indices = np.tile(indices, multiplier)
+        new_indices = new_indices[:min_size]
+
+        return new_indices
+
+    def split(self, train_dict, min_size=1):
         """Split training dict
 
         Args:
             train_dict: dictionary containing paired X and y data
+            min_size: minimum number of images for each split. If the supplied split size
+                leads to a split with fewer than min_size images, the included images are
+                duplicated up to the specified count
 
         Returns:
             dict: dict of dicts containing each split
         """
         self._validate_dict(train_dict)
         X = train_dict['X']
         y = train_dict['y']
         N_batches = X.shape[0]
         index = np.arange(N_batches)
 
         # randomize index so that we can take sequentially larger splits
         permuted_index = np.random.RandomState(seed=self.seed).permutation(index)
 
         split_dict = {}
         for split in self.splits:
-            new_train_dict = {}
-            train_size = int(split * N_batches)
+            # minimum of 1 image per split
+            train_size = max(int(split * N_batches), 1)
             split_idx = permuted_index[0:train_size]
 
+            # duplicate indices up to minimum batch size if necessary
+            if len(split_idx) < min_size:
+                split_idx = self._duplicate_indices(indices=split_idx, min_size=min_size)
+
+            new_train_dict = {}
             new_train_dict['X'] = X[split_idx]
             new_train_dict['y'] = y[split_idx]
             split_dict[str(split)] = new_train_dict
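For intuition on the duplication logic above: np.tile repeats the index array enough whole times to cover min_size, and the trailing slice trims the result to exactly min_size entries, so (whenever min_size is at least the number of indices) every original index appears at least once. A minimal standalone sketch of the same computation; the free function duplicate_indices here is illustrative, not part of the package:

import numpy as np

def duplicate_indices(indices, min_size):
    # repeat the indices enough whole times to cover min_size
    multiplier = int(np.ceil(min_size / len(indices)))
    new_indices = np.tile(indices, multiplier)
    # trim to exactly min_size entries
    return new_indices[:min_size]

print(duplicate_indices(np.array([4, 7, 2]), min_size=8))
# [4 7 2 4 7 2 4 7]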
23 changes: 23 additions & 0 deletions caliban_toolbox/dataset_splitter_test.py
@@ -75,6 +75,19 @@ def test__validate_dict():
         ds._validate_dict(invalid_dict)
 
 
+def test__duplicate_indices():
+    test_indices = [np.arange(5), np.arange(1), np.arange(7)]
+    min_size = 8
+
+    for test_idx in test_indices:
+        ds = DatasetSplitter()
+        duplicated_indices = ds._duplicate_indices(indices=test_idx, min_size=min_size)
+
+        assert len(duplicated_indices) == min_size
+        # all of the same indices are still present
+        assert set(test_idx) == set(duplicated_indices)
+
+
 def test_split():
     X_vals = np.arange(100)
     y_vals = np.arange(100, 200)
@@ -126,3 +139,13 @@ def test_split():
 
     for data in current_split:
         assert not np.array_equal(current_split[data], original_split[data])
+
+    # split corresponding to fewer than 1 image returns a single image
+    splits = [0.001, 0.3, 1]
+    ds = DatasetSplitter(splits=splits, seed=0)
+    split_dict = ds.split(train_dict=data_dict)
+    assert len(split_dict['0.001']['X']) == 1
+
+    # setting minimum size
+    split_dict = ds.split(train_dict=data_dict, min_size=10)
+    assert len(split_dict['0.001']['X']) == 10
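Taken together, the new behavior can be exercised end to end. A usage sketch mirroring the tests above, assuming data_dict is built from the test arrays with X and y indexed along the first axis (the diff does not show the data_dict construction, so that part is an assumption):

import numpy as np
from caliban_toolbox.dataset_splitter import DatasetSplitter

# 100 paired examples, batch dimension first (assumed shape)
data_dict = {'X': np.arange(100), 'y': np.arange(100, 200)}

ds = DatasetSplitter(splits=[0.001, 0.3, 1], seed=0)

# a split of 0.001 * 100 batches rounds to 0 images, but is bumped up to 1
split_dict = ds.split(train_dict=data_dict)
assert len(split_dict['0.001']['X']) == 1

# with min_size=10, that single image is duplicated to reach 10 entries
split_dict = ds.split(train_dict=data_dict, min_size=10)
assert len(split_dict['0.001']['X']) == 10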