diff --git a/diffsims/simulations/simulation2d.py b/diffsims/simulations/simulation2d.py index 682f424d..115f9e71 100644 --- a/diffsims/simulations/simulation2d.py +++ b/diffsims/simulations/simulation2d.py @@ -152,8 +152,6 @@ def __init__( else: # iterable of Rotation rotations = np.array(rotations, dtype=object) coordinates = np.array(coordinates, dtype=object) - if len(coordinates.shape) != 2: - coordinates = coordinates[:, np.newaxis] phases = np.array(phases) if rotations.size != phases.size: raise ValueError( @@ -162,7 +160,13 @@ def __init__( ) for r, c in zip(rotations, coordinates): - if r.size != c.size: + if isinstance(c, ReciprocalLatticeVector): + c = np.array( + [ + c, + ] + ) + if r.size != len(c): raise ValueError( f"The number of rotations: {r.size} must match the number of " f"coordinates {c.shape[0]}" diff --git a/diffsims/tests/generators/test_simulation_generator.py b/diffsims/tests/generators/test_simulation_generator.py index f1116e6b..06c6ddea 100644 --- a/diffsims/tests/generators/test_simulation_generator.py +++ b/diffsims/tests/generators/test_simulation_generator.py @@ -21,6 +21,7 @@ import diffpy.structure from orix.crystal_map import Phase +from orix.quaternion import Rotation from diffsims.generators.simulation_generator import SimulationGenerator from diffsims.utils.shape_factor_models import ( @@ -226,6 +227,15 @@ def test_simulate_1d(self, is_hex): assert len(h) == 3 +def test_multiphase_multirotation_simulation(): + generator = SimulationGenerator(300) + silicon = make_phase(5) + big_silicon = make_phase(10) + rot = Rotation.from_euler([[0, 0, 0], [0.1, 0.1, 0.1]]) + rot2 = Rotation.from_euler([[0, 0, 0], [0.1, 0.1, 0.1], [0.2, 0.2, 0.2]]) + sim = generator.calculate_ed_data([silicon, big_silicon], rotation=[rot, rot2]) + + @pytest.mark.parametrize("scattering_param", ["lobato", "xtables"]) def test_param_check(scattering_param): generator = SimulationGenerator(300, scattering_params=scattering_param)