Skip to content

Commit

Permalink
Merge #114 from pc494/tidy-ups
Browse files Browse the repository at this point in the history
Small simulator improvements
  • Loading branch information
pc494 committed Sep 3, 2020
2 parents 297d2d7 + cbf11da commit 2de88f5
Show file tree
Hide file tree
Showing 9 changed files with 129 additions and 125 deletions.
131 changes: 65 additions & 66 deletions diffsims/generators/diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
get_vectorized_list_for_atomic_scattering_factors,
is_lattice_hexagonal,
get_intensities_params,
get_scattering_params_dict
)
from diffsims.utils.fourier_transform import from_recip

Expand All @@ -59,29 +60,28 @@ class DiffractionGenerator(object):
accelerating_voltage : float
The accelerating voltage of the microscope in kV.
max_excitation_error : float
The maximum extent of the relrods in reciprocal angstroms. Typically
equal to 1/{specimen thickness}.
debye_waller_factors : dict of str : float
Removed in this version, defaults to None
debye_waller_factors : dict of str:value pairs
Maps element names to their temperature-dependent Debye-Waller factors.
scattering_params : str
"lobato" or "xtables"
"""

# TODO: Refactor the excitation error to a structure property.

def __init__(
self,
accelerating_voltage,
max_excitation_error,
debye_waller_factors=None,
scattering_params="lobato",
max_excitation_error=None,
debye_waller_factors={},
scattering_params="lobato"
):
if max_excitation_error is not None:
print(
"This class changed in v0.3 and no longer takes a maximum_excitation_error"
)
self.wavelength = get_electron_wavelength(accelerating_voltage)
self.max_excitation_error = max_excitation_error
self.debye_waller_factors = debye_waller_factors or {}

scattering_params_dict = {"lobato": "lobato", "xtables": "xtables"}
if scattering_params in scattering_params_dict:
self.scattering_params = scattering_params_dict[scattering_params]
self.debye_waller_factors = debye_waller_factors
if scattering_params in ["lobato","xtables"]:
self.scattering_params = scattering_params
else:
raise NotImplementedError(
"The scattering parameters `{}` is not implemented. "
Expand All @@ -90,7 +90,14 @@ def __init__(
)

def calculate_ed_data(
self, structure, reciprocal_radius, rotation=(0, 0, 0), with_direct_beam=True
self,
structure,
reciprocal_radius,
rotation=(0, 0, 0),
shape_factor_model="linear",
max_excitation_error=1e-2,
with_direct_beam=True,
**kwargs
):
"""Calculates the Electron Diffraction data for a structure.
Expand All @@ -106,9 +113,16 @@ def calculate_ed_data(
rotation : tuple
Euler angles, in degrees, in the rzxz convention. Default is (0,0,0)
which aligns 'z' with the electron beam
shape_factor_model : function or str
a function that takes excitation_error and max_excitation_error (and potentially **kwargs) and returns an intensity
scaling factor. The code provides "linear" and "binary" options accessed with by parsing the associated strings
max_excitation_error : float
the exctinction distance for reflections, in reciprocal Angstroms
with_direct_beam : bool
If True, the direct beam is included in the simulated diffraction
pattern. If False, it is not.
**kwargs :
passed to shape_factor_model
Returns
-------
Expand All @@ -118,12 +132,9 @@ def calculate_ed_data(
"""
# Specify variables used in calculation
wavelength = self.wavelength
max_excitation_error = self.max_excitation_error
debye_waller_factors = self.debye_waller_factors
latt = structure.lattice
scattering_params = self.scattering_params

# Obtain crystallographic reciprocal lattice points within `max_r` and
# Obtain crystallographic reciprocal lattice points within `reciprocal_radius` and
# g-vector magnitudes for intensity calculations.
recip_latt = latt.reciprocal()
spot_indices, cartesian_coordinates, spot_distances = get_points_in_sphere(
Expand All @@ -150,19 +161,24 @@ def calculate_ed_data(
g_indices = spot_indices[intersection]
excitation_error = excitation_error[intersection]
g_hkls = spot_distances[intersection]
multiplicites = np.ones_like(g_hkls)

shape_factor = 1 - (excitation_error / max_excitation_error)
if shape_factor_model == "linear":
shape_factor = 1 - (excitation_error / max_excitation_error)
elif shape_factor_model == "binary":
shape_factor = 1
else:
shape_factor = shape_factor_model(
excitation_error, max_excitation_error, **kwargs
)

# Calculate diffracted intensities based on a kinematical model.
intensities = get_kinematical_intensities(
structure,
g_indices,
g_hkls,
debye_waller_factors,
multiplicites,
scattering_params,
shape_factor,
prefactor=shape_factor,
scattering_params=self.scattering_params,
debye_waller_factors=self.debye_waller_factors,
)

# Threshold peaks included in simulation based on minimum intensity.
Expand Down Expand Up @@ -197,7 +213,7 @@ def calculate_profile_data(
reciprocal angstroms.
magnitude_tolerance : float
The minimum difference between diffraction magnitudes in reciprocal
angstroms for two peaks to be consdiered different.
angstroms for two peaks to be considered different.
minimum_intensity : float
The minimum intensity required for a diffraction peak to be
considered real. Deals with numerical precision issues.
Expand All @@ -208,12 +224,8 @@ def calculate_profile_data(
The diffraction profile corresponding to this structure and
experimental conditions.
"""
max_r = reciprocal_radius
wavelength = self.wavelength
scattering_params = self.scattering_params

latt = structure.lattice
is_hex = is_lattice_hexagonal(latt)

# Obtain crystallographic reciprocal lattice points within range
recip_latt = latt.reciprocal()
Expand All @@ -222,30 +234,29 @@ def calculate_profile_data(
)

##spot_indicies is a numpy.array of the hkls allowd in the recip radius
unique_hkls, multiplicites, g_hkls = get_intensities_params(
g_indices, multiplicities, g_hkls = get_intensities_params(
recip_latt, reciprocal_radius
)
g_indices = unique_hkls
debye_waller_factors = self.debye_waller_factors
excitation_error = None
max_excitation_error = None
g_hkls_array = np.asarray(g_hkls)

i_hkl = get_kinematical_intensities(
structure,
g_indices,
g_hkls_array,
debye_waller_factors,
multiplicites,
scattering_params,
shape_factor=1,
np.asarray(g_hkls),
prefactor=multiplicities,
scattering_params=self.scattering_params,
debye_waller_factors=self.debye_waller_factors,
)

if is_hex:
if is_lattice_hexagonal(latt):
# Use Miller-Bravais indices for hexagonal lattices.
g_indices = (g_indices[0], g_indices[1], -g_indices[0] - g_indices[1], g_indices[2])
g_indices = (
g_indices[0],
g_indices[1],
-g_indices[0] - g_indices[1],
g_indices[2],
)

hkls_labels = ["".join([str(int(x)) for x in xs]) for xs in unique_hkls]
hkls_labels = ["".join([str(int(x)) for x in xs]) for xs in g_indices]

peaks = {}
for l, i, g in zip(hkls_labels, i_hkl, g_hkls):
Expand Down Expand Up @@ -285,21 +296,11 @@ class AtomicDiffractionGenerator:
"""

def __init__(
self,
accelerating_voltage,
detector,
reciprocal_mesh=False,
debye_waller_factors=None,
):
def __init__(self, accelerating_voltage, detector, reciprocal_mesh=False):
self.wavelength = get_electron_wavelength(accelerating_voltage)
# Always store a 'real' mesh
self.detector = detector if not reciprocal_mesh else from_recip(detector)

if debye_waller_factors:
raise NotImplementedError("Not implemented for this simulator")
self.debye_waller_factors = debye_waller_factors or {}

def calculate_ed_data(
self,
structure,
Expand All @@ -319,10 +320,8 @@ def calculate_ed_data(
Parameters
----------
coordinates : ndarray of floats, shape [n_atoms, 3]
List of atomic coordinates, i.e. atom i is centred at <coordinates>[i]
species : ndarray of integers, shape [n_atoms]
List of atomic numbers, i.e. atom i has atomic number <species>[i]
structure : Structure
The structure for upon which to perform the calculation
probe : instance of probeFunction
Function representing 3D shape of beam
slice_thickness : float
Expand Down Expand Up @@ -366,14 +365,16 @@ def calculate_ed_data(

species = structure.element
coordinates = structure.xyz_cartn.reshape(species.size, -1)
dim = coordinates.shape[1]
assert dim == 3
dim = coordinates.shape[1] # guarenteed to be 3

if not ZERO > 0:
raise ValueError("The value of the ZERO argument must be greater than 0")

if probe_centre is None:
probe_centre = np.zeros(dim)
elif len(probe_centre) < dim:
elif len(probe_centre) == (dim - 1):
probe_centre = np.array(list(probe_centre) + [0])
probe_centre = np.array(probe_centre)

coordinates = coordinates - probe_centre[None]

if not precessed:
Expand All @@ -385,8 +386,6 @@ def calculate_ed_data(
dtype = round(dtype.itemsize / (1 if dtype.kind == "f" else 2))
dtype = "f" + str(dtype), "c" + str(2 * dtype)

assert ZERO > 0

# Filter list of atoms
for d in range(dim - 1):
ind = coordinates[:, d] >= self.detector[d].min() - 20
Expand Down
1 change: 0 additions & 1 deletion diffsims/sims/diffraction_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def __init__(self, magnitudes, intensities, hkls):
self.intensities = intensities
self.hkls = hkls


def get_plot(self, g_max, annotate_peaks=True, with_labels=True, fontsize=12):

"""Plots the diffraction profile simulation for the
Expand Down
3 changes: 1 addition & 2 deletions diffsims/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@ def default_structure():
@pytest.fixture
def default_simulator():
accelerating_voltage = 300
max_excitation_error = 1e-2
return DiffractionGenerator(accelerating_voltage, max_excitation_error)
return DiffractionGenerator(accelerating_voltage)
41 changes: 25 additions & 16 deletions diffsims/tests/test_generators/test_diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@
import diffpy.structure


@pytest.fixture(
params=[(300, 0.02, None),]
)
@pytest.fixture(params=[(300)])
def diffraction_calculator(request):
return DiffractionGenerator(*request.param)
return DiffractionGenerator(request.param)


@pytest.fixture(params=[(300, [np.linspace(-1, 1, 10)] * 2)])
Expand Down Expand Up @@ -87,6 +85,7 @@ def probe(x, out=None, scale=None):
class TestDiffractionCalculator:
def test_init(self, diffraction_calculator: DiffractionGenerator):
assert diffraction_calculator.debye_waller_factors == {}
_ = DiffractionGenerator(300, 2)

def test_matching_results(self, diffraction_calculator, local_structure):
diffraction = diffraction_calculator.calculate_ed_data(
Expand Down Expand Up @@ -125,6 +124,20 @@ def test_appropriate_intensities(self, diffraction_calculator, local_structure):
)
assert np.all(smaller)

@pytest.mark.parametrize("string", ["linear", "binary"])
def test_shape_factor_strings(
self, diffraction_calculator, local_structure, string
):
_ = diffraction_calculator.calculate_ed_data(
local_structure, 2, shape_factor_model=string
)

def test_shape_factor_custom(self, diffraction_calculator, local_structure):
def local_excite(excitation_error, maximum_excitation_error, t):
return (np.sin(t) * excitation_error) / maximum_excitation_error

_ = diffraction_calculator.calculate_ed_data(local_structure, 2,shape_factor_model=local_excite, t=0.2)

def test_calculate_profile_class(self, local_structure, diffraction_calculator):
# tests the non-hexagonal (cubic) case
profile = diffraction_calculator.calculate_profile_data(
Expand All @@ -143,7 +156,6 @@ def test_calculate_profile_class(self, local_structure, diffraction_calculator):

class TestDiffractionCalculatorAtomic:
def test_init(self, diffraction_calculator_atomic: AtomicDiffractionGenerator):
assert diffraction_calculator_atomic.debye_waller_factors == {}
assert len(diffraction_calculator_atomic.detector) == 2

def test_shapes(self, diffraction_calculator_atomic, local_structure, precessed):
Expand All @@ -168,28 +180,25 @@ def test_mode(self, diffraction_calculator_atomic, local_structure):
local_structure, probe, 1, mode="other"
)


scattering_params = ["lobato", "xtables"]
@pytest.mark.xfail(raises=ValueError, strict=True)
def test_bad_ZERO(self, diffraction_calculator_atomic, local_structure):
_ = diffraction_calculator_atomic.calculate_ed_data(
local_structure, probe, 1, ZERO=-1
)


@pytest.mark.parametrize("scattering_param", scattering_params)
@pytest.mark.parametrize("scattering_param", ["lobato", "xtables"])
def test_param_check(scattering_param):
generator = DiffractionGenerator(300, 0.2, None, scattering_params=scattering_param)
generator = DiffractionGenerator(300, scattering_params=scattering_param)


@pytest.mark.xfail(raises=NotImplementedError)
def test_invalid_scattering_params():
scattering_param = "_empty"
generator = DiffractionGenerator(300, 0.2, None, scattering_params=scattering_param)
generator = DiffractionGenerator(300, scattering_params=scattering_param)


@pytest.mark.parametrize("shape", [(10, 20), (20, 10)])
def test_param_check_atomic(shape):
detector = [np.linspace(-1, 1, s) for s in shape]
generator = AtomicDiffractionGenerator(300, detector, True)


@pytest.mark.xfail(raises=NotImplementedError)
def test_invalid_scattering_params_atomic():
detector = [np.linspace(-1, 1, 10)] * 2
generator = AtomicDiffractionGenerator(300, detector, debye_waller_factors=True)
2 changes: 1 addition & 1 deletion diffsims/tests/test_generators/test_library_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

@pytest.fixture
def diffraction_calculator():
return DiffractionGenerator(300.0, 0.02)
return DiffractionGenerator(300.0)


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion diffsims/tests/test_library/test_diffraction_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

@pytest.fixture
def get_library(default_structure):
diffraction_calculator = DiffractionGenerator(300.0, 0.02)
diffraction_calculator = DiffractionGenerator(300.0)
dfl = DiffractionLibraryGenerator(diffraction_calculator)
structure_library = StructureLibrary(
["Phase"], [default_structure], [np.array([(0, 0, 0), (0, 0.2, 0)])]
Expand Down
1 change: 1 addition & 0 deletions diffsims/tests/test_sims/test_diffraction_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def profile_simulation():
],
)


def test_plot_profile_simulation(profile_simulation):
profile_simulation.get_plot(g_max=1)

Expand Down
Loading

0 comments on commit 2de88f5

Please sign in to comment.