-
Notifications
You must be signed in to change notification settings - Fork 1
/
reshape_data_test.py
290 lines (236 loc) · 13.2 KB
/
reshape_data_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Copyright 2016-2020 The Van Valen Lab at the California Institute of
# Technology (Caltech), with support from the Paul Allen Family Foundation,
# Google, & National Institutes of Health (NIH) under Grant U24CA224309-01.
# All rights reserved.
#
# Licensed under a modified Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.github.com/vanvalenlab/caliban-toolbox/LICENSE
#
# The Work provided may be used for non-commercial academic purposes only.
# For any other use of the Work, including commercial use, please contact:
# vanvalenlab@gmail.com
#
# Neither the name of Caltech nor the names of its contributors may be used
# to endorse or promote products derived from this software without specific
# prior written permission.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import tempfile
import pytest
import numpy as np
import xarray as xr
from caliban_toolbox import reshape_data
from caliban_toolbox.utils import crop_utils, io_utils
from caliban_toolbox.utils.crop_utils_test import _blank_data_xr
def test_crop_multichannel_data():
    """Check crop_multichannel_data output shape, crop count and argument validation.

    A valid call is compared against the crop count derived from
    crop_utils.compute_crop_indices. Every invalid call that follows is
    expected to raise ValueError.
    """
    # img params
    fov_len, stack_len, crop_num, slice_num, row_len = 2, 1, 1, 1, 200
    col_len, channel_len = 200, 1
    crop_size = (50, 50)
    overlap_frac = 0.2

    # test only one crop
    test_X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                 slice_num=slice_num, row_len=row_len, col_len=col_len,
                                 chan_len=channel_len)

    test_y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                 slice_num=slice_num, row_len=row_len, col_len=col_len,
                                 chan_len=channel_len)

    X_data_cropped, y_data_cropped, log_data = \
        reshape_data.crop_multichannel_data(X_data=test_X_data,
                                            y_data=test_y_data,
                                            crop_size=crop_size,
                                            overlap_frac=overlap_frac,
                                            test_parameters=False)

    # rows and cols have the same length and crop size here, so the number of
    # crops along one axis is squared to get the total
    expected_crop_num = len(crop_utils.compute_crop_indices(img_len=row_len,
                                                            crop_size=crop_size[0],
                                                            overlap_frac=overlap_frac)[0]) ** 2
    assert (X_data_cropped.shape == (fov_len, stack_len, expected_crop_num, slice_num,
                                     crop_size[0], crop_size[1], channel_len))

    assert log_data["num_crops"] == expected_crop_num

    # invalid arguments

    # no crop_size or crop_num
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data)

    # both crop_size and crop_num
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_size=(20, 20), crop_num=(20, 20))

    # bad crop_size dtype
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_size=5)

    # bad crop_size shape
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_size=(10, 5, 2))

    # bad crop_size values
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_size=(0, 5))

    # bad crop_size values
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_size=(1.5, 5))

    # bad crop_num dtype
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_num=5)

    # bad crop_num shape
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_num=(10, 5, 2))

    # bad crop_num values
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_num=(0, 5))

    # bad crop_num values
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_num=(1.5, 5))

    # bad overlap_frac value
    # a valid crop_size is supplied so that overlap_frac is the only invalid
    # argument; without it this call would be indistinguishable from the
    # "no crop_size or crop_num" case above, which already raises ValueError
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data,
                                            crop_size=crop_size, overlap_frac=1.2)

    # bad X_data dims
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data[0], y_data=test_y_data,
                                            crop_size=(5, 5))

    # bad y_data dims
    with pytest.raises(ValueError):
        reshape_data.crop_multichannel_data(X_data=test_X_data, y_data=test_y_data[0],
                                            crop_num=(5, 5))
def test_create_slice_data():
    """Check the shape returned by create_slice_data when slices divide the stack evenly."""
    # identical dimensions are used for the X and y inputs
    dims = dict(fov_len=1, stack_len=40, crop_num=1, slice_num=1,
                row_len=50, col_len=50, chan_len=3)
    slice_stack_len = 4

    X_data = _blank_data_xr(**dims)
    y_data = _blank_data_xr(**dims)

    sliced = reshape_data.create_slice_data(X_data, y_data, slice_stack_len)
    X_slice, _y_slice, _slice_indices = sliced

    # the stack axis shrinks to slice_stack_len and the slice axis holds the
    # number of slices needed to cover the full stack
    expected_slice_num = int(np.ceil(dims["stack_len"] / slice_stack_len))
    expected_shape = (dims["fov_len"], slice_stack_len, dims["crop_num"],
                      expected_slice_num,
                      dims["row_len"], dims["col_len"], dims["chan_len"])

    assert X_slice.shape == expected_shape
def test_reconstruct_image_stack():
    """Round-trip labels through crop/slice, save, and reconstruct_image_stack.

    Three scenarios are exercised, each in its own temporary directory:
    cropping only, slicing only, and cropping followed by slicing. In each
    case the stitched result is compared against the original y_data.
    """
    # --- scenario 1: cropped data only ---
    with tempfile.TemporaryDirectory() as temp_dir:
        # generate stack of crops from image with grid pattern
        (fov_len, stack_len, crop_num,
         slice_num, row_len, col_len, chan_len) = 2, 1, 1, 1, 400, 400, 4

        X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len, col_len=col_len, chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len, col_len=col_len, chan_len=1)

        # create image with artificial objects to be segmented
        # NOTE(review): the label counter is advanced once per grid position
        # (not once per fov); confirm the nesting against the upstream file
        cell_idx = 1
        for i in range(12):
            for j in range(11):
                for fov in range(y_data.shape[0]):
                    y_data[fov, :, :, :, (i * 35):(i * 35 + 10 + fov * 10),
                           (j * 37):(j * 37 + 8 + fov * 10), 0] = cell_idx
                cell_idx += 1

        # Crop the data
        crop_size, overlap_frac = 100, 0.2
        X_cropped, y_cropped, log_data = \
            reshape_data.crop_multichannel_data(X_data=X_data,
                                                y_data=y_data,
                                                crop_size=(crop_size, crop_size),
                                                overlap_frac=overlap_frac)

        io_utils.save_npzs_for_caliban(X_data=X_cropped, y_data=y_cropped, original_data=X_data,
                                       log_data=log_data, save_dir=temp_dir)

        stitched_imgs = reshape_data.reconstruct_image_stack(crop_dir=temp_dir)

        # dims are the same
        assert stitched_imgs.shape == y_data.shape

        # all the same pixels are marked
        assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0)))

        # there are the same number of cells
        assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data)))

    # --- scenario 2: sliced data only ---
    with tempfile.TemporaryDirectory() as temp_dir:
        # generate data with the corner tagged
        fov_len, stack_len, crop_num, slice_num = 1, 40, 1, 1
        row_len, col_len, chan_len = 50, 50, 3
        slice_stack_len = 4

        X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len, col_len=col_len, chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len, col_len=col_len, chan_len=1)

        # tag upper left hand corner of the label in each image
        tags = np.arange(stack_len)
        y_data[0, :, 0, 0, 0, 0, 0] = tags

        X_slice, y_slice, slice_log_data = \
            reshape_data.create_slice_data(X_data=X_data,
                                           y_data=y_data,
                                           slice_stack_len=slice_stack_len)

        io_utils.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data,
                                       log_data={**slice_log_data}, save_dir=temp_dir,
                                       blank_labels="include",
                                       save_format="npz", verbose=False)

        stitched_imgs = reshape_data.reconstruct_image_stack(temp_dir)

        # dims are the same
        assert stitched_imgs.shape == y_data.shape

        # the per-frame corner tags come back in their original order
        assert np.all(np.equal(stitched_imgs[0, :, 0, 0, 0, 0, 0], tags))

    # --- scenario 3: cropped and then sliced data ---
    with tempfile.TemporaryDirectory() as temp_dir:
        # generate data with both corners tagged and images labeled
        (fov_len, stack_len, crop_num,
         slice_num, row_len, col_len, chan_len) = 1, 8, 1, 1, 400, 400, 4

        # defined here explicitly so this scenario does not depend on the
        # value left over from the previous one
        slice_stack_len = 4

        X_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len, col_len=col_len, chan_len=chan_len)

        y_data = _blank_data_xr(fov_len=fov_len, stack_len=stack_len, crop_num=crop_num,
                                slice_num=slice_num,
                                row_len=row_len, col_len=col_len, chan_len=1)

        # create image with artificial objects to be segmented
        # NOTE(review): the label counter is advanced once per grid position
        # (not once per frame); confirm the nesting against the upstream file
        cell_idx = 1
        for i in range(1, 12):
            for j in range(1, 11):
                for stack in range(stack_len):
                    y_data[:, stack, :, :, (i * 35):(i * 35 + 10 + stack * 2),
                           (j * 37):(j * 37 + 8 + stack * 2), 0] = cell_idx
                cell_idx += 1

        # tag upper left hand corner of each image with squares of increasing size
        for stack in range(stack_len):
            y_data[0, stack, 0, 0, :stack, :stack, 0] = 1

        # Crop the data
        crop_size, overlap_frac = 100, 0.2
        X_cropped, y_cropped, log_data = \
            reshape_data.crop_multichannel_data(X_data=X_data,
                                                y_data=y_data,
                                                crop_size=(crop_size, crop_size),
                                                overlap_frac=overlap_frac)

        X_slice, y_slice, slice_log_data = \
            reshape_data.create_slice_data(X_data=X_cropped,
                                           y_data=y_cropped,
                                           slice_stack_len=slice_stack_len)

        io_utils.save_npzs_for_caliban(X_data=X_slice, y_data=y_slice, original_data=X_data,
                                       log_data={**slice_log_data, **log_data},
                                       save_dir=temp_dir,
                                       blank_labels="include",
                                       save_format="npz", verbose=False)

        stitched_imgs = reshape_data.reconstruct_image_stack(temp_dir)

        # dims are the same
        assert stitched_imgs.shape == y_data.shape

        # all the same pixels are marked
        assert (np.all(np.equal(stitched_imgs[:, :, 0] > 0, y_data[:, :, 0] > 0)))

        # there are the same number of cells
        assert (len(np.unique(stitched_imgs)) == len(np.unique(y_data)))

        # check mark in upper left hand corner of image
        for stack in range(stack_len):
            original = np.zeros((10, 10))
            original[:stack, :stack] = 1
            new = stitched_imgs[0, stack, 0, 0, :10, :10, 0]
            assert np.array_equal(original > 0, new > 0)