Skip to content

Commit

Permalink
Merge d09a2f3 into c024409
Browse files Browse the repository at this point in the history
  • Loading branch information
pc494 committed Mar 25, 2024
2 parents c024409 + d09a2f3 commit 2b24e20
Show file tree
Hide file tree
Showing 7 changed files with 215 additions and 134 deletions.
12 changes: 6 additions & 6 deletions pyxem/signals/diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,17 +60,17 @@
_slice_radial_integrate,
_slice_radial_integrate1d,
)
from pyxem.utils.dask_tools import (
from pyxem.utils._dask import (
_get_dask_array,
get_signal_dimension_host_chunk_slice,
align_single_frame,
_get_signal_dimension_host_chunk_slice,
_align_single_frame,
)
from pyxem.utils.signal import (
select_method_from_method_dict,
to_hyperspy_index,
)
import pyxem.utils._pixelated_stem_tools as pst
import pyxem.utils.dask_tools as dt
import pyxem.utils._dask as dt
import pyxem.utils.ransac_ellipse_tools as ret
from pyxem.utils._deprecated import deprecated, deprecated_argument

Expand Down Expand Up @@ -860,7 +860,7 @@ def center_direct_beam(
else:
align_kwargs["order"] = 0
aligned = self.map(
align_single_frame,
_align_single_frame,
shifts=shifts,
inplace=inplace,
lazy_output=lazy_output,
Expand Down Expand Up @@ -1332,7 +1332,7 @@ def make_probe_navigation(self, method="fast"):
x = round(self.axes_manager.signal_shape[0] / 2)
y = round(self.axes_manager.signal_shape[1] / 2)
if self._lazy:
isig_slice = get_signal_dimension_host_chunk_slice(
isig_slice = _get_signal_dimension_host_chunk_slice(
x, y, self.data.chunks
)
else:
Expand Down
2 changes: 1 addition & 1 deletion pyxem/signals/insitu_diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import dask.array as da
from dask.graph_manipulation import clone

from pyxem.utils.dask_tools import _get_dask_array, _get_chunking
from pyxem.utils._dask import _get_dask_array, _get_chunking
from pyxem.utils._insitu import (
_register_drift_5d,
_register_drift_2d,
Expand Down
24 changes: 12 additions & 12 deletions pyxem/tests/utils/test_dask_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import skimage.morphology as sm

from pyxem.signals import Diffraction2D, LazyDiffraction2D
import pyxem.utils.dask_tools as dt
import pyxem.utils._dask as dt
import pyxem.utils._background_subtraction as bt
import pyxem.utils.expt_utils as et
import pyxem.utils._pixelated_stem_tools as pst
Expand All @@ -36,7 +36,7 @@ class TestSignalDimensionGetChunkSliceList:
def test_chunksizes(self, sig_chunks):
xchunk, ychunk = sig_chunks
data = da.zeros((20, 20, 20, 20), chunks=(10, 10, ychunk, xchunk))
chunk_slice_list = dt.get_signal_dimension_chunk_slice_list(data.chunks)
chunk_slice_list = dt._get_signal_dimension_chunk_slice_list(data.chunks)
assert len(data.chunks[-1]) * len(data.chunks[-2]) == len(chunk_slice_list)
for chunk_slice in chunk_slice_list:
xsize = chunk_slice[0].stop - chunk_slice[0].start
Expand All @@ -46,28 +46,28 @@ def test_chunksizes(self, sig_chunks):

def test_non_square_chunks(self):
data = da.zeros((2, 2, 20, 20), chunks=(2, 2, 15, 15))
chunk_slice_list = dt.get_signal_dimension_chunk_slice_list(data.chunks)
chunk_slice_list = dt._get_signal_dimension_chunk_slice_list(data.chunks)
assert chunk_slice_list[0] == (slice(0, 15, None), slice(0, 15, None))
assert chunk_slice_list[1] == (slice(15, 20, None), slice(0, 15, None))
assert chunk_slice_list[2] == (slice(0, 15, None), slice(15, 20, None))
assert chunk_slice_list[3] == (slice(15, 20, None), slice(15, 20, None))

def test_one_signal_chunk(self):
data = da.zeros((2, 2, 20, 20), chunks=(1, 1, 20, 20))
chunk_slice_list = dt.get_signal_dimension_chunk_slice_list(data.chunks)
chunk_slice_list = dt._get_signal_dimension_chunk_slice_list(data.chunks)
assert len(chunk_slice_list) == 1
assert chunk_slice_list[0] == np.s_[0:20, 0:20]

def test_rechunk(self):
data = da.zeros((2, 2, 20, 20), chunks=(1, 1, 20, 20))
data1 = data.rechunk((2, 2, 10, 10))
chunk_slice_list = dt.get_signal_dimension_chunk_slice_list(data1.chunks)
chunk_slice_list = dt._get_signal_dimension_chunk_slice_list(data1.chunks)
assert len(chunk_slice_list) == 4

def test_slice_navigation(self):
data = da.zeros((2, 2, 20, 20), chunks=(1, 1, 20, 20))
data1 = data[0, 1]
chunk_slice_list = dt.get_signal_dimension_chunk_slice_list(data1.chunks)
chunk_slice_list = dt._get_signal_dimension_chunk_slice_list(data1.chunks)
assert len(chunk_slice_list) == 1
assert chunk_slice_list[0] == np.s_[0:20, 0:20]

Expand Down Expand Up @@ -95,7 +95,7 @@ class TestGetSignalDimensionHostChunkSlice:
def test_simple(self, xy, sig_slice, xchunk, ychunk):
x, y = xy
data = da.zeros((2, 2, 20, 20), chunks=(1, 1, ychunk, xchunk))
chunk_slice = dt.get_signal_dimension_host_chunk_slice(x, y, data.chunks)
chunk_slice = dt._get_signal_dimension_host_chunk_slice(x, y, data.chunks)
assert chunk_slice == sig_slice

@pytest.mark.parametrize(
Expand All @@ -111,7 +111,7 @@ def test_simple(self, xy, sig_slice, xchunk, ychunk):
def test_non_square(self, xy, sig_slice):
x, y = xy
data = da.zeros((2, 2, 30, 20), chunks=(1, 1, 10, 10))
chunk_slice = dt.get_signal_dimension_host_chunk_slice(x, y, data.chunks)
chunk_slice = dt._get_signal_dimension_host_chunk_slice(x, y, data.chunks)
assert chunk_slice == sig_slice


Expand All @@ -124,7 +124,7 @@ def test_simple(self, shifts):
image = np.zeros((y_size, x_size), dtype=np.uint16)
x, y = 3, 7
image[y, x] = 7
image_shifted = dt.align_single_frame(image, shifts)
image_shifted = dt._align_single_frame(image, shifts)
pos = np.s_[y + shifts[1], x + shifts[0]]
assert image_shifted[pos] == 7
image_shifted[pos] = 0
Expand All @@ -144,7 +144,7 @@ def test_subpixel_integer_image(self, shifts, pos):
image = np.zeros((y_size, x_size), dtype=np.uint16)
x, y = 3, 7
image[y, x] = 8
image_shifted = dt.align_single_frame(image, shifts, order=1)
image_shifted = dt._align_single_frame(image, shifts, order=1)
assert (image_shifted[pos] >= 2).all()
image_shifted[pos] = 0
assert not image_shifted.any()
Expand All @@ -163,7 +163,7 @@ def test_subpixel_float_image(self, shifts, pos):
image = np.zeros((y_size, x_size), dtype=np.float32)
x, y = 3, 7
image[y, x] = 9
image_shifted = dt.align_single_frame(image, shifts, order=1)
image_shifted = dt._align_single_frame(image, shifts, order=1)
assert image_shifted[pos].sum() == 9
image_shifted[pos] = 0
assert not image_shifted.any()
Expand All @@ -174,7 +174,7 @@ def test_not_subpixel_float_image(self, shifts):
image = np.zeros((y_size, x_size), dtype=np.float32)
x, y = 3, 7
image[y, x] = 9
image_shifted = dt.align_single_frame(image, shifts, order=0)
image_shifted = dt._align_single_frame(image, shifts, order=0)
pos = np.s_[y + round(shifts[1]), x + round(shifts[0])]
assert image_shifted[pos] == 9.0
image_shifted[pos] = 0
Expand Down
163 changes: 163 additions & 0 deletions pyxem/utils/_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# -*- coding: utf-8 -*-
# Copyright 2016-2024 The pyXem developers
#
# This file is part of pyXem.
#
# pyXem is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# pyXem is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with pyXem. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import dask.array as da
import scipy.ndimage as ndi
from skimage import morphology
from hyperspy.misc.utils import isiterable


def _align_single_frame(image, shifts, **kwargs):
temp_image = ndi.shift(image, shifts[::-1], **kwargs)
return temp_image


def _get_signal_dimension_chunk_slice_list(chunks):
"""Convenience function for getting the signal chunks as slices
The slices are assumed to be used on a HyperSpy signal object.
Thus the input will be in the Dask chunk order (y, x), while the
output will be in the HyperSpy order (x, y).
"""
chunk_slice_raw_list = da.core.slices_from_chunks(chunks[-2:])
chunk_slice_list = []
for chunk_slice_raw in chunk_slice_raw_list:
chunk_slice_list.append((chunk_slice_raw[1], chunk_slice_raw[0]))
return chunk_slice_list


def _get_signal_dimension_host_chunk_slice(x, y, chunks):
chunk_slice_list = _get_signal_dimension_chunk_slice_list(chunks)
for chunk_slice in chunk_slice_list:
x_slice, y_slice = chunk_slice
if y_slice.start <= y < y_slice.stop:
if x_slice.start <= x < x_slice.stop:
return chunk_slice
return False


def _intensity_peaks_image_single_frame(frame, peaks, disk_r):
"""Intensity of the peaks is calculated by taking the mean value
of the pixel values inside radius disk_r where the centers are the
peak positions. If the peak position plus disk_r exceed the detector
edges, then the intensity for that peak will be put to zero.
Parameters
----------
frame : NumPy 2D array
peaks: Numpy 2D array with x and y coordinates of peaks
disk : NumPy 2D array
Must be smaller than frame
peaks: NumPy Object
can have multiple peaks per image
Returns
-------
intensity_array : NumPy array with
peak coordinates and intensity of peaks
Examples
--------
>>> import pyxem.utils.dask_tools as dt
>>> s = pxm.dummy_data.dummy_data.get_cbed_signal()
>>> peaks = np.array(([50,50],[25,50]))
>>> intensity = dt._intensity_peaks_image_single_frame(
... s.data[0,0,:,:], peaks, 5)
"""
array_shape = peaks.shape
mask = morphology.disk(disk_r)
size = np.shape(frame)
intensity_array = np.zeros((array_shape[0], 3), dtype="float64")
for i in range(array_shape[0]):
cx = int(peaks[i, 0])
cy = int(peaks[i, 1])
intensity_array[i, 0] = peaks[i, 0]
intensity_array[i, 1] = peaks[i, 1]
if (
(cx - disk_r < 0)
| (cx + disk_r + 1 >= size[0])
| (cy - disk_r < 0)
| (cy + disk_r + 1 >= size[1])
):
intensity_array[i, 2] = 0
else:
subframe = frame[
cx - disk_r : cx + disk_r + 1, cy - disk_r : cy + disk_r + 1
]
intensity_array[i, 2] = np.mean(mask * subframe)

return intensity_array


def _get_chunking(signal, chunk_shape=None, chunk_bytes=None):
"""Get chunk tuple based on the size of the dataset.
The signal dimensions will be within one chunk, and the navigation
dimensions will be chunked based on either chunk_shape, or
be optimized based on the chunk_bytes.
Parameters
----------
signal : hyperspy or pyxem signal
chunk_shape : int, optional
Size of the navigation chunk, of None (the default), the chunk
size will be set automatically.
chunk_bytes : int or string, optional
Number of bytes in each chunk. For example '60MiB'. If None (the default),
the limit will be '30MiB'. Will not be used if chunk_shape is None.
Returns
-------
chunks : tuple
"""
if chunk_bytes is None:
chunk_bytes = "30MiB"
nav_dim = signal.axes_manager.navigation_dimension
sig_dim = signal.axes_manager.signal_dimension
if chunk_shape is not None:
if not isiterable(chunk_shape):
chunk_shape = [chunk_shape] * nav_dim

chunks_dict = {}
for i in range(nav_dim):
if chunk_shape is None:
chunks_dict[i] = "auto"
else:
chunks_dict[i] = chunk_shape[i]
for i in range(nav_dim, nav_dim + sig_dim):
chunks_dict[i] = -1

chunks = da.core.normalize_chunks(
chunks=chunks_dict,
shape=signal.data.shape,
limit=chunk_bytes,
dtype=signal.data.dtype,
)
return chunks


def _get_dask_array(signal, chunk_shape=None, chunk_bytes=None):
if signal._lazy:
dask_array = signal.data
else:
chunks = _get_chunking(signal, chunk_shape, chunk_bytes)
dask_array = da.from_array(signal.data, chunks=chunks)
return dask_array

0 comments on commit 2b24e20

Please sign in to comment.