Skip to content

Commit

Permalink
Merge #993 which refactors azimuthal integrator
Browse files Browse the repository at this point in the history
  • Loading branch information
pc494 committed Jan 29, 2024
2 parents a558762 + 8745ab2 commit 961a3f4
Show file tree
Hide file tree
Showing 9 changed files with 733 additions and 46 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ Fixed
-----
- Fixed pytest failure. Changed `setup` --> `setup_method` (#997)

Added
-----
- Added `pyxem.utils.calibration_utils.Calibration` class for calibrating the signal axes of a 4-D STEM dataset(#993)

Deprecated
----------
- The module & all functions within `utils.reduced_intensity1d` are deprecated in favour of using the methods of `ReducedIntensity1D` (#994).
Expand Down
22 changes: 22 additions & 0 deletions examples/processing/calibrating.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""
Calibrating a dataset
=====================
There are two different ways to calibrate a dataset in pyxem and depending on what
kind of data you have you may need to use each of these methods.
The first method is to basically ignore the Ewald sphere effects. This is the
easiest method but not the most correct. For a 200+ keV microscope the assumption
that the Ewald sphere is flat is not a bad one. For lower energy microscopes
this assumption is not as great but still not terrible. For x-ray data with longer
wavelengths this assumption starts to break down.
"""

# import pyxem as pxm

# al = pxm.data.al_peaks()

# determine the pixel size from one peak

# al.calibrate(scale=0.1, center=None, units="k_nm^-1")
# al.plot()
94 changes: 70 additions & 24 deletions pyxem/signals/diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,7 @@
#
# You should have received a copy of the GNU General Public License
# along with pyXem. If not, see <http://www.gnu.org/licenses/>.


import numba
import numpy as np
from scipy.ndimage import rotate
from skimage import morphology
Expand Down Expand Up @@ -63,6 +62,7 @@
medfilt_1d,
sigma_clip,
)
from pyxem.utils._azimuthal_utils import _slice_radial_integrate
from pyxem.utils.dask_tools import (
_get_dask_array,
get_signal_dimension_host_chunk_slice,
Expand All @@ -84,6 +84,7 @@
_subtract_hdome,
_subtract_radial_median,
)
from pyxem.utils.calibration_utils import Calibration


class Diffraction2D(Signal2D, CommonDiffraction):
Expand All @@ -101,6 +102,21 @@ class Diffraction2D(Signal2D, CommonDiffraction):

""" Methods that make geometrical changes to a diffraction pattern """

def __init__(self, *args, **kwargs):
"""
Create a Diffraction2D object from numpy.ndarray.
Parameters
----------
*args :
Passed to the __init__ of Signal2D. The first arg should be
numpy.ndarray
**kwargs :
Passed to the __init__ of Signal2D
"""
super().__init__(*args, **kwargs)
self.calibrate = Calibration(self)

def apply_affine_transformation(
self, D, order=1, keep_dtype=False, inplace=True, *args, **kwargs
):
Expand Down Expand Up @@ -1813,6 +1829,11 @@ def angular_slice_radial_integration(self):
"angular_slice_radial_average"
)

@deprecated(
since="0.17",
alternative="pyxem.signals.diffraction2d.azimuthal_integral2d",
removal="1.0.0",
)
def angular_slice_radial_average(
self,
angleN=20,
Expand Down Expand Up @@ -1918,6 +1939,11 @@ def ai(self):
except AttributeError:
raise ValueError("ai property is not currently set")

@deprecated(
since="0.18",
removal="1.0.0",
alternative="pyxem.signals.diffraction2d.calibrate",
)
def set_ai(
self, center=None, wavelength=None, affine=None, radial_range=None, **kwargs
):
Expand Down Expand Up @@ -2089,7 +2115,7 @@ def get_azimuthal_integral2d(
radial_range=None,
azimuth_range=None,
inplace=False,
method="splitpixel",
method="splitpixel_pyxem",
sum=False,
correctSolidAngle=True,
**kwargs,
Expand Down Expand Up @@ -2119,7 +2145,8 @@ def get_azimuthal_integral2d(
method: str
Can be “numpy”, “cython”, “BBox” or “splitpixel”, “lut”, “csr”,
“nosplit_csr”, “full_csr”, “lut_ocl” and “csr_ocl” if you want
to go on GPU. To Specify the device: “csr_ocl_1,2”
to go on GPU. To Specify the device: “csr_ocl_1,2”. For pure
pyxem based methods use "splitpixel_pyxem".
sum: bool
If true the radial integration is returned rather then the Azimuthal Integration.
correctSolidAngle: bool
Expand Down Expand Up @@ -2164,27 +2191,46 @@ def get_azimuthal_integral2d(
>>> ds.get_azimuthal_integral2d(npt_rad=100)
"""
sig_shape = self.axes_manager.signal_shape
if radial_range is None:
radial_range = _get_radial_extent(
ai=self.ai, shape=sig_shape, unit=self.unit
usepyfai = method not in ["splitpixel_pyxem"]
if not usepyfai:
# get_slices2d should be sped up in the future by
# getting rid of shapely and using numba on the for loop
slices, factors, factors_slice, radial_range = self.calibrate.get_slices2d(
npt, npt_azim, radial_range=radial_range
)
integration = self.map(
_slice_radial_integrate,
slices=slices,
factors=factors,
factors_slice=factors_slice,
npt_rad=npt,
npt_azim=npt_azim,
inplace=inplace,
**kwargs,
)

else:
sig_shape = self.axes_manager.signal_shape
if radial_range is None:
radial_range = _get_radial_extent(
ai=self.ai, shape=sig_shape, unit=self.unit
)
radial_range[0] = 0
integration = self.map(
azimuthal_integrate2d,
azimuthal_integrator=self.ai,
npt_rad=npt,
npt_azim=npt_azim,
azimuth_range=azimuth_range,
radial_range=radial_range,
method=method,
inplace=inplace,
unit=self.unit,
mask=mask,
sum=sum,
correctSolidAngle=correctSolidAngle,
**kwargs,
)
radial_range[0] = 0
integration = self.map(
azimuthal_integrate2d,
azimuthal_integrator=self.ai,
npt_rad=npt,
npt_azim=npt_azim,
azimuth_range=azimuth_range,
radial_range=radial_range,
method=method,
inplace=inplace,
unit=self.unit,
mask=mask,
sum=sum,
correctSolidAngle=correctSolidAngle,
**kwargs,
)

s = self if inplace else integration
s.set_signal_type("polar_diffraction")
Expand Down
10 changes: 8 additions & 2 deletions pyxem/tests/signals/test_diffraction2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,12 +456,12 @@ def test_2d_azimuthal_integral(self, ones):

def test_2d_azimuthal_integral_scale(self, ring):
ring.set_ai(wavelength=2.5e-12)
az = ring.get_azimuthal_integral2d(npt=500)
az = ring.get_azimuthal_integral2d(npt=500, method="bbox")
peak = np.argmax(az.sum(axis=0)).data * az.axes_manager[1].scale
np.testing.assert_almost_equal(peak[0], 3, decimal=1)
ring.unit = "k_A^-1"
ring.set_ai(wavelength=2.5e-12)
az = ring.get_azimuthal_integral2d(npt=500)
az = ring.get_azimuthal_integral2d(npt=500, method="bbox")
peak = np.argmax(az.sum(axis=0)).data * az.axes_manager[1].scale
np.testing.assert_almost_equal(peak[0], 3, decimal=1)

Expand Down Expand Up @@ -552,6 +552,12 @@ def test_2d_azimuthal_integral_sum(self, ones):
npt=10, npt_azim=15, radial_range=[0, 0.5], sum=True, mask=mask
)

def test_internal_azimuthal_integration(self, ring):
ring.calibrate(scale=1)
az = ring.get_azimuthal_integral2d(npt=40, npt_azim=100, radial_range=(0, 40))
ring_sum = np.sum(az.data, axis=1)
assert ring_sum.shape == (40,)


class TestPyFAIIntegration:
@pytest.fixture
Expand Down
106 changes: 106 additions & 0 deletions pyxem/tests/utils/test_calibration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,113 @@
import pytest
import numpy as np

from hyperspy.axes import UniformDataAxis

from pyxem.utils import calibration_utils
from pyxem.utils.calibration_utils import Calibration
from pyxem.signals import Diffraction2D


class TestCalibrationClass:
@pytest.fixture
def calibration(self):
s = Diffraction2D(np.zeros((10, 10)))
return Calibration(s)

def test_init(self, calibration):
assert isinstance(calibration, Calibration)

def test_set_center(self, calibration):
calibration(center=(5, 5))
assert calibration.signal.axes_manager[0].offset == -5
assert calibration.signal.axes_manager[1].offset == -5
assert calibration.flat_ewald is True

def test_set_beam_energy(self, calibration):
calibration(beam_energy=200)
assert calibration.beam_energy == 200
assert calibration.wavelength is not None

def test_set_wavelength(self, calibration):
calibration(wavelength=0.02508)
assert calibration.wavelength == 0.02508

def test_set_scale(self, calibration):
calibration(scale=0.01)
assert calibration.signal.axes_manager[0].scale == 0.01
assert calibration.signal.axes_manager[1].scale == 0.01
assert calibration.flat_ewald is True

def test_set_failure(self, calibration):
assert calibration.wavelength is None
assert calibration.beam_energy is None
with pytest.raises(ValueError):
calibration.detector(pixel_size=0.1, detector_distance=1)
calibration.beam_energy = 200
calibration.detector(pixel_size=0.1, detector_distance=1)
calibration.detector(
pixel_size=0.1, detector_distance=1, beam_energy=200, units="k_nm^-1"
)
assert calibration.flat_ewald is False
with pytest.raises(ValueError):
calibration(scale=0.01)
assert calibration.scale is None
with pytest.raises(ValueError):
calibration(center=(5, 5))
assert calibration.center == [5, 5]

def test_set_detector(self, calibration):
calibration.detector(
pixel_size=15e-6, # 15 um
detector_distance=3.8e-2, # 38 mm
beam_energy=200, # 200 keV
units="k_nm^-1",
)
assert not isinstance(calibration.signal.axes_manager[0], UniformDataAxis)
diff_arr = np.diff(
calibration.signal.axes_manager[0].axis
) # assume mostly flat.
assert np.allclose(
diff_arr,
diff_arr[0],
)
assert calibration.flat_ewald == False

def test_get_slices2d(self, calibration):
calibration(scale=0.01)
slices, factors, _, _ = calibration.get_slices2d(5, 90)
assert len(slices) == 5 * 90

def test_get_slices_and_factors(self):
s = Diffraction2D(np.zeros((100, 100)))
s.calibrate(scale=0.1, center=None)
slices, factors, factor_slices = s.calibrate._get_slices_and_factors(
npt=100, npt_azim=360, radial_range=(0, 4)
)
# check that the number of pixels for each radial slice is the same
sum_factors = [np.sum(factors[f[0] : f[1]]) for f in factor_slices]
sum_factors = np.reshape(sum_factors, (360, 100)).T
for row in sum_factors:
print(np.min(row), np.max(row))
assert np.allclose(row, row[0], atol=1e-2)
# Check that the total number of pixels accounted for is equal to the area of the circle
# Up to rounding due to the fact that we are actually finding the area of an n-gon where
# n = npt_azim
all_sum = np.sum(sum_factors)
assert np.allclose(all_sum, 3.1415 * 40**2, atol=1)
slices, factors, factor_slices = s.calibrate._get_slices_and_factors(
npt=100, npt_azim=360, radial_range=(0, 15)
)
# check that the number of pixels for each radial slice is the same
sum_factors = [np.sum(factors[f[0] : f[1]]) for f in factor_slices]
sum_factors = np.reshape(sum_factors, (360, 100)).T
# Check that the total number of pixels accounted for is equal to the area of the circle
# Up to rounding due to the fact that we are actually finding the area of an n-gon where
# n = npt_azim
all_sum = np.sum(sum_factors)
# For some reason we are missing 1 row/ column of pixels on the edge
# of the image so this is 9801 instead of 10000!
# assert np.allclose(all_sum, 10000, atol=1)


@pytest.mark.skip(
Expand Down
20 changes: 8 additions & 12 deletions pyxem/tests/utils/test_ransac_ellipse_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,13 +931,11 @@ def test_determine_ellipse_rotated(self, execution_number):
center, affine = ret.determine_ellipse(
s, mask=mask, use_ransac=False, num_points=2000
)
s.unit = "k_nm^-1"
s.beam_energy = 200
s.axes_manager.signal_axes[0].scale = 0.1
s.axes_manager.signal_axes[1].scale = 0.1
s.set_ai(center=center, affine=affine)
s.calibrate(
center=center, affine=affine, scale=0.1, units="k_nm^-1", beam_energy=200
)
s_az = s.get_azimuthal_integral2d(npt=100)
assert np.sum((s_az.sum(axis=0).isig[6:] > 1).data) < 11
assert np.sum((s_az.sum(axis=0).isig[6:] > 1).data) < 12

@mark.parametrize("rot", np.linspace(0, 2 * np.pi, 10))
def test_determine_ellipse_ring(self, rot):
Expand All @@ -951,13 +949,11 @@ def test_determine_ellipse_ring(self, rot):
mask = np.zeros_like(s.data, dtype=bool)
mask[100 - 20 : 100 + 20, 100 - 20 : 100 + 20] = True
center, affine = ret.determine_ellipse(s, mask=mask, use_ransac=False)
s.unit = "k_nm^-1"
s.beam_energy = 200
s.axes_manager.signal_axes[0].scale = 0.1
s.axes_manager.signal_axes[1].scale = 0.1
s.set_ai(center=center, affine=affine)
s.calibrate(
center=center, affine=affine, scale=0.1, units="k_nm^-1", beam_energy=200
)
s_az = s.get_azimuthal_integral2d(npt=100)
assert np.sum((s_az.sum(axis=0).isig[10:] > 1).data) < 10
assert np.sum((s_az.sum(axis=0).isig[10:] > 1).data) < 11

def test_get_max_pos(self):
t = np.ones((100, 100))
Expand Down

0 comments on commit 961a3f4

Please sign in to comment.