Skip to content

Commit

Permalink
Testing: Fix hexagonal Simulation1D
Browse files Browse the repository at this point in the history
  • Loading branch information
CSSFrancis committed Dec 19, 2023
1 parent 8580ce7 commit 659af99
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
14 changes: 8 additions & 6 deletions diffsims/generators/simulation_generator.py
Expand Up @@ -248,12 +248,14 @@ def calculate_diffraction1d(

if is_lattice_hexagonal(latt):
# Use Miller-Bravais indices for hexagonal lattices.
g_indices = (
g_indices[0],
g_indices[1],
-g_indices[0] - g_indices[1],
g_indices[2],
)
g_indices = np.array(
[
g_indices[:, 0],
g_indices[:, 1],
g_indices[:, 0] - g_indices[:, 1],
g_indices[:, 2],
]
).T

hkls_labels = ["".join([str(int(x)) for x in xs]) for xs in g_indices]

Expand Down
21 changes: 20 additions & 1 deletion diffsims/tests/generators/test_simulation_generator.py
Expand Up @@ -32,6 +32,7 @@
_shape_factor_precession,
)
from diffsims.simulations import Simulation1D
from diffsims.utils.sim_utils import is_lattice_hexagonal


@pytest.fixture(params=[(300)])
Expand Down Expand Up @@ -200,12 +201,30 @@ def test_shape_factor_custom(self, diffraction_calculator, local_structure):
# softly makes sure the two sims are different
assert np.sum(t1.coordinates.intensity) != np.sum(t2.coordinates.intensity)

def test_simulate_1d(self):
@pytest.mark.parametrize("is_hex", [True, False])
def test_simulate_1d(self, is_hex):
generator = SimulationGenerator(300)
phase = make_phase()
if is_hex:
phase.structure.lattice.a = phase.structure.lattice.b
phase.structure.lattice.alpha = 90
phase.structure.lattice.beta = 90
phase.structure.lattice.gamma = 120
assert is_lattice_hexagonal(phase.structure.lattice)
else:
assert not is_lattice_hexagonal(phase.structure.lattice)
sim = generator.calculate_diffraction1d(phase, 0.5)
assert isinstance(sim, Simulation1D)

assert len(sim.intensities) == len(sim.reciprocal_spacing)
assert len(sim.intensities) == len(sim.hkl)
for h in sim.hkl:
h = h.replace("-", "")
if is_hex:
assert len(h) == 4
else:
assert len(h) == 3


@pytest.mark.parametrize("scattering_param", ["lobato", "xtables"])
def test_param_check(scattering_param):
Expand Down

0 comments on commit 659af99

Please sign in to comment.