Balanced datasets (#118)
* migrate dataset_builder to new branch

* switched from lists to arrays for metadata

* removed old kwarg

* simplify argument name
ngreenwald committed Sep 4, 2020
1 parent 1dc8b57 commit aff807a
Showing 2 changed files with 161 additions and 36 deletions.
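
For context, a minimal usage sketch of the new balance flag (the dataset path is a placeholder and the three returned dicts are assumed to be the train/val/test splits; this snippet is an illustration, not part of the diff):

from caliban_toolbox.dataset_builder import DatasetBuilder

db = DatasetBuilder('/path/to/dataset')  # placeholder path
train_dict, val_dict, test_dict = db.build_dataset(
    tissues='all', platforms='all', output_shape=(512, 512), resize=False,
    data_split=(0.8, 0.1, 0.1), seed=0,
    balance=True)  # upsample under-represented tissue types in the train/val splits
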
109 changes: 86 additions & 23 deletions caliban_toolbox/dataset_builder.py
@@ -51,8 +51,8 @@ def __init__(self, dataset_path):
self.dataset_path = dataset_path
self.experiment_folders = experiment_folders

self.all_tissues = []
self.all_platforms = []
self.all_tissues = None
self.all_platforms = None

# dicts to hold aggregated data
self.train_dict = {}
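
Side note on the sentinel change above (toy values, not repo code): with the metadata now held in NumPy arrays, == comparisons are elementwise, so an empty-list check is no longer a reliable "not yet populated" test and None is used instead, checked with "is None" in build_dataset below.

import numpy as np

tissues = np.array(['breast', 'gi', 'gi'])   # made-up metadata

print(tissues == 'gi')                       # elementwise: [False  True  True]
print(tissues[tissues == 'gi'])              # boolean masks index the metadata directly
print(tissues is None)                       # False; unambiguous sentinel check
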
@@ -117,8 +117,8 @@ def _identify_tissue_and_platform_types(self):
tissues.append(metadata['tissue'])
platforms.append(metadata['platform'])

self.all_tissues.extend(tissues)
self.all_platforms.extend(platforms)
self.all_tissues = np.array(tissues)
self.all_platforms = np.array(platforms)

def _load_experiment(self, experiment_path):
"""Load the NPZ files present in a single experiment folder
@@ -282,11 +282,11 @@ def _subset_data_dict(self, data_dict, tissues, platforms):
raise ValueError('No matching data for specified parameters')

X, y = X[combined_idx], y[combined_idx]
tissue_list = np.array(tissue_list)[combined_idx]
platform_list = np.array(platform_list)[combined_idx]
tissue_list = tissue_list[combined_idx]
platform_list = platform_list[combined_idx]

subset_dict = {'X': X, 'y': y, 'tissue_list': list(tissue_list),
'platform_list': list(platform_list)}
subset_dict = {'X': X, 'y': y, 'tissue_list': tissue_list,
'platform_list': platform_list}
return subset_dict

def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize_target=400,
@@ -306,8 +306,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize
median cell size before resizing occurs
"""
X, y = data_dict['X'], data_dict['y']
tissue_list = np.array(data_dict['tissue_list'])
platform_list = np.array(data_dict['platform_list'])
tissue_list = data_dict['tissue_list']
platform_list = data_dict['platform_list']

if not resize:
# no resizing
@@ -318,8 +318,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize
multiplier = int(X_new.shape[0] / X.shape[0])

# then we duplicate the labels in place to match expanded array size
tissue_list_new = [item for item in tissue_list for _ in range(multiplier)]
platform_list_new = [item for item in platform_list for _ in range(multiplier)]
tissue_list_new = np.repeat(tissue_list, multiplier)
platform_list_new = np.repeat(platform_list, multiplier)

elif isinstance(resize, (float, int)):
# resized based on supplied value
@@ -331,8 +331,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize
multiplier = int(X_new.shape[0] / X.shape[0])

# then we duplicate the labels in place to match expanded array size
tissue_list_new = [item for item in tissue_list for _ in range(multiplier)]
platform_list_new = [item for item in platform_list for _ in range(multiplier)]
tissue_list_new = np.repeat(tissue_list, multiplier)
platform_list_new = np.repeat(platform_list, multiplier)
else:
X_new, y_new, tissue_list_new, platform_list_new = [], [], [], []

@@ -377,9 +377,8 @@ def _reshape_dict(self, data_dict, resize=False, output_shape=(512, 512), resize
multiplier = int(X_batch_resized.shape[0] / X_batch.shape[0])

# then we duplicate the labels in place to match expanded array size
tissue_list_batch = [item for item in tissue_list_batch for _ in range(multiplier)]
platform_list_batch = \
[item for item in platform_list_batch for _ in range(multiplier)]
tissue_list_batch = np.repeat(tissue_list_batch, multiplier)
platform_list_batch = np.repeat(platform_list_batch, multiplier)

# add each batch onto main list
X_new.append(X_batch_resized)
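
A toy check of what np.repeat does to the label arrays in these branches (made-up values): each entry is duplicated in place to match the expanded image axis, the same ordering the old list comprehensions produced.

import numpy as np

tissue_list = np.array(['breast', 'gi'])
multiplier = 3                               # e.g. each image became 3 crops

print(np.repeat(tissue_list, multiplier))    # ['breast' 'breast' 'breast' 'gi' 'gi' 'gi']
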
@@ -409,8 +408,8 @@ def _clean_labels(self, data_dict, relabel=False, small_object_threshold=0,
cleaned_dict: dictionary with cleaned labels
"""
X, y = data_dict['X'], data_dict['y']
tissue_list = np.array(data_dict['tissue_list'])
platform_list = np.array(data_dict['platform_list'])
tissue_list = data_dict['tissue_list']
platform_list = data_dict['platform_list']
keep_idx = np.repeat(True, y.shape[0])
cleaned_y = np.zeros_like(y)

@@ -434,11 +433,68 @@ def _clean_labels(self, data_dict, relabel=False, small_object_threshold=0,
cleaned_tissue = tissue_list[keep_idx]
cleaned_platform = platform_list[keep_idx]

cleaned_dict = {'X': cleaned_X, 'y': cleaned_y, 'tissue_list': list(cleaned_tissue),
'platform_list': list(cleaned_platform)}
cleaned_dict = {'X': cleaned_X, 'y': cleaned_y, 'tissue_list': cleaned_tissue,
'platform_list': cleaned_platform}

return cleaned_dict

def _balance_dict(self, data_dict, seed, category):
"""Balance a dictionary of training data so that each category is equally represented
Args:
data_dict: dictionary of training data
seed: seed for random duplication of less-represented classes
category: name of the key in the dictionary to use for balancing
Returns:
dict: training data that has been balanced
"""

np.random.seed(seed)
category_list = data_dict[category]

unique_categories, unique_counts = np.unique(category_list, return_counts=True)
max_counts = np.max(unique_counts)

# original variables
X_unbalanced, y_unbalanced = data_dict['X'], data_dict['y']
tissue_unbalanced = np.array(data_dict['tissue_list'])
platform_unbalanced = np.array(data_dict['platform_list'])

# create list to hold balanced versions
X_balanced, y_balanced, tissue_balanced, platform_balanced = [], [], [], []
for category in unique_categories:
cat_idx = category == category_list
X_cat, y_cat = X_unbalanced[cat_idx], y_unbalanced[cat_idx]
tissue_cat, platform_cat = tissue_unbalanced[cat_idx], platform_unbalanced[cat_idx]

category_counts = X_cat.shape[0]
if category_counts == max_counts:
# we don't need to balance, as this category already has max number of examples
X_balanced.append(X_cat)
y_balanced.append(y_cat)
tissue_balanced.append(tissue_cat)
platform_balanced.append(platform_cat)
else:
# randomly select max_counts number of indices to upsample data
balance_idx = np.random.choice(range(category_counts), size=max_counts,
replace=True)

# index into each array using random index to generate randomly upsampled version
X_balanced.append(X_cat[balance_idx])
y_balanced.append(y_cat[balance_idx])
tissue_balanced.append(tissue_cat[balance_idx])
platform_balanced.append(platform_cat[balance_idx])

# combine balanced versions of each category into single array
X_balanced = np.concatenate(X_balanced, axis=0)
y_balanced = np.concatenate(y_balanced, axis=0)
tissue_balanced = np.concatenate(tissue_balanced, axis=0)
platform_balanced = np.concatenate(platform_balanced, axis=0)

return {'X': X_balanced, 'y': y_balanced, 'tissue_list': tissue_balanced,
'platform_list': platform_balanced}
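
A standalone sketch of the upsampling step performed by _balance_dict, on toy arrays rather than the repo's data:

import numpy as np

np.random.seed(0)
tissue_list = np.array(['gi'] * 6 + ['breast'] * 2)
X = np.arange(8)

unique_categories, unique_counts = np.unique(tissue_list, return_counts=True)
max_counts = unique_counts.max()                     # 6

X_balanced = []
for category in unique_categories:
    X_cat = X[tissue_list == category]
    if X_cat.shape[0] == max_counts:
        X_balanced.append(X_cat)                     # already at the max count, keep as is
    else:
        # sample with replacement until this category also has max_counts examples
        balance_idx = np.random.choice(range(X_cat.shape[0]), size=max_counts, replace=True)
        X_balanced.append(X_cat[balance_idx])

print(np.concatenate(X_balanced, axis=0).shape)      # (12,): 6 examples per tissue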

def _validate_categories(self, category_list, supplied_categories):
"""Check that an appropriate subset of a list of categories was supplied
@@ -508,7 +564,7 @@ def _validate_output_shape(self, output_shape):
'or length 3, got {}'.format(output_shape))

def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512), resize=False,
data_split=(0.8, 0.1, 0.1), seed=0, **kwargs):
data_split=(0.8, 0.1, 0.1), seed=0, balance=False, **kwargs):
"""Construct a dataset for model training and evaluation
Args:
@@ -526,6 +582,8 @@ def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512),
- by_image. Resizes by median cell size within each image
data_split: tuple specifying the fraction of the dataset for train/val/test
seed: seed for reproducible splitting of dataset
balance: if true, randomly duplicate less-represented tissue types
in train and val splits so that there are the same number of images of each type
**kwargs: other arguments to be passed to helper functions
Returns:
@@ -534,7 +592,7 @@ def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512),
Raises:
ValueError: If invalid resize parameter supplied
"""
if self.all_tissues == []:
if self.all_tissues is None:
self._identify_tissue_and_platform_types()

# validate inputs
@@ -582,6 +640,11 @@ def build_dataset(self, tissues='all', platforms='all', output_shape=(512, 512),
current_dict = self._clean_labels(data_dict=current_dict, relabel=relabel,
small_object_threshold=small_object_threshold,
min_objects=min_objects)

# don't balance test split
if balance and idx != 2:
current_dict = self._balance_dict(current_dict, seed=seed, category='tissue_list')

dicts[idx] = current_dict
return dicts

88 changes: 75 additions & 13 deletions caliban_toolbox/dataset_builder_test.py
@@ -93,8 +93,8 @@ def _create_test_dict(tissues, platforms):
X_data = data
y_data = data[..., :1].astype('int16')

tissue_list = [tissues[i] for i in range(len(tissues)) for _ in range(5)]
platform_list = [platforms[i] for i in range(len(platforms)) for _ in range(5)]
tissue_list = np.repeat(tissues, 5)
platform_list = np.repeat(platforms, 5)

return {'X': X_data, 'y': y_data, 'tissue_list': tissue_list, 'platform_list': platform_list}

@@ -290,8 +290,8 @@ def test__subset_data_dict(tmp_path):

X = np.arange(100)
y = np.arange(100)
tissue_list = ['tissue1'] * 10 + ['tissue2'] * 50 + ['tissue3'] * 40
platform_list = ['platform1'] * 20 + ['platform2'] * 40 + ['platform3'] * 40
tissue_list = np.array(['tissue1'] * 10 + ['tissue2'] * 50 + ['tissue3'] * 40)
platform_list = np.array(['platform1'] * 20 + ['platform2'] * 40 + ['platform3'] * 40)
data_dict = {'X': X, 'y': y, 'tissue_list': tissue_list, 'platform_list': platform_list}

db = DatasetBuilder(tmp_path)
@@ -306,17 +306,17 @@ def test__subset_data_dict(tmp_path):
assert np.all(X_subset == X[keep_idx])

# all platforms, one tissue
tissues = ['tissue2']
platforms = ['platform1', 'platform2', 'platform3']
tissues = np.array(['tissue2'])
platforms = np.array(['platform1', 'platform2', 'platform3'])
subset_dict = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms)
X_subset = subset_dict['X']
keep_idx = np.isin(tissue_list, tissues)

assert np.all(X_subset == X[keep_idx])

# drop tissue 1 and platform 3
tissues = ['tissue2', 'tissue3']
platforms = ['platform1', 'platform2']
tissues = np.array(['tissue2', 'tissue3'])
platforms = np.array(['platform1', 'platform2'])
subset_dict = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms)
X_subset = subset_dict['X']
platform_keep_idx = np.isin(platform_list, platforms)
@@ -326,8 +326,8 @@ def test__subset_data_dict(tmp_path):
assert np.all(X_subset == X[keep_idx])

# tissue/platform combination that doesn't exist
tissues = ['tissue1']
platforms = ['platform3']
tissues = np.array(['tissue1'])
platforms = np.array(['platform3'])
with pytest.raises(ValueError):
_ = db._subset_data_dict(data_dict=data_dict, tissues=tissues, platforms=platforms)

@@ -495,8 +495,8 @@ def test__clean_labels(tmp_path):
test_labels[0, ..., 0] = test_label

test_X = np.zeros_like(test_labels)
test_tissue = ['tissue1', 'tissue2']
test_platform = ['platform2', 'platform3']
test_tissue = np.array(['tissue1', 'tissue2'])
test_platform = np.array(['platform2', 'platform3'])

test_dict = {'X': test_X, 'y': test_labels, 'tissue_list': test_tissue,
'platform_list': test_platform}
@@ -524,6 +524,67 @@ def test__clean_labels(tmp_path):
assert cleaned_dict['platform_list'][0] == 'platform2'


def test__balance_dict(tmp_path):
_create_minimal_dataset(tmp_path)
db = DatasetBuilder(tmp_path)

X_data = np.random.rand(9, 10, 10, 3)
y_data = np.random.rand(9, 10, 10, 1)
tissue_list = np.array(['tissue1'] * 3 + ['tissue2'] * 3 + ['tissue3'] * 3)
platform_list = np.array(['platform1'] * 3 + ['platform2'] * 3 + ['platform3'] * 3)

balanced_dict = {'X': X_data, 'y': y_data, 'tissue_list': tissue_list,
'platform_list': platform_list}
output_dict = db._balance_dict(data_dict=balanced_dict, seed=0, category='tissue_list')

# data is already balanced, all items should be identical
for key in output_dict:
assert np.all(output_dict[key] == balanced_dict[key])

# tissue 3 has most, others need to be upsampled
tissue_list = np.array(['tissue1'] * 1 + ['tissue2'] * 2 + ['tissue3'] * 6)
unbalanced_dict = {'X': X_data, 'y': y_data, 'tissue_list': tissue_list,
'platform_list': platform_list}
output_dict = db._balance_dict(data_dict=unbalanced_dict, seed=0, category='tissue_list')

# tissue 3 is unchanged
for key in output_dict:
assert np.all(output_dict[key][-6:] == unbalanced_dict[key][-6:])

# tissue 1 only has a single example, all copies should be equal
tissue1_idx = np.where(output_dict['tissue_list'] == 'tissue1')[0]
for key in output_dict:
vals = output_dict[key]
for idx in tissue1_idx:
new_val = vals[idx]
old_val = unbalanced_dict[key][0]
assert np.all(new_val == old_val)

# tissue 2 has 2 examples, all copies should be equal to one of those values
tissue2_idx = np.where(output_dict['tissue_list'] == 'tissue2')[0]
for key in output_dict:
vals = output_dict[key]
for idx in tissue2_idx:
new_val = vals[idx]
old_val1 = unbalanced_dict[key][1]
old_val2 = unbalanced_dict[key][2]
assert np.all(new_val == old_val1) or np.all(new_val == old_val2)

# check with same seed
output_dict_same_seed = db._balance_dict(data_dict=unbalanced_dict, seed=0,
category='tissue_list')

for key in output_dict_same_seed:
assert np.all(output_dict_same_seed[key] == output_dict[key])

# check with different seed
output_dict_diff_seed = db._balance_dict(data_dict=unbalanced_dict, seed=1,
category='tissue_list')

for key in ['X', 'y']:
assert not np.all(output_dict_diff_seed[key] == output_dict[key])
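
A quick illustration of the seeding behavior these last checks rely on (toy draws, not repo code): reseeding with the same value reproduces the same random choices, while a different seed generally does not.

import numpy as np

np.random.seed(0); a = np.random.choice(range(2), size=100, replace=True)
np.random.seed(0); b = np.random.choice(range(2), size=100, replace=True)
np.random.seed(1); c = np.random.choice(range(2), size=100, replace=True)

print(np.array_equal(a, b))   # True: same seed, same draws
print(np.array_equal(a, c))   # almost certainly False with a different seed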


def test__validate_categories(tmp_path):
_create_minimal_dataset(tmp_path)
db = DatasetBuilder(tmp_path)
@@ -646,7 +707,8 @@ def test_build_dataset(tmp_path):

# full runthrough with default options changed
_ = db.build_dataset(tissues='all', platforms=platforms, output_shape=(10, 10),
relabel_hard=True, resize='by_image', small_object_threshold=5)
relabel=True, resize='by_image', small_object_threshold=5,
balance=True)


def test_summarize_dataset(tmp_path):
