-
Notifications
You must be signed in to change notification settings - Fork 3
/
to_multiscales.py
340 lines (305 loc) · 13.2 KB
/
to_multiscales.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
from typing import Union, Optional, Sequence, Mapping, Dict, Tuple, Any, List
from typing_extensions import Literal
from collections.abc import MutableMapping
import time
import shutil
from pathlib import Path
import atexit
import signal
from zarr.core import Array as ZarrArray
from numpy.typing import ArrayLike
from dask.array.core import Array as DaskArray
import numpy as np
import zarr
import dask
from .methods._dask_image import _downsample_dask_image
from .methods._itk import _downsample_itk_bin_shrink, _downsample_itk_gaussian, _downsample_itk_label
from .to_ngff_image import to_ngff_image
from .ngff_image import NgffImage
from .multiscales import Multiscales
from .zarr_metadata import Metadata, Axis, Translation, Scale, Dataset
from .methods import Methods
from .config import config
from .rich_dask_progress import NgffProgress, NgffProgressCallback
from .memory_usage import memory_usage
from .methods._support import _spatial_dims
def _ngff_image_scale_factors(ngff_image, min_length, out_chunks):
    """Compute cumulative per-scale downsampling factors for an image.

    Spatial dimensions are repeatedly halved (factors are cumulative
    relative to scale 0) until every spatial extent fits within twice its
    output chunk size, or until the remaining element count would be too
    small relative to ``min_length`` to be statistically useful.

    Parameters
    ----------
    ngff_image : NgffImage
        Image whose ``dims`` and ``data.shape`` determine spatial sizes.
    min_length : int
        Minimum desired scale length; generation stops before a scale
        whose total remaining elements fall below twice this value.
    out_chunks : dict
        Mapping of dimension name to output chunk size.

    Returns
    -------
    list of dict
        One ``{dim: cumulative_factor}`` entry per generated scale.
    """
    sizes = {
        d: s
        for d, s in zip(ngff_image.dims, ngff_image.data.shape)
        if d in _spatial_dims
    }
    spatial = _spatial_dims.intersection(ngff_image.dims)
    previous = {d: 1 for d in spatial}
    sizes_array = np.array(list(sizes.values()))
    double_chunks = np.array(
        [2 * out_chunks[d] for d in _spatial_dims.intersection(out_chunks)]
    )
    scale_factors = []
    while (sizes_array > double_chunks).any():
        max_size = np.array(list(sizes.values())).max()
        # Dimensions already at most half the largest extent are held back
        # this round so anisotropic volumes downsample toward isotropy.
        to_skip = {d: sizes[d] <= max_size / 2 for d in previous}
        scale_factor = {}
        for dim in previous:
            # Hold the factor when skipping, or when another halving would
            # shrink this dimension below its output chunk size.
            if to_skip[dim] or sizes[dim] / 2 < out_chunks[dim]:
                scale_factor[dim] = previous[dim]
                continue
            scale_factor[dim] = 2 * previous[dim]
            sizes[dim] = int(sizes[dim] / 2)
        sizes_array = np.array(list(sizes.values()))
        previous = scale_factor
        # There should be sufficient data in the result for statistics, etc.
        if (np.prod(sizes_array) / min_length) < 2:
            break
        scale_factors.append(scale_factor)
    return scale_factors
def _large_image_serialization(image: NgffImage, progress: Optional[Union[NgffProgress, NgffProgressCallback]]) -> NgffImage:
    """Serialize a large image's voxel data into the configured cache store.

    The dask array backing *image* is rechunked to cache-friendly chunk
    sizes and written region-by-region to ``config.cache_store`` so that
    downstream multiscale generation reads from the store instead of
    holding the whole task graph's intermediates in memory.  Cleanup of
    the cached data is registered for interpreter exit and for
    SIGTERM/SIGINT.  Returns the same NgffImage with ``image.data``
    replaced by a dask array backed by the cache store.
    """
    # Smaller spatial chunks for volumetric (z-present) data, larger for 2D.
    if "z" in image.dims:
        optimized_chunks = 512
    else:
        optimized_chunks = 1024
    # Timestamped path keeps concurrent caches from colliding in the store.
    base_path = f"{image.name}-cache-{time.time()}"
    cache_store = config.cache_store
    base_path_removed = False
    def remove_from_cache_store(sig_id, frame):
        # Idempotent cleanup handler shared by atexit and the signal
        # handlers: removes this image's cached arrays at most once.
        nonlocal base_path_removed
        if not base_path_removed:
            if isinstance(cache_store, zarr.storage.DirectoryStore):
                full_path = Path(cache_store.dir_path()) / base_path
                if full_path.exists():
                    shutil.rmtree(full_path, ignore_errors=True)
            else:
                zarr.storage.rmdir(cache_store, base_path)
            base_path_removed = True
    atexit.register(remove_from_cache_store, None, None)
    # NOTE(review): this replaces any previously-installed SIGTERM/SIGINT
    # handlers, including ones registered by an earlier call — confirm
    # chaining is not needed.
    signal.signal(signal.SIGTERM, remove_from_cache_store)
    signal.signal(signal.SIGINT, remove_from_cache_store)
    data = image.data
    dims = list(image.dims)
    x_index = dims.index('x')
    y_index = dims.index('y')
    # Per-axis rechunking: one chunk per timepoint/channel; spatial axes
    # capped at optimized_chunks (or the full extent if smaller).
    rechunks = {}
    for index, dim in enumerate(dims):
        if dim == 't':
            rechunks[index] = 1
        elif dim == 'c':
            rechunks[index] = 1
        else:
            rechunks[index] = min(optimized_chunks, data.shape[index])
    if 'z' in dims:
        z_index = dims.index('z')
        # Number of z-slices whose xy planes fit within the memory target,
        # bounded by the total z extent.
        slice_bytes = data.dtype.itemsize * data.shape[x_index] * data.shape[y_index]
        slab_slices = min(int(np.ceil(config.memory_target / slice_bytes)), data.shape[z_index])
        if optimized_chunks < data.shape[z_index]:
            slab_slices = min(slab_slices, optimized_chunks)
        rechunks[z_index] = slab_slices
        path = f"{base_path}/slabs"
        slabs = data.rechunk(rechunks)
        chunks = tuple([c[0] for c in slabs.chunks])
        # Optimize the task graph once up front so each per-slab write
        # below does not repeat whole-graph optimization.
        optimized = dask.array.Array(dask.array.optimize(slabs.__dask_graph__(),
            slabs.__dask_keys__()), slabs.name,
            slabs.chunks, meta=slabs)
        zarr_array = zarr.creation.open_array(
            shape=data.shape,
            chunks=chunks,
            dtype=data.dtype,
            store=cache_store,
            path=path,
            mode='a',
        )
        n_slabs = int(np.ceil(data.shape[z_index] / slab_slices))
        if progress:
            progress.add_cache_task(f"[blue]Caching z-slabs", n_slabs)
        # Write the volume slab-by-slab to bound peak memory usage.
        for slab_index in range(n_slabs):
            if progress:
                if isinstance(progress, NgffProgressCallback):
                    progress.add_callback_task(f"[blue]Caching z-slabs {slab_index+1} of {n_slabs}")
                progress.update_cache_task_completed((slab_index+1))
            # Full-extent slices everywhere except the current z slab.
            region = [slice(data.shape[i]) for i in range(data.ndim)]
            region[z_index] = slice(slab_index*slab_slices, min((slab_index+1)*slab_slices, data.shape[z_index]))
            region = tuple(region)
            arr_region = optimized[region]
            dask.array.to_zarr(
                arr_region,
                zarr_array,
                region=region,
                component=path,
                overwrite=False,
                compute=True,
                return_stored=False,
            )
        data = dask.array.from_zarr(cache_store, component=path)
        if optimized_chunks < data.shape[z_index] and slab_slices < optimized_chunks:
            # Second pass: grow the z chunking from slab_slices up to
            # optimized_chunks and rewrite into a new cache component.
            rechunks[z_index] = optimized_chunks
            data = data.rechunk(rechunks)
            path = f"{base_path}/optimized_chunks"
            # NOTE(review): chunks are taken from `optimized` (the slab
            # layout) and data is immediately rechunked to them, which
            # appears to undo the rechunk just above — confirm whether
            # `optimized.chunks` or `data.chunks` is intended here.
            chunks = tuple([c[0] for c in optimized.chunks])
            data = data.rechunk(chunks)
            zarr_array = zarr.creation.open_array(
                shape=data.shape,
                chunks=chunks,
                dtype=data.dtype,
                store=cache_store,
                path=path,
                mode='a',
            )
            n_slabs = int(np.ceil(data.shape[z_index] / optimized_chunks))
            for slab_index in range(n_slabs):
                if progress:
                    if isinstance(progress, NgffProgressCallback):
                        progress.add_callback_task(f"[blue]Caching z-rechunk {slab_index+1} of {n_slabs}")
                    progress.update_cache_task_completed((slab_index+1))
                region = [slice(data.shape[i]) for i in range(data.ndim)]
                region[z_index] = slice(slab_index*optimized_chunks, min((slab_index+1)*optimized_chunks, data.shape[z_index]))
                region = tuple(region)
                arr_region = data[region]
                dask.array.to_zarr(
                    arr_region,
                    zarr_array,
                    region=region,
                    component=path,
                    overwrite=False,
                    compute=True,
                    return_stored=False,
                )
            data = dask.array.from_zarr(cache_store, component=path)
        else:
            data = data.rechunk(rechunks)
    else:
        # 2D (no z): rechunk and write the whole array in one pass.
        data = data.rechunk(rechunks)
        # TODO: Slab, chunk optimized very large 2D images
        path = base_path + f"/optimized_chunks"
        if progress:
            progress.add_callback_task(f"[blue]Caching optimized chunks")
        dask.array.to_zarr(
            data,
            cache_store,
            component=path,
            overwrite=False,
            compute=True,
            return_stored=False,
        )
        data = dask.array.from_zarr(cache_store, component=path)
    image.data = data
    return image
def to_multiscales(
    data: Union[NgffImage, ArrayLike, MutableMapping, str, ZarrArray],
    scale_factors: Union[int, Sequence[Union[Dict[str, int], int]]] = 128,
    method: Optional[Methods] = None,
    chunks: Optional[
        Union[
            int,
            Tuple[int, ...],
            Tuple[Tuple[int, ...], ...],
            Mapping[Any, Union[None, int, Tuple[int, ...]]],
        ]
    ] = None,
    progress: Optional[Union[NgffProgress, NgffProgressCallback]] = None,
    cache: Optional[bool] = None,
) -> Multiscales:
    """
    Generate multiple resolution scales for the OME-NGFF standard data model.

    data: NgffImage, ArrayLike, ZarrArray, MutableMapping, str
        Multi-dimensional array that provides the image pixel values, or image pixel values + image metadata when an NgffImage.

    scale_factors : int of minimum length, int per scale or dict of spatial dimension int's per scale
        If a single integer, scale factors in spatial dimensions will be increased by a factor of two until this minimum length is reached.
        If a list, integer scale factors to apply uniformly across all spatial dimensions or
        along individual spatial dimensions.
        Examples: 64 or [2, 4] or [{'x': 2, 'y': 4 }, {'x': 5, 'y': 10}]

    method : ngff_zarr.Methods, optional
        Downsampling method. Defaults to Methods.DASK_IMAGE_GAUSSIAN.

    chunks : Dask array chunking specification, optional
        Specify the chunking used in each output scale.

    progress : NgffProgress, NgffProgressCallback, optional
        Optional progress logger.

    cache : bool, optional
        Cache intermediate results to disk to limit memory consumption. If None, the default, determine based on ngff_zarr.config.memory_target.

    Returns
    -------
    multiscales: Multiscales
        NgffImage for each resolution and NGFF multiscales metadata

    Raises
    ------
    ValueError
        If *method* is not a supported Methods member.
    KeyError
        If an image dimension identifier is not one of {'t', 'c', 'z', 'y', 'x'}.
    """
    if isinstance(data, NgffImage):
        ngff_image = data
    else:
        ngff_image = to_ngff_image(data)

    # IPFS and visualization friendly default chunks
    if "z" in ngff_image.dims:
        default_chunks = 128
    else:
        default_chunks = 256
    default_chunks = {d: default_chunks for d in ngff_image.dims}
    if "t" in ngff_image.dims:
        default_chunks["t"] = 1

    # Normalize the chunk specification to a per-dimension mapping.
    out_chunks = chunks
    if out_chunks is None:
        out_chunks = default_chunks
    elif isinstance(out_chunks, int):
        out_chunks = {d: out_chunks for d in ngff_image.dims}
    elif isinstance(out_chunks, tuple):
        out_chunks = {d: out_chunks[i] for i, d in enumerate(ngff_image.dims)}
    da_out_chunks = tuple(out_chunks[d] for d in ngff_image.dims)

    # Ensure the pixel data is a dask array before downsampling.
    if not isinstance(ngff_image.data, DaskArray):
        if isinstance(ngff_image.data, (ZarrArray, str, MutableMapping)):
            ngff_image.data = dask.array.from_zarr(ngff_image.data)
        else:
            ngff_image.data = dask.array.from_array(ngff_image.data)

    if isinstance(scale_factors, int):
        scale_factors = _ngff_image_scale_factors(ngff_image, scale_factors, out_chunks)

    # Serialize to the disk cache when explicitly requested, or when the
    # estimated memory footprint exceeds the configured target.
    if cache or (cache is None and memory_usage(ngff_image) > config.memory_target):
        ngff_image = _large_image_serialization(ngff_image, progress)

    ngff_image.data = ngff_image.data.rechunk(da_out_chunks)

    if method is None:
        method = Methods.DASK_IMAGE_GAUSSIAN

    if method is Methods.ITK_BIN_SHRINK:
        images = _downsample_itk_bin_shrink(
            ngff_image, default_chunks, out_chunks, scale_factors
        )
    elif method is Methods.ITK_GAUSSIAN:
        images = _downsample_itk_gaussian(
            ngff_image, default_chunks, out_chunks, scale_factors
        )
    elif method is Methods.DASK_IMAGE_GAUSSIAN:
        images = _downsample_dask_image(
            ngff_image, default_chunks, out_chunks, scale_factors, label=False
        )
    elif method is Methods.DASK_IMAGE_NEAREST:
        images = _downsample_dask_image(
            ngff_image, default_chunks, out_chunks, scale_factors, label="nearest"
        )
    elif method is Methods.DASK_IMAGE_MODE:
        images = _downsample_dask_image(
            ngff_image, default_chunks, out_chunks, scale_factors, label="mode"
        )
    else:
        # Previously an unhandled method fell through and crashed later
        # with a NameError on `images`; fail fast and clearly instead.
        # NOTE(review): _downsample_itk_label is imported but never
        # dispatched here — confirm whether an ITK label method is missing.
        raise ValueError(f"Unsupported downsampling method: {method}")

    axes = []
    for dim in ngff_image.dims:
        unit = None
        if ngff_image.axes_units and dim in ngff_image.axes_units:
            unit = ngff_image.axes_units[dim]
        if dim in {"x", "y", "z"}:
            axis = Axis(name=dim, type="space", unit=unit)
        elif dim == "c":
            axis = Axis(name=dim, type="channel", unit=unit)
        elif dim == "t":
            axis = Axis(name=dim, type="time", unit=unit)
        else:
            raise KeyError(f'Dimension identifier is not valid: {dim}')
        axes.append(axis)

    datasets = []
    for index, image in enumerate(images):
        path = f"scale{index}/{ngff_image.name}"
        scale = []
        for dim in image.dims:
            if dim in image.scale:
                scale.append(image.scale[dim])
            else:
                # Identity scale for dimensions without spacing metadata.
                scale.append(1.0)
        translation = []
        for dim in image.dims:
            if dim in image.translation:
                translation.append(image.translation[dim])
            else:
                # Identity translation is 0.0 per the NGFF
                # coordinateTransformations spec; the previous 1.0 was a
                # copy-paste from the scale branch above.
                translation.append(0.0)
        coordinateTransformations = [Scale(scale), Translation(translation)]
        dataset = Dataset(
            path=path, coordinateTransformations=coordinateTransformations
        )
        datasets.append(dataset)
    metadata = Metadata(axes=axes, datasets=datasets, name=ngff_image.name)

    multiscales = Multiscales(images, metadata, scale_factors, method, out_chunks)
    return multiscales