Skip to content

Commit

Permalink
add support for multiple labels (#111)
Browse files Browse the repository at this point in the history
* add support for multiple labels
* moved global variables to settings.py
* switched to global variables
  • Loading branch information
ngreenwald committed Aug 11, 2020
1 parent be970e5 commit df42d68
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 36 deletions.
25 changes: 14 additions & 11 deletions caliban_toolbox/reshape_data.py
Expand Up @@ -32,6 +32,7 @@

import xarray as xr

from caliban_toolbox import settings
from caliban_toolbox.utils import crop_utils, slice_utils, io_utils
from caliban_toolbox.utils.crop_utils import compute_crop_indices, crop_helper

Expand Down Expand Up @@ -91,12 +92,15 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla
if overlap_frac < 0 or overlap_frac > 1:
raise ValueError('overlap_frac must be between 0 and 1')

if list(X_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']:
if list(X_data.dims) != settings.X_DIMENSION_LABELS:
raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims))

if list(y_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']:
if list(y_data.dims) != settings.Y_DIMENSION_LABELS:
raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims))

if y_data.shape[-1] != 1:
raise ValueError('Only one type of segmentation label can be processed at a time')

# check if testing or running all samples
if test_parameters:
X_data, y_data = X_data[:1, ...], y_data[:1, ...]
Expand Down Expand Up @@ -141,6 +145,7 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla
log_data['row_padding'] = int(row_padding)
log_data['col_padding'] = int(col_padding)
log_data['num_crops'] = X_data_cropped.shape[2]
log_data['label_name'] = y_data.dims[-1]

return X_data_cropped, y_data_cropped, log_data

Expand All @@ -164,12 +169,11 @@ def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0):
raise ValueError('invalid input data shape, '
'expected array of len(7), got {}'.format(X_data.shape))

if len(y_data.shape) != 7:
raise ValueError('invalid input data shape, '
'expected array of len(7), got {}'.format(y_data.shape))
if list(X_data.dims) != settings.X_DIMENSION_LABELS:
raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims))

if slice_stack_len > X_data.shape[1]:
raise ValueError('slice size is greater than stack length')
if list(y_data.dims) != settings.Y_DIMENSION_LABELS:
raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims))

# compute indices for slices
stack_len = X_data.shape[1]
Expand Down Expand Up @@ -219,13 +223,12 @@ def reconstruct_image_stack(crop_dir, verbose=True):

# labels for each index within a dimension
_, stack_len, _, _, row_len, col_len, _ = log_data['original_shape']
label_name = log_data['label_name']
coordinate_labels = [log_data['fov_names'], range(stack_len), range(1),
range(1), range(row_len), range(col_len), ['segmentation_label']]
range(1), range(row_len), range(col_len), [label_name]]

# labels for each dimension
dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

stitched_xr = xr.DataArray(data=image_stack, coords=coordinate_labels,
dims=dimension_labels)
dims=settings.Y_DIMENSION_LABELS)

return stitched_xr
13 changes: 8 additions & 5 deletions caliban_toolbox/reshape_data_test.py
Expand Up @@ -49,7 +49,7 @@ def test_crop_multichannel_data():

test_y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num, row_len=row_len, col_len=col_len,
chan_len=channel_len)
chan_len=channel_len, last_dim_name='compartments')

X_data_cropped, y_data_cropped, log_data = \
reshape_data.crop_multichannel_data(X_data=test_X_data,
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_create_slice_data():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops,
slice_num=num_slices, row_len=row_len, col_len=col_len,
chan_len=chan_len)
chan_len=chan_len, last_dim_name='compartments')

X_slice, y_slice, slice_indices = reshape_data.create_slice_data(X_data, y_data,
slice_stack_len)
Expand All @@ -155,7 +155,8 @@ def test_reconstruct_image_stack():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# create image with artificial objects to be segmented

Expand Down Expand Up @@ -201,7 +202,8 @@ def test_reconstruct_image_stack():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# tag upper left hand corner of the label in each image
tags = np.arange(stack_len)
Expand Down Expand Up @@ -234,7 +236,8 @@ def test_reconstruct_image_stack():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# create image with artificial objects to be segmented

Expand Down
33 changes: 33 additions & 0 deletions caliban_toolbox/settings.py
@@ -0,0 +1,33 @@
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division


# Ordered dimension names that X data xarrays are validated against
# (compared to list(X_data.dims)); the last axis is named 'channels'.
X_DIMENSION_LABELS = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

# Ordered dimension names that y (label) data xarrays are validated against
# (compared to list(y_data.dims)); identical to the X labels except that the
# last axis is named 'compartments' rather than 'channels'.
Y_DIMENSION_LABELS = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'compartments']
11 changes: 6 additions & 5 deletions caliban_toolbox/utils/crop_utils.py
Expand Up @@ -86,7 +86,7 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding)
"""Crops an image into pieces according to supplied coordinates
Args:
input_data: xarray of [fovs, stacks, crops, slices, rows, cols, channels] to be cropped
input_data: xarray of either X or y data to be cropped
row_starts: list of indices where row crops start
row_ends: list of indices where row crops end
col_starts: list of indices where col crops start
Expand All @@ -104,6 +104,9 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding)
if input_crop_num > 1:
raise ValueError("Array has already been cropped")

# get name of last dimension from input data to determine if X or y
last_dim_name = input_data.dims[-1]

crop_num = len(row_starts) * len(col_starts)
crop_size_row = row_ends[0] - row_starts[0]
crop_size_col = col_ends[0] - col_starts[0]
Expand All @@ -114,12 +117,10 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding)

# labels for each index within a dimension
coordinate_labels = [input_data.fovs, input_data.stacks, range(crop_num), input_data.slices,
range(crop_size_row), range(crop_size_col), input_data.channels]
range(crop_size_row), range(crop_size_col), input_data[last_dim_name]]

# labels for each dimension
dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=dimension_labels)
cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=input_data.dims)

# pad the input to account for imperfectly overlapping final crop in rows and cols
formatted_padding = ((0, 0), (0, 0), (0, 0), (0, 0), (0, padding[0]), (0, padding[1]), (0, 0))
Expand Down
14 changes: 9 additions & 5 deletions caliban_toolbox/utils/crop_utils_test.py
Expand Up @@ -32,7 +32,8 @@
import xarray as xr


def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len):
def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len,
last_dim_name='channels'):
"""Test function to generate a blank xarray with the supplied dimensions
Inputs
Expand All @@ -43,6 +44,7 @@ def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, ch
row_num: number of rows
col_num: number of cols
chan_num: number of channels
last_dim_name: name of last dimension. Either channels or compartments for X or y data
Outputs
test_xr: xarray of [fov_num, row_num, col_num, chan_num]"""
Expand All @@ -56,7 +58,7 @@ def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, ch
coords=[fovs, range(stack_len), range(crop_num), range(slice_num),
range(row_len), range(col_len), channels],
dims=["fovs", "stacks", "crops", "slices",
"rows", "cols", "channels"])
"rows", "cols", last_dim_name])

return test_stack_xr

Expand Down Expand Up @@ -213,7 +215,8 @@ def test_stitch_crops():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# create image with artificial objects to be segmented

Expand Down Expand Up @@ -276,7 +279,8 @@ def test_stitch_crops():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=chan_len)
row_len=row_len, col_len=col_len, chan_len=chan_len,
last_dim_name='compartments')
side_len = 40
cell_num = y_data.shape[4] // side_len

Expand Down Expand Up @@ -327,7 +331,7 @@ def test_stitch_crops():
log_data["num_crops"] = y_cropped.shape[2]
log_data["original_shape"] = y_data.shape
log_data["fov_names"] = y_data.fovs.values.tolist()
log_data["channel_names"] = y_data.channels.values.tolist()
log_data["label_name"] = str(y_data.coords['compartments'][0].values)

stitched_img = crop_utils.stitch_crops(crop_stack=y_cropped, log_data=log_data)

Expand Down
2 changes: 1 addition & 1 deletion caliban_toolbox/utils/io_utils.py
Expand Up @@ -122,7 +122,7 @@ def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir,
raise NotImplementedError()

log_data['fov_names'] = fov_names.tolist()
log_data['channel_names'] = original_data.channels.values.tolist()
log_data['label_name'] = str(y_data.coords[y_data.dims[-1]][0].values)
log_data['original_shape'] = original_data.shape
log_data['slice_stack_len'] = X_data.shape[1]
log_data['save_format'] = save_format
Expand Down
9 changes: 6 additions & 3 deletions caliban_toolbox/utils/io_utils_test.py
Expand Up @@ -47,7 +47,8 @@ def test_save_npzs_for_caliban():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops,
slice_num=num_slices,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

sliced_X, sliced_y, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data,
slice_stack_len=slice_stack_len)
Expand Down Expand Up @@ -194,7 +195,8 @@ def test_load_npzs():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# slice the data
X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data,
Expand Down Expand Up @@ -249,7 +251,8 @@ def test_load_npzs():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# slice the data
X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data,
Expand Down
9 changes: 5 additions & 4 deletions caliban_toolbox/utils/slice_utils.py
Expand Up @@ -86,6 +86,9 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices):
if input_slice_num > 1:
raise ValueError('Input array already contains slice data')

# get name of last dimension from input data to determine if X or y
last_dim_name = data_xr.dims[-1]

slice_num = len(slice_start_indices)
sliced_stack_len = slice_end_indices[0] - slice_start_indices[0]

Expand All @@ -95,12 +98,10 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices):

# labels for each index within a dimension
coordinate_labels = [data_xr.fovs, range(sliced_stack_len), range(crop_num), range(slice_num),
range(row_len), range(col_len), data_xr.channels]
range(row_len), range(col_len), data_xr[last_dim_name]]

# labels for each dimension
dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=dimension_labels)
slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=data_xr.dims)

# loop through slice indices to generate sliced data
slice_counter = 0
Expand Down
6 changes: 4 additions & 2 deletions caliban_toolbox/utils/slice_utils_test.py
Expand Up @@ -148,7 +148,8 @@ def test_stitch_slices():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# generate ordered data
linear_seq = np.arange(stack_len * row_len * col_len)
Expand Down Expand Up @@ -178,7 +179,8 @@ def test_stitch_slices():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# generate ordered data
linear_seq = np.arange(stack_len * row_len * col_len)
Expand Down

0 comments on commit df42d68

Please sign in to comment.