Skip to content

Commit

Permalink
corner cases for train/test/val
Browse files Browse the repository at this point in the history
  • Loading branch information
ngreenwald committed Aug 24, 2020
1 parent 307ec00 commit 8ea6a87
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 5 deletions.
35 changes: 30 additions & 5 deletions caliban_toolbox/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def train_val_test_split(X_data, y_data, data_split=(0.8, 0.1, 0.1), seed=None):
seed: random seed for reproducible splits
Returns:
list of X and y data split appropriately
list of X and y data split appropriately. If dataset is too small for all splits,
returns None for remaining splits
Raises:
ValueError: if ratios do not sum to 1
Expand All @@ -195,20 +196,44 @@ def train_val_test_split(X_data, y_data, data_split=(0.8, 0.1, 0.1), seed=None):

train_ratio, val_ratio, test_ratio = data_split

# 1 image, train split only
if X_data.shape[0] == 1:
return X_data, y_data, None, None, None, None

# 2 images, train and val split only
if X_data.shape[0] == 2:
return X_data[:1, ...], y_data[:1, ...], X_data[1:, ...], y_data[1:, ...], None, None

# compute fraction not in train
remainder_size = np.round(1 - train_ratio, decimals=2)
val_remainder_ratio = np.round(1 - train_ratio, decimals=2)
val_remainder_count = X_data.shape[0] * val_remainder_ratio
# not enough data for val split, put minimum (1) in each split
if val_remainder_count < 1:
X_train, y_train = X_data[:-2], y_data[:-2]
X_val, y_val = X_data[-1:], y_data[-1:]
X_test, y_test = X_data[-2:-1], y_data[-2:-1]
return X_train, y_train, X_val, y_val, X_test, y_test

# split dataset into train and remainder
X_train, X_remainder, y_train, y_remainder = train_test_split(X_data, y_data,
test_size=remainder_size,
test_size=val_remainder_ratio,
random_state=seed)

# compute fraction of remainder that is test
test_size = np.round(test_ratio / (val_ratio + test_ratio), decimals=2)
test_remainder_ratio = np.round(test_ratio / (val_ratio + test_ratio), decimals=2)
test_remainder_count = X_remainder.shape[0] * test_remainder_ratio

# not enough data for test split, put minimum (1) in test split
if test_remainder_count < 1:
X_test, y_test = X_train[-1:], y_train[-1:]
X_val, y_val = X_remainder, y_remainder
X_train, y_train = X_train[:-1], y_train[:-1]

return X_train, y_train, X_val, y_val, X_test, y_test

# split remainder into val and test
X_val, X_test, y_val, y_test = train_test_split(X_remainder, y_remainder,
test_size=test_size,
test_size=test_remainder_ratio,
random_state=seed)

return X_train, y_train, X_val, y_val, X_test, y_test
57 changes: 57 additions & 0 deletions caliban_toolbox/build_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ def _make_npzs(sizes, num_images):
return npz_list


def _all_unique_vals(arrays):
unique_vals = []
for array in arrays:
unique_vals.append(np.unique(array))

unique_vals = np.concatenate(unique_vals, axis=0)

unique_in_all = np.unique(unique_vals)

return len(unique_vals) == len(unique_in_all)


def test_compute_cell_size():
labels = np.zeros((3, 40, 40, 1), dtype='int')

Expand Down Expand Up @@ -216,6 +228,11 @@ def test_train_val_test_split():
X_data = np.zeros((100, 5, 5, 3))
y_data = np.zeros((100, 5, 5, 1))

unique_vals = np.arange(100)
for val in unique_vals:
X_data[val, ...] = val + 1
y_data[val, ...] = -(val + 1)

train_ratio, val_ratio, test_ratio = 0.7, 0.2, 0.1

X_train, y_train, X_val, y_val, X_test, y_test, = \
Expand All @@ -232,12 +249,52 @@ def test_train_val_test_split():
assert X_test.shape[0] == 100 * test_ratio
assert y_test.shape[0] == 100 * test_ratio

assert _all_unique_vals((X_train, y_train, X_val, y_val, X_test, y_test))

# ensure that None is returned for val and test when data is not large enough to be split
X_train, y_train, X_val, y_val, X_test, y_test, = \
build.train_val_test_split(X_data=X_data[:1],
y_data=y_data[:1],
data_split=[train_ratio, val_ratio, test_ratio])
assert X_train.shape[0] == y_train.shape[0] == 1
assert np.all([val is None for val in [X_val, y_val, X_test, y_test]])

# ensure that None is returned for test when data is not large enough to be split
X_train, y_train, X_val, y_val, X_test, y_test, = \
build.train_val_test_split(X_data=X_data[:2],
y_data=y_data[:2],
data_split=[train_ratio, val_ratio, test_ratio])
assert np.all([val.shape[0] == 1 for val in [X_train, y_train, X_val, y_val]])
assert np.all([val is None for val in [X_test, y_test]])
assert _all_unique_vals((X_train, y_train, X_val, y_val))

# Adjust data appropriately when split sizes will result in zero values for val and test
X_train, y_train, X_val, y_val, X_test, y_test, = \
build.train_val_test_split(X_data=X_data[:5],
y_data=y_data[:5],
data_split=[0.8, 0.1, 0.1])
assert X_train.shape[0] == y_train.shape[0] == 3
assert np.all([val.shape[0] == 1 for val in [X_val, y_val, X_test, y_test]])
assert _all_unique_vals((X_train, y_train, X_val, y_val, X_test, y_test))

# Adjust data appropriately when split sizes will result in zero values for test
X_train, y_train, X_val, y_val, X_test, y_test, = \
build.train_val_test_split(X_data=X_data[:9],
y_data=y_data[:9],
data_split=[0.8, 0.1, 0.1])
assert X_train.shape[0] == y_train.shape[0] == 7
assert np.all([val.shape[0] == 1 for val in [X_val, y_val, X_test, y_test]])
assert _all_unique_vals((X_train, y_train, X_val, y_val, X_test, y_test))

# data split includes 0 for one of splits
with pytest.raises(ValueError):
_ = build.train_val_test_split(X_data=X_data, y_data=y_data, data_split=[0, 0.5, 0.5])

# data split sums to more than 1
with pytest.raises(ValueError):
_ = build.train_val_test_split(X_data=X_data, y_data=y_data, data_split=[0.5, 0.5, 0.5])

# different sizes for X and y
with pytest.raises(ValueError):
_ = build.train_val_test_split(X_data=X_data[:5], y_data=y_data,
data_split=[0.5, 0.5, 0.5])

0 comments on commit 8ea6a87

Please sign in to comment.