From aff807a57fbdf9cee91a757d2a9ab5cc8a94a55a Mon Sep 17 00:00:00 2001 From: Noah Greenwald Date: Fri, 4 Sep 2020 13:52:58 -0700 Subject: [PATCH] Balanced datasets (#118) * migrate dataset_builder to new branch * switched from lists to arrays for metadata * removed old kwarg * simplify argument name --- caliban_toolbox/dataset_builder.py | 109 +++++++++++++++++++----- caliban_toolbox/dataset_builder_test.py | 88 ++++++++++++++++--- 2 files changed, 161 insertions(+), 36 deletions(-) diff --git a/caliban_toolbox/dataset_builder.py b/caliban_toolbox/dataset_builder.py index 5dd7b5f..88a901e 100644 --- a/caliban_toolbox/dataset_builder.py +++ b/caliban_toolbox/dataset_builder.py @@ -51,8 +51,8 @@ def __init__(self, dataset_path): self.dataset_path = dataset_path self.experiment_folders = experiment_folders - self.all_tissues = [] - self.all_platforms = [] + self.all_tissues = None + self.all_platforms = None # dicts to hold aggregated data self.train_dict = {} @@ -117,8 +117,8 @@ def _identify_tissue_and_platform_types(self): tissues.append(metadata['tissue']) platforms.append(metadata['platform']) - self.all_tissues.extend(tissues) - self.all_platforms.extend(platforms) + self.all_tissues = np.array(tissues) + self.all_platforms = np.array(platforms) def _load_experiment(self, experiment_path): """Load the NPZ files present in a single experiment folder @@ -282,11 +282,11 @@ def _subset_data_dict(self, data_dict, tissues, platforms): raise ValueError('No matching data for specified parameters') X, y = X[combined_idx], y[combined_idx] - tissue_list = np.array(tissue_list)[combined_idx] - platform_list = np.array(platform_list)[combined_idx] + tissue_list = tissue_list[combined_idx] + platform_list = platform_list[combined_idx] - subset_dict = {'X': X, 'y': y, 'tissue_list': list(tissue_list), - 'platform_list': list(platform_list)} + subset_dict = {'X': X, 'y': y, 'tissue_list': tissue_list, + 'platform_list': platform_list} return subset_dict def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize_target=400, @@ -306,8 +306,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize median cell size before resizing occurs """ X, y = data_dict['X'], data_dict['y'] - tissue_list = np.array(data_dict['tissue_list']) - platform_list = np.array(data_dict['platform_list']) + tissue_list = data_dict['tissue_list'] + platform_list = data_dict['platform_list'] if not resize: # no resizing @@ -318,8 +318,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize multiplier = int(X_new.shape[0] / X.shape[0]) # then we duplicate the labels in place to match expanded array size - tissue_list_new = [item for item in tissue_list for _ in range(multiplier)] - platform_list_new = [item for item in platform_list for _ in range(multiplier)] + tissue_list_new = np.repeat(tissue_list, multiplier) + platform_list_new = np.repeat(platform_list, multiplier) elif isinstance(resize, (float, int)): # resized based on supplied value @@ -331,8 +331,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize multiplier = int(X_new.shape[0] / X.shape[0]) # then we duplicate the labels in place to match expanded array size - tissue_list_new = [item for item in tissue_list for _ in range(multiplier)] - platform_list_new = [item for item in platform_list for _ in range(multiplier)] + tissue_list_new = np.repeat(tissue_list, multiplier) + platform_list_new = np.repeat(platform_list, multiplier) else: X_new, y_new, 
tissue_list_new, platform_list_new = [], [], [], [] @@ -377,9 +377,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize multiplier = int(X_batch_resized.shape[0] / X_batch.shape[0]) # then we duplicate the labels in place to match expanded array size - tissue_list_batch = [item for item in tissue_list_batch for _ in range(multiplier)] - platform_list_batch = \ - [item for item in platform_list_batch for _ in range(multiplier)] + tissue_list_batch = np.repeat(tissue_list_batch, multiplier) + platform_list_batch = np.repeat(platform_list_batch, multiplier) # add each batch onto main list X_new.append(X_batch_resized) @@ -409,8 +408,8 @@ def _clean_labels(self, data_dict, relabel=False, small_object_threshold=0, cleaned_dict: dictionary with cleaned labels """ X, y = data_dict['X'], data_dict['y'] - tissue_list = np.array(data_dict['tissue_list']) - platform_list = np.array(data_dict['platform_list']) + tissue_list = data_dict['tissue_list'] + platform_list = data_dict['platform_list'] keep_idx = np.repeat(True, y.shape[0]) cleaned_y = np.zeros_like(y) @@ -434,11 +433,68 @@ def _clean_labels(self, data_dict, relabel=False, small_object_threshold=0, cleaned_tissue = tissue_list[keep_idx] cleaned_platform = platform_list[keep_idx] - cleaned_dict = {'X': cleaned_X, 'y': cleaned_y, 'tissue_list': list(cleaned_tissue), - 'platform_list': list(cleaned_platform)} + cleaned_dict = {'X': cleaned_X, 'y': cleaned_y, 'tissue_list': cleaned_tissue, + 'platform_list': cleaned_platform} return cleaned_dict + def _balance_dict(self, data_dict, seed, category): + """Balance a dictionary of training data so that each category is equally represented + + Args: + data_dict: dictionary of training data + seed: seed for random duplication of less-represented classes + category: name of the key in the dictionary to use for balancing + + Returns: + dict: training data that has been balanced + """ + + np.random.seed(seed) + category_list = data_dict[category] + + unique_categories, unique_counts = np.unique(category_list, return_counts=True) + max_counts = np.max(unique_counts) + + # original variables + X_unbalanced, y_unbalanced = data_dict['X'], data_dict['y'] + tissue_unbalanced = np.array(data_dict['tissue_list']) + platform_unbalanced = np.array(data_dict['platform_list']) + + # create list to hold balanced versions + X_balanced, y_balanced, tissue_balanced, platform_balanced = [], [], [], [] + for category in unique_categories: + cat_idx = category == category_list + X_cat, y_cat = X_unbalanced[cat_idx], y_unbalanced[cat_idx] + tissue_cat, platform_cat = tissue_unbalanced[cat_idx], platform_unbalanced[cat_idx] + + category_counts = X_cat.shape[0] + if category_counts == max_counts: + # we don't need to balance, as this category already has max number of examples + X_balanced.append(X_cat) + y_balanced.append(y_cat) + tissue_balanced.append(tissue_cat) + platform_balanced.append(platform_cat) + else: + # randomly select max_counts number of indices to upsample data + balance_idx = np.random.choice(range(category_counts), size=max_counts, + replace=True) + + # index into each array using random index to generate randomly upsampled version + X_balanced.append(X_cat[balance_idx]) + y_balanced.append(y_cat[balance_idx]) + tissue_balanced.append(tissue_cat[balance_idx]) + platform_balanced.append(platform_cat[balance_idx]) + + # combine balanced versions of each category into single array + X_balanced = np.concatenate(X_balanced, axis=0) + y_balanced = np.concatenate(y_balanced, 
axis=0) + tissue_balanced = np.concatenate(tissue_balanced, axis=0) + platform_balanced = np.concatenate(platform_balanced, axis=0) + + return {'X': X_balanced, 'y': y_balanced, 'tissue_list': tissue_balanced, + 'platform_list': platform_balanced} + def _validate_categories(self, category_list, supplied_categories): """Check that an appropriate subset of a list of categories was supplied @@ -508,7 +564,7 @@ def _validate_output_shape(self, output_shape): 'or length 3, got {}'.format(output_shape)) def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512), resize=False, - data_split=(0.8, 0.1, 0.1), seed=0, **kwargs): + data_split=(0.8, 0.1, 0.1), seed=0, balance=False, **kwargs): """Construct a dataset for model training and evaluation Args: @@ -526,6 +582,8 @@ def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512), - by_image. Resizes by median cell size within each image data_split: tuple specifying the fraction of the dataset for train/val/test seed: seed for reproducible splitting of dataset + balance: if true, randomly duplicate less-represented tissue types + in train and val splits so that there are the same number of images of each type **kwargs: other arguments to be passed to helper functions Returns: @@ -534,7 +592,7 @@ def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512), Raises: ValueError: If invalid resize parameter supplied """ - if self.all_tissues == []: + if self.all_tissues is None: self._identify_tissue_and_platform_types() # validate inputs @@ -582,6 +640,11 @@ def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512), current_dict = self._clean_labels(data_dict=current_dict, relabel=relabel, small_object_threshold=small_object_threshold, min_objects=min_objects) + + # don't balance test split + if balance and idx != 2: + current_dict = self._balance_dict(current_dict, seed=seed, category='tissue_list') + dicts[idx] = current_dict return dicts diff --git a/caliban_toolbox/dataset_builder_test.py b/caliban_toolbox/dataset_builder_test.py index 455f48d..5914adf 100644 --- a/caliban_toolbox/dataset_builder_test.py +++ b/caliban_toolbox/dataset_builder_test.py @@ -93,8 +93,8 @@ def _create_test_dict(tissues, platforms): X_data = data y_data = data[..., :1].astype('int16') - tissue_list = [tissues[i] for i in range(len(tissues)) for _ in range(5)] - platform_list = [platforms[i] for i in range(len(platforms)) for _ in range(5)] + tissue_list = np.repeat(tissues, 5) + platform_list = np.repeat(platforms, 5) return {'X': X_data, 'y': y_data, 'tissue_list': tissue_list, 'platform_list': platform_list} @@ -290,8 +290,8 @@ def test__subset_data_dict(tmp_path): X = np.arange(100) y = np.arange(100) - tissue_list = ['tissue1'] * 10 + ['tissue2'] * 50 + ['tissue3'] * 40 - platform_list = ['platform1'] * 20 + ['platform2'] * 40 + ['platform3'] * 40 + tissue_list = np.array(['tissue1'] * 10 + ['tissue2'] * 50 + ['tissue3'] * 40) + platform_list = np.array(['platform1'] * 20 + ['platform2'] * 40 + ['platform3'] * 40) data_dict = {'X': X, 'y': y, 'tissue_list': tissue_list, 'platform_list': platform_list} db = DatasetBuilder(tmp_path) @@ -306,8 +306,8 @@ def test__subset_data_dict(tmp_path): assert np.all(X_subset == X[keep_idx]) # all platforms, one tissue - tissues = ['tissue2'] - platforms = ['platform1', 'platform2', 'platform3'] + tissues = np.array(['tissue2']) + platforms = np.array(['platform1', 'platform2', 'platform3']) subset_dict = db._subset_data_dict(data_dict=data_dict, 
tissues=tissues, platforms=platforms) X_subset = subset_dict['X'] keep_idx = np.isin(tissue_list, tissues) @@ -315,8 +315,8 @@ def test__subset_data_dict(tmp_path): assert np.all(X_subset == X[keep_idx]) # drop tissue 1 and platform 3 - tissues = ['tissue2', 'tissue3'] - platforms = ['platform1', 'platform2'] + tissues = np.array(['tissue2', 'tissue3']) + platforms = np.array(['platform1', 'platform2']) subset_dict = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms) X_subset = subset_dict['X'] platform_keep_idx = np.isin(platform_list, platforms) @@ -326,8 +326,8 @@ def test__subset_data_dict(tmp_path): assert np.all(X_subset == X[keep_idx]) # tissue/platform combination that doesn't exist - tissues = ['tissue1'] - platforms = ['platform3'] + tissues = np.array(['tissue1']) + platforms = np.array(['platform3']) with pytest.raises(ValueError): _ = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms) @@ -495,8 +495,8 @@ def test__clean_labels(tmp_path): test_labels[0, ..., 0] = test_label test_X = np.zeros_like(test_labels) - test_tissue = ['tissue1', 'tissue2'] - test_platform = ['platform2', 'platform3'] + test_tissue = np.array(['tissue1', 'tissue2']) + test_platform = np.array(['platform2', 'platform3']) test_dict = {'X': test_X, 'y': test_labels, 'tissue_list': test_tissue, 'platform_list': test_platform} @@ -524,6 +524,67 @@ def test__clean_labels(tmp_path): assert cleaned_dict['platform_list'][0] == 'platform2' +def test__balance_dict(tmp_path): + _create_minimal_dataset(tmp_path) + db = DatasetBuilder(tmp_path) + + X_data = np.random.rand(9, 10, 10, 3) + y_data = np.random.rand(9, 10, 10, 1) + tissue_list = np.array(['tissue1'] * 3 + ['tissue2'] * 3 + ['tissue3'] * 3) + platform_list = np.array(['platform1'] * 3 + ['platform2'] * 3 + ['platform3'] * 3) + + balanced_dict = {'X': X_data, 'y': y_data, 'tissue_list': tissue_list, + 'platform_list': platform_list} + output_dict = db._balance_dict(data_dict=balanced_dict, seed=0, category='tissue_list') + + # data is already balanced, all items should be identical + for key in output_dict: + assert np.all(output_dict[key] == balanced_dict[key]) + + # tissue 3 has most, others need to be upsampled + tissue_list = np.array(['tissue1'] * 1 + ['tissue2'] * 2 + ['tissue3'] * 6) + unbalanced_dict = {'X': X_data, 'y': y_data, 'tissue_list': tissue_list, + 'platform_list': platform_list} + output_dict = db._balance_dict(data_dict=unbalanced_dict, seed=0, category='tissue_list') + + # tissue 3 is unchanged + for key in output_dict: + assert np.all(output_dict[key][-6:] == unbalanced_dict[key][-6:]) + + # tissue 1 only has a single example, all copies should be equal + tissue1_idx = np.where(output_dict['tissue_list'] == 'tissue1')[0] + for key in output_dict: + vals = output_dict[key] + for idx in tissue1_idx: + new_val = vals[idx] + old_val = unbalanced_dict[key][0] + assert np.all(new_val == old_val) + + # tissue 2 has 2 examples, all copies should be equal to one of those values + tissue2_idx = np.where(output_dict['tissue_list'] == 'tissue2')[0] + for key in output_dict: + vals = output_dict[key] + for idx in tissue2_idx: + new_val = vals[idx] + old_val1 = unbalanced_dict[key][1] + old_val2 = unbalanced_dict[key][2] + assert np.all(new_val == old_val1) or np.all(new_val == old_val2) + + # check with same seed + output_dict_same_seed = db._balance_dict(data_dict=unbalanced_dict, seed=0, + category='tissue_list') + + for key in output_dict_same_seed: + assert 
np.all(output_dict_same_seed[key] == output_dict[key]) + + # check with different seed + output_dict_diff_seed = db._balance_dict(data_dict=unbalanced_dict, seed=1, + category='tissue_list') + + for key in ['X', 'y']: + assert not np.all(output_dict_diff_seed[key] == output_dict[key]) + + def test__validate_categories(tmp_path): _create_minimal_dataset(tmp_path) db = DatasetBuilder(tmp_path) @@ -646,7 +707,8 @@ def test_build_dataset(tmp_path): # full runthrough with default options changed _ = db.build_dataset(tissues='all', platforms=platforms, output_shape=(10, 10), - relabel_hard=True, resize='by_image', small_object_threshold=5) + relabel=True, resize='by_image', small_object_threshold=5, + balance=True) def test_summarize_dataset(tmp_path):
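
For reference, a minimal standalone sketch of the upsampling strategy this patch adds in DatasetBuilder._balance_dict. The function name balance_by_category and its signature are illustrative only and not part of caliban_toolbox; the real method works on the data_dict produced by the earlier pipeline steps and also carries the platform labels through. The sketch assumes X, y, and the per-image category labels are NumPy arrays indexed along axis 0, as they are after the list-to-array changes above.

    import numpy as np

    def balance_by_category(X, y, categories, seed=0):
        """Upsample under-represented categories so each reaches max_counts images.

        Sketch only: mirrors the logic of DatasetBuilder._balance_dict under the
        assumptions stated above; it is not the toolbox implementation itself.
        """
        rng = np.random.RandomState(seed)  # the patch seeds the global np.random state instead
        categories = np.asarray(categories)
        unique_categories, unique_counts = np.unique(categories, return_counts=True)
        max_counts = unique_counts.max()

        X_out, y_out, cat_out = [], [], []
        for cat in unique_categories:
            cat_idx = np.where(categories == cat)[0]
            if cat_idx.shape[0] < max_counts:
                # randomly duplicate images of this category (sampling with
                # replacement) until it matches the most common category
                cat_idx = rng.choice(cat_idx, size=max_counts, replace=True)
            X_out.append(X[cat_idx])
            y_out.append(y[cat_idx])
            cat_out.append(categories[cat_idx])

        return (np.concatenate(X_out, axis=0),
                np.concatenate(y_out, axis=0),
                np.concatenate(cat_out, axis=0))

Because balancing only duplicates existing images, the fixed seed is what makes the result reproducible, which the new test__balance_dict checks by comparing same-seed and different-seed outputs.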
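
A sketch of how the new flag would be used end to end. The dataset path is a placeholder, and the unpacking below assumes the three returned dicts correspond to the train/val/test splits, matching the idx 0/1/2 handling in build_dataset.

    from caliban_toolbox.dataset_builder import DatasetBuilder

    db = DatasetBuilder('/path/to/dataset')  # placeholder: one sub-folder per experiment

    # balance=True randomly duplicates under-represented tissue types in the
    # train and validation splits; the test split is intentionally left
    # unbalanced (see the idx != 2 check in build_dataset)
    splits = db.build_dataset(tissues='all', platforms='all',
                              output_shape=(512, 512), resize=False,
                              data_split=(0.8, 0.1, 0.1), seed=0, balance=True)
    train_dict, val_dict, test_dict = splits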