From 1dc8b576acfed14719e7b498bbd93f509f2cfb6c Mon Sep 17 00:00:00 2001 From: Noah Greenwald Date: Thu, 3 Sep 2020 10:04:42 -0700 Subject: [PATCH] DatasetBuilder and DatasetBenchmarker (#115) * modified train/val/test function * refactored init * switched to loading helper function * loads all experiments and splits data * build dataset function * first pass at DataBuilder * working version of bencharmker * removed separate id and name fields * integrated resizing * pep8 * switch naming in Benchmarker * pep8 * kwargs to pass in additional arguments * fixed mocker * handle images with no labeled cells * fix bug with resizing * corner cases for train/test/val * incorporate train_val_split changes * added summarize * simplified output format for Benchmarker * simplified _reshape_dict * renamed variables * pep8 * renamed files * fixed bug in summarize * different output shapes for each split * simplified benchmarking dicts * renamed file * added explicit 4D checks where necessary, abstracted to ... where not * extract error checking * bugfixes * Apply suggestions from code review Co-authored-by: willgraf <7930703+willgraf@users.noreply.github.com> * add option to resize by constant value * add testing for reproducible splits * pep8 * better input validation * Apply suggestions from code review Co-authored-by: willgraf <7930703+willgraf@users.noreply.github.com> * typo Co-authored-by: willgraf <7930703+willgraf@users.noreply.github.com> --- caliban_toolbox/build.py | 182 +++-- caliban_toolbox/build_test.py | 289 ++++---- caliban_toolbox/dataset_benchmarker.py | 121 ++++ caliban_toolbox/dataset_benchmarker_test.py | 113 ++++ caliban_toolbox/dataset_builder.py | 629 ++++++++++++++++++ caliban_toolbox/dataset_builder_test.py | 695 ++++++++++++++++++++ caliban_toolbox/utils/misc_utils.py | 16 + caliban_toolbox/utils/misc_utils_test.py | 15 + 8 files changed, 1872 insertions(+), 188 deletions(-) create mode 100644 caliban_toolbox/dataset_benchmarker.py create mode 100644 caliban_toolbox/dataset_benchmarker_test.py create mode 100644 caliban_toolbox/dataset_builder.py create mode 100644 caliban_toolbox/dataset_builder_test.py diff --git a/caliban_toolbox/build.py b/caliban_toolbox/build.py index c1a6e8b..b30d31b 100644 --- a/caliban_toolbox/build.py +++ b/caliban_toolbox/build.py @@ -23,12 +23,14 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - +import warnings import math import numpy as np from skimage.measure import regionprops_table +from sklearn.model_selection import train_test_split + from deepcell_toolbox.utils import resize, tile_image @@ -43,14 +45,15 @@ def compute_cell_size(npz_file, method='median', by_image=True): the cell size across the entire npz is returned Returns: - average_sizes: list of typical cell size in NPZ + list: list of typical cell size in NPZ. If no cells, returns None. 
Raises: ValueError if invalid method supplied Raises: ValueError if data does have len(shape) of 4 """ - valid_methods = set(['median', 'mean']) - if method.lower() not in valid_methods: + valid_methods = {'median', 'mean'} + method = str(method).lower() + if method not in valid_methods: raise ValueError('Invalid method supplied: got {}, ' 'method must be one of {}'.format(method, valid_methods)) @@ -63,9 +66,15 @@ def compute_cell_size(npz_file, method='median', by_image=True): for i in range(labels.shape[0]): current_label = labels[i, :, :, 0] - area = regionprops_table(current_label.astype('int'), properties=['area'])['area'] - cell_sizes.append(area) + # check to make sure array contains cells + if len(np.unique(current_label)) > 1: + area = regionprops_table(current_label.astype('int'), properties=['area'])['area'] + cell_sizes.append(area) + + # if all images were empty, return NA + if cell_sizes == []: + return None # compute for each list corresponding to each image if by_image: @@ -80,16 +89,16 @@ def compute_cell_size(npz_file, method='median', by_image=True): else: all_cell_sizes = np.concatenate(cell_sizes) if method == 'mean': - average_cell_sizes = [np.mean(all_cell_sizes)] + average_cell_sizes = np.mean(all_cell_sizes) elif method == 'median': - average_cell_sizes = [np.median(all_cell_sizes)] + average_cell_sizes = np.median(all_cell_sizes) else: raise ValueError('Invalid method supplied') return average_cell_sizes -def reshape_training_image(X_data, y_data, resize_ratio, final_size, stride_ratio): +def reshape_training_data(X_data, y_data, resize_ratio, final_size, stride_ratio=1, tolerance=1.5): """Takes a stack of X and y data and reshapes and crops them to match output dimensions Args: @@ -98,14 +107,20 @@ def reshape_training_image(X_data, y_data, resize_ratio, final_size, stride_rati resize_ratio: resize ratio for the images final_size: the desired shape of the output image stride_ratio: amount of overlap between crops (1 is no overlap, 0.5 is half crop size) + tolerance: ratio that determines when resizing occurs Returns: reshaped_X, reshaped_y: resized and cropped version of input images + + Raises: + ValueError: If image data is not 4D """ + if len(X_data.shape) != 4: + raise ValueError('Image data must be 4D') + # resize if needed - # TODO: Add tolerance to control when resizing happens - if resize_ratio != 1: + if resize_ratio > tolerance or resize_ratio < (1 / tolerance): new_shape = (int(X_data.shape[1] * resize_ratio), int(X_data.shape[2] * resize_ratio)) @@ -135,8 +150,14 @@ def pad_image_stack(images, crop_size): Returns: np.array: padded image stack + + Raises: + ValueError: If images are not 4D """ + if len(images.shape) != 4: + raise ValueError('Image data must be 4D') + row_len, col_len = images.shape[1:3] row_crop, col_crop = crop_size row_num = math.ceil(row_len / crop_size[0]) @@ -149,73 +170,94 @@ def pad_image_stack(images, crop_size): # don't need to pad return images else: - new_images = np.zeros((images.shape[0], new_row_len, new_col_len, images.shape[3])) + new_images = np.zeros((images.shape[0], new_row_len, new_col_len, images.shape[3]), + dtype=images.dtype) new_images[:, :row_len, :col_len, :] = images return new_images -def combine_npz_files(npz_list, resize_ratios, stride_ratio=1, final_size=(256, 256)): - """Take a series of NPZ files and combine together into single training NPZ +def train_val_test_split(X_data, y_data, data_split=(0.8, 0.1, 0.1), seed=None): + """Randomly splits supplied data into specified sizes for model assessment 
Args: - npz_list: list of NPZ files to combine. Currently only works on 2D static data - resize_ratios: ratio used to resize each NPZ if data is of different resolutions. Must - be either 1 for each NPZ file, or 1 for each image within the NPZ file - stride_ratio: amount of overlap between crops (1 is no overlap, 0.5 is half crop size) - final_size: size of the final crops to be produced + X_data: array of X data + y_data: array of y_data + data_split: tuple specifying the fraction of the dataset for train/val/test + seed: random seed for reproducible splits Returns: - np.array: array containing resized and cropped data from all input NPZs + list of X and y data split appropriately. If dataset is too small for all splits, + returns None for remaining splits Raises: - ValueError: If mismatch between number of resize ratios and number of images + ValueError: if ratios do not sum to 1 + ValueError: if any of the splits are 0 + ValueError: If length of X and y data is not equal """ - combined_x = [] - combined_y = [] - - for idx, npz in enumerate(npz_list): - current_x = npz['X'] - current_y = npz['y'] - current_resize = resize_ratios[idx] - - # same resize value for entire NPZ file - if len(current_resize) == 1: - current_x, current_y = reshape_training_image(X_data=current_x, - y_data=current_y, - resize_ratio=current_resize[0], - final_size=final_size, - stride_ratio=stride_ratio) - combined_x.append(current_x) - combined_y.append(current_y) - - # different resize value for each image within the NPZ file - else: - unique_x, unique_y = [], [] - if len(current_resize) != current_x.shape[0]: - raise ValueError('Resize ratios must have same length as image data.' - 'Provided resize ratios has length {} ' - 'and image data has shape {}'.format(len(resize_ratios), - current_x.shape)) - # loop over each image and resize + crop appropriately - for img in range(current_x.shape[0]): - x_batch, y_batch = reshape_training_image(X_data=current_x[img:(img + 1)], - y_data=current_y[img:(img + 1)], - resize_ratio=current_resize[img], - final_size=final_size, - stride_ratio=stride_ratio) - unique_x.append(x_batch) - unique_y.append(y_batch) - - # combine all images from this NPZ together - current_x = np.concatenate(unique_x, axis=0) - current_y = np.concatenate(unique_y, axis=0) - - # add combined images from this NPZ onto main accumulator list - combined_x.append(current_x) - combined_y.append(current_y) - - # combine all images from all NPZs together - combined_x = np.concatenate(combined_x, axis=0) - combined_y = np.concatenate(combined_y, axis=0) - - return combined_x, combined_y + + total = np.round(np.sum(data_split), decimals=2) + if total != 1: + raise ValueError('Data splits must sum to 1, supplied splits sum to {}'.format(total)) + + if 0 in data_split: + raise ValueError('All splits must be non-zero') + + if X_data.shape[0] != y_data.shape[0]: + raise ValueError('Supplied X and y data do not have the same ' + 'length over batches dimension. 
' + 'X.shape: {}, y.shape: {}'.format(X_data.shape, y_data.shape)) + + train_ratio, val_ratio, test_ratio = data_split + + # 1 image, train split only + if X_data.shape[0] == 1: + warnings.warn('Only one image in current NPZ, returning training split only') + return X_data, y_data, None, None, None, None + + # 2 images, train and val split only + if X_data.shape[0] == 2: + warnings.warn('Only two images in current NPZ, returning training and val split only') + return X_data[:1, ...], y_data[:1, ...], X_data[1:, ...], y_data[1:, ...], None, None + + # compute fraction not in train + val_remainder_ratio = np.round(1 - train_ratio, decimals=2) + val_remainder_count = X_data.shape[0] * val_remainder_ratio + + # not enough data for val split, put minimum (1) in each split + if val_remainder_count < 1: + warnings.warn('Not enough data in current NPZ for specified data split.' + 'Returning modified data split') + X_train, X_remainder, y_train, y_remainder = train_test_split(X_data, y_data, + test_size=2, + random_state=seed) + X_val, X_test, y_val, y_test = train_test_split(X_remainder, y_remainder, + test_size=1, + random_state=seed) + return X_train, y_train, X_val, y_val, X_test, y_test + + # split dataset into train and remainder + X_train, X_remainder, y_train, y_remainder = train_test_split(X_data, y_data, + test_size=val_remainder_ratio, + random_state=seed) + + # compute fraction of remainder that is test + test_remainder_ratio = np.round(test_ratio / (val_ratio + test_ratio), decimals=2) + test_remainder_count = X_remainder.shape[0] * test_remainder_ratio + + # not enough data in remainder for test split, put minimum (1) in test split from train split + if test_remainder_count < 1: + warnings.warn('Not enough data in current NPZ for specified data split.' + 'Returning modified data split') + X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, + test_size=1, + random_state=seed) + X_val, y_val = X_remainder, y_remainder + + return X_train, y_train, X_val, y_val, X_test, y_test + + # split remainder into val and test + X_val, X_test, y_val, y_test = train_test_split(X_remainder, y_remainder, + test_size=test_remainder_ratio, + random_state=seed) + + return X_train, y_train, X_val, y_val, X_test, y_test diff --git a/caliban_toolbox/build_test.py b/caliban_toolbox/build_test.py index 46f8ca8..8a78f6c 100644 --- a/caliban_toolbox/build_test.py +++ b/caliban_toolbox/build_test.py @@ -43,6 +43,18 @@ def _make_npzs(sizes, num_images): return npz_list +def _all_unique_vals(arrays): + unique_vals = [] + for array in arrays: + unique_vals.append(np.unique(array)) + + unique_vals = np.concatenate(unique_vals, axis=0) + + unique_in_all = np.unique(unique_vals) + + return len(unique_vals) == len(unique_in_all) + + def test_compute_cell_size(): labels = np.zeros((3, 40, 40, 1), dtype='int') @@ -73,6 +85,19 @@ def test_compute_cell_size(): cell_sizes = build.compute_cell_size(npz_file=example_npz, method='mean', by_image=False) assert np.round(cell_sizes, 2) == [37.14] # mean across all images + # adding blank images shouldn't change the value returned + labels_blank = np.zeros((5, 40, 40, 1)) + labels_blank[1:4, ...] 
= labels + + cell_sizes = build.compute_cell_size(npz_file={'y': labels_blank}, method='mean', + by_image=False) + assert np.round(cell_sizes, 2) == [37.14] # mean across all images + + # completely blank image should return None + cell_sizes = build.compute_cell_size(npz_file={'y': np.zeros((3, 40, 40, 1))}, method='mean', + by_image=False) + assert cell_sizes is None + # incorrect method with pytest.raises(ValueError): _ = build.compute_cell_size(npz_file=example_npz, method='bad_method', by_image=True) @@ -82,18 +107,18 @@ def test_compute_cell_size(): _ = build.compute_cell_size(npz_file={'y': labels[0]}, method='bad_method', by_image=True) -def test_reshape_training_image(): +def test_reshape_training_data(): # test without resizing or cropping X_data, y_data = np.zeros((5, 40, 40, 3)), np.zeros((5, 40, 40, 2)) resize_ratio = 1 final_size = (40, 40) stride_ratio = 1 - reshaped_X, reshaped_y = build.reshape_training_image(X_data=X_data, - y_data=y_data, - resize_ratio=resize_ratio, - final_size=final_size, - stride_ratio=stride_ratio) + reshaped_X, reshaped_y = build.reshape_training_data(X_data=X_data, + y_data=y_data, + resize_ratio=resize_ratio, + final_size=final_size, + stride_ratio=stride_ratio) assert reshaped_X.shape == X_data.shape assert reshaped_y.shape == y_data.shape @@ -103,25 +128,25 @@ def test_reshape_training_image(): final_size = (40, 40) stride_ratio = 1 - reshaped_X, reshaped_y = build.reshape_training_image(X_data=X_data, - y_data=y_data, - resize_ratio=resize_ratio, - final_size=final_size, - stride_ratio=stride_ratio) + reshaped_X, reshaped_y = build.reshape_training_data(X_data=X_data, + y_data=y_data, + resize_ratio=resize_ratio, + final_size=final_size, + stride_ratio=stride_ratio) assert list(reshaped_X.shape) == [X_data.shape[0] * 2] + list(final_size) + [X_data.shape[-1]] assert list(reshaped_y.shape) == [y_data.shape[0] * 2] + list(final_size) + [y_data.shape[-1]] # test with just resizing X_data, y_data = np.zeros((5, 40, 40, 3)), np.zeros((5, 40, 40, 2)) - resize_ratio = 1 + resize_ratio = 2 final_size = (80, 80) - stride_ratio = 2 + stride_ratio = 1 - reshaped_X, reshaped_y = build.reshape_training_image(X_data=X_data, - y_data=y_data, - resize_ratio=resize_ratio, - final_size=final_size, - stride_ratio=stride_ratio) + reshaped_X, reshaped_y = build.reshape_training_data(X_data=X_data, + y_data=y_data, + resize_ratio=resize_ratio, + final_size=final_size, + stride_ratio=stride_ratio) assert list(reshaped_X.shape) == [X_data.shape[0]] + list(final_size) + [X_data.shape[-1]] assert list(reshaped_y.shape) == [y_data.shape[0]] + list(final_size) + [y_data.shape[-1]] @@ -129,16 +154,48 @@ def test_reshape_training_image(): X_data, y_data = np.zeros((5, 40, 40, 3)), np.zeros((5, 40, 40, 2)) resize_ratio = 2 final_size = (40, 40) - stride_ratio = 2 + stride_ratio = 1 - reshaped_X, reshaped_y = build.reshape_training_image(X_data=X_data, - y_data=y_data, - resize_ratio=resize_ratio, - final_size=final_size, - stride_ratio=stride_ratio) + reshaped_X, reshaped_y = build.reshape_training_data(X_data=X_data, + y_data=y_data, + resize_ratio=resize_ratio, + final_size=final_size, + stride_ratio=stride_ratio) assert list(reshaped_X.shape) == [X_data.shape[0] * 4] + list(X_data.shape[1:]) assert list(reshaped_y.shape) == [y_data.shape[0] * 4] + list(y_data.shape[1:]) + # test with resizing below threshold for increase + X_data, y_data = np.zeros((5, 40, 40, 3)), np.zeros((5, 40, 40, 2)) + resize_ratio = 1.5 + resize_tolerance = 2 + final_size = (40, 40) + stride_ratio 
= 1 + + reshaped_X, reshaped_y = build.reshape_training_data(X_data=X_data, + y_data=y_data, + resize_ratio=resize_ratio, + final_size=final_size, + stride_ratio=stride_ratio, + tolerance=resize_tolerance) + assert reshaped_X.shape == X_data.shape + assert reshaped_y.shape == y_data.shape + + # test with resizing below threshold for decrease + X_data, y_data = np.zeros((5, 40, 40, 3)), np.zeros((5, 40, 40, 2)) + resize_ratio = 0.65 + resize_tolerance = 2 + final_size = (40, 40) + stride_ratio = 1 + + reshaped_X, reshaped_y = build.reshape_training_data(X_data=X_data, + y_data=y_data, + resize_ratio=resize_ratio, + final_size=final_size, + stride_ratio=stride_ratio, + tolerance=resize_tolerance) + assert reshaped_X.shape == X_data.shape + assert reshaped_y.shape == y_data.shape + def test_pad_image_stack(): # rows and cols both need to be modified @@ -167,100 +224,96 @@ def test_pad_image_stack(): assert np.all(padded_stack[:, 0, 0, 0] == tags) -def test_combine_npz_files(): - # NPZ files are appropriate size and resolution - num_images = [2, 2] - sizes = [(256, 256), (256, 256)] - npz_list = _make_npzs(sizes=sizes, num_images=num_images) - resize_ratios = [[1], [1]] - final_size = (256, 256) - - combined_x, combined_y = build.combine_npz_files(npz_list=npz_list, - resize_ratios=resize_ratios, - final_size=final_size) - - # check that correct number of NPZs present - assert combined_x.shape[0] == np.sum(num_images) - - # check correct size of NPZs - assert combined_x.shape[1:3] == final_size - - # NPZ files need to be cropped - num_images = [2, 2] - sizes = [(512, 512), (512, 512)] - npz_crop_list = _make_npzs(sizes=sizes, num_images=num_images) - resize_ratios = [[1], [1]] - final_size = (256, 256) - - combined_x, combined_y = build.combine_npz_files(npz_list=npz_crop_list, - resize_ratios=resize_ratios, - final_size=final_size) - - # check that correct number of NPZs present - assert combined_x.shape[0] == np.sum(num_images) * 4 - - # check correct size of NPZs - assert combined_x.shape[1:3] == final_size - - # NPZ files need to be resized - num_images = [2, 2] - sizes = [(128, 128), (128, 128)] - npz_resize_list = _make_npzs(sizes=sizes, num_images=num_images) - resize_ratios = [[2], [2]] - final_size = (256, 256) - - combined_x, combined_y = build.combine_npz_files(npz_list=npz_resize_list, - resize_ratios=resize_ratios, - final_size=final_size) - - # check that correct number of NPZs present - assert combined_x.shape[0] == np.sum(num_images) - - # check correct size of NPZs - assert combined_x.shape[1:3] == final_size - - # some need to be cropped, some need to be resized - npz_list = npz_crop_list + npz_resize_list - resize_ratios = [[1], [1], [2], [2]] - final_size = (256, 256) - - combined_npz = build.combine_npz_files(npz_list=npz_list, resize_ratios=resize_ratios, - final_size=final_size) - - combined_x, combined_y = combined_npz - - # check that correct number of NPZs present - assert combined_x.shape[0] == (np.sum(num_images) + np.sum(num_images) * 4) - - # check correct size of NPZs - assert combined_x.shape[1:3] == final_size - - # different resizing for each image in the NPZ - num_images = [2, 2] - sizes = [(256, 256), (256, 256)] - npz_resize_list = _make_npzs(sizes=sizes, num_images=num_images) - resize_ratios = [[1], [1, 2]] - final_size = (256, 256) - - combined_x, combined_y = build.combine_npz_files(npz_list=npz_resize_list, - resize_ratios=resize_ratios, - final_size=final_size) - - # check that correct number of NPZs present - assert combined_x.shape[0] == np.sum(2 + 
1 + 4) +def test_train_val_test_split(): + X_data = np.zeros((100, 5, 5, 3)) + y_data = np.zeros((100, 5, 5, 1)) + + unique_vals = np.arange(100) + for val in unique_vals: + X_data[val, ...] = val + 1 + y_data[val, ...] = -(val + 1) + + train_ratio, val_ratio, test_ratio = 0.7, 0.2, 0.1 + + X_train, y_train, X_val, y_val, X_test, y_test = \ + build.train_val_test_split(X_data=X_data, + y_data=y_data, + data_split=[train_ratio, val_ratio, test_ratio], + seed=1337) + + assert X_train.shape[0] == 100 * train_ratio + assert y_train.shape[0] == 100 * train_ratio + + assert X_val.shape[0] == 100 * val_ratio + assert y_val.shape[0] == 100 * val_ratio + + assert X_test.shape[0] == 100 * test_ratio + assert y_test.shape[0] == 100 * test_ratio + + assert _all_unique_vals((X_train, y_train, X_val, y_val, X_test, y_test)) + + # rerun split with same seed + rerun = build.train_val_test_split(X_data=X_data, y_data=y_data, + data_split=[train_ratio, val_ratio, test_ratio], + seed=1337) + + # make sure identical data with same seed + for version1, version2 in zip((X_train, y_train, X_val, y_val, X_test, y_test), rerun): + assert np.array_equal(version1, version2) + + # rerun split with different seed + rerun = build.train_val_test_split(X_data=X_data, y_data=y_data, + data_split=[train_ratio, val_ratio, test_ratio], + seed=666) + + # make sure different data with different seed + for version1, version2 in zip((X_train, y_train, X_val, y_val, X_test, y_test), rerun): + assert not np.array_equal(version1, version2) + + # ensure that None is returned for val and test when data is not large enough to be split + X_train, y_train, X_val, y_val, X_test, y_test, = \ + build.train_val_test_split(X_data=X_data[:1], + y_data=y_data[:1], + data_split=[train_ratio, val_ratio, test_ratio]) + assert X_train.shape[0] == y_train.shape[0] == 1 + assert np.all([val is None for val in [X_val, y_val, X_test, y_test]]) + + # ensure that None is returned for test when data is not large enough to be split + X_train, y_train, X_val, y_val, X_test, y_test, = \ + build.train_val_test_split(X_data=X_data[:2], + y_data=y_data[:2], + data_split=[train_ratio, val_ratio, test_ratio]) + assert np.all([val.shape[0] == 1 for val in [X_train, y_train, X_val, y_val]]) + assert np.all([val is None for val in [X_test, y_test]]) + assert _all_unique_vals((X_train, y_train, X_val, y_val)) + + # Adjust data appropriately when split sizes will result in zero values for val and test + X_train, y_train, X_val, y_val, X_test, y_test, = \ + build.train_val_test_split(X_data=X_data[:5], + y_data=y_data[:5], + data_split=[0.8, 0.1, 0.1]) + assert X_train.shape[0] == y_train.shape[0] == 3 + assert np.all([val.shape[0] == 1 for val in [X_val, y_val, X_test, y_test]]) + assert _all_unique_vals((X_train, y_train, X_val, y_val, X_test, y_test)) + + # Adjust data appropriately when split sizes will result in zero values for test + X_train, y_train, X_val, y_val, X_test, y_test, = \ + build.train_val_test_split(X_data=X_data[:9], + y_data=y_data[:9], + data_split=[0.8, 0.1, 0.1]) + assert X_train.shape[0] == y_train.shape[0] == 7 + assert np.all([val.shape[0] == 1 for val in [X_val, y_val, X_test, y_test]]) + assert _all_unique_vals((X_train, y_train, X_val, y_val, X_test, y_test)) + + # data split includes 0 for one of splits + with pytest.raises(ValueError): + _ = build.train_val_test_split(X_data=X_data, y_data=y_data, data_split=[0, 0.5, 0.5]) - # check correct size of NPZs - assert combined_x.shape[1:3] == final_size + # data split sums to more than 1 + 
with pytest.raises(ValueError): + _ = build.train_val_test_split(X_data=X_data, y_data=y_data, data_split=[0.5, 0.5, 0.5]) - # mismatch between resize_ratios and npz size + # different sizes for X and y with pytest.raises(ValueError): - # different resizing for each image in the NPZ - num_images = [4, 2] - sizes = [(256, 256), (256, 256)] - npz_resize_list = _make_npzs(sizes=sizes, num_images=num_images) - resize_ratios = [[1, 1, 1], [1, 2]] - final_size = (256, 256) - - _, _ = build.combine_npz_files(npz_list=npz_resize_list, - resize_ratios=resize_ratios, - final_size=final_size) + _ = build.train_val_test_split(X_data=X_data[:5], y_data=y_data, + data_split=[0.5, 0.5, 0.5]) diff --git a/caliban_toolbox/dataset_benchmarker.py b/caliban_toolbox/dataset_benchmarker.py new file mode 100644 index 0000000..57810fd --- /dev/null +++ b/caliban_toolbox/dataset_benchmarker.py @@ -0,0 +1,121 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np + +from deepcell_toolbox.metrics import Metrics, stats_pixelbased +from scipy.stats import hmean + + +class DatasetBenchmarker(object): + """Class to perform benchmarking across different tissue and platform types + + Args: + y_true: true labels + y_pred: predicted labels + tissue_list: list of tissue names for each image + platform_list: list of platform names for each image + model_name: name of the model used to generate the predictions + metrics_kwargs: arguments to be passed to metrics package + + Raises: + ValueError: if y_true and y_pred have different shapes + ValueError: if y_true and y_pred are not 4D + ValueError: if tissue_ids or platform_ids is not same length as labels + """ + def __init__(self, + y_true, + y_pred, + tissue_list, + platform_list, + model_name, + metrics_kwargs={}): + if y_true.shape != y_pred.shape: + raise ValueError('Shape mismatch: y_true has shape {}, ' + 'y_pred has shape {}. 
Labels must have the same' + 'shape.'.format(y_true.shape, y_pred.shape)) + if len(y_true.shape) != 4: + raise ValueError('Data must be 4D, supplied data is {}'.format(y_true.shape)) + + self.y_true = y_true + self.y_pred = y_pred + + if len({y_true.shape[0], len(tissue_list), len(platform_list)}) != 1: + raise ValueError('Tissue_list and platform_list must have same length as labels') + + self.tissue_list = tissue_list + self.platform_list = platform_list + self.model_name = model_name + self.metrics = Metrics(model_name, **metrics_kwargs) + + def _benchmark_category(self, category_ids): + """Compute benchmark stats over the different categories in supplied list + + Args: + category_ids: list specifying which category each image belongs to + + Returns: + dict: benchmarking results across each category + """ + + unique_categories = np.unique(category_ids) + + # create dict to hold stats across each category + stats_dict = {} + for cat in unique_categories: + # cat the index of metrics corresponding to current category + cat_idx = np.isin(category_ids, cat) + cat_dict = {} + + # sum metrics across individual images within current category + for key in self.metrics.stats: + cat_dict[key] = self.metrics.stats[key][cat_idx].sum() + + # compute additional metrics not produced by Metrics class + cat_dict['recall'] = cat_dict['correct_detections'] / cat_dict['n_true'] + + cat_dict['precision'] = cat_dict['correct_detections'] / cat_dict['n_pred'] + + cat_dict['f1'] = hmean([cat_dict['recall'], cat_dict['precision']]) + + pixel_stats = stats_pixelbased(self.y_true[cat_idx] != 0, + self.y_pred[cat_idx] != 0) + cat_dict['jaccard'] = pixel_stats['jaccard'] + + # save current category dict to overall dict + stats_dict[cat] = cat_dict + + return stats_dict + + def benchmark(self): + self.metrics.calc_object_stats(self.y_true, self.y_pred) + + tissue_stats = self._benchmark_category(category_ids=self.tissue_list) + platform_stats = self._benchmark_category(category_ids=self.platform_list) + all_stats = self._benchmark_category(category_ids=['all'] * len(self.tissue_list)) + tissue_stats['all'] = all_stats['all'] + platform_stats['all'] = all_stats['all'] + + return tissue_stats, platform_stats diff --git a/caliban_toolbox/dataset_benchmarker_test.py b/caliban_toolbox/dataset_benchmarker_test.py new file mode 100644 index 0000000..90bb45e --- /dev/null +++ b/caliban_toolbox/dataset_benchmarker_test.py @@ -0,0 +1,113 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import pytest + +import numpy as np + +from caliban_toolbox.dataset_benchmarker import DatasetBenchmarker + + +def _create_labels(offset=0): + labels = np.zeros((5, 100, 100, 1)) + + base_label = np.zeros((1, 100, 100, 1)) + base_label[0, :20, :20] = 1 + base_label[0, 20:34, 30:50] = 2 + base_label[0, 48:52, 10:20] = 3 + base_label[0, 82:100, 70:90] = 4 + + labels[:, offset:, offset:] = base_label[:, :(100 - offset), :(100 - offset)] + + return labels + + +def test__init__(): + y_true, y_pred = _create_labels(), _create_labels() + tissue_list = ['tissue{}'.format(i) for i in range(5)] + platform_list = ['platform{}'.format(i) for i in range(5)] + db = DatasetBenchmarker(y_true=y_true, y_pred=y_pred, tissue_list=tissue_list, + platform_list=platform_list, model_name='test') + + with pytest.raises(ValueError, match='Shape mismatch'): + _ = DatasetBenchmarker(y_true=y_true, y_pred=y_pred[0], tissue_list=tissue_list, + platform_list=platform_list, model_name='test') + + with pytest.raises(ValueError, match='Data must be 4D'): + _ = DatasetBenchmarker(y_true=y_true[0], y_pred=y_pred[0], tissue_list=tissue_list, + platform_list=platform_list, model_name='test') + + with pytest.raises(ValueError, match='Tissue_list and platform_list'): + _ = DatasetBenchmarker(y_true=y_true, y_pred=y_pred, tissue_list=tissue_list[1:], + platform_list=platform_list, model_name='test') + + with pytest.raises(ValueError, match='Tissue_list and platform_list'): + _ = DatasetBenchmarker(y_true=y_true, y_pred=y_pred, tissue_list=tissue_list, + platform_list=platform_list[1:], model_name='test') + + +def test__benchmark_category(): + # perfect agreement + y_true_category_1, y_pred_category_1 = _create_labels(), _create_labels() + + # small offset between labels + y_true_category_2, y_pred_category_2 = _create_labels(), _create_labels(offset=3) + + # large offset between labels + y_true_category_3, y_pred_category_3 = _create_labels(), _create_labels(offset=5) + + y_true = np.concatenate((y_true_category_1, y_true_category_2, y_true_category_3)) + y_pred = np.concatenate((y_pred_category_1, y_pred_category_2, y_pred_category_3)) + tissue_list = ['tissue1'] * 5 + ['tissue2'] * 5 + ['tissue3'] * 5 + platform_list = ['platform1'] * 15 + + # initialize + db = DatasetBenchmarker(y_true=y_true, y_pred=y_pred, tissue_list=tissue_list, + platform_list=platform_list, model_name='test') + db.metrics.calc_object_stats(y_true, y_pred) + + # compute across tissues + stats_dict = db._benchmark_category(category_ids=tissue_list) + + assert stats_dict['tissue1']['recall'] == 1 + assert stats_dict['tissue1']['jaccard'] == 1 + + assert stats_dict['tissue2']['recall'] > stats_dict['tissue3']['recall'] + assert stats_dict['tissue2']['jaccard'] > stats_dict['tissue3']['jaccard'] + + +def test_benchmark(): + y_true, y_pred = _create_labels(), _create_labels(offset=1) + tissue_list = ['tissue1'] * 2 + ['tissue2'] * 3 + platform_list = ['platform1'] * 3 + ['platform2'] * 2 + + db = DatasetBenchmarker(y_true=y_true, y_pred=y_pred, tissue_list=tissue_list, + platform_list=platform_list, model_name='test') + + tissue_stats, platform_stats = db.benchmark() + + assert set(tissue_stats.keys()) == set(tissue_list + ['all']) + assert set(platform_stats.keys()) == set(platform_list + ['all']) diff --git a/caliban_toolbox/dataset_builder.py 
b/caliban_toolbox/dataset_builder.py new file mode 100644 index 0000000..5dd7b5f --- /dev/null +++ b/caliban_toolbox/dataset_builder.py @@ -0,0 +1,629 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import json +import warnings + +import numpy as np + +from skimage.measure import label +from skimage.morphology import remove_small_objects + +from caliban_toolbox.utils.misc_utils import list_npzs_folder, list_folders +from caliban_toolbox.build import train_val_test_split, reshape_training_data, compute_cell_size + + +class DatasetBuilder(object): + """Class to build a dataset from annotated data + + Args: + dataset_path: path to dataset. 
Within the dataset, each unique experiment + has its own folder with a dedicated metadata file + """ + def __init__(self, dataset_path): + + self._validate_dataset(dataset_path) + + experiment_folders = list_folders(dataset_path) + self.dataset_path = dataset_path + self.experiment_folders = experiment_folders + + self.all_tissues = [] + self.all_platforms = [] + + # dicts to hold aggregated data + self.train_dict = {} + self.val_dict = {} + self.test_dict = {} + + # parameters for splitting the data + self.data_split = None + self.seed = None + + def _validate_dataset(self, dataset_path): + """Check to make sure that supplied dataset is formatted appropriately + + Args: + dataset_path: path to dataset + + Raises: + ValueError: If dataset_path doesn't exist + ValueError: If dataset_path doesn't contain any folders + ValueError: If dataset_path has any folders without an NPZ file + ValueError: If dataset_path has any folders without a metadata file + """ + + if not os.path.isdir(dataset_path): + raise ValueError('Invalid dataset_path, must be a directory') + experiment_folders = list_folders(dataset_path) + + if experiment_folders == []: + raise ValueError('dataset_path must include at least one folder') + + for folder in experiment_folders: + if not os.path.exists(os.path.join(dataset_path, folder, 'metadata.json')): + raise ValueError('No metadata file found in {}'.format(folder)) + npz_files = list_npzs_folder(os.path.join(dataset_path, folder)) + + if len(npz_files) == 0: + raise ValueError('No NPZ files found in {}'.format(folder)) + + def _get_metadata(self, experiment_folder): + """Get the metadata associated with a specific experiment + + Args: + experiment_folder: folder to get metadata from + + Returns: + dictionary containing relevant metadata""" + + metadata_file = os.path.join(self.dataset_path, experiment_folder, 'metadata.json') + with open(metadata_file) as f: + metadata = json.load(f) + + return metadata + + def _identify_tissue_and_platform_types(self): + """Identify all of the unique tissues and platforms in the dataset""" + + tissues = [] + platforms = [] + for folder in self.experiment_folders: + metadata = self._get_metadata(experiment_folder=folder) + + tissues.append(metadata['tissue']) + platforms.append(metadata['platform']) + + self.all_tissues.extend(tissues) + self.all_platforms.extend(platforms) + + def _load_experiment(self, experiment_path): + """Load the NPZ files present in a single experiment folder + + Args: + experiment_path: the full path to a folder of NPZ files and metadata file + + Returns: + tuple of X and y data from all NPZ files in the experiment + tissue: the tissue type of this experiment + platform: the platform type of this experiment + """ + + X_list = [] + y_list = [] + + # get all NPZ files present in current experiment directory + npz_files = list_npzs_folder(experiment_path) + for file in npz_files: + npz_path = os.path.join(experiment_path, file) + training_data = np.load(npz_path) + + X = training_data['X'] + y = training_data['y'] + + X_list.append(X) + y_list.append(y) + + # get associated metadata + metadata = self._get_metadata(experiment_folder=experiment_path) + + # combine all NPZ files together + X = np.concatenate(X_list, axis=0) + y = np.concatenate(y_list, axis=0) + if np.issubdtype(y.dtype, np.floating): + warnings.warn('Converting float labels to integers') + y = y.astype('int64') + + tissue = metadata['tissue'] + platform = metadata['platform'] + + return X, y, tissue, platform + + def _load_all_experiments(self, 
data_split, seed): + """Loads all experiment data from experiment folder to enable dataset building + + Args: + data_split: tuple specifying the fraction of the dataset for train/val/test + seed: seed for reproducible splitting of dataset + + Raises: + ValueError: If any of the NPZ files have different non-batch dimensions + """ + X_train, X_val, X_test = [], [], [] + y_train, y_val, y_test = [], [], [] + tissue_list_train, tissue_list_val, tissue_list_test = [], [], [] + platform_list_train, platform_list_val, platform_list_test = [], [], [] + + # loop through all experiments + for folder in self.experiment_folders: + # Get all NPZ files from each experiment + folder_path = os.path.join(self.dataset_path, folder) + X, y, tissue, platform = self._load_experiment(folder_path) + + # split data according to specified ratios + X_train_batch, y_train_batch, X_val_batch, y_val_batch, X_test_batch, y_test_batch = \ + train_val_test_split(X_data=X, y_data=y, data_split=data_split, seed=seed) + + # construct list for each split + tissue_list_train_batch = [tissue] * X_train_batch.shape[0] + platform_list_train_batch = [platform] * X_train_batch.shape[0] + X_train.append(X_train_batch) + y_train.append(y_train_batch) + tissue_list_train.append(tissue_list_train_batch) + platform_list_train.append(platform_list_train_batch) + + if X_val_batch is not None: + tissue_list_val_batch = [tissue] * X_val_batch.shape[0] + platform_list_val_batch = [platform] * X_val_batch.shape[0] + X_val.append(X_val_batch) + y_val.append(y_val_batch) + tissue_list_val.append(tissue_list_val_batch) + platform_list_val.append(platform_list_val_batch) + + if X_test_batch is not None: + tissue_list_test_batch = [tissue] * X_test_batch.shape[0] + platform_list_test_batch = [platform] * X_test_batch.shape[0] + X_test.append(X_test_batch) + y_test.append(y_test_batch) + tissue_list_test.append(tissue_list_test_batch) + platform_list_test.append(platform_list_test_batch) + + # make sure that all data has same shape + first_shape = X_train[0].shape + for i in range(1, len(X_train)): + current_shape = X_train[i].shape + if first_shape[1:] != current_shape[1:]: + raise ValueError('Found mismatching dimensions between ' + 'first NPZ and npz at position {}. 
' + 'Shapes of {}, {}'.format(i, first_shape, current_shape)) + + # concatenate lists together + X_train = np.concatenate(X_train, axis=0) + X_val = np.concatenate(X_val, axis=0) + X_test = np.concatenate(X_test, axis=0) + + y_train = np.concatenate(y_train, axis=0) + y_val = np.concatenate(y_val, axis=0) + y_test = np.concatenate(y_test, axis=0) + + tissue_list_train = np.concatenate(tissue_list_train, axis=0) + tissue_list_val = np.concatenate(tissue_list_val, axis=0) + tissue_list_test = np.concatenate(tissue_list_test, axis=0) + + platform_list_train = np.concatenate(platform_list_train, axis=0) + platform_list_val = np.concatenate(platform_list_val, axis=0) + platform_list_test = np.concatenate(platform_list_test, axis=0) + + # create combined dicts + train_dict = {'X': X_train, 'y': y_train, 'tissue_list': tissue_list_train, + 'platform_list': platform_list_train} + + val_dict = {'X': X_val, 'y': y_val, 'tissue_list': tissue_list_val, + 'platform_list': platform_list_val} + + test_dict = {'X': X_test, 'y': y_test, 'tissue_list': tissue_list_test, + 'platform_list': platform_list_test} + + self.train_dict = train_dict + self.val_dict = val_dict + self.test_dict = test_dict + self.data_split = data_split + self.seed = seed + + def _subset_data_dict(self, data_dict, tissues, platforms): + """Subsets a dictionary to only include from the specified tissues and platforms + + Args: + data_dict: dictionary to subset from + tissues: list of tissues to include + platforms: list of platforms to include + + Returns: + subset_dict: dictionary containing examples desired data + + Raises: + ValueError: If no matching data for tissue/platform combination + """ + X, y = data_dict['X'], data_dict['y'] + tissue_list, platform_list = data_dict['tissue_list'], data_dict['platform_list'] + + # Identify locations with the correct categories types + tissue_idx = np.isin(tissue_list, tissues) + platform_idx = np.isin(platform_list, platforms) + + # get indices which meet both criteria + combined_idx = tissue_idx * platform_idx + + # check that there is data which meets requirements + if np.sum(combined_idx) == 0: + raise ValueError('No matching data for specified parameters') + + X, y = X[combined_idx], y[combined_idx] + tissue_list = np.array(tissue_list)[combined_idx] + platform_list = np.array(platform_list)[combined_idx] + + subset_dict = {'X': X, 'y': y, 'tissue_list': list(tissue_list), + 'platform_list': list(platform_list)} + return subset_dict + + def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize_target=400, + resize_tolerance=1.5): + """Takes a dictionary of training data and reshapes it to appropriate size + + data_dict: dictionary of training data + resize: flag to control resizing of the data. + Valid arguments: + - False. No resizing + - int/float: resizes by given ratio for all images + - by_tissue. Resizes by median cell size within each tissue type + - by_image. 
Resizes by median cell size within each image + output_shape: output shape for image data + resize_target: desired median cell size after resizing + resize_tolerance: sets maximum allowable ratio between resize_target and + median cell size before resizing occurs + """ + X, y = data_dict['X'], data_dict['y'] + tissue_list = np.array(data_dict['tissue_list']) + platform_list = np.array(data_dict['platform_list']) + + if not resize: + # no resizing + X_new, y_new = reshape_training_data(X_data=X, y_data=y, resize_ratio=1, + final_size=output_shape, stride_ratio=1) + + # to preserve category labels, we need to figure out how much the array grew by + multiplier = int(X_new.shape[0] / X.shape[0]) + + # then we duplicate the labels in place to match expanded array size + tissue_list_new = [item for item in tissue_list for _ in range(multiplier)] + platform_list_new = [item for item in platform_list for _ in range(multiplier)] + + elif isinstance(resize, (float, int)): + # resized based on supplied value + X_new, y_new = reshape_training_data(X_data=X, y_data=y, resize_ratio=resize, + final_size=output_shape, stride_ratio=1, + tolerance=resize_tolerance) + + # to preserve category labels, we need to figure out how much the array grew by + multiplier = int(X_new.shape[0] / X.shape[0]) + + # then we duplicate the labels in place to match expanded array size + tissue_list_new = [item for item in tissue_list for _ in range(multiplier)] + platform_list_new = [item for item in platform_list for _ in range(multiplier)] + else: + X_new, y_new, tissue_list_new, platform_list_new = [], [], [], [] + + if resize == 'by_tissue': + batch_ids = np.unique(tissue_list) + elif resize == 'by_image': + batch_ids = np.arange(0, X.shape[0]) + else: + raise ValueError('Invalid `resize` value: {}'.format(resize)) + + # loop over each batch + for batch_id in batch_ids: + + # get tissue types that match current tissue type + if isinstance(batch_id, str): + batch_idx = np.isin(tissue_list, batch_id) + + # get boolean index for current image + else: + batch_idx = np.arange(X.shape[0]) == batch_id + + X_batch, y_batch = X[batch_idx], y[batch_idx] + tissue_list_batch = tissue_list[batch_idx] + platform_list_batch = platform_list[batch_idx] + + # compute appropriate resize ratio + median_cell_size = compute_cell_size({'X': X_batch, 'y': y_batch}, by_image=False) + + # check for empty images + if median_cell_size is not None: + resize_ratio = np.sqrt(resize_target / median_cell_size) + else: + resize_ratio = 1 + + # resize the data + X_batch_resized, y_batch_resized = reshape_training_data( + X_data=X_batch, y_data=y_batch, + resize_ratio=resize_ratio, final_size=output_shape, + tolerance=resize_tolerance) + + # to preserve category labels, we need to figure out how much the array grew by + multiplier = int(X_batch_resized.shape[0] / X_batch.shape[0]) + + # then we duplicate the labels in place to match expanded array size + tissue_list_batch = [item for item in tissue_list_batch for _ in range(multiplier)] + platform_list_batch = \ + [item for item in platform_list_batch for _ in range(multiplier)] + + # add each batch onto main list + X_new.append(X_batch_resized) + y_new.append(y_batch_resized) + tissue_list_new.append(tissue_list_batch) + platform_list_new.append(platform_list_batch) + + X_new = np.concatenate(X_new, axis=0) + y_new = np.concatenate(y_new, axis=0) + tissue_list_new = np.concatenate(tissue_list_new, axis=0) + platform_list_new = np.concatenate(platform_list_new, axis=0) + + return {'X': X_new, 'y': y_new, 
'tissue_list': tissue_list_new, + 'platform_list': platform_list_new} + + def _clean_labels(self, data_dict, relabel=False, small_object_threshold=0, + min_objects=0): + """Cleans labels prior to creating final dict + + Args: + data_dict: dictionary of training data + relabel: if True, relabels the image with new labels + small_object_threshold: threshold for removing small objects + min_objects: minimum number of objects per image + + Returns: + cleaned_dict: dictionary with cleaned labels + """ + X, y = data_dict['X'], data_dict['y'] + tissue_list = np.array(data_dict['tissue_list']) + platform_list = np.array(data_dict['platform_list']) + keep_idx = np.repeat(True, y.shape[0]) + cleaned_y = np.zeros_like(y) + + # TODO: remove once data QC happens in main toolbox pipeline + for i in range(y.shape[0]): + y_current = y[i, ..., 0] + if relabel: + y_current = label(y_current) + + y_current = remove_small_objects(y_current, min_size=small_object_threshold) + + unique_objects = len(np.unique(y_current)) - 1 + if unique_objects < min_objects: + keep_idx[i] = False + + cleaned_y[i, ..., 0] = y_current + + # subset all dict members to include only relevant images + cleaned_y = cleaned_y[keep_idx] + cleaned_X = X[keep_idx] + cleaned_tissue = tissue_list[keep_idx] + cleaned_platform = platform_list[keep_idx] + + cleaned_dict = {'X': cleaned_X, 'y': cleaned_y, 'tissue_list': list(cleaned_tissue), + 'platform_list': list(cleaned_platform)} + + return cleaned_dict + + def _validate_categories(self, category_list, supplied_categories): + """Check that an appropriate subset of a list of categories was supplied + + Args: + category_list: list of all categories + supplied_categories: specified categories provided by user. Must be either + - a list containing the desired category members + - a string of a single category name + - a string of 'all', in which case all will be used + + Returns: + list: a properly formatted sub_category list + + Raises: + ValueError: if invalid supplied_categories argument + """ + if isinstance(supplied_categories, list): + for cat in supplied_categories: + if cat not in category_list: + raise ValueError('{} is not one of {}'.format(cat, category_list)) + return supplied_categories + elif supplied_categories == 'all': + return category_list + elif supplied_categories in category_list: + return [supplied_categories] + else: + raise ValueError( + 'Specified categories should be "all", one of {}, or a list ' + 'of acceptable tissue types'.format(category_list)) + + def _validate_output_shape(self, output_shape): + """Check that appropriate values were provided for output_shape + + Args: + output_shape: output_shape supplied by the user + + Returns: + list: a properly formatted output_shape + + Raises: + ValueError: If invalid output_shape provided + """ + if not isinstance(output_shape, (list, tuple)): + raise ValueError('output_shape must be either a list of tuples or a tuple') + + if len(output_shape) == 2: + for val in output_shape: + if not isinstance(val, int): + raise ValueError('A list of length two was supplied, but not all ' + 'list items were ints, got {}'.format(val)) + # convert to list with same shape for each split + output_shape = [output_shape, output_shape, output_shape] + return output_shape + elif len(output_shape) == 3: + for sub_shape in output_shape: + if not len(sub_shape) == 2: + raise ValueError('A list of length three was supplied, bu not all ' + 'of the sublists had len 2, got {}'.format(sub_shape)) + for val in sub_shape: + if not isinstance(val, int): 
+ raise ValueError('A list of lists was supplied, but not all ' + 'sub_list items were ints, got {}'.format(val)) + + return output_shape + else: + raise ValueError('output_shape must be a list of length 2 ' + 'or length 3, got {}'.format(output_shape)) + + def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512), resize=False, + data_split=(0.8, 0.1, 0.1), seed=0, **kwargs): + """Construct a dataset for model training and evaluation + + Args: + tissues: which tissues to include. Must be either a list of tissue types, + a single tissue type, or 'all' + platforms: which platforms to include. Must be either a list of platform types, + a single platform type, or 'all' + output_shape: output shape for dataset. Either a single tuple, in which case + train/va/test will all have same size, or a list of three tuples + resize: flag to control resizing the input data. + Valid arguments: + - False. No resizing + - float/int: Resizes all images by supplied value + - by_tissue. Resizes by median cell size within each tissue type + - by_image. Resizes by median cell size within each image + data_split: tuple specifying the fraction of the dataset for train/val/test + seed: seed for reproducible splitting of dataset + **kwargs: other arguments to be passed to helper functions + + Returns: + list of dicts containing the split dataset + + Raises: + ValueError: If invalid resize parameter supplied + """ + if self.all_tissues == []: + self._identify_tissue_and_platform_types() + + # validate inputs + tissues = self._validate_categories(category_list=self.all_tissues, + supplied_categories=tissues) + platforms = self._validate_categories(category_list=self.all_platforms, + supplied_categories=platforms) + + valid_resize = ['by_tissue', 'by_image'] + if resize in valid_resize or not resize: + pass + elif isinstance(resize, (float, int)): + if resize <= 0: + raise ValueError('Resize values must be greater than 0') + else: + raise ValueError('resize must be one of {}, or an integer value'.format(valid_resize)) + + output_shape = self._validate_output_shape(output_shape=output_shape) + + # if any of the split parameters are different we need to reload the dataset + if self.seed != seed or self.data_split != data_split: + self._load_all_experiments(data_split=data_split, seed=seed) + + dicts = [self.train_dict, self.val_dict, self.test_dict] + # process each dict + for idx, current_dict in enumerate(dicts): + # subset dict to include only relevant tissues and platforms + current_dict = self._subset_data_dict(data_dict=current_dict, tissues=tissues, + platforms=platforms) + current_shape = output_shape[idx] + + # if necessary, reshape and resize data to be of correct output size + if current_dict['X'].shape[1:3] != current_shape or resize is not False: + resize_target = kwargs.get('resize_target', 400) + resize_tolerance = kwargs.get('resize_tolerance', 1.5) + current_dict = self._reshape_dict(data_dict=current_dict, resize=resize, + output_shape=current_shape, + resize_target=resize_target, + resize_tolerance=resize_tolerance) + + # clean labels + relabel = kwargs.get('relabel', False) + small_object_threshold = kwargs.get('small_object_threshold', 0) + min_objects = kwargs.get('min_objects', 0) + current_dict = self._clean_labels(data_dict=current_dict, relabel=relabel, + small_object_threshold=small_object_threshold, + min_objects=min_objects) + dicts[idx] = current_dict + return dicts + + def summarize_dataset(self): + """Computes summary statistics for the images in the dataset + + 
Returns: + dict of cell counts and image counts by tissue + dict of cell counts and image counts by platform + """ + all_y = np.concatenate((self.train_dict['y'], + self.val_dict['y'], + self.test_dict['y']), + axis=0) + all_tissue = np.concatenate((self.train_dict['tissue_list'], + self.val_dict['tissue_list'], + self.test_dict['tissue_list']), + axis=0) + + all_platform = np.concatenate((self.train_dict['platform_list'], + self.val_dict['platform_list'], + self.test_dict['platform_list']), + axis=0) + all_counts = np.zeros(all_y.shape[0]) + for i in range(all_y.shape[0]): + unique_counts = len(np.unique(all_y[i, ..., 0])) - 1 + all_counts[i] = unique_counts + + tissue_dict = {} + for tissue in np.unique(all_tissue): + tissue_idx = np.isin(all_tissue, tissue) + tissue_counts = np.sum(all_counts[tissue_idx]) + tissue_unique = np.sum(tissue_idx) + tissue_dict[tissue] = {'cell_num': tissue_counts, + 'image_num': tissue_unique} + + platform_dict = {} + for platform in np.unique(all_platform): + platform_idx = np.isin(all_platform, platform) + platform_counts = np.sum(all_counts[platform_idx]) + platform_unique = np.sum(platform_idx) + platform_dict[platform] = {'cell_num': platform_counts, + 'image_num': platform_unique} + + return tissue_dict, platform_dict diff --git a/caliban_toolbox/dataset_builder_test.py b/caliban_toolbox/dataset_builder_test.py new file mode 100644 index 0000000..455f48d --- /dev/null +++ b/caliban_toolbox/dataset_builder_test.py @@ -0,0 +1,695 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +import os +import json +import pytest + +import numpy as np +from pathlib import Path + +from caliban_toolbox.dataset_builder import DatasetBuilder + + +def _create_test_npz(path, constant_value=1, X_shape=(10, 20, 20, 3), y_shape=(10, 20, 20, 1)): + X_data = np.full(X_shape, constant_value) + y_data = np.full(y_shape, constant_value * 2, dtype='int16') + np.savez(path, X=X_data, y=y_data) + + +def _create_minimal_dataset(path): + """Creates a minimal dataset so that __init__ checks pass""" + exp_path = os.path.join(path, 'example_exp1') + os.makedirs(exp_path) + Path(os.path.join(exp_path, 'metadata.json')).touch() + Path(os.path.join(exp_path, 'example_data.npz')).touch() + + +def _create_test_dataset(path, experiments, tissues, platforms, npz_num): + """Creates an example directory to load data from + + Args: + path: folder to hold datasets + experiments: list of experiment names + tissues: list of tissue types for each experiment + platforms: list of platform types for each experiment + npz_num: number of unique NPZ files within each experiment. The NPZs within + each experiment are constant-valued arrays corresponding to the index of that exp + + Raises: + ValueError: If tissue_list, platform_list, or NPZ_num have different lengths + """ + lengths = [len(x) for x in [experiments, tissues, platforms, npz_num]] + if len(set(lengths)) != 1: + raise ValueError('All inputs must have the same length') + + for i in range(len(experiments)): + experiment_folder = os.path.join(path, experiments[i]) + os.makedirs(experiment_folder) + + metadata = dict() + metadata['tissue'] = tissues[i] + metadata['platform'] = platforms[i] + + metadata_path = os.path.join(experiment_folder, 'metadata.json') + + with open(metadata_path, 'w') as write_file: + json.dump(metadata, write_file) + + for npz in range(npz_num[i]): + _create_test_npz(path=os.path.join(experiment_folder, 'sub_exp_{}.npz'.format(npz)), + constant_value=i) + + +def _create_test_dict(tissues, platforms): + data = [] + for i in range(len(tissues)): + current_data = np.full((5, 40, 40, 3), i) + data.append(current_data) + + data = np.concatenate(data, axis=0) + X_data = data + y_data = data[..., :1].astype('int16') + + tissue_list = [tissues[i] for i in range(len(tissues)) for _ in range(5)] + platform_list = [platforms[i] for i in range(len(platforms)) for _ in range(5)] + + return {'X': X_data, 'y': y_data, 'tissue_list': tissue_list, 'platform_list': platform_list} + + +def mocked_compute_cell_size(data_dict, by_image): + """Mocks compute cell size so we don't need to create synthetic data with correct cell size""" + X = data_dict['X'] + constant_val = X[0, 0, 0, 0] + + # The default resize is 400. 
We want to create median cell sizes that divide evenly + # into that number when computing the desired resize ratio + + # even constant_vals will return a median cell size 1/4 the size of the target, odds 4x + if constant_val % 2 == 0: + cell_size = 100 + else: + cell_size = 1600 + + return cell_size + + +def test__init__(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + assert db.dataset_path == tmp_path + + # bad path + with pytest.raises(ValueError): + _ = DatasetBuilder(dataset_path='bad_path') + + +def test__validate_dataset(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(dataset_path=tmp_path) + + # bad path + with pytest.raises(ValueError): + db._validate_dataset('bad_path') + + dataset_path = os.path.join(tmp_path, 'example_dataset') + os.makedirs(dataset_path) + + # no folders in supplied dataset + with pytest.raises(ValueError): + db._validate_dataset(dataset_path) + + os.makedirs(os.path.join(dataset_path, 'experiment_1')) + Path(os.path.join(dataset_path, 'experiment_1', 'example_file.npz')).touch() + + # supplied experiment has an NPZ and no metadata file + with pytest.raises(ValueError): + db._validate_dataset(tmp_path) + + # directory has a metadata file and no NPZ + os.remove(os.path.join(dataset_path, 'experiment_1', 'example_file.npz')) + Path(os.path.join(dataset_path, 'experiment_1', 'metadata.json')).touch() + + with pytest.raises(ValueError): + db._validate_dataset(os.path.join(tmp_path)) + + +def test__get_metadata(tmp_path): + tissues = ['tissue1', 'tissue2'] + platforms = ['platform1', 'platform2'] + experiments = ['exp1', 'exp2'] + npzs = [1, 1] + + _create_test_dataset(path=tmp_path, experiments=experiments, platforms=platforms, + tissues=tissues, npz_num=npzs) + + db = DatasetBuilder(tmp_path) + for i in range(len(experiments)): + metadata = db._get_metadata(os.path.join(tmp_path, experiments[i])) + assert metadata['tissue'] == tissues[i] + assert metadata['platform'] == platforms[i] + + +def test__identify_tissue_and_platform_types(tmp_path): + # create dataset + experiments = ['exp{}'.format(i) for i in range(5)] + tissues = ['tissue1', 'tissue2', 'tissue3', 'tissue2', 'tissue1'] + platforms = ['platform1', 'platform1', 'platform2', 'platform2', 'platform3'] + npz_num = [1] * 5 + _create_test_dataset(tmp_path, experiments=experiments, tissues=tissues, + platforms=platforms, npz_num=npz_num) + + db = DatasetBuilder(dataset_path=tmp_path) + + db._identify_tissue_and_platform_types() + + # check that all tissues and platforms added + assert set(db.all_tissues) == set(tissues) + assert set(db.all_platforms) == set(platforms) + + +def test__load_experiment_single_npz(tmp_path): + experiments, tissues, platforms, npz_num = ['exp1'], ['tissue1'], ['platform1'], [1] + _create_test_dataset(tmp_path, experiments=experiments, tissues=tissues, + platforms=platforms, npz_num=npz_num) + + # initialize db + db = DatasetBuilder(tmp_path) + + # load dataset + X, y, tissue, platform = db._load_experiment(os.path.join(tmp_path, experiments[0])) + + # A single NPZ with 10 images + assert X.shape[0] == 10 + assert y.shape[0] == 10 + + assert tissue == tissues[0] + assert platform == platforms[0] + + +def test__load_experiment_multiple_npz(tmp_path): + experiments, tissues, platforms, npz_num = ['exp1'], ['tissue1'], ['platform1'], [5] + _create_test_dataset(tmp_path, experiments=experiments, tissues=tissues, + platforms=platforms, npz_num=npz_num) + + # initialize db + db = DatasetBuilder(tmp_path) + + # load dataset + X, y, 
tissue, platform = db._load_experiment(os.path.join(tmp_path, experiments[0])) + + # 5 NPZs with 10 images each + assert X.shape[0] == 50 + assert y.shape[0] == 50 + + assert tissue == tissues[0] + assert platform == platforms[0] + + +def test__load_all_experiments(tmp_path): + # create dataset + experiments = ['exp{}'.format(i) for i in range(5)] + tissues = ['tissue1', 'tissue2', 'tissue3', 'tissue4', 'tissue5'] + platforms = ['platform5', 'platform4', 'platform3', 'platform2', 'platform1'] + npz_num = [2, 2, 4, 6, 8] + _create_test_dataset(tmp_path, experiments=experiments, tissues=tissues, + platforms=platforms, npz_num=npz_num) + + total_img_num = np.sum(npz_num) * 10 + + # initialize db + db = DatasetBuilder(tmp_path) + db._identify_tissue_and_platform_types() + + train_ratio, val_ratio, test_ratio = 0.7, 0.2, 0.1 + + db._load_all_experiments(data_split=[train_ratio, val_ratio, test_ratio], seed=None) + + # get outputs + train_dict, val_dict, test_dict = db.train_dict, db.val_dict, db.test_dict + + # check that splits were performed correctly + for ratio, dict in zip((train_ratio, val_ratio, test_ratio), + (train_dict, val_dict, test_dict)): + + X_data, y_data = dict['X'], dict['y'] + assert X_data.shape[0] == ratio * total_img_num + assert y_data.shape[0] == ratio * total_img_num + + tissue_list, platform_list = dict['tissue_list'], dict['platform_list'] + assert len(tissue_list) == len(platform_list) == X_data.shape[0] + + # check that the metadata maps to the correct images + for dict in (train_dict, val_dict, test_dict): + X_data, tissue_list, platform_list = dict['X'], dict['tissue_list'], dict['platform_list'] + + # loop over each tissue type, and check that the NPZ is filled with correct constant value + for constant_val, tissue in enumerate(tissues): + + # index of images with matching tissue type + tissue_idx = tissue_list == tissue + + images = X_data[tissue_idx] + assert np.all(images == constant_val) + + # loop over each platform type, and check that the NPZ contains correct constant value + for constant_val, platform in enumerate(platforms): + + # index of images with matching platform type + platform_idx = platform_list == platform + + images = X_data[platform_idx] + assert np.all(images == constant_val) + + +def test__subset_data_dict(tmp_path): + _create_minimal_dataset(tmp_path) + + X = np.arange(100) + y = np.arange(100) + tissue_list = ['tissue1'] * 10 + ['tissue2'] * 50 + ['tissue3'] * 40 + platform_list = ['platform1'] * 20 + ['platform2'] * 40 + ['platform3'] * 40 + data_dict = {'X': X, 'y': y, 'tissue_list': tissue_list, 'platform_list': platform_list} + + db = DatasetBuilder(tmp_path) + + # all tissues, one platform + tissues = ['tissue1', 'tissue2', 'tissue3'] + platforms = ['platform1'] + subset_dict = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms) + X_subset = subset_dict['X'] + keep_idx = np.isin(platform_list, platforms) + + assert np.all(X_subset == X[keep_idx]) + + # all platforms, one tissue + tissues = ['tissue2'] + platforms = ['platform1', 'platform2', 'platform3'] + subset_dict = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms) + X_subset = subset_dict['X'] + keep_idx = np.isin(tissue_list, tissues) + + assert np.all(X_subset == X[keep_idx]) + + # drop tissue 1 and platform 3 + tissues = ['tissue2', 'tissue3'] + platforms = ['platform1', 'platform2'] + subset_dict = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms) + X_subset = subset_dict['X'] + 
platform_keep_idx = np.isin(platform_list, platforms) + tissue_keep_idx = np.isin(tissue_list, tissues) + keep_idx = np.logical_and(platform_keep_idx, tissue_keep_idx) + + assert np.all(X_subset == X[keep_idx]) + + # tissue/platform combination that doesn't exist + tissues = ['tissue1'] + platforms = ['platform3'] + with pytest.raises(ValueError): + _ = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms) + + +def test__reshape_dict_no_resize(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + # create dict + tissues = ['tissue1', 'tissue2', 'tissue3'] + platforms = ['platform1', 'platform2', 'platform3'] + data_dict = _create_test_dict(tissues=tissues, platforms=platforms) + + # this is 1/2 the size on each dimension as original, so we expect 4x more crops + output_shape = (20, 20) + + reshaped_dict = db._reshape_dict(data_dict=data_dict, resize=False, output_shape=output_shape) + X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list'] + assert X_reshaped.shape[1:3] == output_shape + + assert X_reshaped.shape[0] == 4 * data_dict['X'].shape[0] + + # make sure that for each tissue, the arrays with correct value have correct tissue label + for constant_val, tissue in enumerate(tissues): + tissue_idx = X_reshaped[:, 0, 0, 0] == constant_val + tissue_labels = np.array(tissue_list_reshaped)[tissue_idx] + assert np.all(tissue_labels == tissue) + + +def test__reshape_dict_by_value(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + # create dict + tissues = ['tissue1', 'tissue2', 'tissue3'] + platforms = ['platform1', 'platform2', 'platform3'] + data_dict = _create_test_dict(tissues=tissues, platforms=platforms) + + # same size as input data + output_shape = (40, 40) + + reshaped_dict = db._reshape_dict(data_dict=data_dict, resize=3, + output_shape=output_shape) + X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list'] + assert X_reshaped.shape[1:3] == output_shape + + # make sure that for each tissue, the arrays with correct value have correct tissue label + for constant_val, tissue in enumerate(tissues): + # each image was tagged with a different, compute that here + image_val = np.max(X_reshaped, axis=(1, 2, 3)) + + tissue_idx = image_val == constant_val + tissue_labels = np.array(tissue_list_reshaped)[tissue_idx] + assert np.all(tissue_labels == tissue) + + # There were originally 5 images of each tissue type. Each dimension was resized 3x, + # so there should be 9x more images + assert len(tissue_labels) == 5 * 9 + + # now with a resize to make images smaller + reshaped_dict = db._reshape_dict(data_dict=data_dict, resize=0.5, + output_shape=output_shape) + X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list'] + assert X_reshaped.shape[1:3] == output_shape + + # make sure that for each tissue, the arrays with correct value have correct tissue label + for constant_val, tissue in enumerate(tissues): + # each image was tagged with a different, compute that here + image_val = np.max(X_reshaped, axis=(1, 2, 3)) + + tissue_idx = image_val == constant_val + tissue_labels = np.array(tissue_list_reshaped)[tissue_idx] + assert np.all(tissue_labels == tissue) + + # There were originally 5 images of each tissue type. 
Each dimension was resized 0.5, + # and because the images are padded there should be the same total number of images + assert len(tissue_labels) == 5 + + +def test__reshape_dict_by_tissue(tmp_path, mocker): + mocker.patch('caliban_toolbox.dataset_builder.compute_cell_size', mocked_compute_cell_size) + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + # create dict + tissues = ['tissue1', 'tissue2', 'tissue3'] + platforms = ['platform1', 'platform2', 'platform3'] + data_dict = _create_test_dict(tissues=tissues, platforms=platforms) + + # same size as input data + output_shape = (40, 40) + + reshaped_dict = db._reshape_dict(data_dict=data_dict, resize='by_tissue', + output_shape=output_shape) + X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list'] + assert X_reshaped.shape[1:3] == output_shape + + # make sure that for each tissue, the arrays with correct value have correct tissue label + for constant_val, tissue in enumerate(tissues): + # each image was tagged with a different, compute that here + image_val = np.max(X_reshaped, axis=(1, 2, 3)) + + tissue_idx = image_val == constant_val + tissue_labels = np.array(tissue_list_reshaped)[tissue_idx] + assert np.all(tissue_labels == tissue) + + # There were originally 5 images of each tissue type. Tissue types with even values + # are resized to be 2x larger on each dimension, and should have 4x more images + if constant_val % 2 == 0: + assert len(tissue_labels) == 5 * 4 + # tissue types with odd values are resized to be smaller, which leads to same number + # of unique images due to padding + else: + assert len(tissue_labels) == 5 + + +# TODO: Is there a way to check the resize value of each unique image? +def test__reshape_dict_by_image(tmp_path, mocker): + mocker.patch('caliban_toolbox.dataset_builder.compute_cell_size', mocked_compute_cell_size) + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + # create dict + tissues = ['tissue1', 'tissue2', 'tissue3'] + platforms = ['platform1', 'platform2', 'platform3'] + data_dict = _create_test_dict(tissues=tissues, platforms=platforms) + + # same size as input data + output_shape = (40, 40) + + reshaped_dict = db._reshape_dict(data_dict=data_dict, resize='by_image', + output_shape=output_shape) + X_reshaped, tissue_list_reshaped = reshaped_dict['X'], reshaped_dict['tissue_list'] + assert X_reshaped.shape[1:3] == output_shape + + # make sure that for each tissue, the arrays with correct value have correct tissue label + for constant_val, tissue in enumerate(tissues): + # each image was tagged with a different, compute that here + image_val = np.max(X_reshaped, axis=(1, 2, 3)) + + tissue_idx = image_val == constant_val + tissue_labels = np.array(tissue_list_reshaped)[tissue_idx] + assert np.all(tissue_labels == tissue) + + # There were originally 5 images of each tissue type. 
Tissue types with even values + # are resized to be 2x larger on each dimension, and should have 4x more images + if constant_val % 2 == 0: + assert len(tissue_labels) == 5 * 4 + # tissue types with odd values are resized to be smaller, which leads to same number + # of unique images due to padding + else: + assert len(tissue_labels) == 5 + + +def test__clean_labels(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + test_label = np.zeros((50, 50), dtype='int') + test_label[:10, :10] = 2 + test_label[12:17, 12:17] = 2 + test_label[20:22, 22:23] = 3 + + test_labels = np.zeros((2, 50, 50, 1), dtype='int') + test_labels[0, ..., 0] = test_label + + test_X = np.zeros_like(test_labels) + test_tissue = ['tissue1', 'tissue2'] + test_platform = ['platform2', 'platform3'] + + test_dict = {'X': test_X, 'y': test_labels, 'tissue_list': test_tissue, + 'platform_list': test_platform} + + # relabel sequential + cleaned_dict = db._clean_labels(data_dict=test_dict, relabel=False) + assert len(np.unique(cleaned_dict['y'])) == 2 + 1 # 0 for background + + # true relabel + cleaned_dict = db._clean_labels(data_dict=test_dict, relabel=True) + assert len(np.unique(cleaned_dict['y'])) == 3 + 1 + + # remove small objects + cleaned_dict = db._clean_labels(data_dict=test_dict, relabel=True, + small_object_threshold=15) + assert len(np.unique(cleaned_dict['y'])) == 2 + 1 + + # remove sparse images + cleaned_dict = db._clean_labels(data_dict=test_dict, relabel=True, min_objects=1) + assert cleaned_dict['y'].shape[0] == 1 + assert cleaned_dict['X'].shape[0] == 1 + assert len(cleaned_dict['tissue_list']) == 1 + assert cleaned_dict['tissue_list'][0] == 'tissue1' + assert len(cleaned_dict['platform_list']) == 1 + assert cleaned_dict['platform_list'][0] == 'platform2' + + +def test__validate_categories(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + category_list = ['cat1', 'cat2', 'cat3'] + + # convert single category to list + supplied_categories = 'cat1' + validated = db._validate_categories(category_list=category_list, + supplied_categories=supplied_categories) + assert validated == [supplied_categories] + + # convert 'all' to list of all categories + supplied_categories = 'all' + validated = db._validate_categories(category_list=category_list, + supplied_categories=supplied_categories) + assert np.all(validated == category_list) + + # convert 'all' to list of all categories + supplied_categories = ['cat1', 'cat3'] + validated = db._validate_categories(category_list=category_list, + supplied_categories=supplied_categories) + assert np.all(validated == supplied_categories) + + # invalid string + supplied_categories = 'cat4' + with pytest.raises(ValueError): + _ = db._validate_categories(category_list=category_list, + supplied_categories=supplied_categories) + + # invalid list + supplied_categories = ['cat4', 'cat1'] + with pytest.raises(ValueError): + _ = db._validate_categories(category_list=category_list, + supplied_categories=supplied_categories) + + +def test__validate_output_shape(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + # make sure list or tuple is converted + output_shapes = [[222, 333], (222, 333)] + for output_shape in output_shapes: + validated_shape = db._validate_output_shape(output_shape) + assert validated_shape == [output_shape, output_shape, output_shape] + + # not all splits specified + output_shape = [(123, 456), (789, 1011)] + with pytest.raises(ValueError): + _ = 
db._validate_output_shape(output_shape=output_shape) + + # not all splits have 2 entries + output_shape = [(12, 34), (56, 78), (910, 1112, 1314)] + with pytest.raises(ValueError): + _ = db._validate_output_shape(output_shape=output_shape) + + # too many splits + output_shape = [(12, 34), (56, 78), (910, 1112), (1314, )] + with pytest.raises(ValueError): + _ = db._validate_output_shape(output_shape=output_shape) + + # not a list/tuple + output_shape = 56 + with pytest.raises(ValueError): + _ = db._validate_output_shape(output_shape=output_shape) + + +def test_build_dataset(tmp_path): + # create dataset + experiments = ['exp{}'.format(i) for i in range(5)] + tissues = ['tissue1', 'tissue2', 'tissue3', 'tissue4', 'tissue5'] + platforms = ['platform5', 'platform4', 'platform3', 'platform2', 'platform1'] + npz_num = [2, 2, 4, 6, 8] + _create_test_dataset(tmp_path, experiments=experiments, tissues=tissues, + platforms=platforms, npz_num=npz_num) + + db = DatasetBuilder(tmp_path) + + # dataset with all data included + output_dicts = db.build_dataset(tissues=tissues, platforms=platforms, output_shape=(20, 20)) + + for dict in output_dicts: + # make sure correct tissues and platforms loaded + current_tissues = dict['tissue_list'] + current_platforms = dict['platform_list'] + assert set(current_tissues) == set(tissues) + assert set(current_platforms) == set(platforms) + + # dataset with only a subset included + tissues, platforms = tissues[:3], platforms[:3] + output_dicts = db.build_dataset(tissues=tissues, platforms=platforms, output_shape=(20, 20)) + + for dict in output_dicts: + # make sure correct tissues and platforms loaded + current_tissues = dict['tissue_list'] + current_platforms = dict['platform_list'] + assert set(current_tissues) == set(tissues) + assert set(current_platforms) == set(platforms) + + # cropping to 1/2 the size, there should be 4x more crops + output_dicts_crop = db.build_dataset(tissues=tissues, platforms=platforms, + output_shape=(10, 10), relabel=True) + + for base_dict, crop_dict in zip(output_dicts, output_dicts_crop): + X_base, X_crop = base_dict['X'], crop_dict['X'] + assert X_base.shape[0] * 4 == X_crop.shape[0] + + # check that NPZs have been relabeled + for current_dict in output_dicts_crop: + assert len(np.unique(current_dict['y'])) == 2 + + # different sizes for different splits + output_dicts_diff_sizes = db.build_dataset(tissues=tissues, platforms=platforms, + output_shape=[(10, 10), (15, 15), (20, 20)]) + + assert output_dicts_diff_sizes[0]['X'].shape[1:3] == (10, 10) + assert output_dicts_diff_sizes[1]['X'].shape[1:3] == (15, 15) + assert output_dicts_diff_sizes[2]['X'].shape[1:3] == (20, 20) + + # full runthrough with default options changed + _ = db.build_dataset(tissues='all', platforms=platforms, output_shape=(10, 10), + relabel_hard=True, resize='by_image', small_object_threshold=5) + + +def test_summarize_dataset(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + # create dict + tissues = ['tissue1', 'tissue2', 'tissue3'] + platforms = ['platform1', 'platform2', 'platform3'] + train_dict = _create_test_dict(tissues=tissues, platforms=platforms) + val_dict = _create_test_dict(tissues=tissues[1:], platforms=platforms[1:]) + test_dict = _create_test_dict(tissues=tissues[:-1], platforms=platforms[:-1]) + + # make sure each dict has 2 cells in every image for counting purposes + for current_dict in [train_dict, val_dict, test_dict]: + current_labels = current_dict['y'] + current_labels[:, 0, 0, 0] = 5 + current_labels[:, 10, 
0, 0] = 12
+
+        current_dict['y'] = current_labels
+
+    db.train_dict = train_dict
+    db.val_dict = val_dict
+    db.test_dict = test_dict
+
+    tissue_dict, platform_dict = db.summarize_dataset()
+
+    # check that all tissues and platforms are present
+    for i in range(len(tissues)):
+        assert tissues[i] in tissue_dict
+        assert platforms[i] in platform_dict
+
+    # check that cell and image counts are computed correctly
+    for dict in [tissue_dict, platform_dict]:
+        for key in list(dict.keys()):
+
+            # each image has only two cells
+            cell_num = dict[key]['cell_num']
+            image_num = dict[key]['image_num']
+            assert cell_num == image_num * 2
+
+            # middle categories are present in all three dicts, and hence have 15 images
+            if key in ['tissue2', 'platform2']:
+                assert image_num == 15
+            else:
+                assert image_num == 10
diff --git a/caliban_toolbox/utils/misc_utils.py b/caliban_toolbox/utils/misc_utils.py
index faec3db..82013f7 100644
--- a/caliban_toolbox/utils/misc_utils.py
+++ b/caliban_toolbox/utils/misc_utils.py
@@ -119,3 +119,19 @@ def list_npzs_folder(npz_dir):
     npz_list = sorted_nicely(npz_list)
 
     return npz_list
+
+
+def list_folders(base_dir):
+    """Lists all folders in the specified directory
+
+    Args:
+        base_dir: directory with folders
+
+    Returns:
+        list of folder names in base_dir; empty if there are none
+    """
+
+    files = os.listdir(base_dir)
+    folders = [file for file in files if os.path.isdir(os.path.join(base_dir, file))]
+
+    return folders
diff --git a/caliban_toolbox/utils/misc_utils_test.py b/caliban_toolbox/utils/misc_utils_test.py
index a17811f..88acd13 100644
--- a/caliban_toolbox/utils/misc_utils_test.py
+++ b/caliban_toolbox/utils/misc_utils_test.py
@@ -27,6 +27,7 @@ from __future__ import division
 from __future__ import print_function
 
+import os
 import numpy as np
 
 from caliban_toolbox.utils import misc_utils
 
@@ -41,3 +42,17 @@ def test_sorted_nicely():
     expected = ['test_0_0', 'test_1_0', 'test_1_1']
     unsorted = ['test_1_1', 'test_0_0', 'test_1_0']
     assert(np.array_equal(expected, misc_utils.sorted_nicely(unsorted)))
+
+
+def test_list_folders(tmp_path):
+    # test with no folders
+    empty_dir = misc_utils.list_folders(tmp_path)
+
+    assert empty_dir == []
+
+    os.makedirs(os.path.join(tmp_path, 'folder1'))
+    os.makedirs(os.path.join(tmp_path, 'folder2'))
+
+    folders = misc_utils.list_folders(tmp_path)
+
+    assert set(folders) == {'folder1', 'folder2'}
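
For reviewers trying out the new API, a minimal usage sketch of DatasetBuilder as exercised by the tests above; the 'datasets/' path, tissue names, and keyword values here are illustrative assumptions rather than values fixed by this patch:

    from caliban_toolbox.dataset_builder import DatasetBuilder

    # assumed layout: each experiment folder under 'datasets/' holds a metadata.json
    # with 'tissue' and 'platform' keys plus one or more NPZ files containing 'X' and 'y'
    db = DatasetBuilder('datasets/')

    # split into train/val/test dicts, keeping two tissue types and cropping to 256x256
    train_dict, val_dict, test_dict = db.build_dataset(tissues=['tissue1', 'tissue2'],
                                                       platforms='all',
                                                       output_shape=(256, 256),
                                                       resize='by_tissue',
                                                       relabel=True)

    # cell and image counts per tissue and per platform for the loaded splits
    tissue_counts, platform_counts = db.summarize_dataset()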
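
The expected image counts in the _reshape_dict tests follow from an area-based scale factor between the mocked median cell size and the default resize_target of 400; a short worked sketch of that arithmetic (the square-root relationship is inferred from the test expectations, not copied from the implementation):

    import numpy as np

    resize_target = 400  # default target cell size passed through build_dataset kwargs

    # mocked median cell area of 100 (even constant values): each spatial dimension
    # scales by sqrt(400 / 100) = 2, so a 40x40 test image becomes 80x80 and is
    # cropped back to 40x40, giving 4x as many images
    even_ratio = np.sqrt(resize_target / 100)   # 2.0

    # mocked median cell area of 1600 (odd constant values): the factor is
    # sqrt(400 / 1600) = 0.5, and padding the shrunken image back up to 40x40
    # leaves the image count unchanged
    odd_ratio = np.sqrt(resize_target / 1600)   # 0.5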