diff --git a/caliban_toolbox/reshape_data.py b/caliban_toolbox/reshape_data.py index 23fe5b1..b877a80 100644 --- a/caliban_toolbox/reshape_data.py +++ b/caliban_toolbox/reshape_data.py @@ -32,6 +32,7 @@ import xarray as xr +from caliban_toolbox import settings from caliban_toolbox.utils import crop_utils, slice_utils, io_utils from caliban_toolbox.utils.crop_utils import compute_crop_indices, crop_helper @@ -91,12 +92,15 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla if overlap_frac < 0 or overlap_frac > 1: raise ValueError('overlap_frac must be between 0 and 1') - if list(X_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']: + if list(X_data.dims) != settings.X_DIMENSION_LABELS: raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims)) - if list(y_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']: + if list(y_data.dims) != settings.Y_DIMENSION_LABELS: raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims)) + if y_data.shape[-1] != 1: + raise ValueError('Only one type of segmentation label can be processed at a time') + # check if testing or running all samples if test_parameters: X_data, y_data = X_data[:1, ...], y_data[:1, ...] @@ -141,6 +145,7 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla log_data['row_padding'] = int(row_padding) log_data['col_padding'] = int(col_padding) log_data['num_crops'] = X_data_cropped.shape[2] + log_data['label_name'] = y_data.dims[-1] return X_data_cropped, y_data_cropped, log_data @@ -164,12 +169,11 @@ def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0): raise ValueError('invalid input data shape, ' 'expected array of len(7), got {}'.format(X_data.shape)) - if len(y_data.shape) != 7: - raise ValueError('invalid input data shape, ' - 'expected array of len(7), got {}'.format(y_data.shape)) + if list(X_data.dims) != settings.X_DIMENSION_LABELS: + raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims)) - if slice_stack_len > X_data.shape[1]: - raise ValueError('slice size is greater than stack length') + if list(y_data.dims) != settings.Y_DIMENSION_LABELS: + raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims)) # compute indices for slices stack_len = X_data.shape[1] @@ -219,13 +223,12 @@ def reconstruct_image_stack(crop_dir, verbose=True): # labels for each index within a dimension _, stack_len, _, _, row_len, col_len, _ = log_data['original_shape'] + label_name = log_data['label_name'] coordinate_labels = [log_data['fov_names'], range(stack_len), range(1), - range(1), range(row_len), range(col_len), ['segmentation_label']] + range(1), range(row_len), range(col_len), [label_name]] # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - stitched_xr = xr.DataArray(data=image_stack, coords=coordinate_labels, - dims=dimension_labels) + dims=settings.Y_DIMENSION_LABELS) return stitched_xr diff --git a/caliban_toolbox/reshape_data_test.py b/caliban_toolbox/reshape_data_test.py index 669057a..e89fc80 100644 --- a/caliban_toolbox/reshape_data_test.py +++ b/caliban_toolbox/reshape_data_test.py @@ -49,7 +49,7 @@ def test_crop_multichannel_data(): test_y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, row_len=row_len, col_len=col_len, - chan_len=channel_len) + chan_len=channel_len, last_dim_name='compartments') X_data_cropped, y_data_cropped, log_data = \ reshape_data.crop_multichannel_data(X_data=test_X_data, @@ -133,7 +133,7 @@ def test_create_slice_data(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, slice_num=num_slices, row_len=row_len, col_len=col_len, - chan_len=chan_len) + chan_len=chan_len, last_dim_name='compartments') X_slice, y_slice, slice_indices = reshape_data.create_slice_data(X_data, y_data, slice_stack_len) @@ -155,7 +155,8 @@ def test_reconstruct_image_stack(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # create image with artificial objects to be segmented @@ -201,7 +202,8 @@ def test_reconstruct_image_stack(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # tag upper left hand corner of the label in each image tags = np.arange(stack_len) @@ -234,7 +236,8 @@ def test_reconstruct_image_stack(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # create image with artificial objects to be segmented diff --git a/caliban_toolbox/settings.py b/caliban_toolbox/settings.py new file mode 100644 index 0000000..6ea3272 --- /dev/null +++ b/caliban_toolbox/settings.py @@ -0,0 +1,33 @@ +# Copyright 2016-2020 The Van Valen Lab at the California Institute of +# Technology (Caltech), with support from the Paul Allen Family Foundation, +# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01. +# All rights reserved. +# +# Licensed under a modified Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE +# +# The Work provided may be used for non-commercial academic purposes only. +# For any other use of the Work, including commercial use, please contact: +# vanvalenlab@gmail.com +# +# Neither the name of Caltech nor the names of its contributors may be used +# to endorse or promote products derived from this software without specific +# prior written permission. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + + +X_DIMENSION_LABELS = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] + +Y_DIMENSION_LABELS = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'compartments'] diff --git a/caliban_toolbox/utils/crop_utils.py b/caliban_toolbox/utils/crop_utils.py index 1c257b5..297a1fd 100644 --- a/caliban_toolbox/utils/crop_utils.py +++ b/caliban_toolbox/utils/crop_utils.py @@ -86,7 +86,7 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding) """Crops an image into pieces according to supplied coordinates Args: - input_data: xarray of [fovs, stacks, crops, slices, rows, cols, channels] to be cropped + input_data: xarray of either X or y data to be cropped row_starts: list of indices where row crops start row_ends: list of indices where row crops end col_starts: list of indices where col crops start @@ -104,6 +104,9 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding) if input_crop_num > 1: raise ValueError("Array has already been cropped") + # get name of last dimension from input data to determine if X or y + last_dim_name = input_data.dims[-1] + crop_num = len(row_starts) * len(col_starts) crop_size_row = row_ends[0] - row_starts[0] crop_size_col = col_ends[0] - col_starts[0] @@ -114,12 +117,10 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding) # labels for each index within a dimension coordinate_labels = [input_data.fovs, input_data.stacks, range(crop_num), input_data.slices, - range(crop_size_row), range(crop_size_col), input_data.channels] + range(crop_size_row), range(crop_size_col), input_data[last_dim_name]] # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - - cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=dimension_labels) + cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=input_data.dims) # pad the input to account for imperfectly overlapping final crop in rows and cols formatted_padding = ((0, 0), (0, 0), (0, 0), (0, 0), (0, padding[0]), (0, padding[1]), (0, 0)) diff --git a/caliban_toolbox/utils/crop_utils_test.py b/caliban_toolbox/utils/crop_utils_test.py index f4830aa..24b776a 100644 --- a/caliban_toolbox/utils/crop_utils_test.py +++ b/caliban_toolbox/utils/crop_utils_test.py @@ -32,7 +32,8 @@ import xarray as xr -def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len): +def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len, + last_dim_name='channels'): """Test function to generate a blank xarray with the supplied dimensions Inputs @@ -43,6 +44,7 @@ def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, ch row_num: number of rows col_num: number of cols chan_num: number of channels + last_dim_name: name of last dimension. Either channels or compartments for X or y data Outputs test_xr: xarray of [fov_num, row_num, col_num, chan_num]""" @@ -56,7 +58,7 @@ def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, ch coords=[fovs, range(stack_len), range(crop_num), range(slice_num), range(row_len), range(col_len), channels], dims=["fovs", "stacks", "crops", "slices", - "rows", "cols", "channels"]) + "rows", "cols", last_dim_name]) return test_stack_xr @@ -213,7 +215,8 @@ def test_stitch_crops(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # create image with artificial objects to be segmented @@ -276,7 +279,8 @@ def test_stitch_crops(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=chan_len) + row_len=row_len, col_len=col_len, chan_len=chan_len, + last_dim_name='compartments') side_len = 40 cell_num = y_data.shape[4] // side_len @@ -327,7 +331,7 @@ def test_stitch_crops(): log_data["num_crops"] = y_cropped.shape[2] log_data["original_shape"] = y_data.shape log_data["fov_names"] = y_data.fovs.values.tolist() - log_data["channel_names"] = y_data.channels.values.tolist() + log_data["label_name"] = str(y_data.coords['compartments'][0].values) stitched_img = crop_utils.stitch_crops(crop_stack=y_cropped, log_data=log_data) diff --git a/caliban_toolbox/utils/io_utils.py b/caliban_toolbox/utils/io_utils.py index 773c2dc..fe12238 100644 --- a/caliban_toolbox/utils/io_utils.py +++ b/caliban_toolbox/utils/io_utils.py @@ -122,7 +122,7 @@ def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir, raise NotImplementedError() log_data['fov_names'] = fov_names.tolist() - log_data['channel_names'] = original_data.channels.values.tolist() + log_data['label_name'] = str(y_data.coords[y_data.dims[-1]][0].values) log_data['original_shape'] = original_data.shape log_data['slice_stack_len'] = X_data.shape[1] log_data['save_format'] = save_format diff --git a/caliban_toolbox/utils/io_utils_test.py b/caliban_toolbox/utils/io_utils_test.py index f015a7f..5caca56 100644 --- a/caliban_toolbox/utils/io_utils_test.py +++ b/caliban_toolbox/utils/io_utils_test.py @@ -47,7 +47,8 @@ def test_save_npzs_for_caliban(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops, slice_num=num_slices, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') sliced_X, sliced_y, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data, slice_stack_len=slice_stack_len) @@ -194,7 +195,8 @@ def test_load_npzs(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # slice the data X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data, @@ -249,7 +251,8 @@ def test_load_npzs(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # slice the data X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, diff --git a/caliban_toolbox/utils/slice_utils.py b/caliban_toolbox/utils/slice_utils.py index e80a3d3..a7f10b5 100644 --- a/caliban_toolbox/utils/slice_utils.py +++ b/caliban_toolbox/utils/slice_utils.py @@ -86,6 +86,9 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices): if input_slice_num > 1: raise ValueError('Input array already contains slice data') + # get name of last dimension from input data to determine if X or y + last_dim_name = data_xr.dims[-1] + slice_num = len(slice_start_indices) sliced_stack_len = slice_end_indices[0] - slice_start_indices[0] @@ -95,12 +98,10 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices): # labels for each index within a dimension coordinate_labels = [data_xr.fovs, range(sliced_stack_len), range(crop_num), range(slice_num), - range(row_len), range(col_len), data_xr.channels] + range(row_len), range(col_len), data_xr[last_dim_name]] # labels for each dimension - dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels'] - - slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=dimension_labels) + slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=data_xr.dims) # loop through slice indices to generate sliced data slice_counter = 0 diff --git a/caliban_toolbox/utils/slice_utils_test.py b/caliban_toolbox/utils/slice_utils_test.py index 0aa9f7b..cece1b4 100644 --- a/caliban_toolbox/utils/slice_utils_test.py +++ b/caliban_toolbox/utils/slice_utils_test.py @@ -148,7 +148,8 @@ def test_stitch_slices(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # generate ordered data linear_seq = np.arange(stack_len * row_len * col_len) @@ -178,7 +179,8 @@ def test_stitch_slices(): y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num, slice_num=slice_num, - row_len=row_len, col_len=col_len, chan_len=1) + row_len=row_len, col_len=col_len, chan_len=1, + last_dim_name='compartments') # generate ordered data linear_seq = np.arange(stack_len * row_len * col_len)