Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small simulator improvements for v0.3 #114

Merged
merged 25 commits into from
Sep 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
071d39e
Tidy up func sig for get_kinematical_intensities
pc494 Sep 2, 2020
992d1f1
adjustments for calculate_ed_data
pc494 Sep 2, 2020
f71d74f
Adjusting call of calculate_profile_data
pc494 Sep 2, 2020
e4c3427
Update __init__ of DiffractionGenerator
pc494 Sep 2, 2020
17ce971
Some more tidy up in kinematic sims
pc494 Sep 2, 2020
f1e77f9
Some style corrections for FFT generator
pc494 Sep 2, 2020
1ffaaa7
linear is a string, not a function for now
pc494 Sep 2, 2020
b46e2c9
Update default DiffractionGenerator
pc494 Sep 2, 2020
ef37df7
Minor tidy, fixes unbound local error
pc494 Sep 2, 2020
48f5737
Reordering to avoid unbound local
pc494 Sep 2, 2020
3d44bb0
Fix spelling error
pc494 Sep 2, 2020
969c8ec
General test updates
pc494 Sep 2, 2020
9cef9d2
propogating unique_hkl change to g_indicies
pc494 Sep 2, 2020
1d0d976
Dropping excitation_error again
pc494 Sep 2, 2020
412c6d6
Fixing syntaxes within tests
pc494 Sep 2, 2020
f740b90
adds binary feature + docstrings
pc494 Sep 2, 2020
070f167
stash changes prior to running black
pc494 Sep 2, 2020
445391f
black formatting
pc494 Sep 2, 2020
f2bee9f
catching the missed spare arg coverage
pc494 Sep 2, 2020
d8fb291
Coverage boosts for newly added code
pc494 Sep 2, 2020
a9c6e38
Small refactor and some more docstrings
pc494 Sep 2, 2020
3cbb7d5
hotfix + minor code space savings
pc494 Sep 2, 2020
da9a0e5
Cleaning up testing syntax, hopefully addressing coverage
pc494 Sep 2, 2020
bebba5f
More coverage adjustments
pc494 Sep 2, 2020
cbf11da
excitation_function---> shape_factor_model
pc494 Sep 3, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 65 additions & 66 deletions diffsims/generators/diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
get_vectorized_list_for_atomic_scattering_factors,
is_lattice_hexagonal,
get_intensities_params,
get_scattering_params_dict
)
from diffsims.utils.fourier_transform import from_recip

Expand All @@ -59,29 +60,28 @@ class DiffractionGenerator(object):
accelerating_voltage : float
The accelerating voltage of the microscope in kV.
max_excitation_error : float
The maximum extent of the relrods in reciprocal angstroms. Typically
equal to 1/{specimen thickness}.
debye_waller_factors : dict of str : float
Removed in this version, defaults to None
debye_waller_factors : dict of str:value pairs
Maps element names to their temperature-dependent Debye-Waller factors.

scattering_params : str
"lobato" or "xtables"
"""

# TODO: Refactor the excitation error to a structure property.

def __init__(
self,
accelerating_voltage,
max_excitation_error,
debye_waller_factors=None,
scattering_params="lobato",
max_excitation_error=None,
debye_waller_factors={},
scattering_params="lobato"
):
if max_excitation_error is not None:
print(
"This class changed in v0.3 and no longer takes a maximum_excitation_error"
)
self.wavelength = get_electron_wavelength(accelerating_voltage)
self.max_excitation_error = max_excitation_error
self.debye_waller_factors = debye_waller_factors or {}

scattering_params_dict = {"lobato": "lobato", "xtables": "xtables"}
if scattering_params in scattering_params_dict:
self.scattering_params = scattering_params_dict[scattering_params]
self.debye_waller_factors = debye_waller_factors
if scattering_params in ["lobato","xtables"]:
self.scattering_params = scattering_params
else:
raise NotImplementedError(
"The scattering parameters `{}` is not implemented. "
Expand All @@ -90,7 +90,14 @@ def __init__(
)

def calculate_ed_data(
self, structure, reciprocal_radius, rotation=(0, 0, 0), with_direct_beam=True
self,
structure,
reciprocal_radius,
rotation=(0, 0, 0),
shape_factor_model="linear",
max_excitation_error=1e-2,
with_direct_beam=True,
**kwargs
):
"""Calculates the Electron Diffraction data for a structure.

Expand All @@ -106,9 +113,16 @@ def calculate_ed_data(
rotation : tuple
Euler angles, in degrees, in the rzxz convention. Default is (0,0,0)
which aligns 'z' with the electron beam
shape_factor_model : function or str
a function that takes excitation_error and max_excitation_error (and potentially **kwargs) and returns an intensity
scaling factor. The code provides "linear" and "binary" options accessed with by parsing the associated strings
max_excitation_error : float
the exctinction distance for reflections, in reciprocal Angstroms
with_direct_beam : bool
If True, the direct beam is included in the simulated diffraction
pattern. If False, it is not.
**kwargs :
passed to shape_factor_model

Returns
-------
Expand All @@ -118,12 +132,9 @@ def calculate_ed_data(
"""
# Specify variables used in calculation
wavelength = self.wavelength
max_excitation_error = self.max_excitation_error
debye_waller_factors = self.debye_waller_factors
latt = structure.lattice
scattering_params = self.scattering_params

# Obtain crystallographic reciprocal lattice points within `max_r` and
# Obtain crystallographic reciprocal lattice points within `reciprocal_radius` and
# g-vector magnitudes for intensity calculations.
recip_latt = latt.reciprocal()
spot_indices, cartesian_coordinates, spot_distances = get_points_in_sphere(
Expand All @@ -150,19 +161,24 @@ def calculate_ed_data(
g_indices = spot_indices[intersection]
excitation_error = excitation_error[intersection]
g_hkls = spot_distances[intersection]
multiplicites = np.ones_like(g_hkls)

shape_factor = 1 - (excitation_error / max_excitation_error)
if shape_factor_model == "linear":
shape_factor = 1 - (excitation_error / max_excitation_error)
elif shape_factor_model == "binary":
shape_factor = 1
else:
shape_factor = shape_factor_model(
excitation_error, max_excitation_error, **kwargs
)

# Calculate diffracted intensities based on a kinematical model.
intensities = get_kinematical_intensities(
structure,
g_indices,
g_hkls,
debye_waller_factors,
multiplicites,
scattering_params,
shape_factor,
prefactor=shape_factor,
scattering_params=self.scattering_params,
debye_waller_factors=self.debye_waller_factors,
)

# Threshold peaks included in simulation based on minimum intensity.
Expand Down Expand Up @@ -197,7 +213,7 @@ def calculate_profile_data(
reciprocal angstroms.
magnitude_tolerance : float
The minimum difference between diffraction magnitudes in reciprocal
angstroms for two peaks to be consdiered different.
angstroms for two peaks to be considered different.
minimum_intensity : float
The minimum intensity required for a diffraction peak to be
considered real. Deals with numerical precision issues.
Expand All @@ -208,12 +224,8 @@ def calculate_profile_data(
The diffraction profile corresponding to this structure and
experimental conditions.
"""
max_r = reciprocal_radius
wavelength = self.wavelength
scattering_params = self.scattering_params

latt = structure.lattice
is_hex = is_lattice_hexagonal(latt)

# Obtain crystallographic reciprocal lattice points within range
recip_latt = latt.reciprocal()
Expand All @@ -222,30 +234,29 @@ def calculate_profile_data(
)

##spot_indicies is a numpy.array of the hkls allowd in the recip radius
unique_hkls, multiplicites, g_hkls = get_intensities_params(
g_indices, multiplicities, g_hkls = get_intensities_params(
recip_latt, reciprocal_radius
)
g_indices = unique_hkls
debye_waller_factors = self.debye_waller_factors
excitation_error = None
max_excitation_error = None
g_hkls_array = np.asarray(g_hkls)

i_hkl = get_kinematical_intensities(
structure,
g_indices,
g_hkls_array,
debye_waller_factors,
multiplicites,
scattering_params,
shape_factor=1,
np.asarray(g_hkls),
prefactor=multiplicities,
scattering_params=self.scattering_params,
debye_waller_factors=self.debye_waller_factors,
)

if is_hex:
if is_lattice_hexagonal(latt):
# Use Miller-Bravais indices for hexagonal lattices.
g_indices = (g_indices[0], g_indices[1], -g_indices[0] - g_indices[1], g_indices[2])
g_indices = (
g_indices[0],
g_indices[1],
-g_indices[0] - g_indices[1],
g_indices[2],
)

hkls_labels = ["".join([str(int(x)) for x in xs]) for xs in unique_hkls]
hkls_labels = ["".join([str(int(x)) for x in xs]) for xs in g_indices]

peaks = {}
for l, i, g in zip(hkls_labels, i_hkl, g_hkls):
Expand Down Expand Up @@ -285,21 +296,11 @@ class AtomicDiffractionGenerator:

"""

def __init__(
self,
accelerating_voltage,
detector,
reciprocal_mesh=False,
debye_waller_factors=None,
):
def __init__(self, accelerating_voltage, detector, reciprocal_mesh=False):
self.wavelength = get_electron_wavelength(accelerating_voltage)
# Always store a 'real' mesh
self.detector = detector if not reciprocal_mesh else from_recip(detector)

if debye_waller_factors:
raise NotImplementedError("Not implemented for this simulator")
self.debye_waller_factors = debye_waller_factors or {}

def calculate_ed_data(
self,
structure,
Expand All @@ -319,10 +320,8 @@ def calculate_ed_data(

Parameters
----------
coordinates : ndarray of floats, shape [n_atoms, 3]
List of atomic coordinates, i.e. atom i is centred at <coordinates>[i]
species : ndarray of integers, shape [n_atoms]
List of atomic numbers, i.e. atom i has atomic number <species>[i]
structure : Structure
The structure for upon which to perform the calculation
probe : instance of probeFunction
Function representing 3D shape of beam
slice_thickness : float
Expand Down Expand Up @@ -366,14 +365,16 @@ def calculate_ed_data(

species = structure.element
coordinates = structure.xyz_cartn.reshape(species.size, -1)
dim = coordinates.shape[1]
assert dim == 3
dim = coordinates.shape[1] # guarenteed to be 3

if not ZERO > 0:
raise ValueError("The value of the ZERO argument must be greater than 0")

if probe_centre is None:
probe_centre = np.zeros(dim)
elif len(probe_centre) < dim:
elif len(probe_centre) == (dim - 1):
probe_centre = np.array(list(probe_centre) + [0])
probe_centre = np.array(probe_centre)

coordinates = coordinates - probe_centre[None]

if not precessed:
Expand All @@ -385,8 +386,6 @@ def calculate_ed_data(
dtype = round(dtype.itemsize / (1 if dtype.kind == "f" else 2))
dtype = "f" + str(dtype), "c" + str(2 * dtype)

assert ZERO > 0

# Filter list of atoms
for d in range(dim - 1):
ind = coordinates[:, d] >= self.detector[d].min() - 20
Expand Down
1 change: 0 additions & 1 deletion diffsims/sims/diffraction_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,6 @@ def __init__(self, magnitudes, intensities, hkls):
self.intensities = intensities
self.hkls = hkls


def get_plot(self, g_max, annotate_peaks=True, with_labels=True, fontsize=12):

"""Plots the diffraction profile simulation for the
Expand Down
3 changes: 1 addition & 2 deletions diffsims/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@ def default_structure():
@pytest.fixture
def default_simulator():
accelerating_voltage = 300
max_excitation_error = 1e-2
return DiffractionGenerator(accelerating_voltage, max_excitation_error)
return DiffractionGenerator(accelerating_voltage)
41 changes: 25 additions & 16 deletions diffsims/tests/test_generators/test_diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,9 @@
import diffpy.structure


@pytest.fixture(
params=[(300, 0.02, None),]
)
@pytest.fixture(params=[(300)])
def diffraction_calculator(request):
return DiffractionGenerator(*request.param)
return DiffractionGenerator(request.param)


@pytest.fixture(params=[(300, [np.linspace(-1, 1, 10)] * 2)])
Expand Down Expand Up @@ -87,6 +85,7 @@ def probe(x, out=None, scale=None):
class TestDiffractionCalculator:
def test_init(self, diffraction_calculator: DiffractionGenerator):
assert diffraction_calculator.debye_waller_factors == {}
_ = DiffractionGenerator(300, 2)

def test_matching_results(self, diffraction_calculator, local_structure):
diffraction = diffraction_calculator.calculate_ed_data(
Expand Down Expand Up @@ -125,6 +124,20 @@ def test_appropriate_intensities(self, diffraction_calculator, local_structure):
)
assert np.all(smaller)

@pytest.mark.parametrize("string", ["linear", "binary"])
def test_shape_factor_strings(
self, diffraction_calculator, local_structure, string
):
_ = diffraction_calculator.calculate_ed_data(
local_structure, 2, shape_factor_model=string
)

def test_shape_factor_custom(self, diffraction_calculator, local_structure):
def local_excite(excitation_error, maximum_excitation_error, t):
return (np.sin(t) * excitation_error) / maximum_excitation_error

_ = diffraction_calculator.calculate_ed_data(local_structure, 2,shape_factor_model=local_excite, t=0.2)

def test_calculate_profile_class(self, local_structure, diffraction_calculator):
# tests the non-hexagonal (cubic) case
profile = diffraction_calculator.calculate_profile_data(
Expand All @@ -143,7 +156,6 @@ def test_calculate_profile_class(self, local_structure, diffraction_calculator):

class TestDiffractionCalculatorAtomic:
def test_init(self, diffraction_calculator_atomic: AtomicDiffractionGenerator):
assert diffraction_calculator_atomic.debye_waller_factors == {}
assert len(diffraction_calculator_atomic.detector) == 2

def test_shapes(self, diffraction_calculator_atomic, local_structure, precessed):
Expand All @@ -168,28 +180,25 @@ def test_mode(self, diffraction_calculator_atomic, local_structure):
local_structure, probe, 1, mode="other"
)


scattering_params = ["lobato", "xtables"]
@pytest.mark.xfail(raises=ValueError, strict=True)
def test_bad_ZERO(self, diffraction_calculator_atomic, local_structure):
_ = diffraction_calculator_atomic.calculate_ed_data(
local_structure, probe, 1, ZERO=-1
)


@pytest.mark.parametrize("scattering_param", scattering_params)
@pytest.mark.parametrize("scattering_param", ["lobato", "xtables"])
def test_param_check(scattering_param):
generator = DiffractionGenerator(300, 0.2, None, scattering_params=scattering_param)
generator = DiffractionGenerator(300, scattering_params=scattering_param)


@pytest.mark.xfail(raises=NotImplementedError)
def test_invalid_scattering_params():
scattering_param = "_empty"
generator = DiffractionGenerator(300, 0.2, None, scattering_params=scattering_param)
generator = DiffractionGenerator(300, scattering_params=scattering_param)


@pytest.mark.parametrize("shape", [(10, 20), (20, 10)])
def test_param_check_atomic(shape):
detector = [np.linspace(-1, 1, s) for s in shape]
generator = AtomicDiffractionGenerator(300, detector, True)


@pytest.mark.xfail(raises=NotImplementedError)
def test_invalid_scattering_params_atomic():
detector = [np.linspace(-1, 1, 10)] * 2
generator = AtomicDiffractionGenerator(300, detector, debye_waller_factors=True)
2 changes: 1 addition & 1 deletion diffsims/tests/test_generators/test_library_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

@pytest.fixture
def diffraction_calculator():
return DiffractionGenerator(300.0, 0.02)
return DiffractionGenerator(300.0)


@pytest.fixture
Expand Down
2 changes: 1 addition & 1 deletion diffsims/tests/test_library/test_diffraction_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

@pytest.fixture
def get_library(default_structure):
diffraction_calculator = DiffractionGenerator(300.0, 0.02)
diffraction_calculator = DiffractionGenerator(300.0)
dfl = DiffractionLibraryGenerator(diffraction_calculator)
structure_library = StructureLibrary(
["Phase"], [default_structure], [np.array([(0, 0, 0), (0, 0.2, 0)])]
Expand Down
1 change: 1 addition & 0 deletions diffsims/tests/test_sims/test_diffraction_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def profile_simulation():
],
)


def test_plot_profile_simulation(profile_simulation):
profile_simulation.get_plot(g_max=1)

Expand Down
Loading