diff --git a/diffsims/generators/diffraction_generator.py b/diffsims/generators/diffraction_generator.py index e481526a..27addaa9 100644 --- a/diffsims/generators/diffraction_generator.py +++ b/diffsims/generators/diffraction_generator.py @@ -289,16 +289,15 @@ def calculate_ed_data( r_spot = np.sqrt(np.sum(np.square(cartesian_coordinates[:, :2]), axis=1)) z_spot = cartesian_coordinates[:, 2] - if self.precession_angle > 0: - if not self.approximate_precession: - # We find the average excitation error - this step can be - # quite expensive - excitation_error = _average_excitation_error_precession( - z_spot, - r_spot, - wavelength, - self.precession_angle, - ) + if self.precession_angle > 0 and not self.approximate_precession: + # We find the average excitation error - this step can be + # quite expensive + excitation_error = _average_excitation_error_precession( + z_spot, + r_spot, + wavelength, + self.precession_angle, + ) else: z_sphere = -np.sqrt(r_sphere ** 2 - r_spot ** 2) + r_sphere excitation_error = z_sphere - z_spot @@ -317,7 +316,7 @@ def calculate_ed_data( wavelength, self.precession_angle, self.shape_factor_model, - self.max_excitation_error, + max_excitation_error, **self.shape_factor_kwargs, ) elif self.precession_angle > 0 and self.approximate_precession: diff --git a/diffsims/tests/test_generators/test_diffraction_generator.py b/diffsims/tests/test_generators/test_diffraction_generator.py index 8655d442..fae195ca 100644 --- a/diffsims/tests/test_generators/test_diffraction_generator.py +++ b/diffsims/tests/test_generators/test_diffraction_generator.py @@ -28,15 +28,27 @@ _average_excitation_error_precession, ) import diffpy.structure -from diffsims.utils.shape_factor_models import (linear, binary, sinc, sin2c, - atanc, lorentzian, - lorentzian_precession) +from diffsims.utils.shape_factor_models import (linear, binary, sin2c, + atanc, lorentzian) + @pytest.fixture(params=[(300)]) def diffraction_calculator(request): return DiffractionGenerator(request.param) +@pytest.fixture(scope="module") +def diffraction_calculator_precession_full(): + return DiffractionGenerator(300, precession_angle=0.5, + approximate_precession=False) + + +@pytest.fixture(scope="module") +def diffraction_calculator_precession_simple(): + return DiffractionGenerator(300, precession_angle=0.5, + approximate_precession=True) + + @pytest.fixture(params=[(300, [np.linspace(-1, 1, 10)] * 2)]) def diffraction_calculator_atomic(request): return AtomicDiffractionGenerator(*request.param) @@ -96,7 +108,7 @@ def test_z_sphere_precession(parameters, expected): assert np.allclose(result, expected) -@pytest.mark.parametrize("model", [linear, atanc, sin2c, lorentzian]) +@pytest.mark.parametrize("model", [binary, linear, atanc, sin2c, lorentzian]) def test_shape_factor_precession(model): z = np.array([-0.1, 0.1]) r = np.array([1, 5]) @@ -133,6 +145,22 @@ def test_matching_results(self, diffraction_calculator, local_structure): assert len(diffraction.indices) == len(diffraction.coordinates) assert len(diffraction.coordinates) == len(diffraction.intensities) + def test_precession_simple(self, diffraction_calculator_precession_simple, + local_structure): + diffraction = diffraction_calculator_precession_simple.calculate_ed_data( + local_structure, reciprocal_radius=5.0, + ) + assert len(diffraction.indices) == len(diffraction.coordinates) + assert len(diffraction.coordinates) == len(diffraction.intensities) + + def test_precession_full(self, diffraction_calculator_precession_full, + local_structure): + diffraction = diffraction_calculator_precession_full.calculate_ed_data( + local_structure, reciprocal_radius=5.0, + ) + assert len(diffraction.indices) == len(diffraction.coordinates) + assert len(diffraction.coordinates) == len(diffraction.intensities) + def test_appropriate_scaling(self, diffraction_calculator: DiffractionGenerator): """Tests that doubling the unit cell halves the pattern spacing.""" silicon = make_structure(5)