diff --git a/diffsims/generators/diffraction_generator.py b/diffsims/generators/diffraction_generator.py index 796740d4..55f961eb 100644 --- a/diffsims/generators/diffraction_generator.py +++ b/diffsims/generators/diffraction_generator.py @@ -58,7 +58,7 @@ class DiffractionGenerator(object): ---------- accelerating_voltage : float The accelerating voltage of the microscope in kV. - debye_waller_factors : dict of str + debye_waller_factors : dict of str:value pairs Maps element names to their temperature-dependent Debye-Waller factors. scattering_params : str "lobato" or "xtables" @@ -68,8 +68,10 @@ def __init__( self, accelerating_voltage, debye_waller_factors={}, - scattering_params="lobato", + scattering_params="lobato", *args ): + if args: + print("This class changed in v0.3 and no longer takes a maximum_excitation_error") self.wavelength = get_electron_wavelength(accelerating_voltage) self.debye_waller_factors = debye_waller_factors @@ -151,7 +153,7 @@ def calculate_ed_data( if excitation_function == "linear": shape_factor = 1 - (excitation_error / max_excitation_error) - elif excitation_function = "binary": + elif excitation_function == "binary": shape_factor = 1 else: shape_factor = excitation_function(excitation_error,max_excitation_error,**kwargs) @@ -350,12 +352,10 @@ def calculate_ed_data( species = structure.element coordinates = structure.xyz_cartn.reshape(species.size, -1) - dim = coordinates.shape[1] + dim = coordinates.shape[1] #guarenteed to be 3 if not ZERO > 0: raise ValueError("The value of the ZERO argument must be greater than 0") - if not dim == 3: - raise ValueError("This code currently only supports structure represented in 3D") if probe_centre is None: probe_centre = np.zeros(dim) diff --git a/diffsims/tests/test_generators/test_diffraction_generator.py b/diffsims/tests/test_generators/test_diffraction_generator.py index 5c3522a0..1058bb33 100644 --- a/diffsims/tests/test_generators/test_diffraction_generator.py +++ b/diffsims/tests/test_generators/test_diffraction_generator.py @@ -167,15 +167,16 @@ def test_mode(self, diffraction_calculator_atomic, local_structure): local_structure, probe, 1, mode="other" ) + @pytest.mark.xfail(raises=ValueError,strict=True) + def test_bad_ZERO(self,diffraction_calculator_atomic,local_structure): + _ = diffraction_calculator_atomic.calculate_ed_data(local_structure,probe,1,ZERO=-1) -scattering_params = ["lobato", "xtables"] -@pytest.mark.parametrize("scattering_param", scattering_params) +@pytest.mark.parametrize("scattering_param",["lobato", "xtables"]) def test_param_check(scattering_param): generator = DiffractionGenerator(300,scattering_params=scattering_param) - @pytest.mark.xfail(raises=NotImplementedError) def test_invalid_scattering_params(): scattering_param = "_empty" @@ -186,9 +187,3 @@ def test_invalid_scattering_params(): def test_param_check_atomic(shape): detector = [np.linspace(-1, 1, s) for s in shape] generator = AtomicDiffractionGenerator(300, detector, True) - - -@pytest.mark.xfail(raises=NotImplementedError) -def test_invalid_scattering_params_atomic(): - detector = [np.linspace(-1, 1, 10)] * 2 - generator = AtomicDiffractionGenerator(300, detector)