Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for multiple labels #111

Merged
merged 3 commits into from
Aug 11, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 14 additions & 11 deletions caliban_toolbox/reshape_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

import xarray as xr

from caliban_toolbox import settings
from caliban_toolbox.utils import crop_utils, slice_utils, io_utils
from caliban_toolbox.utils.crop_utils import compute_crop_indices, crop_helper

Expand Down Expand Up @@ -91,12 +92,15 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla
if overlap_frac < 0 or overlap_frac > 1:
raise ValueError('overlap_frac must be between 0 and 1')

if list(X_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']:
if list(X_data.dims) != settings.X_DIMENSION_LABELS:
raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims))

if list(y_data.dims) != ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']:
if list(y_data.dims) != settings.Y_DIMENSION_LABELS:
raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims))

if y_data.shape[-1] != 1:
raise ValueError('Only one type of segmentation label can be processed at a time')

# check if testing or running all samples
if test_parameters:
X_data, y_data = X_data[:1, ...], y_data[:1, ...]
Expand Down Expand Up @@ -141,6 +145,7 @@ def crop_multichannel_data(X_data, y_data, crop_size=None, crop_num=None, overla
log_data['row_padding'] = int(row_padding)
log_data['col_padding'] = int(col_padding)
log_data['num_crops'] = X_data_cropped.shape[2]
log_data['label_name'] = y_data.dims[-1]

return X_data_cropped, y_data_cropped, log_data

Expand All @@ -164,12 +169,11 @@ def create_slice_data(X_data, y_data, slice_stack_len, slice_overlap=0):
raise ValueError('invalid input data shape, '
'expected array of len(7), got {}'.format(X_data.shape))

if len(y_data.shape) != 7:
raise ValueError('invalid input data shape, '
'expected array of len(7), got {}'.format(y_data.shape))
if list(X_data.dims) != settings.X_DIMENSION_LABELS:
raise ValueError('X_data does not have expected dims, found {}'.format(X_data.dims))

if slice_stack_len > X_data.shape[1]:
raise ValueError('slice size is greater than stack length')
if list(y_data.dims) != settings.Y_DIMENSION_LABELS:
raise ValueError('y_data does not have expected dims, found {}'.format(y_data.dims))

# compute indices for slices
stack_len = X_data.shape[1]
Expand Down Expand Up @@ -216,13 +220,12 @@ def reconstruct_image_stack(crop_dir, verbose=True):

# labels for each index within a dimension
_, stack_len, _, _, row_len, col_len, _ = log_data['original_shape']
label_name = log_data['label_name']
coordinate_labels = [log_data['fov_names'], range(stack_len), range(1),
range(1), range(row_len), range(col_len), ['segmentation_label']]
range(1), range(row_len), range(col_len), [label_name]]

# labels for each dimension
dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

stitched_xr = xr.DataArray(data=image_stack, coords=coordinate_labels,
dims=dimension_labels)
dims=settings.Y_DIMENSION_LABELS)

stitched_xr.to_netcdf(os.path.join(crop_dir, 'stitched_images.xr'))
13 changes: 8 additions & 5 deletions caliban_toolbox/reshape_data_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def test_crop_multichannel_data():

test_y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num, row_len=row_len, col_len=col_len,
chan_len=channel_len)
chan_len=channel_len, last_dim_name='compartments')

X_data_cropped, y_data_cropped, log_data = \
reshape_data.crop_multichannel_data(X_data=test_X_data,
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_create_slice_data():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops,
slice_num=num_slices, row_len=row_len, col_len=col_len,
chan_len=chan_len)
chan_len=chan_len, last_dim_name='compartments')

X_slice, y_slice, slice_indices = reshape_data.create_slice_data(X_data, y_data,
slice_stack_len)
Expand All @@ -155,7 +155,8 @@ def test_reconstruct_image_stack():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# create image with artificial objects to be segmented

Expand Down Expand Up @@ -203,7 +204,8 @@ def test_reconstruct_image_stack():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# tag upper left hand corner of the label in each image
tags = np.arange(stack_len)
Expand Down Expand Up @@ -237,7 +239,8 @@ def test_reconstruct_image_stack():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# create image with artificial objects to be segmented

Expand Down
38 changes: 38 additions & 0 deletions caliban_toolbox/settings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from decouple import config

X_DIMENSION_LABELS = config('DIMENSION_LABELS', cast=list,
default=['fovs', 'stacks', 'crops', 'slices',
'rows', 'cols', 'channels'])

Y_DIMENSION_LABELS = config('DIMENSION_LABELS', cast=list,
default=['fovs', 'stacks', 'crops', 'slices',
'rows', 'cols', 'compartments'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does casting to a list work? Does it expect a comma-separated entry or something?

11 changes: 6 additions & 5 deletions caliban_toolbox/utils/crop_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding)
"""Crops an image into pieces according to supplied coordinates

Args:
input_data: xarray of [fovs, stacks, crops, slices, rows, cols, channels] to be cropped
input_data: xarray of either X or y data to be cropped
row_starts: list of indices where row crops start
row_ends: list of indices where row crops end
col_starts: list of indices where col crops start
Expand All @@ -104,6 +104,9 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding)
if input_crop_num > 1:
raise ValueError("Array has already been cropped")

# get name of last dimension from input data to determine if X or y
last_dim_name = input_data.dims[-1]

crop_num = len(row_starts) * len(col_starts)
crop_size_row = row_ends[0] - row_starts[0]
crop_size_col = col_ends[0] - col_starts[0]
Expand All @@ -114,12 +117,10 @@ def crop_helper(input_data, row_starts, row_ends, col_starts, col_ends, padding)

# labels for each index within a dimension
coordinate_labels = [input_data.fovs, input_data.stacks, range(crop_num), input_data.slices,
range(crop_size_row), range(crop_size_col), input_data.channels]
range(crop_size_row), range(crop_size_col), input_data[last_dim_name]]

# labels for each dimension
dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=dimension_labels)
cropped_xr = xr.DataArray(data=cropped_stack, coords=coordinate_labels, dims=input_data.dims)

# pad the input to account for imperfectly overlapping final crop in rows and cols
formatted_padding = ((0, 0), (0, 0), (0, 0), (0, 0), (0, padding[0]), (0, padding[1]), (0, 0))
Expand Down
14 changes: 9 additions & 5 deletions caliban_toolbox/utils/crop_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
import xarray as xr


def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len):
def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, chan_len,
last_dim_name='channels'):
"""Test function to generate a blank xarray with the supplied dimensions

Inputs
Expand All @@ -43,6 +44,7 @@ def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, ch
row_num: number of rows
col_num: number of cols
chan_num: number of channels
last_dim_name: name of last dimension. Either channels or compartments for X or y data

Outputs
test_xr: xarray of [fov_num, row_num, col_num, chan_num]"""
Expand All @@ -56,7 +58,7 @@ def _blank_data_xr(fov_len, stack_len, crop_num, slice_num, row_len, col_len, ch
coords=[fovs, range(stack_len), range(crop_num), range(slice_num),
range(row_len), range(col_len), channels],
dims=["fovs", "stacks", "crops", "slices",
"rows", "cols", "channels"])
"rows", "cols", last_dim_name])

return test_stack_xr

Expand Down Expand Up @@ -195,7 +197,8 @@ def test_stitch_crops():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# create image with artificial objects to be segmented

Expand Down Expand Up @@ -258,7 +261,8 @@ def test_stitch_crops():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=chan_len)
row_len=row_len, col_len=col_len, chan_len=chan_len,
last_dim_name='compartments')
side_len = 40
cell_num = y_data.shape[4] // side_len

Expand Down Expand Up @@ -309,7 +313,7 @@ def test_stitch_crops():
log_data["num_crops"] = y_cropped.shape[2]
log_data["original_shape"] = y_data.shape
log_data["fov_names"] = y_data.fovs.values.tolist()
log_data["channel_names"] = y_data.channels.values.tolist()
log_data["label_name"] = str(y_data.coords['compartments'][0].values)

stitched_img = crop_utils.stitch_crops(crop_stack=y_cropped, log_data=log_data)

Expand Down
2 changes: 1 addition & 1 deletion caliban_toolbox/utils/io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def save_npzs_for_caliban(X_data, y_data, original_data, log_data, save_dir,
raise NotImplementedError()

log_data['fov_names'] = fov_names.tolist()
log_data['channel_names'] = original_data.channels.values.tolist()
log_data['label_name'] = str(y_data.coords[y_data.dims[-1]][0].values)
log_data['original_shape'] = original_data.shape
log_data['slice_stack_len'] = X_data.shape[1]
log_data['save_format'] = save_format
Expand Down
9 changes: 6 additions & 3 deletions caliban_toolbox/utils/io_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_save_npzs_for_caliban():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=num_crops,
slice_num=num_slices,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

sliced_X, sliced_y, log_data = reshape_data.create_slice_data(X_data=X_data, y_data=y_data,
slice_stack_len=slice_stack_len)
Expand Down Expand Up @@ -194,7 +195,8 @@ def test_load_npzs():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# slice the data
X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data, y_data,
Expand Down Expand Up @@ -249,7 +251,8 @@ def test_load_npzs():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# slice the data
X_slice, y_slice, log_data = reshape_data.create_slice_data(X_data,
Expand Down
9 changes: 5 additions & 4 deletions caliban_toolbox/utils/slice_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,9 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices):
if input_slice_num > 1:
raise ValueError('Input array already contains slice data')

# get name of last dimension from input data to determine if X or y
last_dim_name = data_xr.dims[-1]

slice_num = len(slice_start_indices)
sliced_stack_len = slice_end_indices[0] - slice_start_indices[0]

Expand All @@ -95,12 +98,10 @@ def slice_helper(data_xr, slice_start_indices, slice_end_indices):

# labels for each index within a dimension
coordinate_labels = [data_xr.fovs, range(sliced_stack_len), range(crop_num), range(slice_num),
range(row_len), range(col_len), data_xr.channels]
range(row_len), range(col_len), data_xr[last_dim_name]]

# labels for each dimension
dimension_labels = ['fovs', 'stacks', 'crops', 'slices', 'rows', 'cols', 'channels']

slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=dimension_labels)
slice_xr = xr.DataArray(data=slice_data, coords=coordinate_labels, dims=data_xr.dims)

# loop through slice indices to generate sliced data
slice_counter = 0
Expand Down
6 changes: 4 additions & 2 deletions caliban_toolbox/utils/slice_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,8 @@ def test_stitch_slices():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# generate ordered data
linear_seq = np.arange(stack_len * row_len * col_len)
Expand Down Expand Up @@ -178,7 +179,8 @@ def test_stitch_slices():

y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
slice_num=slice_num,
row_len=row_len, col_len=col_len, chan_len=1)
row_len=row_len, col_len=col_len, chan_len=1,
last_dim_name='compartments')

# generate ordered data
linear_seq = np.arange(stack_len * row_len * col_len)
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ xarray==0.13.0
netCDF4==1.5.3
pathlib==1.0.1
deepcell-toolbox>=0.6.1
python-decouple==3.1