add option for splitting by image number
ngreenwald committed Oct 9, 2020
1 parent 0615e5b commit 59a03e4
Showing 2 changed files with 180 additions and 62 deletions.
98 changes: 68 additions & 30 deletions caliban_toolbox/dataset_splitter.py
@@ -27,37 +27,53 @@


 class DatasetSplitter(object):
-    def __init__(self, seed=0, splits=None):
+    def __init__(self, seed=0):
         """Class to split a dataset into sequentially increasing tranches for model training
         Args:
             seed: random seed for splitting
-            splits: list of proportions for each split
-        Raises:
-            ValueError: If splits are not sequentially increasing between (0, 1]
         """
 
         self.seed = seed
 
-        if splits is None:
-            self.splits = [0.05, 0.10, 0.25, 0.5, 0.75, 1]
-        else:
-            splits.sort()
-            if splits[0] <= 0:
-                raise ValueError('Smallest split must be non-zero, got {}'.format(splits[0]))
-            if splits[-1] > 1:
-                raise ValueError('Largest split cannot be greater than 1, '
-                                 'got {}'.format(splits[-1]))
-            ids, counts = np.unique(splits, return_counts=True)
-            if np.any(counts != 1):
-                raise ValueError('Duplicate splits are not allowed, each split must be uniqe')
-            self.splits = splits
-
     def _validate_dict(self, train_dict):
         if 'X' not in train_dict or 'y' not in train_dict:
             raise ValueError('X and y must be keys in the training dictionary')
 
+    def _validate_split_counts(self, split_counts):
+        """Ensures that split_counts are properly formatted"""
+
+        split_counts.sort()
+        if split_counts[0] <= 0:
+            raise ValueError('Smallest split_count must be greater than 0, '
+                             'got {}'.format(split_counts[0]))
+
+        ids, counts = np.unique(split_counts, return_counts=True)
+        if np.any(counts != 1):
+            raise ValueError('Duplicate split_counts are not allowed, '
+                             'each split must be unique')
+        dtypes = [isinstance(x, int) for x in split_counts]
+        if not np.all(dtypes):
+            raise ValueError('All split_counts must be integers')
+
+        return split_counts
+
+    def _validate_split_proportions(self, split_proportions):
+        """Ensures that split_proportions are properly formatted"""
+
+        split_proportions.sort()
+        if split_proportions[0] <= 0:
+            raise ValueError('Smallest split_proportion must be non-zero, '
+                             'got {}'.format(split_proportions[0]))
+        if split_proportions[-1] > 1:
+            raise ValueError('Largest split_proportion cannot be greater than 1, '
+                             'got {}'.format(split_proportions[-1]))
+        ids, counts = np.unique(split_proportions, return_counts=True)
+        if np.any(counts != 1):
+            raise ValueError('Duplicate splits are not allowed, each split must be unique')
+
+        return split_proportions
+
     def _duplicate_indices(self, indices, min_size):
         """Duplicates supplied indices so that there are min_size number
@@ -75,31 +91,53 @@ def _duplicate_indices(self, indices, min_size):

         return new_indices
 
-    def split(self, train_dict, min_size=1):
+    def split(self, input_dict, split_counts=None, split_proportions=None, min_size=1):
         """Split training dict
         Args:
-            train_dict: dictionary containing paired X and y data
+            input_dict: dictionary containing paired X and y data
+            split_counts: list with number of images from total dataset in each split
+            split_proportions: list with fraction of total dataset in each split
             min_size: minimum number of images for each split. If supplied split size leads to a
                 split with fewer than min_size, duplicates included images up to specified count
         Returns:
             dict: dict of dicts containing each split
+        Raises:
+            ValueError: If split_counts and split_proportions are both None
         """
-        self._validate_dict(train_dict)
-        X = train_dict['X']
-        y = train_dict['y']
+        self._validate_dict(input_dict)
+
+        X = input_dict['X']
+        y = input_dict['y']
         N_batches = X.shape[0]
+        index = np.arange(N_batches)
+
+        if split_counts is None and split_proportions is None:
+            raise ValueError('Either split_counts or split_proportions must be supplied')
+
+        if split_counts is not None:
+            if split_proportions is not None:
+                raise ValueError('Either split_counts or split_proportions must be supplied, '
+                                 'not both')
+            # get counts per split and key used to store the split
+            split_counts = self._validate_split_counts(split_counts=split_counts)
+            split_keys = split_counts
+
+        if split_proportions is not None:
+            split_props = self._validate_split_proportions(split_proportions=split_proportions)
+
+            # get counts per split and key used to store the split
+            split_counts = [max(int(N_batches * split_prop), 1) for split_prop in split_props]
+            split_keys = split_props
+
         # randomize index so that we can take sequentially larger splits
-        index = np.arange(N_batches)
         permuted_index = np.random.RandomState(seed=self.seed).permutation(index)
 
         split_dict = {}
-        for split in self.splits:
-            # minimum of 1 image per split
-            train_size = max(int(split * N_batches), 1)
-            split_idx = permuted_index[0:train_size]
+        for idx, val in enumerate(split_counts):
+            split_idx = permuted_index[0:val]
 
             # duplicate indices up to minimum batch size if necessary
             if len(split_idx) < min_size:
@@ -108,6 +146,6 @@ def split(self, train_dict, min_size=1):
             new_train_dict = {}
             new_train_dict['X'] = X[split_idx]
             new_train_dict['y'] = y[split_idx]
-            split_dict[str(split)] = new_train_dict
+            split_dict[str(split_keys[idx])] = new_train_dict
 
         return split_dict
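
For reference, a minimal usage sketch of the API after this change. The shapes and values below are illustrative only, not taken from the repository; X and y are assumed to be numpy arrays whose first axis indexes images.

import numpy as np

from caliban_toolbox.dataset_splitter import DatasetSplitter

# illustrative data: 20 images with paired labels
data_dict = {'X': np.zeros((20, 32, 32, 1)), 'y': np.zeros((20, 32, 32, 1))}

ds = DatasetSplitter(seed=0)

# new in this commit: split by absolute image counts
by_count = ds.split(input_dict=data_dict, split_counts=[2, 10, 20])
print(by_count['2']['X'].shape)   # (2, 32, 32, 1)

# previous behavior, now passed per call rather than to __init__
by_prop = ds.split(input_dict=data_dict, split_proportions=[0.1, 0.5, 1])
print(by_prop['0.1']['X'].shape)  # (2, 32, 32, 1)

Because every split is a prefix of the same permuted index, each smaller tranche is contained in every larger one; the subset assertions in the tests below check exactly this.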
144 changes: 112 additions & 32 deletions caliban_toolbox/dataset_splitter_test.py
@@ -31,34 +31,57 @@


 def test__init__():
-    seed = 123,
-    splits = [0.5, 0.75, 1]
+    seed = 123
     ds = DatasetSplitter(seed=seed)
     assert ds.seed == seed
 
-    ds = DatasetSplitter(seed=seed, splits=splits)
-
-    assert ds.seed == seed
-    assert ds.splits == splits
+
+def test__validate_split_counts():
+    # unsorted split_counts get sorted
+    ds = DatasetSplitter()
+    split_counts = [5, 1, 10]
+    valid_counts = ds._validate_split_counts(split_counts=split_counts)
+
+    assert valid_counts == sorted(valid_counts)
+
+    with pytest.raises(ValueError):
+        # first split_count is size 0
+        split_counts = [0, 1, 4]
+        _ = ds._validate_split_counts(split_counts=split_counts)
+
+    with pytest.raises(ValueError):
+        # duplicate split_counts
+        split_counts = [4, 8, 8]
+        _ = ds._validate_split_counts(split_counts=split_counts)
+
+    with pytest.raises(ValueError):
+        # non-integer split counts
+        split_counts = [4, 0.25, 7]
+        _ = ds._validate_split_counts(split_counts=split_counts)
 
-    # unsorted splits get sorted
-    splits = [0.8, 0.3, 0.5]
-    ds = DatasetSplitter(seed=seed, splits=splits)
-
-    assert ds.splits == sorted(ds.splits)
+
+def test__validate_split_proportions():
+    # unsorted split_proportions get sorted
+    ds = DatasetSplitter()
+    split_proportions = [0.8, 0.3, 0.5]
+    valid_proportions = ds._validate_split_proportions(split_proportions=split_proportions)
+
+    assert valid_proportions == sorted(valid_proportions)
 
     with pytest.raises(ValueError):
-        # first split is size 0
-        splits = [0, 0.25, 0.5]
-        _ = DatasetSplitter(splits=splits)
+        # first split_proportion is size 0
+        split_proportions = [0, 0.25, 0.5]
+        _ = ds._validate_split_proportions(split_proportions=split_proportions)
 
     with pytest.raises(ValueError):
-        # last split is greater than 1
-        splits = [0.1, 0.25, 1.5]
-        _ = DatasetSplitter(splits=splits)
+        # last split_proportion is greater than 1
+        split_proportions = [0.1, 0.25, 1.5]
+        _ = ds._validate_split_proportions(split_proportions=split_proportions)
 
     with pytest.raises(ValueError):
-        # duplicate splits
-        splits = [0.1, 0.1, 1]
-        _ = DatasetSplitter(splits=splits)
+        # duplicate split_proportions
+        split_proportions = [0.1, 0.1, 1]
+        _ = ds._validate_split_proportions(split_proportions=split_proportions)
 
 
 def test__validate_dict():
@@ -88,18 +111,18 @@ def test__duplicate_indices():
     assert set(test_idx) == set(duplicated_indices)
 
 
-def test_split():
+def test_split_by_proportion():
     X_vals = np.arange(100)
     y_vals = np.arange(100, 200)
 
     data_dict = {'X': X_vals, 'y': y_vals}
 
-    splits = [0.1, 0.5, 1]
-    ds = DatasetSplitter(splits=splits, seed=0)
-    split_dict = ds.split(train_dict=data_dict)
+    split_proportions = [0.1, 0.5, 1]
+    ds = DatasetSplitter(seed=0)
+    split_dict = ds.split(input_dict=data_dict, split_proportions=split_proportions)
 
     split_x_vals, split_y_vals = [], []
-    for split in splits:
+    for split in split_proportions:
         current_split = split_dict[str(split)]
 
         assert len(current_split['X']) == int(100 * split)
@@ -121,18 +144,18 @@ def test_split():
         split_y_vals = current_y_vals
 
     # same seed should produce same values
-    ds = DatasetSplitter(splits=splits, seed=0)
-    split_dict_same_seed = ds.split(train_dict=data_dict)
+    ds = DatasetSplitter(seed=0)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_proportions=split_proportions)
     for split in split_dict_same_seed:
         current_split = split_dict_same_seed[split]
         original_split = split_dict[split]
 
         for data in current_split:
             assert np.array_equal(current_split[data], original_split[data])
 
-    # differet seed should produce different values
-    ds = DatasetSplitter(splits=splits, seed=1)
-    split_dict_same_seed = ds.split(train_dict=data_dict)
+    # different seed should produce different values
+    ds = DatasetSplitter(seed=1)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_proportions=split_proportions)
     for split in split_dict_same_seed:
         current_split = split_dict_same_seed[split]
         original_split = split_dict[split]
@@ -141,11 +164,68 @@ def test_split():
             assert not np.array_equal(current_split[data], original_split[data])
 
     # split corresponding to fewer than 1 image returns a single image
-    splits = [0.001, 0.3, 1]
-    ds = DatasetSplitter(splits=splits, seed=0)
-    split_dict = ds.split(train_dict=data_dict)
+    split_proportions = [0.001, 0.3, 1]
+    ds = DatasetSplitter(seed=0)
+    split_dict = ds.split(input_dict=data_dict, split_proportions=split_proportions)
     assert len(split_dict['0.001']['X']) == 1
 
     # setting minimum size
-    split_dict = ds.split(train_dict=data_dict, min_size=10)
+    split_dict = ds.split(input_dict=data_dict, min_size=10, split_proportions=split_proportions)
     assert len(split_dict['0.001']['X']) == 10
+
+
+def test_split_by_count():
+    X_vals = np.arange(100)
+    y_vals = np.arange(100, 200)
+
+    data_dict = {'X': X_vals, 'y': y_vals}
+
+    split_counts = [4, 50, 100]
+    ds = DatasetSplitter(seed=0)
+    split_dict = ds.split(input_dict=data_dict, split_counts=split_counts)
+
+    split_x_vals, split_y_vals = [], []
+    for split in split_counts:
+        current_split = split_dict[str(split)]
+
+        assert len(current_split['X']) == split
+
+        if split_x_vals == []:
+            # first split
+            split_x_vals = current_split['X']
+            split_y_vals = current_split['y']
+        else:
+            # make sure all previous values are in current split
+            current_x_vals = current_split['X']
+            current_y_vals = current_split['y']
+
+            assert np.all(np.isin(split_x_vals, current_x_vals))
+            assert np.all(np.isin(split_y_vals, current_y_vals))
+
+            # update counter with current values
+            split_x_vals = current_x_vals
+            split_y_vals = current_y_vals
+
+    # same seed should produce same values
+    ds = DatasetSplitter(seed=0)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_counts=split_counts)
+    for split in split_dict_same_seed:
+        current_split = split_dict_same_seed[split]
+        original_split = split_dict[split]
+
+        for data in current_split:
+            assert np.array_equal(current_split[data], original_split[data])
+
+    # different seed should produce different values
+    ds = DatasetSplitter(seed=1)
+    split_dict_same_seed = ds.split(input_dict=data_dict, split_counts=split_counts)
+    for split in split_dict_same_seed:
+        current_split = split_dict_same_seed[split]
+        original_split = split_dict[split]
+
+        for data in current_split:
+            assert not np.array_equal(current_split[data], original_split[data])
+
+    # setting minimum size
+    split_dict = ds.split(input_dict=data_dict, min_size=10, split_counts=split_counts)
+    assert len(split_dict['4']['X']) == 10
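
The min_size assertions above exercise the _duplicate_indices helper, whose body is folded out of this diff. As a rough, illustrative stand-in (the real helper's repetition order may differ), the duplication amounts to cyclically tiling the chosen indices up to the requested length:

import numpy as np

def duplicate_indices_sketch(indices, min_size):
    # np.resize repeats the array cyclically to reach the requested length;
    # illustrative stand-in only, not the repository's implementation
    return np.resize(np.asarray(indices), max(len(indices), min_size))

print(duplicate_indices_sketch([3, 17, 42, 8], 10))
# -> [ 3 17 42  8  3 17 42  8  3 17]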
