From 857a5dc0624a6de761903fe92bf207e968d40c7c Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 17:18:55 -0700 Subject: [PATCH 01/12] refactored crop pipeline to use separate X and y arrays --- caliban_toolbox/reshape_data.py | 82 ++++++++++++++++------------ caliban_toolbox/reshape_data_test.py | 76 ++++++++++++++++---------- 2 files changed, 96 insertions(+), 62 deletions(-) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index f6be08d..e9895f7 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -121,11 +121,12 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding) return cropped_xr, padded_input.shape -def crop_multichannel_data(data_xr, crop_size, overlap_frac, test_parameters=False): +def crop_multichannel_data(X_data, y_data, crop_size, overlap_frac, test_parameters=False): """Reads in a stack of images and crops them into small pieces for easier annotation Args: - data_xr: xarray to be cropped of size [fovs, stacks, 1, slices, rows, cols, channels] + X_data: xarray containing raw images to be cropped + y_data: xarray containing labeled images to be chopped crop_size: (row_crop, col_crop) tuple specifying shape of the crop overlap_frac: fraction that crops will overlap each other on each edge test_parameters: boolean to determine whether to run all fovs, or only the first @@ -146,26 +147,33 @@ def crop_multichannel_data(data_xr, crop_size, overlap_frac, test_parameters=Fal if overlap_frac < 0 or overlap_frac > 1: raise ValueError('overlap_frac must be between 0 and 1') - if list(data_xr.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']: - raise ValueError('data_xr does not have expected dims, found {}'.format(data_xr.dims)) + if list(X_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']: + raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims)) + + if list(y_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']: + raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims)) # check if testing or running all samples if test_parameters: - data_xr = data_xr[:1, ...] + X_data, y_data = X_data[:1, ...], y_data[:1, ...] # compute the start and end coordinates for the row and column crops - row_starts, row_ends, row_padding = compute_crop_indices(img_len=data_xr.shape[4], + row_starts, row_ends, row_padding = compute_crop_indices(img_len=X_data.shape[4], crop_size=crop_size[0], overlap_frac=overlap_frac) - col_starts, col_ends, col_padding = compute_crop_indices(img_len=data_xr.shape[5], + col_starts, col_ends, col_padding = compute_crop_indices(img_len=X_data.shape[5], crop_size=crop_size[1], overlap_frac=overlap_frac) # crop images - data_xr_cropped, padded_shape = crop_helper(data_xr, row_starts=row_starts, row_ends=row_ends, - col_starts=col_starts, col_ends=col_ends, - padding=(row_padding, col_padding)) + X_data_cropped, padded_shape = crop_helper(X_data, row_starts=row_starts, row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(row_padding, col_padding)) + + y_data_cropped, padded_shape = crop_helper(y_data, row_starts=row_starts, row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(row_padding, col_padding)) # save relevant parameters for reconstructing image log_data = {} @@ -177,9 +185,9 @@ def crop_multichannel_data(data_xr, crop_size, overlap_frac, test_parameters=Fal log_data['col_crop_size'] = crop_size[1] log_data['row_padding'] = int(row_padding) log_data['col_padding'] = int(col_padding) - log_data['num_crops'] = data_xr_cropped.shape[2] + log_data['num_crops'] = X_data_cropped.shape[2] - return data_xr_cropped, log_data + return X_data_cropped, y_data_cropped, log_data def compute_slice_indices(stack_len, slice_len, slice_overlap): @@ -267,11 +275,12 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices): return slice_xr -def create_slice_data(data_xr, slice_stack_len, slice_overlap=0): +def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0): """Takes an array of data and splits it up into smaller pieces along the stack dimension Args: - data_xr: xarray of [fovs, stacks, crops, slices, rows, cols, channels] to be split up + X_data: xarray of raw image data to be split + y_data: xarray of labels to be split slice_stack_len: number of z/t frames in each slice slice_overlap: number of z/t frames in each slice that overlap one another @@ -281,35 +290,41 @@ def create_slice_data(data_xr, slice_stack_len, slice_overlap=0): """ # sanitize inputs - if len(data_xr.shape) != 7: + if len(X_data.shape) != 7: + raise ValueError('invalid input data shape, ' + 'expected array of len(7), got {}'.format(X_data.shape)) + + if len(y_data.shape) != 7: raise ValueError('invalid input data shape, ' - 'expected array of len(7), got {}'.format(data_xr.shape)) + 'expected array of len(7), got {}'.format(y_data.shape)) - if slice_stack_len > data_xr.shape[1]: + if slice_stack_len > X_data.shape[1]: raise ValueError('slice size is greater than stack length') # compute indices for slices - stack_len = data_xr.shape[1] + stack_len = X_data.shape[1] slice_start_indices, slice_end_indices = \ compute_slice_indices(stack_len, slice_stack_len, slice_overlap) - slice_xr = slice_helper(data_xr, slice_start_indices, slice_end_indices) + X_data_slice = slice_helper(X_data, slice_start_indices, slice_end_indices) + y_data_slice = slice_helper(y_data, slice_start_indices, slice_end_indices) log_data = {} log_data['slice_start_indices'] = slice_start_indices.tolist() log_data['slice_end_indices'] = slice_end_indices.tolist() log_data['num_slices'] = len(slice_start_indices) - return slice_xr, log_data + return X_data_slice, y_data_slice, log_data -def save_npzs_for_caliban(resized_xr, original_xr, log_data, save_dir, blank_labels='include', +def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, blank_labels='include', save_format='npz', verbose=True): """Take an array of processed image data and save as NPZ for caliban Args: - resized_xr: 7D tensor of cropped and sliced data - original_xr: the unmodified xarray + X_data: 7D tensor of cropped and sliced raw images + y_data: 7D tensor of cropped and sliced labeled images + original_data: the original unmodified images log_data: data used to reconstruct images save_dir: path to save the npz and JSON files blank_labels: whether to include NPZs with blank labels (poor predictions) @@ -325,7 +340,7 @@ def save_npzs_for_caliban(resized_xr, original_xr, log_data, save_dir, blank_lab num_crops = log_data.get('num_crops', 1) num_slices = log_data.get('num_slices', 1) - fov_names = original_xr.fovs.values + fov_names = original_data.fovs.values fov_len = len(fov_names) if blank_labels not in ['skip', 'include', 'separate']: @@ -340,10 +355,9 @@ def save_npzs_for_caliban(resized_xr, original_xr, log_data, save_dir, blank_lab # generate identifier for crop npz_id = 'fov_{}_crop_{}_slice_{}'.format(fov_names[fov], crop, slice) - # subset xarray based on supplied indices - current_xr = resized_xr[fov, :, crop, slice, ...] - labels = current_xr[..., -1:].values - channels = current_xr[..., :-1].values + # get working batch + labels = y_data[fov, :, crop, slice, ...].values + channels = X_data[fov, :, crop, slice, ...].values # determine if labels are blank, and if so what to do with npz if np.sum(labels) == 0: @@ -359,7 +373,7 @@ def save_npzs_for_caliban(resized_xr, original_xr, log_data, save_dir, blank_lab np.savez(save_path + '.npz', X=channels, y=labels) elif save_format == 'xr': - current_xr.to_netcdf(save_path + '.xr') + raise NotImplementedError() # blank labels don't get saved, empty area of tissue elif blank_labels == 'skip': @@ -377,7 +391,7 @@ def save_npzs_for_caliban(resized_xr, original_xr, log_data, save_dir, blank_lab np.savez(save_path + '.npz', X=channels, y=labels) elif save_format == 'xr': - current_xr.to_netcdf(save_path + '.xr') + raise NotImplementedError() else: # crop is not blank, save based on file_format @@ -388,12 +402,12 @@ def save_npzs_for_caliban(resized_xr, original_xr, log_data, save_dir, blank_lab np.savez(save_path + '.npz', X=channels, y=labels) elif save_format == 'xr': - current_xr.to_netcdf(save_path + '.xr') + raise NotImplementedError() log_data['fov_names'] = fov_names.tolist() - log_data['channel_names'] = original_xr.channels.values.tolist() - log_data['original_shape'] = original_xr.shape - log_data['slice_stack_len'] = resized_xr.shape[1] + log_data['channel_names'] = original_data.channels.values.tolist() + log_data['original_shape'] = original_data.shape + log_data['slice_stack_len'] = X_data.shape[1] log_data['save_format'] = save_format log_path = os.path.join(save_dir, 'log_data.json') diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index d208e6a..c06f753 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -177,18 +177,24 @@ def test_crop_multichannel_data(): overlap_frac = 0.2 # test only one crop - test_xr = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=channel_len) + test_X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=channel_len) - data_xr_cropped, log_data = reshape_data.crop_multichannel_data(data_xr=test_xr, - crop_size=crop_size, - overlap_frac=overlap_frac, - test_parameters=False) + test_y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=channel_len) + + X_data_cropped, y_data_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=test_X_data, + y_data=test_y_data, + crop_size=crop_size, + overlap_frac=overlap_frac, + test_parameters=False) expected_crop_num = len(reshape_data.compute_crop_indices(row_len, crop_size[0], overlap_frac)[0]) ** 2 - assert (data_xr_cropped.shape == (fov_len, stack_len, expected_crop_num, slice_num, + assert (X_data_cropped.shape == (fov_len, stack_len, expected_crop_num, slice_num, crop_size[0], crop_size[1], channel_len)) assert log_data["num_crops"] == expected_crop_num @@ -306,13 +312,17 @@ def test_create_slice_data(): fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, - slice_num=num_slices, row_len=row_len, col_len=col_len, - chan_len=chan_len) + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, + slice_num=num_slices, row_len=row_len, col_len=col_len, + chan_len=chan_len) - slice_xr, slice_indices = reshape_data.create_slice_data(input_data, slice_stack_len) + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, + slice_num=num_slices, row_len=row_len, col_len=col_len, + chan_len=chan_len) - assert slice_xr.shape == (fov_len, slice_stack_len, num_crops, + X_slice, y_slice, slice_indices = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) + + assert X_slice.shape == (fov_len, slice_stack_len, num_crops, int(np.ceil(stack_len / slice_stack_len)), row_len, col_len, chan_len) @@ -321,14 +331,19 @@ def test_save_npzs_for_caliban(): fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, slice_num=num_slices, row_len=row_len, col_len=col_len, chan_len=chan_len) - slice_xr, log_data = reshape_data.create_slice_data(input_data, slice_stack_len) + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, + slice_num=num_slices, + row_len=row_len, col_len=col_len, chan_len=1) + + sliced_X, sliced_y, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(resized_xr=slice_xr, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, original_data=X_data, log_data=copy.copy(log_data), save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) @@ -344,38 +359,41 @@ def test_save_npzs_for_caliban(): with open(os.path.join(temp_dir, "log_data.json")) as json_file: saved_log_data = json.load(json_file) - assert saved_log_data["original_shape"] == list(input_data.shape) + assert saved_log_data["original_shape"] == list(X_data.shape) with tempfile.TemporaryDirectory() as temp_dir: # check that combined crop and slice saving works crop_size = (10, 10) overlap_frac = 0.2 - data_xr_cropped, log_data_crop = \ - reshape_data.crop_multichannel_data(data_xr=slice_xr, + X_cropped, y_cropped, log_data_crop = \ + reshape_data.crop_multichannel_data(X_data=sliced_X, + y_data=sliced_y, crop_size=crop_size, overlap_frac=overlap_frac, test_parameters=False) - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, + original_data=X_data, log_data={**log_data, **log_data_crop}, save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) - expected_crop_num = data_xr_cropped.shape[2] * data_xr_cropped.shape[3] + expected_crop_num = X_cropped.shape[2] * X_cropped.shape[3] files = os.listdir(temp_dir) files = [file for file in files if "npz" in file] assert len(files) == expected_crop_num + # check that arguments specifying what to do with blank crops are working - # set specified crops to not be blank - slice_xr[0, 0, 0, [1, 4, 7], 0, 0, -1] = 27 - np.sum(np.nonzero(slice_xr.values)) - expected_crop_num = slice_xr.shape[2] * slice_xr.shape[3] + # set specified crops to not be blank + sliced_y[0, 0, 0, [1, 4, 7], 0, 0, 0] = 27 + expected_crop_num = sliced_X.shape[2] * sliced_X.shape[3] # test that function correctly includes blank crops when saving with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(resized_xr=slice_xr, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, + original_data=X_data, log_data=copy.copy(log_data), save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) @@ -388,7 +406,8 @@ def test_save_npzs_for_caliban(): # test that function correctly skips blank crops when saving with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(resized_xr=slice_xr, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, + original_data=X_data, log_data=copy.copy(log_data), save_dir=temp_dir, save_format="npz", blank_labels="skip", verbose=False) @@ -400,7 +419,8 @@ def test_save_npzs_for_caliban(): # test that function correctly saves blank crops to separate folder with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(resized_xr=slice_xr, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, + original_data=X_data, log_data=copy.copy(log_data), save_dir=temp_dir, save_format="npz", blank_labels="separate", verbose=False) From 1756509815eb6c120bf3124c152aea731ee1b2d9 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 18:39:43 -0700 Subject: [PATCH 02/12] refactored test suite for separate X and y --- caliban_toolbox/reshape_data.py | 33 +-- caliban_toolbox/reshape_data_test.py | 336 +++++++++++++++------------ 2 files changed, 204 insertions(+), 165 deletions(-) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index e9895f7..33a726c 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -356,6 +356,8 @@ def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, bla npz_id = 'fov_{}_crop_{}_slice_{}'.format(fov_names[fov], crop, slice) # get working batch + print('y_data shape is {}'.format(y_data.shape)) + print('indices are {}'.format((fov, crop, slice))) labels = y_data[fov, :, crop, slice, ...].values channels = X_data[fov, :, crop, slice, ...].values @@ -497,21 +499,22 @@ def load_npzs(crop_dir, log_data, verbose=True): # load xarray elif save_format == 'xr': - xr_path = os.path.join(crop_dir, get_saved_file_path(saved_files, fov_names[fov], - crop, slice)) - if os.path.exists(xr_path): - temp_xr = xr.open_dataarray(xr_path) - - # last slice may be truncated, modify index - if slice == num_slices - 1: - current_stack_len = temp_xr.shape[1] - else: - current_stack_len = stack_len - - stack[fov, :current_stack_len, crop, slice, ...] = temp_xr[..., -1:] - else: - # npz not generated, did not contain any labels, keep blank - print('could not find xr {}, skipping'.format(xr_path)) + raise NotImplementedError() + # xr_path = os.path.join(crop_dir, get_saved_file_path(saved_files, fov_names[fov], + # crop, slice)) + # if os.path.exists(xr_path): + # temp_xr = xr.open_dataarray(xr_path) + # + # # last slice may be truncated, modify index + # if slice == num_slices - 1: + # current_stack_len = temp_xr.shape[1] + # else: + # current_stack_len = stack_len + # + # stack[fov, :current_stack_len, crop, slice, ...] = temp_xr[..., -1:] + # else: + # # npz not generated, did not contain any labels, keep blank + # print('could not find xr {}, skipping'.format(xr_path)) return stack diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index c06f753..14a528c 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -478,32 +478,39 @@ def test_load_npzs(): row_len, col_len, chan_len = 50, 50, 3 slice_stack_len = 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) # slice the data - slice_xr, log_data = reshape_data.create_slice_data(input_data, slice_stack_len) + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, + slice_stack_len) # crop the data crop_size = (10, 10) overlap_frac = 0.2 - data_xr_cropped, log_data_crop = \ + X_cropped, y_cropped, log_data_crop = \ reshape_data.crop_multichannel_data( - data_xr=slice_xr, + X_data=X_slice, + y_data=y_slice, crop_size=crop_size, overlap_frac=overlap_frac, test_parameters=False) + # tag the upper left hand corner of the label in each slice - slice_tags = np.arange(data_xr_cropped.shape[3]) - crop_tags = np.arange(data_xr_cropped.shape[2]) - data_xr_cropped[0, 0, :, 0, 0, 0, 2] = crop_tags - data_xr_cropped[0, 0, 0, :, 0, 0, 2] = slice_tags + slice_tags = np.arange(y_cropped.shape[3]) + crop_tags = np.arange(y_cropped.shape[2]) + y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags + y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags combined_log_data = {**log_data, **log_data_crop} # save the tagged data - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=combined_log_data, save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) @@ -514,7 +521,7 @@ def test_load_npzs(): loaded_slices = reshape_data.load_npzs(temp_dir, saved_log_data, verbose=False) # dims other than channels are the same - assert (np.all(loaded_slices.shape[:-1] == data_xr_cropped.shape[:-1])) + assert (np.all(loaded_slices.shape[:-1] == X_cropped.shape[:-1])) assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) @@ -526,33 +533,38 @@ def test_load_npzs(): row_len, col_len, chan_len = 50, 50, 3 slice_stack_len = 7 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, chan_len=chan_len) + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + # slice the data - slice_xr, log_data = reshape_data.create_slice_data(input_data, slice_stack_len) + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) # crop the data crop_size = (10, 10) overlap_frac = 0.2 - data_xr_cropped, log_data_crop = \ + X_cropped, y_cropped, log_data_crop = \ reshape_data.crop_multichannel_data( - data_xr=slice_xr, + X_data=X_slice, + y_data=y_slice, crop_size=crop_size, overlap_frac=overlap_frac, test_parameters=False) # tag the upper left hand corner of the annotations in each slice - slice_tags = np.arange(data_xr_cropped.shape[3]) - crop_tags = np.arange(data_xr_cropped.shape[2]) - data_xr_cropped[0, 0, :, 0, 0, 0, 2] = crop_tags - data_xr_cropped[0, 0, 0, :, 0, 0, 2] = slice_tags + slice_tags = np.arange(y_cropped.shape[3]) + crop_tags = np.arange(X_cropped.shape[2]) + y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags + y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags combined_log_data = {**log_data, **log_data_crop} # save the tagged data - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=combined_log_data, save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) @@ -560,7 +572,7 @@ def test_load_npzs(): loaded_slices = reshape_data.load_npzs(temp_dir, combined_log_data) # dims other than channels are the same - assert (np.all(loaded_slices.shape[:-1] == data_xr_cropped.shape[:-1])) + assert (np.all(loaded_slices.shape[:-1] == X_cropped.shape[:-1])) assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) @@ -570,30 +582,36 @@ def test_stitch_crops(): # generate stack of crops from image with grid pattern fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, chan_len=chan_len) + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + # create image with artificial objects to be segmented cell_idx = 1 for i in range(12): for j in range(11): - for fov in range(input_data.shape[0]): - input_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 3] = cell_idx + for fov in range(y_data.shape[0]): + y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), + (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx cell_idx += 1 # crop the image crop_size, overlap_frac = 400, 0.2 - cropped, log_data = reshape_data.crop_multichannel_data(data_xr=input_data, - crop_size=(crop_size, crop_size), - overlap_frac=overlap_frac) - cropped_labels = cropped[..., -1:].values - log_data["original_shape"] = input_data.shape + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=(crop_size, crop_size), + overlap_frac=overlap_frac) + + log_data["original_shape"] = X_data.shape # stitch the crops back together - stitched_img = reshape_data.stitch_crops(annotated_data=cropped_labels, log_data=log_data) + stitched_img = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) # trim padding row_padding, col_padding = log_data["row_padding"], log_data["col_padding"] @@ -602,31 +620,31 @@ def test_stitch_crops(): if col_padding > 0: stitched_img = stitched_img[:, :, :, :, :, :-col_padding, :] - # dims other than channels are the same - assert np.all(stitched_img.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_img.shape == y_data.shape) # check that objects are at same location - assert (np.all(np.equal(stitched_img[..., 0] > 0, input_data.values[..., 3] > 0))) + assert (np.all(np.equal(stitched_img[..., 0] > 0, y_data.values[..., 0] > 0))) # check that same number of unique objects - assert len(np.unique(stitched_img)) == len(np.unique(input_data.values)) + assert len(np.unique(stitched_img)) == len(np.unique(y_data.values)) # test stitching imperfect annotator labels that slightly overlap # generate stack of crops from image with grid pattern fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 1, 1, 1, 800, 800, 1 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) side_len = 40 - cell_num = input_data.shape[4] // side_len + cell_num = y_data.shape[4] // side_len cell_id = np.arange(1, cell_num ** 2 + 1) cell_id = np.random.choice(cell_id, cell_num ** 2, replace=False) cell_idx = 0 for row in range(cell_num): for col in range(cell_num): - input_data[0, 0, 0, 0, row * side_len:(row + 1) * side_len, + y_data[0, 0, 0, 0, row * side_len:(row + 1) * side_len, col * side_len:(col + 1) * side_len, 0] = cell_id[cell_idx] cell_idx += 1 @@ -647,10 +665,10 @@ def test_stitch_crops(): row_starts, row_ends = starts + row_offset, ends + row_offset col_starts, col_ends = starts + col_offset, ends + col_offset - cropped, padded = reshape_data.crop_helper(input_data=input_data, row_starts=row_starts, - row_ends=row_ends, - col_starts=col_starts, col_ends=col_ends, - padding=(padding, padding)) + y_cropped, padded = reshape_data.crop_helper(input_data=y_data, row_starts=row_starts, + row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(padding, padding)) # generate log data, since we had to go inside the upper level # function to modify crop_helper inputs @@ -665,14 +683,12 @@ def test_stitch_crops(): log_data["num_col_crops"] = len(col_starts) log_data["row_padding"] = int(padding) log_data["col_padding"] = int(padding) - log_data["num_crops"] = cropped.shape[2] - log_data["original_shape"] = input_data.shape - log_data["fov_names"] = input_data.fovs.values.tolist() - log_data["channel_names"] = input_data.channels.values.tolist() - - cropped_labels = cropped[..., -1:].values + log_data["num_crops"] = y_cropped.shape[2] + log_data["original_shape"] = y_data.shape + log_data["fov_names"] = y_data.fovs.values.tolist() + log_data["channel_names"] = y_data.channels.values.tolist() - stitched_img = reshape_data.stitch_crops(annotated_data=cropped_labels, log_data=log_data) + stitched_img = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) # trim padding stitched_img = stitched_img[:, :, :, :, :-padding, :-padding, :] @@ -681,11 +697,11 @@ def test_stitch_crops(): props = skimage.measure.regionprops_table(relabeled, properties=["area", "label"]) - # dims other than channels are the same - assert np.all(stitched_img.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_img.shape == y_data.shape) # same number of unique objects before and after - assert (len(np.unique(relabeled)) == len(np.unique(input_data[0, 0, 0, 0, :, :, 0]))) + assert (len(np.unique(relabeled)) == len(np.unique(y_data[0, 0, 0, 0, :, :, 0]))) # no cell is smaller than offset subtracted from each side min_size = (side_len - offset_len * 2) ** 2 @@ -701,29 +717,33 @@ def test_reconstruct_image_data(): fov_len, stack_len, crop_num, slice_num = 2, 1, 1, 1 row_len, col_len, chan_len = 400, 400, 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) # create image with cell_idx = 1 for i in range(12): for j in range(11): - for fov in range(input_data.shape[0]): - input_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 3] = cell_idx + for fov in range(y_data.shape[0]): + y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), + (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx cell_idx += 1 crop_size, overlap_frac = 40, 0.2 # crop data - data_xr_cropped, log_data = \ - reshape_data.crop_multichannel_data(data_xr=input_data, + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, y_data=y_data, crop_size=(crop_size, crop_size), overlap_frac=0.2) # stitch data - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=log_data, save_dir=temp_dir, verbose=False) @@ -731,26 +751,27 @@ def test_reconstruct_image_data(): stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - # dims other than channels are the same - assert np.all(stitched_xr.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_xr.shape == y_data.shape) # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, input_data[:, :, 0] > 0))) + assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(input_data))) + assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) with tempfile.TemporaryDirectory() as temp_dir: # test single crop in x crop_size, overlap_frac = (400, 40), 0.2 # crop data - data_xr_cropped, log_data = reshape_data.crop_multichannel_data(data_xr=input_data, - crop_size=crop_size, - overlap_frac=0.2) + X_cropped, y_cropped, log_data = reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=crop_size, + overlap_frac=0.2) # stitch data - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=log_data, save_dir=temp_dir, verbose=False) @@ -758,26 +779,27 @@ def test_reconstruct_image_data(): stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - # dims other than channels are the same - assert np.all(stitched_xr.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_xr.shape == y_data.shape) # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, input_data[:, :, 0] > 0))) + assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(input_data))) + assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) with tempfile.TemporaryDirectory() as temp_dir: # test single crop in both crop_size, overlap_frac = (400, 400), 0.2 # crop data - data_xr_cropped, log_data = reshape_data.crop_multichannel_data(data_xr=input_data, - crop_size=crop_size, - overlap_frac=0.2) + X_cropped, y_cropped, log_data = reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=crop_size, + overlap_frac=0.2) # stitch data - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=log_data, save_dir=temp_dir, verbose=False) @@ -785,95 +807,98 @@ def test_reconstruct_image_data(): stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - # dims other than channels are the same - assert np.all(stitched_xr.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_xr.shape == y_data.shape) # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, input_data[:, :, 0] > 0))) + assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(input_data))) + assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) def test_stitch_slices(): - # generate data - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 + with tempfile.TemporaryDirectory() as temp_dir: + # generate data + (fov_len, stack_len, crop_num, + slice_num, row_len, col_len, chan_len) = 2, 1, 1, 1, 400, 400, 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, chan_len=chan_len) - # create image with - cell_idx = 1 - for i in range(12): - for j in range(11): - for fov in range(input_data.shape[0]): - input_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 3] = cell_idx - cell_idx += 1 + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=1) - crop_size, overlap_frac = 50, 0.2 - save_dir = "tests/caliban_toolbox/test_crop_and_stitch" + # create image with + cell_idx = 1 + for i in range(12): + for j in range(11): + for fov in range(y_data.shape[0]): + y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), + (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx + cell_idx += 1 - # crop data - data_xr_cropped, log_data = \ - reshape_data.crop_multichannel_data(data_xr=input_data, - crop_size=(crop_size, crop_size), - overlap_frac=0.2) + crop_size, overlap_frac = 50, 0.2 - # stitch data - reshape_data.save_npzs_for_caliban(resized_xr=data_xr_cropped, original_xr=input_data, - log_data=log_data, - save_dir=save_dir) + # crop data + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, y_data=y_data, + crop_size=(crop_size, crop_size), + overlap_frac=0.2) - reshape_data.reconstruct_image_stack(crop_dir=save_dir) + # stitch data + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, + log_data=log_data, + save_dir=temp_dir) - stitched_xr = xr.open_dataarray(os.path.join(save_dir, "stitched_images.nc")) + reshape_data.reconstruct_image_stack(crop_dir=temp_dir) - # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, input_data[:, :, 0] > 0))) + stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(input_data))) + # all the same pixels are marked + assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) - # clean up - shutil.rmtree(save_dir) + # there are the same number of cells + assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) def test_stitch_slices(): fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) # generate ordered data linear_seq = np.arange(stack_len * row_len * col_len) test_vals = linear_seq.reshape((stack_len, row_len, col_len)) - input_data[0, :, 0, 0, :, :, 2] = test_vals + y_data[0, :, 0, 0, :, :, 0] = test_vals - slice_xr, log_data = reshape_data.create_slice_data(input_data, slice_stack_len) + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) - # TODO move crop + slice testing to another test function crop_size = (10, 10) overlap_frac = 0.2 - data_xr_cropped, log_data_crop = reshape_data.crop_multichannel_data(data_xr=slice_xr, - crop_size=crop_size, - overlap_frac=overlap_frac, - test_parameters=False) - - # # get parameters - # row_crop_size, col_crop_size = crop_size[0], crop_size[1] - # num_row_crops, num_col_crops = log_data_crop["num_row_crops"], log_data_crop["num_col_crops"] - # num_slices = log_data["num_slices"] + X_cropped, y_cropped, log_data_crop = \ + reshape_data.crop_multichannel_data(X_data=X_slice, + y_data=y_slice, + crop_size=crop_size, + overlap_frac=overlap_frac, + test_parameters=False) - log_data["original_shape"] = input_data.shape - log_data["fov_names"] = input_data.fovs.values - stitched_slices = reshape_data.stitch_slices(slice_xr[..., -1:], {**log_data}) + log_data["original_shape"] = X_data.shape + log_data["fov_names"] = X_data.fovs.values + stitched_slices = reshape_data.stitch_slices(y_slice, {**log_data}) - # dims other than channels are the same - assert np.all(stitched_slices.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_slices.shape == y_data.shape) assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) @@ -882,24 +907,28 @@ def test_stitch_slices(): fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 7 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, chan_len=chan_len) + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + # generate ordered data linear_seq = np.arange(stack_len * row_len * col_len) test_vals = linear_seq.reshape((stack_len, row_len, col_len)) - input_data[0, :, 0, 0, :, :, 2] = test_vals + y_data[0, :, 0, 0, :, :, 0] = test_vals - slice_xr, log_data = reshape_data.create_slice_data(input_data, slice_stack_len) + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) # get parameters - log_data["original_shape"] = input_data.shape - log_data["fov_names"] = input_data.fovs.values - stitched_slices = reshape_data.stitch_slices(slice_xr[..., -1:], log_data) + log_data["original_shape"] = y_data.shape + log_data["fov_names"] = y_data.fovs.values + stitched_slices = reshape_data.stitch_slices(y_slice, log_data) - # dims other than channels are the same - assert np.all(stitched_slices.shape[:-1] == input_data.shape[:-1]) + assert np.all(stitched_slices.shape == y_data.shape) assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) @@ -911,24 +940,31 @@ def test_reconstruct_slice_data(): row_len, col_len, chan_len = 50, 50, 3 slice_stack_len = 4 - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) # tag upper left hand corner of the label in each image tags = np.arange(stack_len) - input_data[0, :, 0, 0, 0, 0, 2] = tags + y_data[0, :, 0, 0, 0, 0, 0] = tags - slice_xr, slice_log_data = reshape_data.create_slice_data(input_data, slice_stack_len) + X_slice, y_slice, slice_log_data = \ + reshape_data.create_slice_data(X_data=X_data, + y_data=y_data, + slice_stack_len=slice_stack_len) - reshape_data.save_npzs_for_caliban(resized_xr=slice_xr, original_xr=input_data, + reshape_data.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, log_data={**slice_log_data}, save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) stitched_slices = reshape_data.reconstruct_slice_data(temp_dir) - # dims other than channels are the same - assert np.all(stitched_slices.shape[:-1] == input_data.shape[:-1]) + # dims are the same + assert np.all(stitched_slices.shape == y_data.shape) assert np.all(np.equal(stitched_slices[0, :, 0, 0, 0, 0, 0], tags)) From 2fe53cec23c58fc9b6091fed9f42a6f183f20583 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 18:58:14 -0700 Subject: [PATCH 03/12] updated old stitch test --- caliban_toolbox/reshape_data_test.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 14a528c..0d6115e 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -821,7 +821,7 @@ def test_stitch_slices(): with tempfile.TemporaryDirectory() as temp_dir: # generate data (fov_len, stack_len, crop_num, - slice_num, row_len, col_len, chan_len) = 2, 1, 1, 1, 400, 400, 4 + slice_num, row_len, col_len, chan_len) = 2, 12, 1, 1, 400, 400, 4 X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, @@ -840,22 +840,19 @@ def test_stitch_slices(): (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx cell_idx += 1 - crop_size, overlap_frac = 50, 0.2 + slice_stack_len = 4 - # crop data - X_cropped, y_cropped, log_data = \ - reshape_data.crop_multichannel_data(X_data=X_data, y_data=y_data, - crop_size=(crop_size, crop_size), - overlap_frac=0.2) + # slice data + X_slice, y_slice, log_data = \ + reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) # stitch data - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, + reshape_data.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, log_data=log_data, save_dir=temp_dir) - reshape_data.reconstruct_image_stack(crop_dir=temp_dir) - - stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) + stitched_xr = reshape_data.reconstruct_slice_data(save_dir=temp_dir) # all the same pixels are marked assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) @@ -864,7 +861,7 @@ def test_stitch_slices(): assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) -def test_stitch_slices(): +def test_stitch_slices1(): fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 4 From 5b0918f752b7d98d96638c91d89549c434675c69 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 19:44:59 -0700 Subject: [PATCH 04/12] simplified slice and crop functions --- caliban_toolbox/reshape_data.py | 82 +++++------ caliban_toolbox/reshape_data_test.py | 194 ++++----------------------- 2 files changed, 67 insertions(+), 209 deletions(-) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 33a726c..715f96b 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -593,53 +593,55 @@ def stitch_crops(annotated_data, log_data): stitched_labels[fov, stack, 0, 0, row_starts[row]:row_ends[row], col_starts[col]:col_ends[col], 0] = combined_crop - # relabel images to remove skipped cell_ids + # trim padding to put image back to original size + if row_padding > 0: + stitched_labels = stitched_labels[:, :, :, :, :-row_padding, :, :] + if col_padding > 0: + stitched_labels = stitched_labels[:, :, :, :, :, :-col_padding, :] + return stitched_labels -def reconstruct_image_stack(crop_dir, verbose=True): +def reconstruct_crops(crop_stack, log_data): """High level function to combine crops together into a single stitched image Args: - crop_dir: directory where cropped files are stored - verbose: flag to control print statements + crop_stack: stack of cropped images + log_data: dict generated during crop creation process """ - # sanitize inputs - if not os.path.isdir(crop_dir): - raise ValueError('crop_dir not a valid directory: {}'.format(crop_dir)) + # # sanitize inputs + # if not os.path.isdir(crop_dir): + # raise ValueError('crop_dir not a valid directory: {}'.format(crop_dir)) + # + # # unpack JSON data + # with open(os.path.join(crop_dir, 'log_data.json')) as json_file: + # log_data = json.load(json_file) - # unpack JSON data - with open(os.path.join(crop_dir, 'log_data.json')) as json_file: - log_data = json.load(json_file) - row_padding, col_padding = log_data['row_padding'], log_data['col_padding'] - fov_names = log_data['fov_names'] - # combine all npz crops into a single stack - crop_stack = load_npzs(crop_dir=crop_dir, log_data=log_data, verbose=verbose) + # # combine all npz crops into a single stack + # crop_stack = load_npzs(crop_dir=crop_dir, log_data=log_data, verbose=verbose) # stitch crops together into single contiguous image stitched_images = stitch_crops(annotated_data=crop_stack, log_data=log_data) # crop image down to original size - if row_padding > 0: - stitched_images = stitched_images[:, :, :, :, :-row_padding, :, :] - if col_padding > 0: - stitched_images = stitched_images[:, :, :, :, :, :-col_padding, :] - - _, stack_len, _, _, row_len, col_len, _ = log_data['original_shape'] - # labels for each index within a dimension - coordinate_labels = [fov_names, range(stack_len), range(1), - range(1), range(row_len), range(col_len), ['segmentation_label']] - # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] + _, stack_len, _, _, row_len, col_len, _ = log_data['original_shape'] - stitched_xr = xr.DataArray(data=stitched_images, coords=coordinate_labels, - dims=dimension_labels) + # # labels for each index within a dimension + # coordinate_labels = [fov_names, range(stack_len), range(1), + # range(1), range(row_len), range(col_len), ['segmentation_label']] + # + # # labels for each dimension + # dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] + # + # stitched_xr = xr.DataArray(data=stitched_images, coords=coordinate_labels, + # dims=dimension_labels) + return stitched_images - stitched_xr.to_netcdf(os.path.join(crop_dir, 'stitched_images.nc')) + #stitched_xr.to_netcdf(os.path.join(crop_dir, 'stitched_images.nc')) def stitch_slices(slice_stack, log_data): @@ -688,26 +690,26 @@ def stitch_slices(slice_stack, log_data): return stitched_xr -def reconstruct_slice_data(save_dir, verbose=True): +def reconstruct_slices(slice_stack, log_data): """High level function to put pieces of a slice back together Args: - save_dir: full path to directory where slice pieces are stored - verbose: flag to control print statements + slice_stack: stack of sliced images + log_data: dict generated during slice creation process Returns: xarray.DataArray: 7D tensor of stitched labeled slices """ - if not os.path.isdir(save_dir): - raise FileNotFoundError('slice directory does not exist') - - json_file_path = os.path.join(save_dir, 'log_data.json') - if not os.path.exists(json_file_path): - raise FileNotFoundError('json file does not exist') - - with open(json_file_path) as json_file: - slice_log_data = json.load(json_file) + # if not os.path.isdir(save_dir): + # raise FileNotFoundError('slice directory does not exist') + # + # json_file_path = os.path.join(save_dir, 'log_data.json') + # if not os.path.exists(json_file_path): + # raise FileNotFoundError('json file does not exist') + # + # with open(json_file_path) as json_file: + # slice_log_data = json.load(json_file) slice_stack = load_npzs(save_dir, slice_log_data, verbose=verbose) diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 0d6115e..619b6bf 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -600,7 +600,7 @@ def test_stitch_crops(): (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx cell_idx += 1 - # crop the image + # ## Test when crop is same size as image crop_size, overlap_frac = 400, 0.2 X_cropped, y_cropped, log_data = \ reshape_data.crop_multichannel_data(X_data=X_data, @@ -613,12 +613,6 @@ def test_stitch_crops(): # stitch the crops back together stitched_img = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) - # trim padding - row_padding, col_padding = log_data["row_padding"], log_data["col_padding"] - if row_padding > 0: - stitched_img = stitched_img[:, :, :, :, :-row_padding, :, :] - if col_padding > 0: - stitched_img = stitched_img[:, :, :, :, :, :-col_padding, :] # dims are the same assert np.all(stitched_img.shape == y_data.shape) @@ -629,6 +623,30 @@ def test_stitch_crops(): # check that same number of unique objects assert len(np.unique(stitched_img)) == len(np.unique(y_data.values)) + + # ## Test when rows has only one crop + crop_size, overlap_frac = (400, 40), 0.2 + + # crop data + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=crop_size, + overlap_frac=overlap_frac) + + # stitch back together + log_data["original_shape"] = X_data.shape + stitched_imgs = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) + + # dims are the same + assert np.all(stitched_imgs.shape == y_data.shape) + + # all the same pixels are marked + assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0))) + + # there are the same number of cells + assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data))) + # test stitching imperfect annotator labels that slightly overlap # generate stack of crops from image with grid pattern fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 1, 1, 1, 800, 800, 1 @@ -690,9 +708,6 @@ def test_stitch_crops(): stitched_img = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) - # trim padding - stitched_img = stitched_img[:, :, :, :, :-padding, :-padding, :] - relabeled = skimage.measure.label(stitched_img[0, 0, 0, 0, :, :, 0]) props = skimage.measure.regionprops_table(relabeled, properties=["area", "label"]) @@ -711,157 +726,7 @@ def test_stitch_crops(): assert (np.all(props["area"] >= min_size)) -def test_reconstruct_image_data(): - # generate stack of crops from image with grid pattern - with tempfile.TemporaryDirectory() as temp_dir: - fov_len, stack_len, crop_num, slice_num = 2, 1, 1, 1 - row_len, col_len, chan_len = 400, 400, 4 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) - - # create image with - cell_idx = 1 - for i in range(12): - for j in range(11): - for fov in range(y_data.shape[0]): - y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx - cell_idx += 1 - - crop_size, overlap_frac = 40, 0.2 - - # crop data - X_cropped, y_cropped, log_data = \ - reshape_data.crop_multichannel_data(X_data=X_data, y_data=y_data, - crop_size=(crop_size, crop_size), - overlap_frac=0.2) - - # stitch data - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=log_data, - save_dir=temp_dir, verbose=False) - - reshape_data.reconstruct_image_stack(crop_dir=temp_dir) - - stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - - # dims are the same - assert np.all(stitched_xr.shape == y_data.shape) - - # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) - - # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) - - with tempfile.TemporaryDirectory() as temp_dir: - # test single crop in x - crop_size, overlap_frac = (400, 40), 0.2 - - # crop data - X_cropped, y_cropped, log_data = reshape_data.crop_multichannel_data(X_data=X_data, - y_data=y_data, - crop_size=crop_size, - overlap_frac=0.2) - - # stitch data - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=log_data, - save_dir=temp_dir, verbose=False) - - reshape_data.reconstruct_image_stack(crop_dir=temp_dir) - - stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - - # dims are the same - assert np.all(stitched_xr.shape == y_data.shape) - - # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) - - # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) - - with tempfile.TemporaryDirectory() as temp_dir: - # test single crop in both - crop_size, overlap_frac = (400, 400), 0.2 - - # crop data - X_cropped, y_cropped, log_data = reshape_data.crop_multichannel_data(X_data=X_data, - y_data=y_data, - crop_size=crop_size, - overlap_frac=0.2) - - # stitch data - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=log_data, - save_dir=temp_dir, verbose=False) - - reshape_data.reconstruct_image_stack(crop_dir=temp_dir) - - stitched_xr = xr.open_dataarray(os.path.join(temp_dir, "stitched_images.nc")) - - # dims are the same - assert np.all(stitched_xr.shape == y_data.shape) - - # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) - - # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) - - def test_stitch_slices(): - with tempfile.TemporaryDirectory() as temp_dir: - # generate data - (fov_len, stack_len, crop_num, - slice_num, row_len, col_len, chan_len) = 2, 12, 1, 1, 400, 400, 4 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=1) - - # create image with - cell_idx = 1 - for i in range(12): - for j in range(11): - for fov in range(y_data.shape[0]): - y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx - cell_idx += 1 - - slice_stack_len = 4 - - # slice data - X_slice, y_slice, log_data = \ - reshape_data.create_slice_data(X_data=X_data, y_data=y_data, - slice_stack_len=slice_stack_len) - - # stitch data - reshape_data.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, - log_data=log_data, - save_dir=temp_dir) - - stitched_xr = reshape_data.reconstruct_slice_data(save_dir=temp_dir) - - # all the same pixels are marked - assert (np.all(np.equal(stitched_xr[:, :, 0] > 0, y_data[:, :, 0] > 0))) - - # there are the same number of cells - assert (len(np.unique(stitched_xr)) == len(np.unique(y_data))) - - -def test_stitch_slices1(): fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 4 @@ -881,15 +746,6 @@ def test_stitch_slices1(): X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, slice_stack_len=slice_stack_len) - crop_size = (10, 10) - overlap_frac = 0.2 - X_cropped, y_cropped, log_data_crop = \ - reshape_data.crop_multichannel_data(X_data=X_slice, - y_data=y_slice, - crop_size=crop_size, - overlap_frac=overlap_frac, - test_parameters=False) - log_data["original_shape"] = X_data.shape log_data["fov_names"] = X_data.fovs.values stitched_slices = reshape_data.stitch_slices(y_slice, {**log_data}) From a29b15695bbc258d252fdf78249b40fe78fa9a48 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 21:11:01 -0700 Subject: [PATCH 05/12] cropping and slicing into same function --- caliban_toolbox/reshape_data.py | 110 +++++++--------------- caliban_toolbox/reshape_data_test.py | 134 +++++++++++++++++++++++++-- 2 files changed, 161 insertions(+), 83 deletions(-) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 715f96b..3f46132 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -356,8 +356,6 @@ def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, bla npz_id = 'fov_{}_crop_{}_slice_{}'.format(fov_names[fov], crop, slice) # get working batch - print('y_data shape is {}'.format(y_data.shape)) - print('indices are {}'.format((fov, crop, slice))) labels = y_data[fov, :, crop, slice, ...].values channels = X_data[fov, :, crop, slice, ...].values @@ -519,11 +517,11 @@ def load_npzs(crop_dir, log_data, verbose=True): return stack -def stitch_crops(annotated_data, log_data): +def stitch_crops(crop_stack, log_data): """Takes a stack of annotated labels and stitches them together into a single image Args: - annotated_data: 7D tensor of labels to be stitched together + crop_stack: 7D tensor of labels to be stitched together log_data: dictionary of parameters for reconstructing original image data Returns: @@ -539,7 +537,7 @@ def stitch_crops(annotated_data, log_data): row_starts, row_ends = log_data['row_starts'], log_data['row_ends'] col_starts, col_ends = log_data['col_starts'], log_data['col_ends'] - if annotated_data.shape[3] != 1: + if crop_stack.shape[3] != 1: raise ValueError('Stacks must be combined before stitching can occur') # for each fov and stack, loop through rows and columns of crop positions @@ -550,7 +548,7 @@ def stitch_crops(annotated_data, log_data): crop_counter = row * len(row_starts) + col # get current crop - crop = annotated_data[fov, stack, crop_counter, 0, :, :, 0] + crop = crop_stack[fov, stack, crop_counter, 0, :, :, 0] # increment values to ensure unique labels across final image lowest_allowed_val = np.amax(stitched_labels[fov, stack, ...]) @@ -602,48 +600,6 @@ def stitch_crops(annotated_data, log_data): return stitched_labels -def reconstruct_crops(crop_stack, log_data): - """High level function to combine crops together into a single stitched image - - Args: - crop_stack: stack of cropped images - log_data: dict generated during crop creation process - """ - - # # sanitize inputs - # if not os.path.isdir(crop_dir): - # raise ValueError('crop_dir not a valid directory: {}'.format(crop_dir)) - # - # # unpack JSON data - # with open(os.path.join(crop_dir, 'log_data.json')) as json_file: - # log_data = json.load(json_file) - - - # # combine all npz crops into a single stack - # crop_stack = load_npzs(crop_dir=crop_dir, log_data=log_data, verbose=verbose) - - # stitch crops together into single contiguous image - stitched_images = stitch_crops(annotated_data=crop_stack, log_data=log_data) - - # crop image down to original size - - - _, stack_len, _, _, row_len, col_len, _ = log_data['original_shape'] - - # # labels for each index within a dimension - # coordinate_labels = [fov_names, range(stack_len), range(1), - # range(1), range(row_len), range(col_len), ['segmentation_label']] - # - # # labels for each dimension - # dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - # - # stitched_xr = xr.DataArray(data=stitched_images, coords=coordinate_labels, - # dims=dimension_labels) - return stitched_images - - #stitched_xr.to_netcdf(os.path.join(crop_dir, 'stitched_images.nc')) - - def stitch_slices(slice_stack, log_data): """Helper function to stitch slices together back into original sized array @@ -678,41 +634,45 @@ def stitch_slices(slice_stack, log_data): stitched_slices[:, slice_start_indices[last_idx]:slice_end_indices[last_idx], :, 0, ...] = \ slice_stack[:, :slice_len, :, last_idx, ...] - # labels for each index within a dimension - coordinate_labels = [fov_names, range(stack_len), range(crop_num), range(1), range(row_len), - range(col_len), ['segmentation_label']] + return stitched_slices - # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - stitched_xr = xr.DataArray(stitched_slices, coords=coordinate_labels, dims=dimension_labels) +def reconstruct_image_stack(crop_dir, verbose=True): + """High level function to recombine data into a single stitched image - return stitched_xr + Args: + crop_dir: full path to directory with cropped images + verbose: flag to control print statements + """ + # sanitize inputs + if not os.path.isdir(crop_dir): + raise ValueError('crop_dir not a valid directory: {}'.format(crop_dir)) -def reconstruct_slices(slice_stack, log_data): - """High level function to put pieces of a slice back together + # unpack JSON data + with open(os.path.join(crop_dir, 'log_data.json')) as json_file: + log_data = json.load(json_file) - Args: - slice_stack: stack of sliced images - log_data: dict generated during slice creation process + # combine all npzs into a single stack + image_stack = load_npzs(crop_dir=crop_dir, log_data=log_data, verbose=verbose) - Returns: - xarray.DataArray: 7D tensor of stitched labeled slices - """ + # stitch slices if data was sliced + if 'num_slices' in log_data: + image_stack = stitch_slices(slice_stack=image_stack, log_data=log_data) - # if not os.path.isdir(save_dir): - # raise FileNotFoundError('slice directory does not exist') - # - # json_file_path = os.path.join(save_dir, 'log_data.json') - # if not os.path.exists(json_file_path): - # raise FileNotFoundError('json file does not exist') - # - # with open(json_file_path) as json_file: - # slice_log_data = json.load(json_file) + # stitch crops if data was cropped + if 'num_crops' in log_data: + image_stack = stitch_crops(crop_stack=image_stack, log_data=log_data) - slice_stack = load_npzs(save_dir, slice_log_data, verbose=verbose) + # labels for each index within a dimension + _, stack_len, _, _, row_len, col_len, _ = log_data['original_shape'] + coordinate_labels = [log_data['fov_names'], range(stack_len), range(1), + range(1), range(row_len), range(col_len), ['segmentation_label']] + + # labels for each dimension + dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - stitched_xr = stitch_slices(slice_stack, slice_log_data) + stitched_xr = xr.DataArray(data=image_stack, coords=coordinate_labels, + dims=dimension_labels) - return stitched_xr + stitched_xr.to_netcdf(os.path.join(crop_dir, 'stitched_images.xr')) diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 619b6bf..554e8d0 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -611,7 +611,7 @@ def test_stitch_crops(): log_data["original_shape"] = X_data.shape # stitch the crops back together - stitched_img = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) + stitched_img = reshape_data.stitch_crops(crop_stack=y_cropped, log_data=log_data) # dims are the same @@ -636,7 +636,7 @@ def test_stitch_crops(): # stitch back together log_data["original_shape"] = X_data.shape - stitched_imgs = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) + stitched_imgs = reshape_data.stitch_crops(crop_stack=y_cropped, log_data=log_data) # dims are the same assert np.all(stitched_imgs.shape == y_data.shape) @@ -706,7 +706,7 @@ def test_stitch_crops(): log_data["fov_names"] = y_data.fovs.values.tolist() log_data["channel_names"] = y_data.channels.values.tolist() - stitched_img = reshape_data.stitch_crops(annotated_data=y_cropped, log_data=log_data) + stitched_img = reshape_data.stitch_crops(crop_stack=y_cropped, log_data=log_data) relabeled = skimage.measure.label(stitched_img[0, 0, 0, 0, :, :, 0]) @@ -786,9 +786,55 @@ def test_stitch_slices(): assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) -def test_reconstruct_slice_data(): +def test_reconstruct_image_stack(): with tempfile.TemporaryDirectory() as temp_dir: - # generate data + # generate stack of crops from image with grid pattern + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # create image with artificial objects to be segmented + + cell_idx = 1 + for i in range(12): + for j in range(11): + for fov in range(y_data.shape[0]): + y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), + (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx + cell_idx += 1 + + # Crop the data + crop_size, overlap_frac = 100, 0.2 + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=(crop_size, crop_size), + overlap_frac=overlap_frac) + + reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, + log_data=log_data, save_dir=temp_dir) + + reshape_data.reconstruct_image_stack(crop_dir=temp_dir) + + stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) + + # dims are the same + assert np.all(stitched_imgs.shape == y_data.shape) + + # all the same pixels are marked + assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0))) + + # there are the same number of cells + assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data))) + + with tempfile.TemporaryDirectory() as temp_dir: + # generate data with the corner tagged fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1 row_len, col_len, chan_len = 50, 50, 3 slice_stack_len = 4 @@ -815,9 +861,81 @@ def test_reconstruct_slice_data(): blank_labels="include", save_format="npz", verbose=False) - stitched_slices = reshape_data.reconstruct_slice_data(temp_dir) + reshape_data.reconstruct_image_stack(temp_dir) + stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) + + assert np.all(stitched_imgs.shape == y_data.shape) + assert np.all(np.equal(stitched_imgs[0, :, 0, 0, 0, 0, 0], tags)) + + with tempfile.TemporaryDirectory() as temp_dir: + # generate data with both corners tagged and images labeled + + (fov_len, stack_len, crop_num, + slice_num, row_len, col_len, chan_len) = 1, 8, 1, 1, 400, 400, 4 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # create image with artificial objects to be segmented + + cell_idx = 1 + for i in range(1, 12): + for j in range(1, 11): + for stack in range(stack_len): + y_data[:, stack, :, :, (i * 35):(i * 35 + 10 + stack * 2), + (j * 37):(j * 37 + 8 + stack * 2), 0] = cell_idx + cell_idx += 1 + + # tag upper left hand corner of each image with squares of increasing size + for stack in range(stack_len): + y_data[0, stack, 0, 0, :stack, :stack, 0] = 1 + + # Crop the data + crop_size, overlap_frac = 100, 0.2 + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=(crop_size, crop_size), + overlap_frac=overlap_frac) + + X_slice, y_slice, slice_log_data = \ + reshape_data.create_slice_data(X_data=X_cropped, + y_data=y_cropped, + slice_stack_len=slice_stack_len) + + reshape_data.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, + log_data={**slice_log_data, **log_data}, + save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) + + reshape_data.reconstruct_image_stack(temp_dir) + stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) + + assert np.all(stitched_imgs.shape == y_data.shape) # dims are the same - assert np.all(stitched_slices.shape == y_data.shape) + assert np.all(stitched_imgs.shape == y_data.shape) + + # all the same pixels are marked + assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0))) + + # there are the same number of cells + assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data))) + + # check mark in upper left hand corner of image + for stack in range(stack_len): + original = np.zeros((10, 10)) + original[:stack, :stack] = 1 + new = stitched_imgs[0, stack, 0, 0, :10, :10, 0] + assert np.array_equal(original > 0, new > 0) + + + + - assert np.all(np.equal(stitched_slices[0, :, 0, 0, 0, 0, 0], tags)) From b063ab12dafde307c0fea1900458380538694838 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 21:24:28 -0700 Subject: [PATCH 06/12] Broke up reshape_data function into multiple utils files --- caliban_toolbox/reshape_data.py | 513 +-------------- caliban_toolbox/reshape_data_test.py | 720 +--------------------- caliban_toolbox/utils/crop_utils.py | 203 ++++++ caliban_toolbox/utils/crop_utils_test.py | 307 +++++++++ caliban_toolbox/utils/io_utils.py | 234 +++++++ caliban_toolbox/utils/io_utils_test.py | 290 +++++++++ caliban_toolbox/utils/slice_utils.py | 159 +++++ caliban_toolbox/utils/slice_utils_test.py | 200 ++++++ 8 files changed, 1411 insertions(+), 1215 deletions(-) create mode 100644 caliban_toolbox/utils/crop_utils.py create mode 100644 caliban_toolbox/utils/crop_utils_test.py create mode 100644 caliban_toolbox/utils/io_utils.py create mode 100644 caliban_toolbox/utils/io_utils_test.py create mode 100644 caliban_toolbox/utils/slice_utils.py create mode 100644 caliban_toolbox/utils/slice_utils_test.py diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 3f46132..8bb89d3 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -27,98 +27,12 @@ from __future__ import print_function from __future__ import division -import math -import numpy as np import os import json -from itertools import product - import xarray as xr - -def compute_crop_indices(img_len, crop_size, overlap_frac): - """Determine how to crop the image across one dimension. - - Args: - img_len: length of the image for given dimension - crop_size: size in pixels of the crop in given dimension - overlap_frac: fraction that adjacent crops will overlap each other on each side - - Returns: - numpy.array: coordinates for where each crop will start in given dimension - numpy.array: coordinates for where each crop will end in given dimension - int: number of pixels of padding at start and end of image in given dimension - """ - - # compute overlap fraction in pixels - overlap_pix = math.floor(crop_size * overlap_frac) - - # the crops start at pixel 0, and are spaced crop_size - overlap_pix away from each other - start_indices = np.arange(0, img_len - overlap_pix, crop_size - overlap_pix) - - # the crops each end crop_size away the start - end_indices = start_indices + crop_size - - # the padding for the final image is the amount that the last crop goes beyond the image size - padding = end_indices[-1] - img_len - - return start_indices, end_indices, padding - - -def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding): - """Crops an image into pieces according to supplied coordinates - - Args: - input_data: xarray of [fovs, stacks, crops, slices, rows, cols, channels] to be cropped - row_starts: list of indices where row crops start - row_ends: list of indices where row crops end - col_starts: list of indices where col crops start - col_ends: list of indices where col crops end - padding: tuple which specifies the amount of padding on the final image - - Returns: - numpy.array: 7D tensor of cropped images - tuple: shape of the final padded image - """ - - # determine key parameters of crop - fov_len, stack_len, input_crop_num, slice_num, _, _, channel_len = input_data.shape - - if input_crop_num > 1: - raise ValueError("Array has already been cropped") - - crop_num = len(row_starts) * len(col_starts) - crop_size_row = row_ends[0] - row_starts[0] - crop_size_col = col_ends[0] - col_starts[0] - - # create xarray to hold crops - cropped_stack = np.zeros((fov_len, stack_len, crop_num, slice_num, - crop_size_row, crop_size_col, channel_len)) - - # labels for each index within a dimension - coordinate_labels = [input_data.fovs, input_data.stacks, range(crop_num), input_data.slices, - range(crop_size_row), range(crop_size_col), input_data.channels] - - # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - - cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=dimension_labels) - - # pad the input to account for imperfectly overlapping final crop in rows and cols - formatted_padding = ((0, 0), (0, 0), (0, 0), (0, 0), (0, padding[0]), (0, padding[1]), (0, 0)) - padded_input = np.pad(input_data, formatted_padding, mode='constant', constant_values=0) - - # loop through rows and cols to generate crops - crop_counter = 0 - for i in range(len(row_starts)): - for j in range(len(col_starts)): - cropped_xr[:, :, crop_counter, ...] = padded_input[:, :, 0, :, - row_starts[i]:row_ends[i], - col_starts[j]:col_ends[j], :] - crop_counter += 1 - - return cropped_xr, padded_input.shape +from caliban_toolbox.utils import crop_utils, slice_utils, io_utils def crop_multichannel_data(X_data, y_data, crop_size, overlap_frac, test_parameters=False): @@ -158,20 +72,20 @@ def crop_multichannel_data(X_data, y_data, crop_size, overlap_frac, test_paramet X_data, y_data = X_data[:1, ...], y_data[:1, ...] # compute the start and end coordinates for the row and column crops - row_starts, row_ends, row_padding = compute_crop_indices(img_len=X_data.shape[4], + row_starts, row_ends, row_padding = crop_utils.compute_crop_indices(img_len=X_data.shape[4], crop_size=crop_size[0], overlap_frac=overlap_frac) - col_starts, col_ends, col_padding = compute_crop_indices(img_len=X_data.shape[5], + col_starts, col_ends, col_padding = crop_utils.compute_crop_indices(img_len=X_data.shape[5], crop_size=crop_size[1], overlap_frac=overlap_frac) # crop images - X_data_cropped, padded_shape = crop_helper(X_data, row_starts=row_starts, row_ends=row_ends, + X_data_cropped, padded_shape = crop_utils.crop_helper(X_data, row_starts=row_starts, row_ends=row_ends, col_starts=col_starts, col_ends=col_ends, padding=(row_padding, col_padding)) - y_data_cropped, padded_shape = crop_helper(y_data, row_starts=row_starts, row_ends=row_ends, + y_data_cropped, padded_shape = crop_utils.crop_helper(y_data, row_starts=row_starts, row_ends=row_ends, col_starts=col_starts, col_ends=col_ends, padding=(row_padding, col_padding)) @@ -190,91 +104,6 @@ def crop_multichannel_data(X_data, y_data, crop_size, overlap_frac, test_paramet return X_data_cropped, y_data_cropped, log_data -def compute_slice_indices(stack_len, slice_len, slice_overlap): - """ Determine how to slice an image across the stack dimension. - - Args: - stack_len: total number of z or t stacks - slice_len: number of z/t frames to be included in each slice - slice_overlap: number of z/t frames that will overlap in each slice - - Returns: - numpy.array: coordinates for the start location of each slice - numpy.array: coordinates for the end location of each slice - """ - - if slice_overlap >= slice_len: - raise ValueError('slice overlap must be less than the length of the slice') - - spacing = slice_len - slice_overlap - - # slices_start indices begin at index 0, and are spaced 'spacing' apart from one another - slice_start_indices = np.arange(0, stack_len - slice_overlap, spacing) - - # slices_end indices are 'spacing' away from the start - slice_end_indices = slice_start_indices + slice_len - - if slice_end_indices[-1] != stack_len: - # if slices overshoot, reduce length of final slice - slice_end_indices[-1] = stack_len - - return slice_start_indices, slice_end_indices - - -def slice_helper(data_xr, slice_start_indices, slice_end_indices): - """Divide a stack into smaller slices according to supplied indices - - Args: - data_xr: xarray of to be split into slices - slice_start_indices: list of indices for where slices start - slice_end_indices: list of indices for where slices end - - Returns: - xarray.DataArray: 7D tensor of sliced images - """ - - # get input image dimensions - fov_len, stack_len, crop_num, input_slice_num, row_len, col_len, chan_len = data_xr.shape - - if input_slice_num > 1: - raise ValueError('Input array already contains slice data') - - slice_num = len(slice_start_indices) - sliced_stack_len = slice_end_indices[0] - slice_start_indices[0] - - # create xarray to hold slices - slice_data = np.zeros((fov_len, sliced_stack_len, crop_num, - slice_num, row_len, col_len, chan_len)) - - # labels for each index within a dimension - coordinate_labels = [data_xr.fovs, range(sliced_stack_len), range(crop_num), range(slice_num), - range(row_len), range(col_len), data_xr.channels] - - # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - - slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=dimension_labels) - - # loop through slice indices to generate sliced data - slice_counter = 0 - for i in range(len(slice_start_indices)): - - if i != len(slice_start_indices) - 1: - # not the last slice - slice_xr[:, :, :, slice_counter, ...] = \ - data_xr[:, slice_start_indices[i]:slice_end_indices[i], :, 0, :, :, :].values - slice_counter += 1 - - else: - # last slice, only index into stack the amount two indices are separated - slice_len = slice_end_indices[i] - slice_start_indices[i] - slice_xr[:, :slice_len, :, slice_counter, ...] = \ - data_xr[:, slice_start_indices[i]:slice_end_indices[i], :, 0, :, :, :].values - slice_counter += 1 - - return slice_xr - - def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0): """Takes an array of data and splits it up into smaller pieces along the stack dimension @@ -304,10 +133,10 @@ def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0): # compute indices for slices stack_len = X_data.shape[1] slice_start_indices, slice_end_indices = \ - compute_slice_indices(stack_len, slice_stack_len, slice_overlap) + slice_utils.compute_slice_indices(stack_len, slice_stack_len, slice_overlap) - X_data_slice = slice_helper(X_data, slice_start_indices, slice_end_indices) - y_data_slice = slice_helper(y_data, slice_start_indices, slice_end_indices) + X_data_slice = slice_utils.slice_helper(X_data, slice_start_indices, slice_end_indices) + y_data_slice = slice_utils.slice_helper(y_data, slice_start_indices, slice_end_indices) log_data = {} log_data['slice_start_indices'] = slice_start_indices.tolist() @@ -317,326 +146,6 @@ def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0): return X_data_slice, y_data_slice, log_data -def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, blank_labels='include', - save_format='npz', verbose=True): - """Take an array of processed image data and save as NPZ for caliban - - Args: - X_data: 7D tensor of cropped and sliced raw images - y_data: 7D tensor of cropped and sliced labeled images - original_data: the original unmodified images - log_data: data used to reconstruct images - save_dir: path to save the npz and JSON files - blank_labels: whether to include NPZs with blank labels (poor predictions) - or skip (no cells) - save_format: format to save the data (currently only NPZ) - verbose: flag to control print statements - """ - - if not os.path.isdir(save_dir): - os.makedirs(save_dir) - - # if these are present, it means data was cropped/sliced. Otherwise, default to 1 - num_crops = log_data.get('num_crops', 1) - num_slices = log_data.get('num_slices', 1) - - fov_names = original_data.fovs.values - fov_len = len(fov_names) - - if blank_labels not in ['skip', 'include', 'separate']: - raise ValueError('blank_labels must be one of ' - '[skip, include, separate], got {}'.format(blank_labels)) - - if blank_labels == 'separate': - os.makedirs(os.path.join(save_dir, 'separate')) - - # for each fov, loop through 2D crops and 3D slices - for fov, crop, slice in product(range(fov_len), range(num_crops), range(num_slices)): - # generate identifier for crop - npz_id = 'fov_{}_crop_{}_slice_{}'.format(fov_names[fov], crop, slice) - - # get working batch - labels = y_data[fov, :, crop, slice, ...].values - channels = X_data[fov, :, crop, slice, ...].values - - # determine if labels are blank, and if so what to do with npz - if np.sum(labels) == 0: - - # blank labels get saved to separate folder - if blank_labels == 'separate': - if verbose: - print('{} is blank, saving to separate folder'.format(npz_id)) - save_path = os.path.join(save_dir, blank_labels, npz_id) - - # save images as either npz or xarray - if save_format == 'npz': - np.savez(save_path + '.npz', X=channels, y=labels) - - elif save_format == 'xr': - raise NotImplementedError() - - # blank labels don't get saved, empty area of tissue - elif blank_labels == 'skip': - if verbose: - print('{} is blank, skipping saving'.format(npz_id)) - - # blank labels get saved along with other crops - elif blank_labels == 'include': - if verbose: - print('{} is blank, saving to folder'.format(npz_id)) - save_path = os.path.join(save_dir, npz_id) - - # save images as either npz or xarray - if save_format == 'npz': - np.savez(save_path + '.npz', X=channels, y=labels) - - elif save_format == 'xr': - raise NotImplementedError() - - else: - # crop is not blank, save based on file_format - save_path = os.path.join(save_dir, npz_id) - - # save images as either npz or xarray - if save_format == 'npz': - np.savez(save_path + '.npz', X=channels, y=labels) - - elif save_format == 'xr': - raise NotImplementedError() - - log_data['fov_names'] = fov_names.tolist() - log_data['channel_names'] = original_data.channels.values.tolist() - log_data['original_shape'] = original_data.shape - log_data['slice_stack_len'] = X_data.shape[1] - log_data['save_format'] = save_format - - log_path = os.path.join(save_dir, 'log_data.json') - with open(log_path, 'w') as write_file: - json.dump(log_data, write_file) - - -def get_saved_file_path(dir_list, fov_name, crop, slice, file_ext='.npz'): - """Helper function to identify correct file path for an npz file - - Args: - dir_list: list of files in directory - fov_name: string of the current fov_name - crop: int of current crop - slice: int of current slice - file_ext: extension file was saved with - - Returns: - string: formatted file name - - Raises: - ValueError: If multiple file path matches were found - """ - - base_string = 'fov_{}_crop_{}_slice_{}'.format(fov_name, crop, slice) - string_matches = [string for string in dir_list if base_string + '_save_version' in string] - - if len(string_matches) == 0: - full_string = base_string + file_ext - elif len(string_matches) == 1: - full_string = string_matches[0] - else: - raise ValueError('Multiple save versions found: ' - 'please select only a single save version. {}'.format(string_matches)) - return full_string - - -def load_npzs(crop_dir, log_data, verbose=True): - """Reads all of the cropped images from a directory, and aggregates them into a single stack - - Args: - crop_dir: path to directory with cropped npz or xarray files - log_data: dictionary of parameters generated during data saving - - verbose: flag to control print statements - - Returns: - numpy.array: 7D tensor of labeled crops - """ - - fov_names = log_data['fov_names'] - fov_len, stack_len, _, _, row_size, col_size, _ = log_data['original_shape'] - save_format = log_data['save_format'] - - # if cropped/sliced, get size of dimensions. Otherwise, use size in original data - row_crop_size = log_data.get('row_crop_size', row_size) - col_crop_size = log_data.get('col_crop_size', col_size) - slice_stack_len = log_data.get('slice_stack_len', stack_len) - - # if cropped/sliced, get number of crops/slices - num_crops, num_slices = log_data.get('num_crops', 1), log_data.get('num_slices', 1) - stack = np.zeros((fov_len, slice_stack_len, num_crops, - num_slices, row_crop_size, col_crop_size, 1)) - saved_files = os.listdir(crop_dir) - - # for each fov, loop over each 2D crop and 3D slice - for fov, crop, slice in product(range(fov_len), range(num_crops), range(num_slices)): - # load NPZs - if save_format == 'npz': - npz_path = os.path.join(crop_dir, get_saved_file_path(saved_files, - fov_names[fov], - crop, slice)) - if os.path.exists(npz_path): - temp_npz = np.load(npz_path) - - # last slice may be truncated, modify index - if slice == num_slices - 1: - current_stack_len = temp_npz['X'].shape[1] - else: - current_stack_len = slice_stack_len - - stack[fov, :current_stack_len, crop, slice, ...] = temp_npz['y'] - else: - # npz not generated, did not contain any labels, keep blank - if verbose: - print('could not find npz {}, skipping'.format(npz_path)) - - # load xarray - elif save_format == 'xr': - raise NotImplementedError() - # xr_path = os.path.join(crop_dir, get_saved_file_path(saved_files, fov_names[fov], - # crop, slice)) - # if os.path.exists(xr_path): - # temp_xr = xr.open_dataarray(xr_path) - # - # # last slice may be truncated, modify index - # if slice == num_slices - 1: - # current_stack_len = temp_xr.shape[1] - # else: - # current_stack_len = stack_len - # - # stack[fov, :current_stack_len, crop, slice, ...] = temp_xr[..., -1:] - # else: - # # npz not generated, did not contain any labels, keep blank - # print('could not find xr {}, skipping'.format(xr_path)) - - return stack - - -def stitch_crops(crop_stack, log_data): - """Takes a stack of annotated labels and stitches them together into a single image - - Args: - crop_stack: 7D tensor of labels to be stitched together - log_data: dictionary of parameters for reconstructing original image data - - Returns: - numpy.array: 7D tensor of reconstructed labels - """ - - # Initialize image with single dimension for channels - fov_len, stack_len, _, _, row_size, col_size, _ = log_data['original_shape'] - row_padding, col_padding = log_data.get('row_padding', 0), log_data.get('col_padding', 0) - stitched_labels = np.zeros((fov_len, stack_len, 1, 1, row_size + row_padding, - col_size + col_padding, 1)) - - row_starts, row_ends = log_data['row_starts'], log_data['row_ends'] - col_starts, col_ends = log_data['col_starts'], log_data['col_ends'] - - if crop_stack.shape[3] != 1: - raise ValueError('Stacks must be combined before stitching can occur') - - # for each fov and stack, loop through rows and columns of crop positions - for fov, stack, row, col in product(range(fov_len), range(stack_len), - range(len(row_starts)), range(len(col_starts))): - - # determine what crop # we're currently working on - crop_counter = row * len(row_starts) + col - - # get current crop - crop = crop_stack[fov, stack, crop_counter, 0, :, :, 0] - - # increment values to ensure unique labels across final image - lowest_allowed_val = np.amax(stitched_labels[fov, stack, ...]) - crop = np.where(crop == 0, crop, crop + lowest_allowed_val) - - # get ids of cells in current crop - potential_overlap_cells = np.unique(crop) - potential_overlap_cells = \ - potential_overlap_cells[np.nonzero(potential_overlap_cells)] - - # get values of stitched image at location where crop will be placed - stitched_crop = stitched_labels[fov, stack, 0, 0, - row_starts[row]:row_ends[row], - col_starts[col]:col_ends[col], 0] - - # loop through each cell in the crop to determine - # if it overlaps with another cell in full image - for cell in potential_overlap_cells: - - # get cell ids present in stitched image - # at location of current cell in crop - stitched_overlap_vals, stitched_overlap_counts = \ - np.unique(stitched_crop[crop == cell], return_counts=True) - - # remove IDs and counts corresponding to overlap with ID 0 (background) - keep_vals = np.nonzero(stitched_overlap_vals) - stitched_overlap_vals = stitched_overlap_vals[keep_vals] - stitched_overlap_counts = stitched_overlap_counts[keep_vals] - - # if there are overlaps, determine which is greatest in count, - # and replace with that ID - if len(stitched_overlap_vals) > 0: - max_overlap = stitched_overlap_vals[np.argmax(stitched_overlap_counts)] - crop[crop == cell] = max_overlap - - # combine the crop with the current values in the stitched image - combined_crop = np.where(stitched_crop > 0, stitched_crop, crop) - - # use this combined crop to update the values of stitched image - stitched_labels[fov, stack, 0, 0, row_starts[row]:row_ends[row], - col_starts[col]:col_ends[col], 0] = combined_crop - - # trim padding to put image back to original size - if row_padding > 0: - stitched_labels = stitched_labels[:, :, :, :, :-row_padding, :, :] - if col_padding > 0: - stitched_labels = stitched_labels[:, :, :, :, :, :-col_padding, :] - - return stitched_labels - - -def stitch_slices(slice_stack, log_data): - """Helper function to stitch slices together back into original sized array - - Args: - slice_stack: xarray of shape [fovs, stacks, crops, slices, rows, cols, segmentation_label] - log_data: log data produced from creation of slice stack - - Returns: - xarray.DataArray: 7D tensor of stitched labeled slices - """ - - # get parameters from dict - fov_len, stack_len, crop_num, _, row_len, col_len, chan_len = log_data['original_shape'] - crop_num = log_data.get('num_crops', crop_num) - row_len = log_data.get('row_crop_size', row_len) - col_len = log_data.get('col_crop_size', col_len) - - slice_start_indices = log_data['slice_start_indices'] - slice_end_indices = log_data['slice_end_indices'] - num_slices, fov_names = log_data['num_slices'], log_data['fov_names'] - - stitched_slices = np.zeros((fov_len, stack_len, crop_num, 1, row_len, col_len, 1)) - - # loop slice indices to generate sliced data - for i in range(num_slices - 1): - stitched_slices[:, slice_start_indices[i]:slice_end_indices[i], :, 0, ...] = \ - slice_stack[:, :, :, i, ...] - - # last slice, only index into stack the amount two indices are separated - last_idx = num_slices - 1 - slice_len = slice_end_indices[last_idx] - slice_start_indices[last_idx] - stitched_slices[:, slice_start_indices[last_idx]:slice_end_indices[last_idx], :, 0, ...] = \ - slice_stack[:, :slice_len, :, last_idx, ...] - - return stitched_slices - - def reconstruct_image_stack(crop_dir, verbose=True): """High level function to recombine data into a single stitched image @@ -654,15 +163,15 @@ def reconstruct_image_stack(crop_dir, verbose=True): log_data = json.load(json_file) # combine all npzs into a single stack - image_stack = load_npzs(crop_dir=crop_dir, log_data=log_data, verbose=verbose) + image_stack = io_utils.load_npzs(crop_dir=crop_dir, log_data=log_data, verbose=verbose) # stitch slices if data was sliced if 'num_slices' in log_data: - image_stack = stitch_slices(slice_stack=image_stack, log_data=log_data) + image_stack = slice_utils.stitch_slices(slice_stack=image_stack, log_data=log_data) # stitch crops if data was cropped if 'num_crops' in log_data: - image_stack = stitch_crops(crop_stack=image_stack, log_data=log_data) + image_stack = crop_utils.stitch_crops(crop_stack=image_stack, log_data=log_data) # labels for each index within a dimension _, stack_len, _, _, row_len, col_len, _ = log_data['original_shape'] diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 554e8d0..d173d56 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -24,149 +24,14 @@ # limitations under the License. # ============================================================================== import os -import shutil -import json -import pytest -import copy import tempfile -import skimage.measure import numpy as np -from caliban_toolbox import reshape_data import xarray as xr -import importlib - -importlib.reload(reshape_data) - - -def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len): - """Test function to generate a blank xarray with the supplied dimensions - - Inputs - fov_num: number of distinct FOVs - stack_num: number of distinct z stacks - crop_num: number of x/y crops - slice_num: number of z/t slices - row_num: number of rows - col_num: number of cols - chan_num: number of channels - - Outputs - test_xr: xarray of [fov_num, row_num, col_num, chan_num]""" - - test_img = np.zeros((fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len)) - - fovs = ["fov" + str(x) for x in range(1, fov_len + 1)] - channels = ["channel" + str(x) for x in range(1, chan_len + 1)] - - test_stack_xr = xr.DataArray(data=test_img, - coords=[fovs, range(stack_len), range(crop_num), range(slice_num), - range(row_len), range(col_len), channels], - dims=["fovs", "stacks", "crops", "slices", - "rows", "cols", "channels"]) - - return test_stack_xr - - -def test_compute_crop_indices(): - # test corner case of only one crop - img_len, crop_size, overlap_frac = 100, 100, 0.2 - starts, ends, padding = reshape_data.compute_crop_indices(img_len=img_len, crop_size=crop_size, - overlap_frac=overlap_frac) - assert (len(starts) == 1) - assert (len(ends) == 1) - - # test crop size that doesn't divide evenly into image size - img_len, crop_size, overlap_frac = 105, 20, 0.2 - starts, ends, padding = reshape_data.compute_crop_indices(img_len=img_len, crop_size=crop_size, - overlap_frac=overlap_frac) - crop_num = np.ceil(img_len / (crop_size - (crop_size * overlap_frac))) - assert (len(starts) == crop_num) - assert (len(ends) == crop_num) - - crop_end = crop_num * (crop_size - (crop_size * overlap_frac)) + crop_size * overlap_frac - assert (ends[-1] == crop_end) - - # test overlap of 0 between crops - img_len, crop_size, overlap_frac = 200, 20, 0 - starts, ends, padding = reshape_data.compute_crop_indices(img_len=img_len, crop_size=crop_size, - overlap_frac=overlap_frac) - assert (np.all(starts == range(0, 200, 20))) - assert (np.all(ends == range(20, 201, 20))) - assert (padding == 0) - - -def test_crop_helper(): - # img params - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 200, 200, 1 - crop_size, overlap_frac = 200, 0.2 - - # test only one crop - test_xr = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=chan_len) - - starts, ends, padding = reshape_data.compute_crop_indices(img_len=row_len, crop_size=crop_size, - overlap_frac=overlap_frac) - cropped, padded = reshape_data.crop_helper(input_data=test_xr, row_starts=starts, - row_ends=ends, col_starts=starts, col_ends=ends, - padding=(padding, padding)) - - assert (cropped.shape == (fov_len, stack_len, 1, slice_num, row_len, col_len, chan_len)) - - # test crops of different row/col dimensions - row_crop, col_crop = 50, 40 - row_starts, row_ends, row_padding = \ - reshape_data.compute_crop_indices(img_len=row_len, crop_size=row_crop, - overlap_frac=overlap_frac) - - col_starts, col_ends, col_padding = \ - reshape_data.compute_crop_indices(img_len=col_len, crop_size=col_crop, - overlap_frac=overlap_frac) - - cropped, padded = reshape_data.crop_helper(input_data=test_xr, row_starts=row_starts, - row_ends=row_ends, col_starts=col_starts, - col_ends=col_ends, - padding=(row_padding, col_padding)) - - assert (cropped.shape == (fov_len, stack_len, 30, slice_num, row_crop, col_crop, chan_len)) - - # test that correct region of image is being cropped - row_crop, col_crop = 40, 40 - - # assign each pixel in the image a unique value - linear_sequence = np.arange(0, fov_len * 1 * 1 * row_len * col_len * chan_len) - linear_sequence_reshaped = np.reshape(linear_sequence, (fov_len, 1, 1, 1, row_len, - col_len, chan_len)) - test_xr[:, :, :, :, :, :, :] = linear_sequence_reshaped - - # crop the image - row_starts, row_ends, row_padding = \ - reshape_data.compute_crop_indices(img_len=row_len, crop_size=row_crop, - overlap_frac=overlap_frac) - - col_starts, col_ends, col_padding = \ - reshape_data.compute_crop_indices(img_len=col_len, crop_size=col_crop, - overlap_frac=overlap_frac) - - cropped, padded = reshape_data.crop_helper(input_data=test_xr, row_starts=row_starts, - row_ends=row_ends, col_starts=col_starts, - col_ends=col_ends, - padding=(row_padding, col_padding)) - - # check that the values of each crop match the value in uncropped image - for img in range(test_xr.shape[0]): - crop_counter = 0 - for row in range(len(row_starts)): - for col in range(len(col_starts)): - crop = cropped[img, 0, crop_counter, 0, :, :, 0].values - - original_image_crop = test_xr[img, 0, 0, 0, row_starts[row]:row_ends[row], - col_starts[col]:col_ends[col], 0].values - assert (np.all(crop == original_image_crop)) - - crop_counter += 1 +from caliban_toolbox import reshape_data +from caliban_toolbox.utils import crop_utils, io_utils +from caliban_toolbox.utils.crop_utils_test import _blank_data_xr def test_crop_multichannel_data(): @@ -192,7 +57,7 @@ def test_crop_multichannel_data(): overlap_frac=overlap_frac, test_parameters=False) - expected_crop_num = len(reshape_data.compute_crop_indices(row_len, crop_size[0], + expected_crop_num = len(crop_utils.compute_crop_indices(row_len, crop_size[0], overlap_frac)[0]) ** 2 assert (X_data_cropped.shape == (fov_len, stack_len, expected_crop_num, slice_num, crop_size[0], crop_size[1], channel_len)) @@ -200,113 +65,6 @@ def test_crop_multichannel_data(): assert log_data["num_crops"] == expected_crop_num -def test_compute_slice_indices(): - # test when slice divides evenly into stack len - stack_len = 40 - slice_len = 4 - slice_overlap = 0 - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_len, - slice_overlap) - assert np.all(np.equal(slice_start_indices, np.arange(0, stack_len, slice_len))) - - # test when slice_num does not divide evenly into stack_len - stack_len = 42 - slice_len = 5 - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_len, - slice_overlap) - - expected_start_indices = np.arange(0, stack_len, slice_len) - assert np.all(np.equal(slice_start_indices, expected_start_indices)) - - # test overlapping slices - stack_len = 40 - slice_len = 4 - slice_overlap = 1 - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_len, - slice_overlap) - - assert len(slice_start_indices) == int(np.floor(stack_len / (slice_len - slice_overlap))) - assert slice_end_indices[-1] == stack_len - assert slice_end_indices[0] - slice_start_indices[0] == slice_len - - -def test_slice_helper(): - # test output shape with even division of slice - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 4 - - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_stack_len, 0) - - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=chan_len) - - slice_output = reshape_data.slice_helper(input_data, slice_start_indices, slice_end_indices) - - assert slice_output.shape == (fov_len, slice_stack_len, crop_num, - int(np.ceil(stack_len / slice_stack_len)), - row_len, col_len, chan_len) - - # test output shape with uneven division of slice - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 6 - - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_stack_len, 0) - - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=chan_len) - - slice_output = reshape_data.slice_helper(input_data, slice_start_indices, slice_end_indices) - - assert slice_output.shape == (fov_len, slice_stack_len, crop_num, - (np.ceil(stack_len / slice_stack_len)), - row_len, col_len, chan_len) - - # test output shape with slice overlaps - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 6 - slice_overlap = 1 - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_stack_len, - slice_overlap) - - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=chan_len) - - slice_output = reshape_data.slice_helper(input_data, slice_start_indices, slice_end_indices) - - assert slice_output.shape == (fov_len, slice_stack_len, crop_num, - (np.ceil(stack_len / (slice_stack_len - slice_overlap))), - row_len, col_len, chan_len) - - # test output values - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 4 - slice_start_indices, slice_end_indices = reshape_data.compute_slice_indices(stack_len, - slice_stack_len, 0) - - input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=chan_len) - - # tag upper left hand corner of each image - tags = np.arange(stack_len) - input_data[0, :, 0, 0, 0, 0, 0] = tags - - slice_output = reshape_data.slice_helper(input_data, slice_start_indices, slice_end_indices) - - # loop through each slice, make sure values increment as expected - for i in range(slice_output.shape[1]): - assert np.all(np.equal(slice_output[0, :, 0, i, 0, 0, 0], tags[i * 4:(i + 1) * 4])) - - def test_create_slice_data(): # test output shape with even division of slice fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 @@ -327,465 +85,6 @@ def test_create_slice_data(): row_len, col_len, chan_len) -def test_save_npzs_for_caliban(): - fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 4 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, - slice_num=num_slices, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, - slice_num=num_slices, - row_len=row_len, col_len=col_len, chan_len=1) - - sliced_X, sliced_y, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, - slice_stack_len=slice_stack_len) - - with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - blank_labels="include", - save_format="npz", verbose=False) - - # check that correct size was saved - test_npz_labels = np.load(os.path.join(temp_dir, "fov_fov1_crop_0_slice_0.npz")) - - assert test_npz_labels["y"].shape == (slice_stack_len, row_len, col_len, 1) - - assert test_npz_labels["y"].shape[:-1] == test_npz_labels["X"].shape[:-1] - - # check that json saved successfully - with open(os.path.join(temp_dir, "log_data.json")) as json_file: - saved_log_data = json.load(json_file) - - assert saved_log_data["original_shape"] == list(X_data.shape) - - with tempfile.TemporaryDirectory() as temp_dir: - # check that combined crop and slice saving works - crop_size = (10, 10) - overlap_frac = 0.2 - X_cropped, y_cropped, log_data_crop = \ - reshape_data.crop_multichannel_data(X_data=sliced_X, - y_data=sliced_y, - crop_size=crop_size, - overlap_frac=overlap_frac, - test_parameters=False) - - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, - original_data=X_data, - log_data={**log_data, **log_data_crop}, - save_dir=temp_dir, - blank_labels="include", save_format="npz", - verbose=False) - expected_crop_num = X_cropped.shape[2] * X_cropped.shape[3] - files = os.listdir(temp_dir) - files = [file for file in files if "npz" in file] - - assert len(files) == expected_crop_num - - # check that arguments specifying what to do with blank crops are working - - # set specified crops to not be blank - sliced_y[0, 0, 0, [1, 4, 7], 0, 0, 0] = 27 - expected_crop_num = sliced_X.shape[2] * sliced_X.shape[3] - - # test that function correctly includes blank crops when saving - with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, - original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - blank_labels="include", - save_format="npz", verbose=False) - - # check that there is the expected number of files saved to directory - files = os.listdir(temp_dir) - files = [file for file in files if "npz" in file] - - assert len(files) == expected_crop_num - - # test that function correctly skips blank crops when saving - with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, - original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - save_format="npz", - blank_labels="skip", verbose=False) - - # check that expected number of files in directory - files = os.listdir(temp_dir) - files = [file for file in files if "npz" in file] - assert len(files) == 3 - - # test that function correctly saves blank crops to separate folder - with tempfile.TemporaryDirectory() as temp_dir: - reshape_data.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, - original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - save_format="npz", - blank_labels="separate", verbose=False) - - # check that expected number of files in each directory - files = os.listdir(temp_dir) - files = [file for file in files if "npz" in file] - assert len(files) == 3 - - files = os.listdir(os.path.join(temp_dir, "separate")) - files = [file for file in files if "npz" in file] - assert len(files) == expected_crop_num - 3 - - -# postprocessing - - -def test_get_npz_file_path(): - # create list of npz_ids - dir_list = ["fov_fov1_crop_2_slice_4.npz", "fov_fov1_crop_2_slice_5_save_version_0.npz", - "fov_fov1_crop_2_slice_6_save_version_0.npz", - "fov_fov1_crop_2_slice_6_save_version_1.npz", - "fov_fov1_crop_2_slice_7_save_version_0.npz", - "fov_fov1_crop_2_slice_7_save_version_0_save_version_2.npz"] - - fov, crop = "fov1", 2 - - # test unmodified npz - slice = 4 - output_string = reshape_data.get_saved_file_path(dir_list, fov, crop, slice) - - assert output_string == dir_list[0] - - # test single modified npz - slice = 5 - output_string = reshape_data.get_saved_file_path(dir_list, fov, crop, slice) - assert output_string == dir_list[1] - - # test that error is raised when multiple save versions present - slice = 6 - with pytest.raises(ValueError): - output_string = reshape_data.get_saved_file_path(dir_list, fov, crop, slice) - - # test that error is raised when multiple save versions present due to resaves - slice = 7 - - with pytest.raises(ValueError): - output_string = reshape_data.get_saved_file_path(dir_list, fov, crop, slice) - - -def test_load_npzs(): - with tempfile.TemporaryDirectory() as temp_dir: - # first generate image stack that will be sliced up - fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1 - row_len, col_len, chan_len = 50, 50, 3 - slice_stack_len = 4 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) - - # slice the data - X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, - slice_stack_len) - - # crop the data - crop_size = (10, 10) - overlap_frac = 0.2 - X_cropped, y_cropped, log_data_crop = \ - reshape_data.crop_multichannel_data( - X_data=X_slice, - y_data=y_slice, - crop_size=crop_size, - overlap_frac=overlap_frac, - test_parameters=False) - - # tag the upper left hand corner of the label in each slice - slice_tags = np.arange(y_cropped.shape[3]) - crop_tags = np.arange(y_cropped.shape[2]) - y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags - y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags - - combined_log_data = {**log_data, **log_data_crop} - - # save the tagged data - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=combined_log_data, save_dir=temp_dir, - blank_labels="include", save_format="npz", - verbose=False) - - with open(os.path.join(temp_dir, "log_data.json")) as json_file: - saved_log_data = json.load(json_file) - - loaded_slices = reshape_data.load_npzs(temp_dir, saved_log_data, verbose=False) - - # dims other than channels are the same - assert (np.all(loaded_slices.shape[:-1] == X_cropped.shape[:-1])) - - assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) - assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) - - # test slices with unequal last length - with tempfile.TemporaryDirectory() as temp_dir: - # first generate image stack that will be sliced up - fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1 - row_len, col_len, chan_len = 50, 50, 3 - slice_stack_len = 7 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) - - # slice the data - X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) - - # crop the data - crop_size = (10, 10) - overlap_frac = 0.2 - X_cropped, y_cropped, log_data_crop = \ - reshape_data.crop_multichannel_data( - X_data=X_slice, - y_data=y_slice, - crop_size=crop_size, - overlap_frac=overlap_frac, - test_parameters=False) - - # tag the upper left hand corner of the annotations in each slice - slice_tags = np.arange(y_cropped.shape[3]) - crop_tags = np.arange(X_cropped.shape[2]) - y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags - y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags - - combined_log_data = {**log_data, **log_data_crop} - - # save the tagged data - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=combined_log_data, save_dir=temp_dir, - blank_labels="include", save_format="npz", - verbose=False) - - loaded_slices = reshape_data.load_npzs(temp_dir, combined_log_data) - - # dims other than channels are the same - assert (np.all(loaded_slices.shape[:-1] == X_cropped.shape[:-1])) - - assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) - assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) - - -def test_stitch_crops(): - # generate stack of crops from image with grid pattern - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) - - # create image with artificial objects to be segmented - - cell_idx = 1 - for i in range(12): - for j in range(11): - for fov in range(y_data.shape[0]): - y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx - cell_idx += 1 - - # ## Test when crop is same size as image - crop_size, overlap_frac = 400, 0.2 - X_cropped, y_cropped, log_data = \ - reshape_data.crop_multichannel_data(X_data=X_data, - y_data=y_data, - crop_size=(crop_size, crop_size), - overlap_frac=overlap_frac) - - log_data["original_shape"] = X_data.shape - - # stitch the crops back together - stitched_img = reshape_data.stitch_crops(crop_stack=y_cropped, log_data=log_data) - - - # dims are the same - assert np.all(stitched_img.shape == y_data.shape) - - # check that objects are at same location - assert (np.all(np.equal(stitched_img[..., 0] > 0, y_data.values[..., 0] > 0))) - - # check that same number of unique objects - assert len(np.unique(stitched_img)) == len(np.unique(y_data.values)) - - - # ## Test when rows has only one crop - crop_size, overlap_frac = (400, 40), 0.2 - - # crop data - X_cropped, y_cropped, log_data = \ - reshape_data.crop_multichannel_data(X_data=X_data, - y_data=y_data, - crop_size=crop_size, - overlap_frac=overlap_frac) - - # stitch back together - log_data["original_shape"] = X_data.shape - stitched_imgs = reshape_data.stitch_crops(crop_stack=y_cropped, log_data=log_data) - - # dims are the same - assert np.all(stitched_imgs.shape == y_data.shape) - - # all the same pixels are marked - assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0))) - - # there are the same number of cells - assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data))) - - # test stitching imperfect annotator labels that slightly overlap - # generate stack of crops from image with grid pattern - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 1, 1, 1, 800, 800, 1 - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - side_len = 40 - cell_num = y_data.shape[4] // side_len - - cell_id = np.arange(1, cell_num ** 2 + 1) - cell_id = np.random.choice(cell_id, cell_num ** 2, replace=False) - cell_idx = 0 - for row in range(cell_num): - for col in range(cell_num): - y_data[0, 0, 0, 0, row * side_len:(row + 1) * side_len, - col * side_len:(col + 1) * side_len, 0] = cell_id[cell_idx] - cell_idx += 1 - - crop_size, overlap_frac = 100, 0.2 - - starts, ends, padding = reshape_data.compute_crop_indices(img_len=row_len, crop_size=crop_size, - overlap_frac=overlap_frac) - - # generate a vector of random offsets to jitter the crop window, - # simulating mismatches between frames - offset_len = 5 - row_offset = np.append( - np.append(0, np.random.randint(-offset_len, offset_len, len(starts) - 2)), 0) - col_offset = np.append( - np.append(0, np.random.randint(-offset_len, offset_len, len(starts) - 2)), 0) - - # modify indices by random offset - row_starts, row_ends = starts + row_offset, ends + row_offset - col_starts, col_ends = starts + col_offset, ends + col_offset - - y_cropped, padded = reshape_data.crop_helper(input_data=y_data, row_starts=row_starts, - row_ends=row_ends, - col_starts=col_starts, col_ends=col_ends, - padding=(padding, padding)) - - # generate log data, since we had to go inside the upper level - # function to modify crop_helper inputs - log_data = {} - log_data["row_starts"] = row_starts.tolist() - log_data["row_ends"] = row_ends.tolist() - log_data["row_crop_size"] = crop_size - log_data["num_row_crops"] = len(row_starts) - log_data["col_starts"] = col_starts.tolist() - log_data["col_ends"] = col_ends.tolist() - log_data["col_crop_size"] = crop_size - log_data["num_col_crops"] = len(col_starts) - log_data["row_padding"] = int(padding) - log_data["col_padding"] = int(padding) - log_data["num_crops"] = y_cropped.shape[2] - log_data["original_shape"] = y_data.shape - log_data["fov_names"] = y_data.fovs.values.tolist() - log_data["channel_names"] = y_data.channels.values.tolist() - - stitched_img = reshape_data.stitch_crops(crop_stack=y_cropped, log_data=log_data) - - relabeled = skimage.measure.label(stitched_img[0, 0, 0, 0, :, :, 0]) - - props = skimage.measure.regionprops_table(relabeled, properties=["area", "label"]) - - # dims are the same - assert np.all(stitched_img.shape == y_data.shape) - - # same number of unique objects before and after - assert (len(np.unique(relabeled)) == len(np.unique(y_data[0, 0, 0, 0, :, :, 0]))) - - # no cell is smaller than offset subtracted from each side - min_size = (side_len - offset_len * 2) ** 2 - max_size = (side_len + offset_len * 2) ** 2 - - assert (np.all(props["area"] <= max_size)) - assert (np.all(props["area"] >= min_size)) - - -def test_stitch_slices(): - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 4 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) - - # generate ordered data - linear_seq = np.arange(stack_len * row_len * col_len) - test_vals = linear_seq.reshape((stack_len, row_len, col_len)) - y_data[0, :, 0, 0, :, :, 0] = test_vals - - X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, - slice_stack_len=slice_stack_len) - - log_data["original_shape"] = X_data.shape - log_data["fov_names"] = X_data.fovs.values - stitched_slices = reshape_data.stitch_slices(y_slice, {**log_data}) - - # dims are the same - assert np.all(stitched_slices.shape == y_data.shape) - - assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) - - # test case without even division of crops into imsize - - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 - slice_stack_len = 7 - - X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) - - y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) - - # generate ordered data - linear_seq = np.arange(stack_len * row_len * col_len) - test_vals = linear_seq.reshape((stack_len, row_len, col_len)) - y_data[0, :, 0, 0, :, :, 0] = test_vals - - X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, - slice_stack_len=slice_stack_len) - - # get parameters - log_data["original_shape"] = y_data.shape - log_data["fov_names"] = y_data.fovs.values - stitched_slices = reshape_data.stitch_slices(y_slice, log_data) - - assert np.all(stitched_slices.shape == y_data.shape) - - assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) - - def test_reconstruct_image_stack(): with tempfile.TemporaryDirectory() as temp_dir: # generate stack of crops from image with grid pattern @@ -817,7 +116,7 @@ def test_reconstruct_image_stack(): crop_size=(crop_size, crop_size), overlap_frac=overlap_frac) - reshape_data.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, + io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, log_data=log_data, save_dir=temp_dir) reshape_data.reconstruct_image_stack(crop_dir=temp_dir) @@ -856,7 +155,7 @@ def test_reconstruct_image_stack(): y_data=y_data, slice_stack_len=slice_stack_len) - reshape_data.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, + io_utils.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, log_data={**slice_log_data}, save_dir=temp_dir, blank_labels="include", save_format="npz", verbose=False) @@ -908,7 +207,7 @@ def test_reconstruct_image_stack(): y_data=y_cropped, slice_stack_len=slice_stack_len) - reshape_data.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, + io_utils.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, log_data={**slice_log_data, **log_data}, save_dir=temp_dir, blank_labels="include", @@ -934,8 +233,3 @@ def test_reconstruct_image_stack(): original[:stack, :stack] = 1 new = stitched_imgs[0, stack, 0, 0, :10, :10, 0] assert np.array_equal(original > 0, new > 0) - - - - - diff --git a/caliban_toolbox/utils/crop_utils.py b/caliban_toolbox/utils/crop_utils.py new file mode 100644 index 0000000..ce46027 --- /dev/null +++ b/caliban_toolbox/utils/crop_utils.py @@ -0,0 +1,203 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import math +import numpy as np + + +from itertools import product + +import xarray as xr + + +def compute_crop_indices(img_len, crop_size, overlap_frac): + """Determine how to crop the image across one dimension. + + Args: + img_len: length of the image for given dimension + crop_size: size in pixels of the crop in given dimension + overlap_frac: fraction that adjacent crops will overlap each other on each side + + Returns: + numpy.array: coordinates for where each crop will start in given dimension + numpy.array: coordinates for where each crop will end in given dimension + int: number of pixels of padding at start and end of image in given dimension + """ + + # compute overlap fraction in pixels + overlap_pix = math.floor(crop_size * overlap_frac) + + # the crops start at pixel 0, and are spaced crop_size - overlap_pix away from each other + start_indices = np.arange(0, img_len - overlap_pix, crop_size - overlap_pix) + + # the crops each end crop_size away the start + end_indices = start_indices + crop_size + + # the padding for the final image is the amount that the last crop goes beyond the image size + padding = end_indices[-1] - img_len + + return start_indices, end_indices, padding + + +def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding): + """Crops an image into pieces according to supplied coordinates + + Args: + input_data: xarray of [fovs, stacks, crops, slices, rows, cols, channels] to be cropped + row_starts: list of indices where row crops start + row_ends: list of indices where row crops end + col_starts: list of indices where col crops start + col_ends: list of indices where col crops end + padding: tuple which specifies the amount of padding on the final image + + Returns: + numpy.array: 7D tensor of cropped images + tuple: shape of the final padded image + """ + + # determine key parameters of crop + fov_len, stack_len, input_crop_num, slice_num, _, _, channel_len = input_data.shape + + if input_crop_num > 1: + raise ValueError("Array has already been cropped") + + crop_num = len(row_starts) * len(col_starts) + crop_size_row = row_ends[0] - row_starts[0] + crop_size_col = col_ends[0] - col_starts[0] + + # create xarray to hold crops + cropped_stack = np.zeros((fov_len, stack_len, crop_num, slice_num, + crop_size_row, crop_size_col, channel_len)) + + # labels for each index within a dimension + coordinate_labels = [input_data.fovs, input_data.stacks, range(crop_num), input_data.slices, + range(crop_size_row), range(crop_size_col), input_data.channels] + + # labels for each dimension + dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] + + cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=dimension_labels) + + # pad the input to account for imperfectly overlapping final crop in rows and cols + formatted_padding = ((0, 0), (0, 0), (0, 0), (0, 0), (0, padding[0]), (0, padding[1]), (0, 0)) + padded_input = np.pad(input_data, formatted_padding, mode='constant', constant_values=0) + + # loop through rows and cols to generate crops + crop_counter = 0 + for i in range(len(row_starts)): + for j in range(len(col_starts)): + cropped_xr[:, :, crop_counter, ...] = padded_input[:, :, 0, :, + row_starts[i]:row_ends[i], + col_starts[j]:col_ends[j], :] + crop_counter += 1 + + return cropped_xr, padded_input.shape + + +def stitch_crops(crop_stack, log_data): + """Takes a stack of annotated labels and stitches them together into a single image + + Args: + crop_stack: 7D tensor of labels to be stitched together + log_data: dictionary of parameters for reconstructing original image data + + Returns: + numpy.array: 7D tensor of reconstructed labels + """ + + # Initialize image with single dimension for channels + fov_len, stack_len, _, _, row_size, col_size, _ = log_data['original_shape'] + row_padding, col_padding = log_data.get('row_padding', 0), log_data.get('col_padding', 0) + stitched_labels = np.zeros((fov_len, stack_len, 1, 1, row_size + row_padding, + col_size + col_padding, 1)) + + row_starts, row_ends = log_data['row_starts'], log_data['row_ends'] + col_starts, col_ends = log_data['col_starts'], log_data['col_ends'] + + if crop_stack.shape[3] != 1: + raise ValueError('Stacks must be combined before stitching can occur') + + # for each fov and stack, loop through rows and columns of crop positions + for fov, stack, row, col in product(range(fov_len), range(stack_len), + range(len(row_starts)), range(len(col_starts))): + + # determine what crop # we're currently working on + crop_counter = row * len(row_starts) + col + + # get current crop + crop = crop_stack[fov, stack, crop_counter, 0, :, :, 0] + + # increment values to ensure unique labels across final image + lowest_allowed_val = np.amax(stitched_labels[fov, stack, ...]) + crop = np.where(crop == 0, crop, crop + lowest_allowed_val) + + # get ids of cells in current crop + potential_overlap_cells = np.unique(crop) + potential_overlap_cells = \ + potential_overlap_cells[np.nonzero(potential_overlap_cells)] + + # get values of stitched image at location where crop will be placed + stitched_crop = stitched_labels[fov, stack, 0, 0, + row_starts[row]:row_ends[row], + col_starts[col]:col_ends[col], 0] + + # loop through each cell in the crop to determine + # if it overlaps with another cell in full image + for cell in potential_overlap_cells: + + # get cell ids present in stitched image + # at location of current cell in crop + stitched_overlap_vals, stitched_overlap_counts = \ + np.unique(stitched_crop[crop == cell], return_counts=True) + + # remove IDs and counts corresponding to overlap with ID 0 (background) + keep_vals = np.nonzero(stitched_overlap_vals) + stitched_overlap_vals = stitched_overlap_vals[keep_vals] + stitched_overlap_counts = stitched_overlap_counts[keep_vals] + + # if there are overlaps, determine which is greatest in count, + # and replace with that ID + if len(stitched_overlap_vals) > 0: + max_overlap = stitched_overlap_vals[np.argmax(stitched_overlap_counts)] + crop[crop == cell] = max_overlap + + # combine the crop with the current values in the stitched image + combined_crop = np.where(stitched_crop > 0, stitched_crop, crop) + + # use this combined crop to update the values of stitched image + stitched_labels[fov, stack, 0, 0, row_starts[row]:row_ends[row], + col_starts[col]:col_ends[col], 0] = combined_crop + + # trim padding to put image back to original size + if row_padding > 0: + stitched_labels = stitched_labels[:, :, :, :, :-row_padding, :, :] + if col_padding > 0: + stitched_labels = stitched_labels[:, :, :, :, :, :-col_padding, :] + + return stitched_labels \ No newline at end of file diff --git a/caliban_toolbox/utils/crop_utils_test.py b/caliban_toolbox/utils/crop_utils_test.py new file mode 100644 index 0000000..fb8502b --- /dev/null +++ b/caliban_toolbox/utils/crop_utils_test.py @@ -0,0 +1,307 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import skimage + +import numpy as np +from caliban_toolbox import reshape_data + +from caliban_toolbox.utils import crop_utils +import xarray as xr + + +def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len): + """Test function to generate a blank xarray with the supplied dimensions + + Inputs + fov_num: number of distinct FOVs + stack_num: number of distinct z stacks + crop_num: number of x/y crops + slice_num: number of z/t slices + row_num: number of rows + col_num: number of cols + chan_num: number of channels + + Outputs + test_xr: xarray of [fov_num, row_num, col_num, chan_num]""" + + test_img = np.zeros((fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len)) + + fovs = ["fov" + str(x) for x in range(1, fov_len + 1)] + channels = ["channel" + str(x) for x in range(1, chan_len + 1)] + + test_stack_xr = xr.DataArray(data=test_img, + coords=[fovs, range(stack_len), range(crop_num), range(slice_num), + range(row_len), range(col_len), channels], + dims=["fovs", "stacks", "crops", "slices", + "rows", "cols", "channels"]) + + return test_stack_xr + + +def test_compute_crop_indices(): + # test corner case of only one crop + img_len, crop_size, overlap_frac = 100, 100, 0.2 + starts, ends, padding = crop_utils.compute_crop_indices(img_len=img_len, crop_size=crop_size, + overlap_frac=overlap_frac) + assert (len(starts) == 1) + assert (len(ends) == 1) + + # test crop size that doesn't divide evenly into image size + img_len, crop_size, overlap_frac = 105, 20, 0.2 + starts, ends, padding = crop_utils.compute_crop_indices(img_len=img_len, crop_size=crop_size, + overlap_frac=overlap_frac) + crop_num = np.ceil(img_len / (crop_size - (crop_size * overlap_frac))) + assert (len(starts) == crop_num) + assert (len(ends) == crop_num) + + crop_end = crop_num * (crop_size - (crop_size * overlap_frac)) + crop_size * overlap_frac + assert (ends[-1] == crop_end) + + # test overlap of 0 between crops + img_len, crop_size, overlap_frac = 200, 20, 0 + starts, ends, padding = crop_utils.compute_crop_indices(img_len=img_len, crop_size=crop_size, + overlap_frac=overlap_frac) + assert (np.all(starts == range(0, 200, 20))) + assert (np.all(ends == range(20, 201, 20))) + assert (padding == 0) + + +def test_crop_helper(): + # img params + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 200, 200, 1 + crop_size, overlap_frac = 200, 0.2 + + # test only one crop + test_xr = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=chan_len) + + starts, ends, padding = crop_utils.compute_crop_indices(img_len=row_len, crop_size=crop_size, + overlap_frac=overlap_frac) + cropped, padded = crop_utils.crop_helper(input_data=test_xr, row_starts=starts, + row_ends=ends, col_starts=starts, col_ends=ends, + padding=(padding, padding)) + + assert (cropped.shape == (fov_len, stack_len, 1, slice_num, row_len, col_len, chan_len)) + + # test crops of different row/col dimensions + row_crop, col_crop = 50, 40 + row_starts, row_ends, row_padding = \ + crop_utils.compute_crop_indices(img_len=row_len, crop_size=row_crop, + overlap_frac=overlap_frac) + + col_starts, col_ends, col_padding = \ + crop_utils.compute_crop_indices(img_len=col_len, crop_size=col_crop, + overlap_frac=overlap_frac) + + cropped, padded = crop_utils.crop_helper(input_data=test_xr, row_starts=row_starts, + row_ends=row_ends, col_starts=col_starts, + col_ends=col_ends, + padding=(row_padding, col_padding)) + + assert (cropped.shape == (fov_len, stack_len, 30, slice_num, row_crop, col_crop, chan_len)) + + # test that correct region of image is being cropped + row_crop, col_crop = 40, 40 + + # assign each pixel in the image a unique value + linear_sequence = np.arange(0, fov_len * 1 * 1 * row_len * col_len * chan_len) + linear_sequence_reshaped = np.reshape(linear_sequence, (fov_len, 1, 1, 1, row_len, + col_len, chan_len)) + test_xr[:, :, :, :, :, :, :] = linear_sequence_reshaped + + # crop the image + row_starts, row_ends, row_padding = \ + crop_utils.compute_crop_indices(img_len=row_len, crop_size=row_crop, + overlap_frac=overlap_frac) + + col_starts, col_ends, col_padding = \ + crop_utils.compute_crop_indices(img_len=col_len, crop_size=col_crop, + overlap_frac=overlap_frac) + + cropped, padded = crop_utils.crop_helper(input_data=test_xr, row_starts=row_starts, + row_ends=row_ends, col_starts=col_starts, + col_ends=col_ends, + padding=(row_padding, col_padding)) + + # check that the values of each crop match the value in uncropped image + for img in range(test_xr.shape[0]): + crop_counter = 0 + for row in range(len(row_starts)): + for col in range(len(col_starts)): + crop = cropped[img, 0, crop_counter, 0, :, :, 0].values + + original_image_crop = test_xr[img, 0, 0, 0, row_starts[row]:row_ends[row], + col_starts[col]:col_ends[col], 0].values + assert (np.all(crop == original_image_crop)) + + crop_counter += 1 + + +def test_stitch_crops(): + # generate stack of crops from image with grid pattern + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # create image with artificial objects to be segmented + + cell_idx = 1 + for i in range(12): + for j in range(11): + for fov in range(y_data.shape[0]): + y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), + (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx + cell_idx += 1 + + # ## Test when crop is same size as image + crop_size, overlap_frac = 400, 0.2 + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=(crop_size, crop_size), + overlap_frac=overlap_frac) + + log_data["original_shape"] = X_data.shape + + # stitch the crops back together + stitched_img = crop_utils.stitch_crops(crop_stack=y_cropped, log_data=log_data) + + # dims are the same + assert np.all(stitched_img.shape == y_data.shape) + + # check that objects are at same location + assert (np.all(np.equal(stitched_img[..., 0] > 0, y_data.values[..., 0] > 0))) + + # check that same number of unique objects + assert len(np.unique(stitched_img)) == len(np.unique(y_data.values)) + + # ## Test when rows has only one crop + crop_size, overlap_frac = (400, 40), 0.2 + + # crop data + X_cropped, y_cropped, log_data = \ + reshape_data.crop_multichannel_data(X_data=X_data, + y_data=y_data, + crop_size=crop_size, + overlap_frac=overlap_frac) + + # stitch back together + log_data["original_shape"] = X_data.shape + stitched_imgs = crop_utils.stitch_crops(crop_stack=y_cropped, log_data=log_data) + + # dims are the same + assert np.all(stitched_imgs.shape == y_data.shape) + + # all the same pixels are marked + assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0))) + + # there are the same number of cells + assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data))) + + # test stitching imperfect annotator labels that slightly overlap + # generate stack of crops from image with grid pattern + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 1, 1, 1, 800, 800, 1 + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + side_len = 40 + cell_num = y_data.shape[4] // side_len + + cell_id = np.arange(1, cell_num ** 2 + 1) + cell_id = np.random.choice(cell_id, cell_num ** 2, replace=False) + cell_idx = 0 + for row in range(cell_num): + for col in range(cell_num): + y_data[0, 0, 0, 0, row * side_len:(row + 1) * side_len, + col * side_len:(col + 1) * side_len, 0] = cell_id[cell_idx] + cell_idx += 1 + + crop_size, overlap_frac = 100, 0.2 + + starts, ends, padding = crop_utils.compute_crop_indices(img_len=row_len, crop_size=crop_size, + overlap_frac=overlap_frac) + + # generate a vector of random offsets to jitter the crop window, + # simulating mismatches between frames + offset_len = 5 + row_offset = np.append( + np.append(0, np.random.randint(-offset_len, offset_len, len(starts) - 2)), 0) + col_offset = np.append( + np.append(0, np.random.randint(-offset_len, offset_len, len(starts) - 2)), 0) + + # modify indices by random offset + row_starts, row_ends = starts + row_offset, ends + row_offset + col_starts, col_ends = starts + col_offset, ends + col_offset + + y_cropped, padded = crop_utils.crop_helper(input_data=y_data, row_starts=row_starts, + row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(padding, padding)) + + # generate log data, since we had to go inside the upper level + # function to modify crop_helper inputs + log_data = {} + log_data["row_starts"] = row_starts.tolist() + log_data["row_ends"] = row_ends.tolist() + log_data["row_crop_size"] = crop_size + log_data["num_row_crops"] = len(row_starts) + log_data["col_starts"] = col_starts.tolist() + log_data["col_ends"] = col_ends.tolist() + log_data["col_crop_size"] = crop_size + log_data["num_col_crops"] = len(col_starts) + log_data["row_padding"] = int(padding) + log_data["col_padding"] = int(padding) + log_data["num_crops"] = y_cropped.shape[2] + log_data["original_shape"] = y_data.shape + log_data["fov_names"] = y_data.fovs.values.tolist() + log_data["channel_names"] = y_data.channels.values.tolist() + + stitched_img = crop_utils.stitch_crops(crop_stack=y_cropped, log_data=log_data) + + relabeled = skimage.measure.label(stitched_img[0, 0, 0, 0, :, :, 0]) + + props = skimage.measure.regionprops_table(relabeled, properties=["area", "label"]) + + # dims are the same + assert np.all(stitched_img.shape == y_data.shape) + + # same number of unique objects before and after + assert (len(np.unique(relabeled)) == len(np.unique(y_data[0, 0, 0, 0, :, :, 0]))) + + # no cell is smaller than offset subtracted from each side + min_size = (side_len - offset_len * 2) ** 2 + max_size = (side_len + offset_len * 2) ** 2 + + assert (np.all(props["area"] <= max_size)) + assert (np.all(props["area"] >= min_size)) \ No newline at end of file diff --git a/caliban_toolbox/utils/io_utils.py b/caliban_toolbox/utils/io_utils.py new file mode 100644 index 0000000..a30022d --- /dev/null +++ b/caliban_toolbox/utils/io_utils.py @@ -0,0 +1,234 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import numpy as np +import os +import json + +from itertools import product + + +def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, blank_labels='include', + save_format='npz', verbose=True): + """Take an array of processed image data and save as NPZ for caliban + + Args: + X_data: 7D tensor of cropped and sliced raw images + y_data: 7D tensor of cropped and sliced labeled images + original_data: the original unmodified images + log_data: data used to reconstruct images + save_dir: path to save the npz and JSON files + blank_labels: whether to include NPZs with blank labels (poor predictions) + or skip (no cells) + save_format: format to save the data (currently only NPZ) + verbose: flag to control print statements + """ + + if not os.path.isdir(save_dir): + os.makedirs(save_dir) + + # if these are present, it means data was cropped/sliced. Otherwise, default to 1 + num_crops = log_data.get('num_crops', 1) + num_slices = log_data.get('num_slices', 1) + + fov_names = original_data.fovs.values + fov_len = len(fov_names) + + if blank_labels not in ['skip', 'include', 'separate']: + raise ValueError('blank_labels must be one of ' + '[skip, include, separate], got {}'.format(blank_labels)) + + if blank_labels == 'separate': + os.makedirs(os.path.join(save_dir, 'separate')) + + # for each fov, loop through 2D crops and 3D slices + for fov, crop, slice in product(range(fov_len), range(num_crops), range(num_slices)): + # generate identifier for crop + npz_id = 'fov_{}_crop_{}_slice_{}'.format(fov_names[fov], crop, slice) + + # get working batch + labels = y_data[fov, :, crop, slice, ...].values + channels = X_data[fov, :, crop, slice, ...].values + + # determine if labels are blank, and if so what to do with npz + if np.sum(labels) == 0: + + # blank labels get saved to separate folder + if blank_labels == 'separate': + if verbose: + print('{} is blank, saving to separate folder'.format(npz_id)) + save_path = os.path.join(save_dir, blank_labels, npz_id) + + # save images as either npz or xarray + if save_format == 'npz': + np.savez(save_path + '.npz', X=channels, y=labels) + + elif save_format == 'xr': + raise NotImplementedError() + + # blank labels don't get saved, empty area of tissue + elif blank_labels == 'skip': + if verbose: + print('{} is blank, skipping saving'.format(npz_id)) + + # blank labels get saved along with other crops + elif blank_labels == 'include': + if verbose: + print('{} is blank, saving to folder'.format(npz_id)) + save_path = os.path.join(save_dir, npz_id) + + # save images as either npz or xarray + if save_format == 'npz': + np.savez(save_path + '.npz', X=channels, y=labels) + + elif save_format == 'xr': + raise NotImplementedError() + + else: + # crop is not blank, save based on file_format + save_path = os.path.join(save_dir, npz_id) + + # save images as either npz or xarray + if save_format == 'npz': + np.savez(save_path + '.npz', X=channels, y=labels) + + elif save_format == 'xr': + raise NotImplementedError() + + log_data['fov_names'] = fov_names.tolist() + log_data['channel_names'] = original_data.channels.values.tolist() + log_data['original_shape'] = original_data.shape + log_data['slice_stack_len'] = X_data.shape[1] + log_data['save_format'] = save_format + + log_path = os.path.join(save_dir, 'log_data.json') + with open(log_path, 'w') as write_file: + json.dump(log_data, write_file) + + +def get_saved_file_path(dir_list, fov_name, crop, slice, file_ext='.npz'): + """Helper function to identify correct file path for an npz file + + Args: + dir_list: list of files in directory + fov_name: string of the current fov_name + crop: int of current crop + slice: int of current slice + file_ext: extension file was saved with + + Returns: + string: formatted file name + + Raises: + ValueError: If multiple file path matches were found + """ + + base_string = 'fov_{}_crop_{}_slice_{}'.format(fov_name, crop, slice) + string_matches = [string for string in dir_list if base_string + '_save_version' in string] + + if len(string_matches) == 0: + full_string = base_string + file_ext + elif len(string_matches) == 1: + full_string = string_matches[0] + else: + raise ValueError('Multiple save versions found: ' + 'please select only a single save version. {}'.format(string_matches)) + return full_string + + +def load_npzs(crop_dir, log_data, verbose=True): + """Reads all of the cropped images from a directory, and aggregates them into a single stack + + Args: + crop_dir: path to directory with cropped npz or xarray files + log_data: dictionary of parameters generated during data saving + + verbose: flag to control print statements + + Returns: + numpy.array: 7D tensor of labeled crops + """ + + fov_names = log_data['fov_names'] + fov_len, stack_len, _, _, row_size, col_size, _ = log_data['original_shape'] + save_format = log_data['save_format'] + + # if cropped/sliced, get size of dimensions. Otherwise, use size in original data + row_crop_size = log_data.get('row_crop_size', row_size) + col_crop_size = log_data.get('col_crop_size', col_size) + slice_stack_len = log_data.get('slice_stack_len', stack_len) + + # if cropped/sliced, get number of crops/slices + num_crops, num_slices = log_data.get('num_crops', 1), log_data.get('num_slices', 1) + stack = np.zeros((fov_len, slice_stack_len, num_crops, + num_slices, row_crop_size, col_crop_size, 1)) + saved_files = os.listdir(crop_dir) + + # for each fov, loop over each 2D crop and 3D slice + for fov, crop, slice in product(range(fov_len), range(num_crops), range(num_slices)): + # load NPZs + if save_format == 'npz': + npz_path = os.path.join(crop_dir, get_saved_file_path(saved_files, + fov_names[fov], + crop, slice)) + if os.path.exists(npz_path): + temp_npz = np.load(npz_path) + + # last slice may be truncated, modify index + if slice == num_slices - 1: + current_stack_len = temp_npz['X'].shape[1] + else: + current_stack_len = slice_stack_len + + stack[fov, :current_stack_len, crop, slice, ...] = temp_npz['y'] + else: + # npz not generated, did not contain any labels, keep blank + if verbose: + print('could not find npz {}, skipping'.format(npz_path)) + + # load xarray + elif save_format == 'xr': + raise NotImplementedError() + # xr_path = os.path.join(crop_dir, get_saved_file_path(saved_files, fov_names[fov], + # crop, slice)) + # if os.path.exists(xr_path): + # temp_xr = xr.open_dataarray(xr_path) + # + # # last slice may be truncated, modify index + # if slice == num_slices - 1: + # current_stack_len = temp_xr.shape[1] + # else: + # current_stack_len = stack_len + # + # stack[fov, :current_stack_len, crop, slice, ...] = temp_xr[..., -1:] + # else: + # # npz not generated, did not contain any labels, keep blank + # print('could not find xr {}, skipping'.format(xr_path)) + + return stack diff --git a/caliban_toolbox/utils/io_utils_test.py b/caliban_toolbox/utils/io_utils_test.py new file mode 100644 index 0000000..959ddc6 --- /dev/null +++ b/caliban_toolbox/utils/io_utils_test.py @@ -0,0 +1,290 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import os +import json +import pytest +import copy +import tempfile + +import numpy as np + + +from caliban_toolbox import reshape_data +from caliban_toolbox.utils import io_utils + +from caliban_toolbox.utils.crop_utils_test import _blank_data_xr + + +def test_save_npzs_for_caliban(): + fov_len, stack_len, num_crops, num_slices, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 4 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, + slice_num=num_slices, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, + slice_num=num_slices, + row_len=row_len, col_len=col_len, chan_len=1) + + sliced_X, sliced_y, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) + + with tempfile.TemporaryDirectory() as temp_dir: + io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) + + # check that correct size was saved + test_npz_labels = np.load(os.path.join(temp_dir, "fov_fov1_crop_0_slice_0.npz")) + + assert test_npz_labels["y"].shape == (slice_stack_len, row_len, col_len, 1) + + assert test_npz_labels["y"].shape[:-1] == test_npz_labels["X"].shape[:-1] + + # check that json saved successfully + with open(os.path.join(temp_dir, "log_data.json")) as json_file: + saved_log_data = json.load(json_file) + + assert saved_log_data["original_shape"] == list(X_data.shape) + + with tempfile.TemporaryDirectory() as temp_dir: + # check that combined crop and slice saving works + crop_size = (10, 10) + overlap_frac = 0.2 + X_cropped, y_cropped, log_data_crop = \ + reshape_data.crop_multichannel_data(X_data=sliced_X, + y_data=sliced_y, + crop_size=crop_size, + overlap_frac=overlap_frac, + test_parameters=False) + + io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, + original_data=X_data, + log_data={**log_data, **log_data_crop}, + save_dir=temp_dir, + blank_labels="include", save_format="npz", + verbose=False) + expected_crop_num = X_cropped.shape[2] * X_cropped.shape[3] + files = os.listdir(temp_dir) + files = [file for file in files if "npz" in file] + + assert len(files) == expected_crop_num + + # check that arguments specifying what to do with blank crops are working + + # set specified crops to not be blank + sliced_y[0, 0, 0, [1, 4, 7], 0, 0, 0] = 27 + expected_crop_num = sliced_X.shape[2] * sliced_X.shape[3] + + # test that function correctly includes blank crops when saving + with tempfile.TemporaryDirectory() as temp_dir: + io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, + original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) + + # check that there is the expected number of files saved to directory + files = os.listdir(temp_dir) + files = [file for file in files if "npz" in file] + + assert len(files) == expected_crop_num + + # test that function correctly skips blank crops when saving + with tempfile.TemporaryDirectory() as temp_dir: + io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, + original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + save_format="npz", + blank_labels="skip", verbose=False) + + # check that expected number of files in directory + files = os.listdir(temp_dir) + files = [file for file in files if "npz" in file] + assert len(files) == 3 + + # test that function correctly saves blank crops to separate folder + with tempfile.TemporaryDirectory() as temp_dir: + io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, + original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + save_format="npz", + blank_labels="separate", verbose=False) + + # check that expected number of files in each directory + files = os.listdir(temp_dir) + files = [file for file in files if "npz" in file] + assert len(files) == 3 + + files = os.listdir(os.path.join(temp_dir, "separate")) + files = [file for file in files if "npz" in file] + assert len(files) == expected_crop_num - 3 + + +# postprocessing + + +def test_get_npz_file_path(): + # create list of npz_ids + dir_list = ["fov_fov1_crop_2_slice_4.npz", "fov_fov1_crop_2_slice_5_save_version_0.npz", + "fov_fov1_crop_2_slice_6_save_version_0.npz", + "fov_fov1_crop_2_slice_6_save_version_1.npz", + "fov_fov1_crop_2_slice_7_save_version_0.npz", + "fov_fov1_crop_2_slice_7_save_version_0_save_version_2.npz"] + + fov, crop = "fov1", 2 + + # test unmodified npz + slice = 4 + output_string = io_utils.get_saved_file_path(dir_list, fov, crop, slice) + + assert output_string == dir_list[0] + + # test single modified npz + slice = 5 + output_string = io_utils.get_saved_file_path(dir_list, fov, crop, slice) + assert output_string == dir_list[1] + + # test that error is raised when multiple save versions present + slice = 6 + with pytest.raises(ValueError): + output_string = io_utils.get_saved_file_path(dir_list, fov, crop, slice) + + # test that error is raised when multiple save versions present due to resaves + slice = 7 + + with pytest.raises(ValueError): + output_string = io_utils.get_saved_file_path(dir_list, fov, crop, slice) + + +def test_load_npzs(): + with tempfile.TemporaryDirectory() as temp_dir: + # first generate image stack that will be sliced up + fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1 + row_len, col_len, chan_len = 50, 50, 3 + slice_stack_len = 4 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # slice the data + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, + slice_stack_len) + + # crop the data + crop_size = (10, 10) + overlap_frac = 0.2 + X_cropped, y_cropped, log_data_crop = \ + reshape_data.crop_multichannel_data( + X_data=X_slice, + y_data=y_slice, + crop_size=crop_size, + overlap_frac=overlap_frac, + test_parameters=False) + + # tag the upper left hand corner of the label in each slice + slice_tags = np.arange(y_cropped.shape[3]) + crop_tags = np.arange(y_cropped.shape[2]) + y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags + y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags + + combined_log_data = {**log_data, **log_data_crop} + + # save the tagged data + io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, + log_data=combined_log_data, save_dir=temp_dir, + blank_labels="include", save_format="npz", + verbose=False) + + with open(os.path.join(temp_dir, "log_data.json")) as json_file: + saved_log_data = json.load(json_file) + + loaded_slices = io_utils.load_npzs(temp_dir, saved_log_data, verbose=False) + + # dims other than channels are the same + assert (np.all(loaded_slices.shape[:-1] == X_cropped.shape[:-1])) + + assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) + assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) + + # test slices with unequal last length + with tempfile.TemporaryDirectory() as temp_dir: + # first generate image stack that will be sliced up + fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1 + row_len, col_len, chan_len = 50, 50, 3 + slice_stack_len = 7 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # slice the data + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) + + # crop the data + crop_size = (10, 10) + overlap_frac = 0.2 + X_cropped, y_cropped, log_data_crop = \ + reshape_data.crop_multichannel_data( + X_data=X_slice, + y_data=y_slice, + crop_size=crop_size, + overlap_frac=overlap_frac, + test_parameters=False) + + # tag the upper left hand corner of the annotations in each slice + slice_tags = np.arange(y_cropped.shape[3]) + crop_tags = np.arange(X_cropped.shape[2]) + y_cropped[0, 0, :, 0, 0, 0, 0] = crop_tags + y_cropped[0, 0, 0, :, 0, 0, 0] = slice_tags + + combined_log_data = {**log_data, **log_data_crop} + + # save the tagged data + io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, + log_data=combined_log_data, save_dir=temp_dir, + blank_labels="include", save_format="npz", + verbose=False) + + loaded_slices = io_utils.load_npzs(temp_dir, combined_log_data) + + # dims other than channels are the same + assert (np.all(loaded_slices.shape[:-1] == X_cropped.shape[:-1])) + + assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) + assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) + diff --git a/caliban_toolbox/utils/slice_utils.py b/caliban_toolbox/utils/slice_utils.py new file mode 100644 index 0000000..1c3ccbc --- /dev/null +++ b/caliban_toolbox/utils/slice_utils.py @@ -0,0 +1,159 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import math +import numpy as np +import os +import json + +from itertools import product + +import xarray as xr + + +def compute_slice_indices(stack_len, slice_len, slice_overlap): + """ Determine how to slice an image across the stack dimension. + + Args: + stack_len: total number of z or t stacks + slice_len: number of z/t frames to be included in each slice + slice_overlap: number of z/t frames that will overlap in each slice + + Returns: + numpy.array: coordinates for the start location of each slice + numpy.array: coordinates for the end location of each slice + """ + + if slice_overlap >= slice_len: + raise ValueError('slice overlap must be less than the length of the slice') + + spacing = slice_len - slice_overlap + + # slices_start indices begin at index 0, and are spaced 'spacing' apart from one another + slice_start_indices = np.arange(0, stack_len - slice_overlap, spacing) + + # slices_end indices are 'spacing' away from the start + slice_end_indices = slice_start_indices + slice_len + + if slice_end_indices[-1] != stack_len: + # if slices overshoot, reduce length of final slice + slice_end_indices[-1] = stack_len + + return slice_start_indices, slice_end_indices + + +def slice_helper(data_xr, slice_start_indices, slice_end_indices): + """Divide a stack into smaller slices according to supplied indices + + Args: + data_xr: xarray of to be split into slices + slice_start_indices: list of indices for where slices start + slice_end_indices: list of indices for where slices end + + Returns: + xarray.DataArray: 7D tensor of sliced images + """ + + # get input image dimensions + fov_len, stack_len, crop_num, input_slice_num, row_len, col_len, chan_len = data_xr.shape + + if input_slice_num > 1: + raise ValueError('Input array already contains slice data') + + slice_num = len(slice_start_indices) + sliced_stack_len = slice_end_indices[0] - slice_start_indices[0] + + # create xarray to hold slices + slice_data = np.zeros((fov_len, sliced_stack_len, crop_num, + slice_num, row_len, col_len, chan_len)) + + # labels for each index within a dimension + coordinate_labels = [data_xr.fovs, range(sliced_stack_len), range(crop_num), range(slice_num), + range(row_len), range(col_len), data_xr.channels] + + # labels for each dimension + dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] + + slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=dimension_labels) + + # loop through slice indices to generate sliced data + slice_counter = 0 + for i in range(len(slice_start_indices)): + + if i != len(slice_start_indices) - 1: + # not the last slice + slice_xr[:, :, :, slice_counter, ...] = \ + data_xr[:, slice_start_indices[i]:slice_end_indices[i], :, 0, :, :, :].values + slice_counter += 1 + + else: + # last slice, only index into stack the amount two indices are separated + slice_len = slice_end_indices[i] - slice_start_indices[i] + slice_xr[:, :slice_len, :, slice_counter, ...] = \ + data_xr[:, slice_start_indices[i]:slice_end_indices[i], :, 0, :, :, :].values + slice_counter += 1 + + return slice_xr + + +def stitch_slices(slice_stack, log_data): + """Helper function to stitch slices together back into original sized array + + Args: + slice_stack: xarray of shape [fovs, stacks, crops, slices, rows, cols, segmentation_label] + log_data: log data produced from creation of slice stack + + Returns: + xarray.DataArray: 7D tensor of stitched labeled slices + """ + + # get parameters from dict + fov_len, stack_len, crop_num, _, row_len, col_len, chan_len = log_data['original_shape'] + crop_num = log_data.get('num_crops', crop_num) + row_len = log_data.get('row_crop_size', row_len) + col_len = log_data.get('col_crop_size', col_len) + + slice_start_indices = log_data['slice_start_indices'] + slice_end_indices = log_data['slice_end_indices'] + num_slices, fov_names = log_data['num_slices'], log_data['fov_names'] + + stitched_slices = np.zeros((fov_len, stack_len, crop_num, 1, row_len, col_len, 1)) + + # loop slice indices to generate sliced data + for i in range(num_slices - 1): + stitched_slices[:, slice_start_indices[i]:slice_end_indices[i], :, 0, ...] = \ + slice_stack[:, :, :, i, ...] + + # last slice, only index into stack the amount two indices are separated + last_idx = num_slices - 1 + slice_len = slice_end_indices[last_idx] - slice_start_indices[last_idx] + stitched_slices[:, slice_start_indices[last_idx]:slice_end_indices[last_idx], :, 0, ...] = \ + slice_stack[:, :slice_len, :, last_idx, ...] + + return stitched_slices diff --git a/caliban_toolbox/utils/slice_utils_test.py b/caliban_toolbox/utils/slice_utils_test.py new file mode 100644 index 0000000..9b873eb --- /dev/null +++ b/caliban_toolbox/utils/slice_utils_test.py @@ -0,0 +1,200 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +import numpy as np + +from caliban_toolbox import reshape_data +from caliban_toolbox.utils import slice_utils + +from caliban_toolbox.utils.crop_utils_test import _blank_data_xr + + +def test_compute_slice_indices(): + # test when slice divides evenly into stack len + stack_len = 40 + slice_len = 4 + slice_overlap = 0 + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_len, + slice_overlap) + assert np.all(np.equal(slice_start_indices, np.arange(0, stack_len, slice_len))) + + # test when slice_num does not divide evenly into stack_len + stack_len = 42 + slice_len = 5 + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_len, + slice_overlap) + + expected_start_indices = np.arange(0, stack_len, slice_len) + assert np.all(np.equal(slice_start_indices, expected_start_indices)) + + # test overlapping slices + stack_len = 40 + slice_len = 4 + slice_overlap = 1 + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_len, + slice_overlap) + + assert len(slice_start_indices) == int(np.floor(stack_len / (slice_len - slice_overlap))) + assert slice_end_indices[-1] == stack_len + assert slice_end_indices[0] - slice_start_indices[0] == slice_len + + +def test_slice_helper(): + # test output shape with even division of slice + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 4 + + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_stack_len, 0) + + input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=chan_len) + + slice_output = slice_utils.slice_helper(input_data, slice_start_indices, slice_end_indices) + + assert slice_output.shape == (fov_len, slice_stack_len, crop_num, + int(np.ceil(stack_len / slice_stack_len)), + row_len, col_len, chan_len) + + # test output shape with uneven division of slice + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 6 + + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_stack_len, 0) + + input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=chan_len) + + slice_output = slice_utils.slice_helper(input_data, slice_start_indices, slice_end_indices) + + assert slice_output.shape == (fov_len, slice_stack_len, crop_num, + (np.ceil(stack_len / slice_stack_len)), + row_len, col_len, chan_len) + + # test output shape with slice overlaps + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 6 + slice_overlap = 1 + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_stack_len, + slice_overlap) + + input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=chan_len) + + slice_output = slice_utils.slice_helper(input_data, slice_start_indices, slice_end_indices) + + assert slice_output.shape == (fov_len, slice_stack_len, crop_num, + (np.ceil(stack_len / (slice_stack_len - slice_overlap))), + row_len, col_len, chan_len) + + # test output values + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 4 + slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, + slice_stack_len, 0) + + input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, row_len=row_len, col_len=col_len, + chan_len=chan_len) + + # tag upper left hand corner of each image + tags = np.arange(stack_len) + input_data[0, :, 0, 0, 0, 0, 0] = tags + + slice_output = slice_utils.slice_helper(input_data, slice_start_indices, slice_end_indices) + + # loop through each slice, make sure values increment as expected + for i in range(slice_output.shape[1]): + assert np.all(np.equal(slice_output[0, :, 0, i, 0, 0, 0], tags[i * 4:(i + 1) * 4])) + + +def test_stitch_slices(): + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 4 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # generate ordered data + linear_seq = np.arange(stack_len * row_len * col_len) + test_vals = linear_seq.reshape((stack_len, row_len, col_len)) + y_data[0, :, 0, 0, :, :, 0] = test_vals + + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) + + log_data["original_shape"] = X_data.shape + log_data["fov_names"] = X_data.fovs.values + stitched_slices = slice_utils.stitch_slices(y_slice, {**log_data}) + + # dims are the same + assert np.all(stitched_slices.shape == y_data.shape) + + assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) + + # test case without even division of crops into imsize + + fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 + slice_stack_len = 7 + + X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) + + y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=1) + + # generate ordered data + linear_seq = np.arange(stack_len * row_len * col_len) + test_vals = linear_seq.reshape((stack_len, row_len, col_len)) + y_data[0, :, 0, 0, :, :, 0] = test_vals + + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, + slice_stack_len=slice_stack_len) + + # get parameters + log_data["original_shape"] = y_data.shape + log_data["fov_names"] = y_data.fovs.values + stitched_slices = slice_utils.stitch_slices(y_slice, log_data) + + assert np.all(stitched_slices.shape == y_data.shape) + + assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) + + From 893703999edfe3fe1c6b471e3d7daf4ac4d10c6a Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sat, 25 Apr 2020 21:56:55 -0700 Subject: [PATCH 07/12] pep8 --- caliban_toolbox/reshape_data.py | 24 +++++---- caliban_toolbox/reshape_data_test.py | 34 ++++++------ caliban_toolbox/utils/crop_utils.py | 3 +- caliban_toolbox/utils/crop_utils_test.py | 48 ++++++++--------- caliban_toolbox/utils/io_utils.py | 4 +- caliban_toolbox/utils/io_utils_test.py | 66 +++++++++++------------ caliban_toolbox/utils/slice_utils_test.py | 28 +++++----- 7 files changed, 104 insertions(+), 103 deletions(-) diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 8bb89d3..4d7529c 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -73,21 +73,23 @@ def crop_multichannel_data(X_data, y_data, crop_size, overlap_frac, test_paramet # compute the start and end coordinates for the row and column crops row_starts, row_ends, row_padding = crop_utils.compute_crop_indices(img_len=X_data.shape[4], - crop_size=crop_size[0], - overlap_frac=overlap_frac) + crop_size=crop_size[0], + overlap_frac=overlap_frac) col_starts, col_ends, col_padding = crop_utils.compute_crop_indices(img_len=X_data.shape[5], - crop_size=crop_size[1], - overlap_frac=overlap_frac) + crop_size=crop_size[1], + overlap_frac=overlap_frac) # crop images - X_data_cropped, padded_shape = crop_utils.crop_helper(X_data, row_starts=row_starts, row_ends=row_ends, - col_starts=col_starts, col_ends=col_ends, - padding=(row_padding, col_padding)) - - y_data_cropped, padded_shape = crop_utils.crop_helper(y_data, row_starts=row_starts, row_ends=row_ends, - col_starts=col_starts, col_ends=col_ends, - padding=(row_padding, col_padding)) + X_data_cropped, padded_shape = crop_utils.crop_helper(X_data, row_starts=row_starts, + row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(row_padding, col_padding)) + + y_data_cropped, padded_shape = crop_utils.crop_helper(y_data, row_starts=row_starts, + row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(row_padding, col_padding)) # save relevant parameters for reconstructing image log_data = {} diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index d173d56..78af6ae 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -58,9 +58,9 @@ def test_crop_multichannel_data(): test_parameters=False) expected_crop_num = len(crop_utils.compute_crop_indices(row_len, crop_size[0], - overlap_frac)[0]) ** 2 + overlap_frac)[0]) ** 2 assert (X_data_cropped.shape == (fov_len, stack_len, expected_crop_num, slice_num, - crop_size[0], crop_size[1], channel_len)) + crop_size[0], crop_size[1], channel_len)) assert log_data["num_crops"] == expected_crop_num @@ -78,17 +78,19 @@ def test_create_slice_data(): slice_num=num_slices, row_len=row_len, col_len=col_len, chan_len=chan_len) - X_slice, y_slice, slice_indices = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) + X_slice, y_slice, slice_indices = reshape_data.create_slice_data(X_data, y_data, + slice_stack_len) assert X_slice.shape == (fov_len, slice_stack_len, num_crops, - int(np.ceil(stack_len / slice_stack_len)), - row_len, col_len, chan_len) + int(np.ceil(stack_len / slice_stack_len)), + row_len, col_len, chan_len) def test_reconstruct_image_stack(): with tempfile.TemporaryDirectory() as temp_dir: # generate stack of crops from image with grid pattern - fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 + (fov_len, stack_len, crop_num, + slice_num, row_len, col_len, chan_len) = 2, 1, 1, 1, 400, 400, 4 X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, @@ -105,7 +107,7 @@ def test_reconstruct_image_stack(): for j in range(11): for fov in range(y_data.shape[0]): y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10), - (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx + (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx cell_idx += 1 # Crop the data @@ -117,7 +119,7 @@ def test_reconstruct_image_stack(): overlap_frac=overlap_frac) io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=log_data, save_dir=temp_dir) + log_data=log_data, save_dir=temp_dir) reshape_data.reconstruct_image_stack(crop_dir=temp_dir) @@ -156,9 +158,9 @@ def test_reconstruct_image_stack(): slice_stack_len=slice_stack_len) io_utils.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, - log_data={**slice_log_data}, save_dir=temp_dir, - blank_labels="include", - save_format="npz", verbose=False) + log_data={**slice_log_data}, save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) reshape_data.reconstruct_image_stack(temp_dir) stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) @@ -187,7 +189,7 @@ def test_reconstruct_image_stack(): for j in range(1, 11): for stack in range(stack_len): y_data[:, stack, :, :, (i * 35):(i * 35 + 10 + stack * 2), - (j * 37):(j * 37 + 8 + stack * 2), 0] = cell_idx + (j * 37):(j * 37 + 8 + stack * 2), 0] = cell_idx cell_idx += 1 # tag upper left hand corner of each image with squares of increasing size @@ -208,10 +210,10 @@ def test_reconstruct_image_stack(): slice_stack_len=slice_stack_len) io_utils.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data, - log_data={**slice_log_data, **log_data}, - save_dir=temp_dir, - blank_labels="include", - save_format="npz", verbose=False) + log_data={**slice_log_data, **log_data}, + save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) reshape_data.reconstruct_image_stack(temp_dir) stitched_imgs = xr.open_dataarray(os.path.join(temp_dir, 'stitched_images.xr')) diff --git a/caliban_toolbox/utils/crop_utils.py b/caliban_toolbox/utils/crop_utils.py index ce46027..20685aa 100644 --- a/caliban_toolbox/utils/crop_utils.py +++ b/caliban_toolbox/utils/crop_utils.py @@ -30,7 +30,6 @@ import math import numpy as np - from itertools import product import xarray as xr @@ -200,4 +199,4 @@ def stitch_crops(crop_stack, log_data): if col_padding > 0: stitched_labels = stitched_labels[:, :, :, :, :, :-col_padding, :] - return stitched_labels \ No newline at end of file + return stitched_labels diff --git a/caliban_toolbox/utils/crop_utils_test.py b/caliban_toolbox/utils/crop_utils_test.py index fb8502b..a33cc23 100644 --- a/caliban_toolbox/utils/crop_utils_test.py +++ b/caliban_toolbox/utils/crop_utils_test.py @@ -65,14 +65,14 @@ def test_compute_crop_indices(): # test corner case of only one crop img_len, crop_size, overlap_frac = 100, 100, 0.2 starts, ends, padding = crop_utils.compute_crop_indices(img_len=img_len, crop_size=crop_size, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) assert (len(starts) == 1) assert (len(ends) == 1) # test crop size that doesn't divide evenly into image size img_len, crop_size, overlap_frac = 105, 20, 0.2 starts, ends, padding = crop_utils.compute_crop_indices(img_len=img_len, crop_size=crop_size, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) crop_num = np.ceil(img_len / (crop_size - (crop_size * overlap_frac))) assert (len(starts) == crop_num) assert (len(ends) == crop_num) @@ -83,7 +83,7 @@ def test_compute_crop_indices(): # test overlap of 0 between crops img_len, crop_size, overlap_frac = 200, 20, 0 starts, ends, padding = crop_utils.compute_crop_indices(img_len=img_len, crop_size=crop_size, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) assert (np.all(starts == range(0, 200, 20))) assert (np.all(ends == range(20, 201, 20))) assert (padding == 0) @@ -100,10 +100,10 @@ def test_crop_helper(): chan_len=chan_len) starts, ends, padding = crop_utils.compute_crop_indices(img_len=row_len, crop_size=crop_size, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) cropped, padded = crop_utils.crop_helper(input_data=test_xr, row_starts=starts, - row_ends=ends, col_starts=starts, col_ends=ends, - padding=(padding, padding)) + row_ends=ends, col_starts=starts, col_ends=ends, + padding=(padding, padding)) assert (cropped.shape == (fov_len, stack_len, 1, slice_num, row_len, col_len, chan_len)) @@ -111,16 +111,16 @@ def test_crop_helper(): row_crop, col_crop = 50, 40 row_starts, row_ends, row_padding = \ crop_utils.compute_crop_indices(img_len=row_len, crop_size=row_crop, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) col_starts, col_ends, col_padding = \ crop_utils.compute_crop_indices(img_len=col_len, crop_size=col_crop, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) cropped, padded = crop_utils.crop_helper(input_data=test_xr, row_starts=row_starts, - row_ends=row_ends, col_starts=col_starts, - col_ends=col_ends, - padding=(row_padding, col_padding)) + row_ends=row_ends, col_starts=col_starts, + col_ends=col_ends, + padding=(row_padding, col_padding)) assert (cropped.shape == (fov_len, stack_len, 30, slice_num, row_crop, col_crop, chan_len)) @@ -136,16 +136,16 @@ def test_crop_helper(): # crop the image row_starts, row_ends, row_padding = \ crop_utils.compute_crop_indices(img_len=row_len, crop_size=row_crop, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) col_starts, col_ends, col_padding = \ crop_utils.compute_crop_indices(img_len=col_len, crop_size=col_crop, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) cropped, padded = crop_utils.crop_helper(input_data=test_xr, row_starts=row_starts, - row_ends=row_ends, col_starts=col_starts, - col_ends=col_ends, - padding=(row_padding, col_padding)) + row_ends=row_ends, col_starts=col_starts, + col_ends=col_ends, + padding=(row_padding, col_padding)) # check that the values of each crop match the value in uncropped image for img in range(test_xr.shape[0]): @@ -166,8 +166,8 @@ def test_stitch_crops(): fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 2, 1, 1, 1, 400, 400, 4 X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, @@ -244,13 +244,13 @@ def test_stitch_crops(): for row in range(cell_num): for col in range(cell_num): y_data[0, 0, 0, 0, row * side_len:(row + 1) * side_len, - col * side_len:(col + 1) * side_len, 0] = cell_id[cell_idx] + col * side_len:(col + 1) * side_len, 0] = cell_id[cell_idx] cell_idx += 1 crop_size, overlap_frac = 100, 0.2 starts, ends, padding = crop_utils.compute_crop_indices(img_len=row_len, crop_size=crop_size, - overlap_frac=overlap_frac) + overlap_frac=overlap_frac) # generate a vector of random offsets to jitter the crop window, # simulating mismatches between frames @@ -265,9 +265,9 @@ def test_stitch_crops(): col_starts, col_ends = starts + col_offset, ends + col_offset y_cropped, padded = crop_utils.crop_helper(input_data=y_data, row_starts=row_starts, - row_ends=row_ends, - col_starts=col_starts, col_ends=col_ends, - padding=(padding, padding)) + row_ends=row_ends, + col_starts=col_starts, col_ends=col_ends, + padding=(padding, padding)) # generate log data, since we had to go inside the upper level # function to modify crop_helper inputs @@ -304,4 +304,4 @@ def test_stitch_crops(): max_size = (side_len + offset_len * 2) ** 2 assert (np.all(props["area"] <= max_size)) - assert (np.all(props["area"] >= min_size)) \ No newline at end of file + assert (np.all(props["area"] >= min_size)) diff --git a/caliban_toolbox/utils/io_utils.py b/caliban_toolbox/utils/io_utils.py index a30022d..a3cc913 100644 --- a/caliban_toolbox/utils/io_utils.py +++ b/caliban_toolbox/utils/io_utils.py @@ -34,8 +34,8 @@ from itertools import product -def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, blank_labels='include', - save_format='npz', verbose=True): +def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, + blank_labels='include', save_format='npz', verbose=True): """Take an array of processed image data and save as NPZ for caliban Args: diff --git a/caliban_toolbox/utils/io_utils_test.py b/caliban_toolbox/utils/io_utils_test.py index 959ddc6..f015a7f 100644 --- a/caliban_toolbox/utils/io_utils_test.py +++ b/caliban_toolbox/utils/io_utils_test.py @@ -31,7 +31,6 @@ import numpy as np - from caliban_toolbox import reshape_data from caliban_toolbox.utils import io_utils @@ -43,8 +42,8 @@ def test_save_npzs_for_caliban(): slice_stack_len = 4 X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, - slice_num=num_slices, - row_len=row_len, col_len=col_len, chan_len=chan_len) + slice_num=num_slices, + row_len=row_len, col_len=col_len, chan_len=chan_len) y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, slice_num=num_slices, @@ -55,9 +54,9 @@ def test_save_npzs_for_caliban(): with tempfile.TemporaryDirectory() as temp_dir: io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - blank_labels="include", - save_format="npz", verbose=False) + log_data=copy.copy(log_data), save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) # check that correct size was saved test_npz_labels = np.load(os.path.join(temp_dir, "fov_fov1_crop_0_slice_0.npz")) @@ -84,11 +83,11 @@ def test_save_npzs_for_caliban(): test_parameters=False) io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, - original_data=X_data, - log_data={**log_data, **log_data_crop}, - save_dir=temp_dir, - blank_labels="include", save_format="npz", - verbose=False) + original_data=X_data, + log_data={**log_data, **log_data_crop}, + save_dir=temp_dir, + blank_labels="include", save_format="npz", + verbose=False) expected_crop_num = X_cropped.shape[2] * X_cropped.shape[3] files = os.listdir(temp_dir) files = [file for file in files if "npz" in file] @@ -104,10 +103,10 @@ def test_save_npzs_for_caliban(): # test that function correctly includes blank crops when saving with tempfile.TemporaryDirectory() as temp_dir: io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, - original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - blank_labels="include", - save_format="npz", verbose=False) + original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + blank_labels="include", + save_format="npz", verbose=False) # check that there is the expected number of files saved to directory files = os.listdir(temp_dir) @@ -118,10 +117,10 @@ def test_save_npzs_for_caliban(): # test that function correctly skips blank crops when saving with tempfile.TemporaryDirectory() as temp_dir: io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, - original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - save_format="npz", - blank_labels="skip", verbose=False) + original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + save_format="npz", + blank_labels="skip", verbose=False) # check that expected number of files in directory files = os.listdir(temp_dir) @@ -131,10 +130,10 @@ def test_save_npzs_for_caliban(): # test that function correctly saves blank crops to separate folder with tempfile.TemporaryDirectory() as temp_dir: io_utils.save_npzs_for_caliban(X_data=sliced_X, y_data=sliced_y, - original_data=X_data, - log_data=copy.copy(log_data), save_dir=temp_dir, - save_format="npz", - blank_labels="separate", verbose=False) + original_data=X_data, + log_data=copy.copy(log_data), save_dir=temp_dir, + save_format="npz", + blank_labels="separate", verbose=False) # check that expected number of files in each directory files = os.listdir(temp_dir) @@ -222,9 +221,9 @@ def test_load_npzs(): # save the tagged data io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=combined_log_data, save_dir=temp_dir, - blank_labels="include", save_format="npz", - verbose=False) + log_data=combined_log_data, save_dir=temp_dir, + blank_labels="include", save_format="npz", + verbose=False) with open(os.path.join(temp_dir, "log_data.json")) as json_file: saved_log_data = json.load(json_file) @@ -245,15 +244,17 @@ def test_load_npzs(): slice_stack_len = 7 X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, chan_len=1) # slice the data - X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) + X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, + y_data, + slice_stack_len) # crop the data crop_size = (10, 10) @@ -276,9 +277,9 @@ def test_load_npzs(): # save the tagged data io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data, - log_data=combined_log_data, save_dir=temp_dir, - blank_labels="include", save_format="npz", - verbose=False) + log_data=combined_log_data, save_dir=temp_dir, + blank_labels="include", save_format="npz", + verbose=False) loaded_slices = io_utils.load_npzs(temp_dir, combined_log_data) @@ -287,4 +288,3 @@ def test_load_npzs(): assert np.all(np.equal(loaded_slices[0, 0, :, 0, 0, 0, 0], crop_tags)) assert np.all(np.equal(loaded_slices[0, 0, 0, :, 0, 0, 0], slice_tags)) - diff --git a/caliban_toolbox/utils/slice_utils_test.py b/caliban_toolbox/utils/slice_utils_test.py index 9b873eb..0aa9f7b 100644 --- a/caliban_toolbox/utils/slice_utils_test.py +++ b/caliban_toolbox/utils/slice_utils_test.py @@ -37,16 +37,16 @@ def test_compute_slice_indices(): slice_len = 4 slice_overlap = 0 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_len, - slice_overlap) + slice_len, + slice_overlap) assert np.all(np.equal(slice_start_indices, np.arange(0, stack_len, slice_len))) # test when slice_num does not divide evenly into stack_len stack_len = 42 slice_len = 5 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_len, - slice_overlap) + slice_len, + slice_overlap) expected_start_indices = np.arange(0, stack_len, slice_len) assert np.all(np.equal(slice_start_indices, expected_start_indices)) @@ -56,8 +56,8 @@ def test_compute_slice_indices(): slice_len = 4 slice_overlap = 1 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_len, - slice_overlap) + slice_len, + slice_overlap) assert len(slice_start_indices) == int(np.floor(stack_len / (slice_len - slice_overlap))) assert slice_end_indices[-1] == stack_len @@ -70,7 +70,7 @@ def test_slice_helper(): slice_stack_len = 4 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_stack_len, 0) + slice_stack_len, 0) input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, @@ -87,7 +87,7 @@ def test_slice_helper(): slice_stack_len = 6 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_stack_len, 0) + slice_stack_len, 0) input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, @@ -104,8 +104,8 @@ def test_slice_helper(): slice_stack_len = 6 slice_overlap = 1 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_stack_len, - slice_overlap) + slice_stack_len, + slice_overlap) input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, @@ -121,7 +121,7 @@ def test_slice_helper(): fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len = 1, 40, 1, 1, 50, 50, 3 slice_stack_len = 4 slice_start_indices, slice_end_indices = slice_utils.compute_slice_indices(stack_len, - slice_stack_len, 0) + slice_stack_len, 0) input_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, @@ -173,8 +173,8 @@ def test_stitch_slices(): slice_stack_len = 7 X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, - slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + slice_num=slice_num, + row_len=row_len, col_len=col_len, chan_len=chan_len) y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, @@ -196,5 +196,3 @@ def test_stitch_slices(): assert np.all(stitched_slices.shape == y_data.shape) assert np.all(np.equal(stitched_slices[0, :, 0, 0, :, :, 0], test_vals)) - - From 60a101b9d1a3f08251d537ba9d721de8b772ed54 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sun, 26 Apr 2020 09:56:36 -0700 Subject: [PATCH 08/12] migrated xarray functions from angelolab repo --- caliban_toolbox/utils/data_utils.py | 155 +++++++++++++++++++++++ caliban_toolbox/utils/data_utils_test.py | 138 ++++++++++++++++++++ 2 files changed, 293 insertions(+) create mode 100644 caliban_toolbox/utils/data_utils.py create mode 100644 caliban_toolbox/utils/data_utils_test.py diff --git a/caliban_toolbox/utils/data_utils.py b/caliban_toolbox/utils/data_utils.py new file mode 100644 index 0000000..344acbc --- /dev/null +++ b/caliban_toolbox/utils/data_utils.py @@ -0,0 +1,155 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np +import xarray as xr + + +def pad_xr_dims(input_xr, padded_dims): + """Takes an xarray and pads it with dimensions of size 1 according to the supplied dims list + + Inputs + input_xr: xarray to pad + padded_dims: ordered list of final dims; new dims will be added with size 1 + + Outputs + padded_xr: xarray that has had additional dims added of size 1 + + Raises: + ValueError: If padded dims includes existing dims in a different order + ValueError: If padded dims includes duplicate names + """ + + # make sure that dimensions which are present in both lists are in same order + old_dims = [dim for dim in padded_dims if dim in input_xr.dims] + + if not old_dims == list(input_xr.dims): + raise ValueError("existing dimensions in the xarray must be in same order") + + if len(np.unique(padded_dims)) != len(padded_dims): + raise ValueError('Dimensions must have unique names') + + # create new output data + output_vals = input_xr.values + output_coords = [] + + for idx, dim in enumerate(padded_dims): + + if dim in input_xr.dims: + # dimension already exists, using existing values and coords + output_coords.append(input_xr[dim]) + else: + output_vals = np.expand_dims(output_vals, axis=idx) + output_coords.append(range(1)) + + padded_xr = xr.DataArray(output_vals, coords=output_coords, dims=padded_dims) + + return padded_xr + + +def create_blank_channel(img_size, dtype='int16', full_blank=False): + """Creates a mostly blank channel of specified size + + Args: + img_size: tuple specifying the size of the image to create + dtype: dtype for image + full_blank: boolean to set whether image has few sparse pixels, or is completely blank + + Returns: + numpy.array: a (mostly) blank array with positive pixels in random values + """ + + blank = np.zeros(img_size, dtype=dtype) + + if full_blank: + return blank + else: + # noise will be created within 100 pixel boxes + row_steps = math.floor(blank.shape[0] / 100) + col_steps = math.floor(blank.shape[1] / 100) + + for row_step in range(row_steps): + for col_step in range(col_steps): + row_index = np.random.randint(0, 100 - 1) + col_index = np.random.randint(0, 100 - 1) + blank[row_step * 100 + row_index, col_step * 100 + col_index] = \ + np.random.randint(1, 15) + + return blank + + +def reorder_channels(new_channel_order, input_data, full_blank=True): + """Reorders the channels in an xarray to match new_channel_order. New channels will be blank + + Args: + new_channel_order: ordered list of channels for output data + input_data: xarray to be reordered + full_blank: whether new channels should be completely blank (for visualization), + or mostly blank with noise (for model training to avoid divide by zero errors). + + Returns: + xarray.DataArray: Reordered version of input_data + + Raises: + ValueError: If new_channel_order contains duplicated entries + """ + + # error checking + vals, counts = np.unique(new_channel_order, return_counts=True) + duplicated = np.where(counts > 1) + if len(duplicated[0] > 0): + raise ValueError("The following channels are duplicated " + "in new_channel_order: {}".format(vals[duplicated[0]])) + + # create array for output data + full_array = np.zeros((input_data.shape[:-1] + (len(new_channel_order),)), + dtype=input_data.dtype) + + existing_channels = input_data.channels + + for i in range(len(new_channel_order)): + if new_channel_order[i] in existing_channels: + current_channel = input_data.loc[:, :, :, new_channel_order[i]].values + full_array[:, :, :, i] = current_channel + else: + print('Creating blank channel with {}'.format(new_channel_order[i])) + blank = create_blank_channel(input_data.shape[1:3], dtype=input_data.dtype, + full_blank=full_blank) + full_array[:, :, :, i] = blank + + coords = [input_data.fovs, range(input_data.shape[1]), + range(input_data.shape[2]), new_channel_order] + + dims = ["fovs", "rows", "cols", "channels"] + + channel_xr_blanked = xr.DataArray(full_array, coords=coords, dims=dims) + + return channel_xr_blanked diff --git a/caliban_toolbox/utils/data_utils_test.py b/caliban_toolbox/utils/data_utils_test.py new file mode 100644 index 0000000..f8c744b --- /dev/null +++ b/caliban_toolbox/utils/data_utils_test.py @@ -0,0 +1,138 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import pytest + +import numpy as np +import xarray as xr +from caliban_toolbox.utils import data_utils + + +def test_pad_xr_dims(): + test_input = np.zeros((2, 10, 10, 3)) + test_coords = [['Point1', 'Point2'], range(test_input.shape[1]), range(test_input.shape[2]), + ['chan0', 'chan1', 'chan2']] + test_dims = ['fovs', 'rows', 'cols', 'channels'] + + test_xr = xr.DataArray(test_input, coords=test_coords, dims=test_dims) + + padded_dims = ['fovs', 'rows', 'new_dim1', 'cols', 'new_dim2', 'channels'] + + padded_xr = data_utils.pad_xr_dims(test_xr, padded_dims) + + assert list(padded_xr.dims) == padded_dims + assert len(padded_xr['new_dim1']) == len(padded_xr['new_dim2']) == 1 + + # check that error raised when wrong dimensions + padded_wrong_order_dims = ['rows', 'fovs', 'new_dim1', 'cols', 'new_dim2', 'channels'] + + with pytest.raises(ValueError): + data_utils.pad_xr_dims(test_xr, padded_wrong_order_dims) + + # check that error raised when duplicated dimensions + padded_duplicated_dims = ['fovs', 'rows', 'new_dim1', 'cols', 'new_dim1', 'channels'] + + with pytest.raises(ValueError): + data_utils.pad_xr_dims(test_xr, padded_duplicated_dims) + + +def test_create_blank_channel(): + + semi_blank = data_utils.create_blank_channel(img_size=(1024, 1024), dtype="int16", + full_blank=False) + + assert semi_blank.shape == (1024, 1024) + assert np.sum(semi_blank > 0) == 10 * 10 + + full_blank = data_utils.create_blank_channel(img_size=(1024, 1024), dtype="int16", + full_blank=True) + assert np.sum(full_blank) == 0 + + +def test_reorder_channels(): + + # test switching without blank channels + test_input = np.random.randint(5, size=(2, 128, 128, 3)) + + # channel 0 is 3x bigger, channel 2 is 3x smaller + test_input[:, :, :, 0] *= 3 + test_input[:, :, :, 2] //= 3 + + coords = [['fov1', 'fov2'], range(test_input.shape[1]), + range(test_input.shape[2]), ['chan0', 'chan1', 'chan2']] + dims = ['fovs', 'rows', 'cols', 'channels'] + input_data = xr.DataArray(test_input, coords=coords, dims=dims) + + new_channel_order = ['chan2', 'chan1', 'chan0'] + + reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, + input_data=input_data) + + # confirm that labels are in correct order, and that values were switched as well + assert np.array_equal(new_channel_order, reordered_data.channels) + + for chan in input_data.channels.values: + assert np.array_equal(reordered_data.loc[:, :, :, chan], input_data.loc[:, :, :, chan]) + + # test switching with blank channels + new_channel_order = ['chan0', 'chan1', 'chan666', 'chan2'] + reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, + input_data=input_data) + + # make sure order was switched, and that blank channel is empty + assert np.array_equal(new_channel_order, reordered_data.channels) + + # make sure new channel is empty and existing channels have same value + for chan in reordered_data.channels.values: + if chan in input_data.channels: + assert np.array_equal(reordered_data.loc[:, :, :, chan], input_data.loc[:, :, :, chan]) + else: + assert np.sum(reordered_data.loc[:, :, :, chan].values > 0) == 0 + + # test switching with blank channels and existing channels in new order + new_channel_order = ['chan2', 'chan11', 'chan1', 'chan12', 'chan0'] + reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, + input_data=input_data) + + assert np.array_equal(new_channel_order, reordered_data.channels) + + # make sure new channel is empty and existing channels have same value + for chan in reordered_data.channels.values: + if chan in input_data.channels: + assert np.array_equal(reordered_data.loc[:, :, :, chan], input_data.loc[:, :, :, chan]) + else: + assert np.sum(reordered_data.loc[:, :, :, chan].values > 0) == 0 + + # New channels have duplicates + with pytest.raises(ValueError): + new_channel_order = ['chan0', 'chan1', 'chan2', 'chan2'] + reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, + input_data=input_data) + + From 546f361096b3ea8a27f9e2f08a22e989393c1590 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Sun, 26 Apr 2020 10:26:53 -0700 Subject: [PATCH 09/12] simplified set_colors function --- caliban_toolbox/utils/plot_utils.py | 43 +++++++------- caliban_toolbox/utils/plot_utils_test.py | 74 ++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 23 deletions(-) create mode 100644 caliban_toolbox/utils/plot_utils_test.py diff --git a/caliban_toolbox/utils/plot_utils.py b/caliban_toolbox/utils/plot_utils.py index f0da5d3..3ed40a2 100644 --- a/caliban_toolbox/utils/plot_utils.py +++ b/caliban_toolbox/utils/plot_utils.py @@ -29,6 +29,8 @@ import numpy as np +from caliban_toolbox.utils import data_utils + def overlay_grid_lines(overlay_img, row_starts, row_ends, col_starts, col_ends): """Visualize the location of image crops on the original uncropped image to assess crop size @@ -132,51 +134,46 @@ def overlay_crop_overlap(img_crop, row_starts, row_ends, col_starts, col_ends): return img_crop -def set_channel_colors(combined_xr, plot_colors): +def set_channel_colors(channel_data, plot_colors): """Modifies the order of image channels so they're displayed with appropriate color in caliban Args: - combined_xr: xarray containing channels and labels - plot_colors: array containing the color of each channel, in order of the current channels + channel_data: xarray containing channels + plot_colors: array containing the desired color for each channel Returns: xarray.DataArray: reordered image data to enable visualization in caliban """ - # first define the order that channels are visualize - channel_order = np.array(['red', 'green', 'blue', 'cyan', - 'magenta', 'yellow', 'segmentation_label']) + # first define which channel index is visualized as what color by caliban + color_order = np.array(['red', 'green', 'blue', 'cyan', + 'magenta', 'yellow']) - # create the array holding the final ordering of channel names - final_channel_names = np.array(['red', 'green', 'blue', 'cyan', 'magenta', - 'yellow', 'segmentation_label'], dtype=' Date: Sun, 26 Apr 2020 10:38:48 -0700 Subject: [PATCH 10/12] pep8 --- caliban_toolbox/aws_functions.py | 4 +++- caliban_toolbox/utils/data_utils_test.py | 10 ++++------ caliban_toolbox/utils/plot_utils.py | 4 ++-- caliban_toolbox/utils/plot_utils_test.py | 1 - 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/caliban_toolbox/aws_functions.py b/caliban_toolbox/aws_functions.py index 8375d43..52c450b 100755 --- a/caliban_toolbox/aws_functions.py +++ b/caliban_toolbox/aws_functions.py @@ -91,7 +91,9 @@ def aws_upload_files(aws_folder, stage, upload_folder, pixel_only, label_only, r subfolders = re.split('/', aws_folder) subfolders = '__'.join(subfolders) - url_dict = {'pixel_only': pixel_only, 'label_only': label_only, 'rgb': rgb_mode} + # TODO: add back label_only once logic is working correctly in caliban + # url_dict = {'pixel_only': pixel_only, 'label_only': label_only, 'rgb': rgb_mode} + url_dict = {'pixel_only': pixel_only, 'rgb': rgb_mode} url_encoded_dict = urlencode(url_dict) # upload images diff --git a/caliban_toolbox/utils/data_utils_test.py b/caliban_toolbox/utils/data_utils_test.py index f8c744b..8bb8ff8 100644 --- a/caliban_toolbox/utils/data_utils_test.py +++ b/caliban_toolbox/utils/data_utils_test.py @@ -60,7 +60,7 @@ def test_pad_xr_dims(): with pytest.raises(ValueError): data_utils.pad_xr_dims(test_xr, padded_duplicated_dims) - + def test_create_blank_channel(): @@ -90,8 +90,8 @@ def test_reorder_channels(): input_data = xr.DataArray(test_input, coords=coords, dims=dims) new_channel_order = ['chan2', 'chan1', 'chan0'] - - reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, + + reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, input_data=input_data) # confirm that labels are in correct order, and that values were switched as well @@ -103,7 +103,7 @@ def test_reorder_channels(): # test switching with blank channels new_channel_order = ['chan0', 'chan1', 'chan666', 'chan2'] reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, - input_data=input_data) + input_data=input_data) # make sure order was switched, and that blank channel is empty assert np.array_equal(new_channel_order, reordered_data.channels) @@ -134,5 +134,3 @@ def test_reorder_channels(): new_channel_order = ['chan0', 'chan1', 'chan2', 'chan2'] reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, input_data=input_data) - - diff --git a/caliban_toolbox/utils/plot_utils.py b/caliban_toolbox/utils/plot_utils.py index 3ed40a2..1c880e9 100644 --- a/caliban_toolbox/utils/plot_utils.py +++ b/caliban_toolbox/utils/plot_utils.py @@ -147,7 +147,7 @@ def set_channel_colors(channel_data, plot_colors): # first define which channel index is visualized as what color by caliban color_order = np.array(['red', 'green', 'blue', 'cyan', - 'magenta', 'yellow']) + 'magenta', 'yellow']) # create the array which will hold the final ordering of channel names final_channel_order = np.array(['red', 'green', 'blue', 'cyan', 'magenta', @@ -174,6 +174,6 @@ def set_channel_colors(channel_data, plot_colors): # reorder the xarray reordered_xr = data_utils.reorder_channels(new_channel_order=final_channel_order, - input_data=channel_data) + input_data=channel_data) return reordered_xr diff --git a/caliban_toolbox/utils/plot_utils_test.py b/caliban_toolbox/utils/plot_utils_test.py index 9298cff..3a95f2f 100644 --- a/caliban_toolbox/utils/plot_utils_test.py +++ b/caliban_toolbox/utils/plot_utils_test.py @@ -71,4 +71,3 @@ def test_set_channel_colors(): with pytest.raises(ValueError): colors = ['magenta', 'blue', 'red', 'yellow'] output_data = plot_utils.set_channel_colors(channel_data=input_data, plot_colors=colors) - From c81c4c60a342dce774dcc6bd12c721b89f049833 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Mon, 27 Apr 2020 09:47:55 -0700 Subject: [PATCH 11/12] updated example notebook --- caliban_toolbox/utils/data_utils.py | 25 +- caliban_toolbox/utils/data_utils_test.py | 4 + .../Caliban_Figure8_Upload_Combined.ipynb | 238 +++++++----------- 3 files changed, 122 insertions(+), 145 deletions(-) diff --git a/caliban_toolbox/utils/data_utils.py b/caliban_toolbox/utils/data_utils.py index 344acbc..1d5f7cd 100644 --- a/caliban_toolbox/utils/data_utils.py +++ b/caliban_toolbox/utils/data_utils.py @@ -33,12 +33,13 @@ import xarray as xr -def pad_xr_dims(input_xr, padded_dims): +def pad_xr_dims(input_xr, padded_dims=None): """Takes an xarray and pads it with dimensions of size 1 according to the supplied dims list Inputs input_xr: xarray to pad - padded_dims: ordered list of final dims; new dims will be added with size 1 + padded_dims: ordered list of final dims; new dims will be added with size 1. If None, + defaults to standard naming scheme for pipeline Outputs padded_xr: xarray that has had additional dims added of size 1 @@ -48,6 +49,8 @@ def pad_xr_dims(input_xr, padded_dims): ValueError: If padded dims includes duplicate names """ + if padded_dims is None: + padded_dims = ["fovs", "stacks", "crops", "slices", "rows", "cols", "channels"] # make sure that dimensions which are present in both lists are in same order old_dims = [dim for dim in padded_dims if dim in input_xr.dims] @@ -153,3 +156,21 @@ def reorder_channels(new_channel_order, input_data, full_blank=True): channel_xr_blanked = xr.DataArray(full_array, coords=coords, dims=dims) return channel_xr_blanked + + +def make_blank_labels(image_data): + """Creates an xarray of blank y_labels which matches the image_data passed in + + Args: + image_data: xarray of image channels used to get label names + + Returns: + xarray.DataArray: blank xarray of labeled data + """ + + blank_data = np.zeros(image_data.shape[:-1] + (1,), dtype='int16') + + coords = [image_data.fovs, image_data.rows, image_data.cols, ['segmentation_label']] + blank_xr = xr.DataArray(blank_data, coords=coords, dims=image_data.dims) + + return blank_xr diff --git a/caliban_toolbox/utils/data_utils_test.py b/caliban_toolbox/utils/data_utils_test.py index 8bb8ff8..8a7f96c 100644 --- a/caliban_toolbox/utils/data_utils_test.py +++ b/caliban_toolbox/utils/data_utils_test.py @@ -134,3 +134,7 @@ def test_reorder_channels(): new_channel_order = ['chan0', 'chan1', 'chan2', 'chan2'] reordered_data = data_utils.reorder_channels(new_channel_order=new_channel_order, input_data=input_data) + + +def test_make_blank_labels(): + assert True diff --git a/notebooks/Caliban_Figure8_Upload_Combined.ipynb b/notebooks/Caliban_Figure8_Upload_Combined.ipynb index 5aec7ba..1f3dffa 100644 --- a/notebooks/Caliban_Figure8_Upload_Combined.ipynb +++ b/notebooks/Caliban_Figure8_Upload_Combined.ipynb @@ -29,9 +29,9 @@ "\n", "from caliban_toolbox import reshape_data\n", "from caliban_toolbox.figure_eight_functions import create_figure_eight_job, download_figure_eight_output\n", - "from caliban_toolbox.utils import widget_utils, plot_utils\n", + "from caliban_toolbox.utils import widget_utils, plot_utils, data_utils, io_utils\n", "\n", - "from segmentation.utils.data_utils import load_imgs_from_dir, pad_xr_dims, reorder_xarray_channels\n", + "from segmentation.utils.data_utils import load_imgs_from_dir\n", "import xarray as xr\n", "\n", "import matplotlib as mpl\n", @@ -66,7 +66,7 @@ "outputs": [], "source": [ "# as a placeholder, run the following code to load example data:\n", - "data_stack_xr = load_imgs_from_dir(\"/example_data/timelapse/HeLa_by_image\", load_axis=\"stacks\",\n", + "channel_data = load_imgs_from_dir(\"/example_data/timelapse/HeLa_by_image\", load_axis=\"fovs\",\n", " imgs=[\"FITC_001.png\", \"Phase_000.png\", \"Phase_001.png\", \"Phase_002.png\"])" ] }, @@ -88,7 +88,7 @@ ], "source": [ "# this loads the imaging data into an 4D array of [time, rows, cols, channels]\n", - "data_stack_xr.shape" + "channel_data.shape" ] }, { @@ -109,7 +109,7 @@ ], "source": [ "# each channel is labeled according to the image name:\n", - "data_stack_xr.channels.values" + "channel_data.channels.values" ] }, { @@ -134,11 +134,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "# call deepcell.applications for appropriate model" + "# call deepcell.applications for appropriate model\n", + "y_data = data_utils.make_blank_labels(channel_data)" ] }, { @@ -158,13 +159,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9dec0240a9884d84971daa2b83093afb", + "model_id": "a4a5548184df4d9eb111289055fd5200", "version_major": 2, "version_minor": 0 }, @@ -177,9 +178,9 @@ } ], "source": [ - "choose_img_output = interactive(widget_utils.choose_img_from_stack, stack = fixed(data_stack_xr), \n", - " slice_idx = (0, data_stack_xr.shape[0]-1, 1),\n", - " chan_name = (data_stack_xr.channels.values));\n", + "choose_img_output = interactive(widget_utils.choose_img_from_stack, stack = fixed(channel_data), \n", + " slice_idx = (0, channel_data.shape[0]-1, 1),\n", + " chan_name = (channel_data.channels.values));\n", "choose_img_output" ] }, @@ -193,13 +194,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "21d7567da3024fbe8def006fe50c6b8c", + "model_id": "0541e382e11c42f79c56abae9852eade", "version_major": 2, "version_minor": 0 }, @@ -214,7 +215,7 @@ "source": [ "# get most recent parameters for selected image\n", "selected_slice_idx, selected_channel_idx = choose_img_output.result\n", - "img = data_stack_xr[selected_slice_idx, :, :, selected_channel_idx]\n", + "img = channel_data[selected_slice_idx, :, :, selected_channel_idx]\n", "\n", "# interative edit mode\n", "adjust_image_output = interactive(widget_utils.adjust_image_interactive, image=fixed(img), blur=(0.0,4,0.1), \n", @@ -232,19 +233,19 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# create placeholder channel to hold the output of channel adjustment\n", - "adjusted_channel_xr = xr.DataArray(np.zeros(data_stack_xr.shape[:-1] + (1,), np.uint8),\n", - " coords=[data_stack_xr.stacks, data_stack_xr.rows, data_stack_xr.cols,\n", + "adjusted_channel_xr = xr.DataArray(np.zeros(channel_data.shape[:-1] + (1,), np.uint8),\n", + " coords=[channel_data.fovs, channel_data.rows, channel_data.cols,\n", " [\"adjusted_channel\"]],\n", - " dims=data_stack_xr.dims)\n", + " dims=channel_data.dims)\n", "\n", "# adjust all slices for given channel\n", - "for i in range(data_stack_xr.shape[0]):\n", - " image = data_stack_xr[i, :, :, selected_channel_idx]\n", + "for i in range(channel_data.shape[0]):\n", + " image = channel_data[i, :, :, selected_channel_idx]\n", " adjusted_channel_xr[i, :, :, 0] = widget_utils.adjust_image(image, adjust_image_output.kwargs)" ] }, @@ -257,13 +258,13 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "502d062f8617474687f1e5234e226532", + "model_id": "b50117d5201d40a7a16784d6c9196338", "version_major": 2, "version_minor": 0 }, @@ -277,7 +278,7 @@ ], "source": [ "check_adjustment_output = interactive(widget_utils.choose_img_from_stack, stack = fixed(adjusted_channel_xr), \n", - " slice_idx = (0, data_stack_xr.shape[0]-1, 1),\n", + " slice_idx = (0, channel_data.shape[0]-1, 1),\n", " chan_name = fixed(\"adjusted_channel\"));\n", "check_adjustment_output" ] @@ -291,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -318,13 +319,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "90040ff569744533ad8697696242e698", + "model_id": "1e114b6db4b6415eb07d7cb9a6e5e884", "version_major": 2, "version_minor": 0 }, @@ -337,9 +338,9 @@ } ], "source": [ - "choose_img_output = interactive(widget_utils.choose_img_from_stack, stack = fixed(data_stack_xr), \n", - " slice_idx = (0, data_stack_xr.shape[0]-1, 1),\n", - " chan_name = (data_stack_xr.channels.values));\n", + "choose_img_output = interactive(widget_utils.choose_img_from_stack, stack = fixed(channel_data), \n", + " slice_idx = (0, channel_data.shape[0]-1, 1),\n", + " chan_name = (channel_data.channels.values));\n", "choose_img_output" ] }, @@ -352,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -361,25 +362,25 @@ "array(['FITC_001', 'Phase_000', 'Phase_001', 'Phase_002'], dtype='" + "" ] }, - "execution_count": 40, + "execution_count": 32, "metadata": {}, "output_type": "execute_result" }, @@ -825,7 +796,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 33, "metadata": {}, "outputs": [], "source": [ @@ -836,16 +807,16 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "" + "" ] }, - "execution_count": 45, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" }, @@ -869,43 +840,24 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "# once the parameters above look good, we'll crop the images\n", - "cropped_stack, log_data = reshape_data.crop_multichannel_data(data_xr=expanded_xr, crop_size=crop_size, \n", - " overlap_frac=overlap_frac)" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(5, 1, 1, 1, 1080, 1280, 4)" - ] - }, - "execution_count": 53, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "expanded_xr.shape" + "X_cropped, y_cropped, log_data = reshape_data.crop_multichannel_data(X_data=X_expanded, y_data=y_expanded, \n", + " crop_size=crop_size, \n", + " overlap_frac=overlap_frac)" ] }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# save cropped data into npz files for upload to Figure8\n", - "reshape_data.save_npzs_for_caliban(resized_xr=cropped_stack, original_xr=combined_xr, save_dir=save_dir,\n", + "io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_data, original_data=channel_data, save_dir=save_dir,\n", " log_data=log_data)" ] }, From 27391b4a2219efc5e241d912435bb513a3c63d28 Mon Sep 17 00:00:00 2001 From: ngreenwald Date: Wed, 29 Apr 2020 15:28:25 -0700 Subject: [PATCH 12/12] restored label_only flag --- caliban_toolbox/aws_functions.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/caliban_toolbox/aws_functions.py b/caliban_toolbox/aws_functions.py index 52c450b..8375d43 100755 --- a/caliban_toolbox/aws_functions.py +++ b/caliban_toolbox/aws_functions.py @@ -91,9 +91,7 @@ def aws_upload_files(aws_folder, stage, upload_folder, pixel_only, label_only, r subfolders = re.split('/', aws_folder) subfolders = '__'.join(subfolders) - # TODO: add back label_only once logic is working correctly in caliban - # url_dict = {'pixel_only': pixel_only, 'label_only': label_only, 'rgb': rgb_mode} - url_dict = {'pixel_only': pixel_only, 'rgb': rgb_mode} + url_dict = {'pixel_only': pixel_only, 'label_only': label_only, 'rgb': rgb_mode} url_encoded_dict = urlencode(url_dict) # upload images