Skip to content

Commit

Permalink
reformatted with black
Browse files Browse the repository at this point in the history
  • Loading branch information
din14970 committed Nov 18, 2020
1 parent 598a95e commit d24eb78
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 92 deletions.
115 changes: 61 additions & 54 deletions diffsims/generators/diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,22 @@
)
from diffsims.utils.fourier_transform import from_recip
from diffsims.utils.shape_factor_models import (
linear,
atanc,
lorentzian,
sinc,
sin2c,
lorentzian_precession,
)
linear,
atanc,
lorentzian,
sinc,
sin2c,
lorentzian_precession,
)


_shape_factor_model_mapping = {
"linear": linear,
"atanc": atanc,
"sinc": sinc,
"sin2c": sin2c,
"lorentzian": lorentzian,
}
"linear": linear,
"atanc": atanc,
"sinc": sinc,
"sin2c": sin2c,
"lorentzian": lorentzian,
}


def _z_sphere_precession(phi, r_spot, wavelength, theta):
Expand All @@ -75,15 +75,17 @@ def _z_sphere_precession(phi, r_spot, wavelength, theta):
The height of the ewald sphere at the point r in A^-1
"""
phi = np.deg2rad(phi)
r = 1/wavelength
r = 1 / wavelength
theta = np.deg2rad(theta)
return (-np.sqrt(r**2*(1-np.sin(theta)**2*np.sin(phi)**2) -
(r_spot - r*np.sin(theta)*np.cos(phi))**2) +
r*np.cos(theta))
return -np.sqrt(
r ** 2 * (1 - np.sin(theta) ** 2 * np.sin(phi) ** 2)
- (r_spot - r * np.sin(theta) * np.cos(phi)) ** 2
) + r * np.cos(theta)


def _shape_factor_precession(z_spot, r_spot, wavelength, precession_angle,
function, max_excitation, **kwargs):
def _shape_factor_precession(
z_spot, r_spot, wavelength, precession_angle, function, max_excitation, **kwargs
):
"""
The rel-rod shape factors for reflections taking into account
precession
Expand Down Expand Up @@ -120,29 +122,30 @@ def _shape_factor_precession(z_spot, r_spot, wavelength, precession_angle,
shf = np.zeros(z_spot.shape)
# loop over all spots
for i, (z_spot_i, r_spot_i) in enumerate(zip(z_spot, r_spot)):

def integrand(phi):
z_sph = _z_sphere_precession(phi, r_spot_i,
wavelength, precession_angle)
return function(z_spot_i-z_sph, max_excitation, **kwargs)
z_sph = _z_sphere_precession(phi, r_spot_i, wavelength, precession_angle)
return function(z_spot_i - z_sph, max_excitation, **kwargs)

# average factor integrated over the full revolution of the beam
shf[i] = (1/(360))*quad(integrand, 0, 360)[0]
shf[i] = (1 / (360)) * quad(integrand, 0, 360)[0]
return shf


def _average_excitation_error_precession(z_spot, r_spot,
wavelength, precession_angle):
def _average_excitation_error_precession(z_spot, r_spot, wavelength, precession_angle):
"""
Calculate the average excitation error for spots
"""
ext = np.zeros(z_spot.shape)
# loop over all spots
for i, (z_spot_i, r_spot_i) in enumerate(zip(z_spot, r_spot)):

def integrand(phi):
z_sph = _z_sphere_precession(phi, r_spot_i,
wavelength, precession_angle)
return z_spot_i-z_sph
z_sph = _z_sphere_precession(phi, r_spot_i, wavelength, precession_angle)
return z_spot_i - z_sph

# average factor integrated over the full revolution of the beam
ext[i] = (1/(360))*quad(integrand, 0, 360)[0]
ext[i] = (1 / (360)) * quad(integrand, 0, 360)[0]
return ext


Expand Down Expand Up @@ -172,7 +175,7 @@ class DiffractionGenerator(object):
Angle about which the beam is precessed. Default is no precession.
approximate_precession : boolean
When using precession, whether to precisely calculate average
excitation errors and intensities or use an approximation.
excitation errors and intensities or use an approximation.
shape_factor_model : function or string
A function that takes excitation_error and
`max_excitation_error` (and potentially kwargs) and returns
Expand Down Expand Up @@ -206,7 +209,9 @@ def __init__(
self.approximate_precession = approximate_precession
if isinstance(shape_factor_model, str):
if shape_factor_model in _shape_factor_model_mapping.keys():
self.shape_factor_model = _shape_factor_model_mapping[shape_factor_model]
self.shape_factor_model = _shape_factor_model_mapping[
shape_factor_model
]
else:
raise NotImplementedError(
f"{shape_factor_model} is not a recognized shape factor "
Expand Down Expand Up @@ -293,11 +298,11 @@ def calculate_ed_data(
# We find the average excitation error - this step can be
# quite expensive
excitation_error = _average_excitation_error_precession(
z_spot,
r_spot,
wavelength,
self.precession_angle,
)
z_spot,
r_spot,
wavelength,
self.precession_angle,
)
else:
z_sphere = -np.sqrt(r_sphere ** 2 - r_spot ** 2) + r_sphere
excitation_error = z_sphere - z_spot
Expand All @@ -311,25 +316,24 @@ def calculate_ed_data(
# take into consideration rel-rods
if self.precession_angle > 0 and not self.approximate_precession:
shape_factor = _shape_factor_precession(
intersection_coordinates[:, 2],
r_spot,
wavelength,
self.precession_angle,
self.shape_factor_model,
max_excitation_error,
**self.shape_factor_kwargs,
)
intersection_coordinates[:, 2],
r_spot,
wavelength,
self.precession_angle,
self.shape_factor_model,
max_excitation_error,
**self.shape_factor_kwargs,
)
elif self.precession_angle > 0 and self.approximate_precession:
shape_factor = lorentzian_precession(
excitation_error,
max_excitation_error,
r_spot,
self.precession_angle,
)
excitation_error,
max_excitation_error,
r_spot,
self.precession_angle,
)
else:
shape_factor = self.shape_factor_model(
excitation_error, max_excitation_error,
**self.shape_factor_kwargs
excitation_error, max_excitation_error, **self.shape_factor_kwargs
)

# Calculate diffracted intensities based on a kinematical model.
Expand All @@ -356,8 +360,11 @@ def calculate_ed_data(
)

def calculate_profile_data(
self, structure, reciprocal_radius=1.0, minimum_intensity=1e-3,
debye_waller_factors={}
self,
structure,
reciprocal_radius=1.0,
minimum_intensity=1e-3,
debye_waller_factors={},
):
"""Calculates a one dimensional diffraction profile for a
structure.
Expand Down Expand Up @@ -469,7 +476,7 @@ def calculate_ed_data(
dtype="float64",
ZERO=1e-14,
mode="kinematic",
**kwargs
**kwargs,
):
"""
Calculates single electron diffraction image for particular atomic
Expand Down
12 changes: 6 additions & 6 deletions diffsims/generators/library_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def get_diffraction_library(
rotation=orientation,
with_direct_beam=with_direct_beam,
max_excitation_error=max_excitation_error,
debye_waller_factors=debye_waller_factors
debye_waller_factors=debye_waller_factors,
)

# Calibrate simulation
Expand All @@ -126,11 +126,11 @@ def get_diffraction_library(
intensities[i] = simulation.intensities

diffraction_library[phase_name] = {
"simulations": simulations,
"orientations": orientations,
"pixel_coords": pixel_coords,
"intensities": intensities,
}
"simulations": simulations,
"orientations": orientations,
"pixel_coords": pixel_coords,
"intensities": intensities,
}
# Pass attributes to diffraction library from structure library.
diffraction_library.identifiers = structure_library.identifiers
diffraction_library.structures = structure_library.structures
Expand Down
2 changes: 1 addition & 1 deletion diffsims/generators/sphere_mesh_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,5 +481,5 @@ def beam_directions_grid_to_euler(vectors):
phi2 = sign * np.nan_to_num(np.arccos(x_comp / norm_proj))
# phi1 is just 0, rotation around z''
phi1 = np.zeros(phi2.shape[0])
grid = np.rad2deg(np.vstack([phi1, Phi, np.pi/2 - phi2]).T)
grid = np.rad2deg(np.vstack([phi1, Phi, np.pi / 2 - phi2]).T)
return grid
65 changes: 34 additions & 31 deletions diffsims/tests/test_generators/test_diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
_average_excitation_error_precession,
)
import diffpy.structure
from diffsims.utils.shape_factor_models import (linear, binary, sin2c,
atanc, lorentzian)
from diffsims.utils.shape_factor_models import linear, binary, sin2c, atanc, lorentzian


@pytest.fixture(params=[(300)])
Expand All @@ -39,14 +38,12 @@ def diffraction_calculator(request):

@pytest.fixture(scope="module")
def diffraction_calculator_precession_full():
return DiffractionGenerator(300, precession_angle=0.5,
approximate_precession=False)
return DiffractionGenerator(300, precession_angle=0.5, approximate_precession=False)


@pytest.fixture(scope="module")
def diffraction_calculator_precession_simple():
return DiffractionGenerator(300, precession_angle=0.5,
approximate_precession=True)
return DiffractionGenerator(300, precession_angle=0.5, approximate_precession=True)


def local_excite(excitation_error, maximum_excitation_error, t):
Expand All @@ -55,9 +52,7 @@ def local_excite(excitation_error, maximum_excitation_error, t):

@pytest.fixture(scope="module")
def diffraction_calculator_custom():
return DiffractionGenerator(300,
shape_factor_model=local_excite,
t=0.2)
return DiffractionGenerator(300, shape_factor_model=local_excite, t=0.2)


@pytest.fixture(params=[(300, [np.linspace(-1, 1, 10)] * 2)])
Expand Down Expand Up @@ -109,11 +104,17 @@ def probe(x, out=None, scale=None):
return v + 0 * x[2].reshape(1, 1, -1)


@pytest.mark.parametrize("parameters, expected",
[([0, 1, 0.001, 0.5], -0.00822681491001731),
([0, np.array([1, 2, 20]), 0.001, 0.5],
np.array([-0.00822681, -0.01545354, 0.02547058])),
([180, 1, 0.001, 0.5], 0.00922693)])
@pytest.mark.parametrize(
"parameters, expected",
[
([0, 1, 0.001, 0.5], -0.00822681491001731),
(
[0, np.array([1, 2, 20]), 0.001, 0.5],
np.array([-0.00822681, -0.01545354, 0.02547058]),
),
([180, 1, 0.001, 0.5], 0.00922693),
],
)
def test_z_sphere_precession(parameters, expected):
result = _z_sphere_precession(*parameters)
assert np.allclose(result, expected)
Expand All @@ -132,10 +133,10 @@ def test_average_excitation_error_precession():
_ = _average_excitation_error_precession(z, r, 0.001, 0.5)


@pytest.mark.parametrize("model, expected",
[("linear", linear),
("lorentzian", lorentzian),
(binary, binary)],)
@pytest.mark.parametrize(
"model, expected",
[("linear", linear), ("lorentzian", lorentzian), (binary, binary)],
)
def test_diffraction_generator_init(model, expected):
generator = DiffractionGenerator(300, shape_factor_model=model)
assert generator.shape_factor_model == expected
Expand All @@ -156,26 +157,30 @@ def test_matching_results(self, diffraction_calculator, local_structure):
assert len(diffraction.indices) == len(diffraction.coordinates)
assert len(diffraction.coordinates) == len(diffraction.intensities)

def test_precession_simple(self, diffraction_calculator_precession_simple,
local_structure):
def test_precession_simple(
self, diffraction_calculator_precession_simple, local_structure
):
diffraction = diffraction_calculator_precession_simple.calculate_ed_data(
local_structure, reciprocal_radius=5.0,
local_structure,
reciprocal_radius=5.0,
)
assert len(diffraction.indices) == len(diffraction.coordinates)
assert len(diffraction.coordinates) == len(diffraction.intensities)

def test_precession_full(self, diffraction_calculator_precession_full,
local_structure):
def test_precession_full(
self, diffraction_calculator_precession_full, local_structure
):
diffraction = diffraction_calculator_precession_full.calculate_ed_data(
local_structure, reciprocal_radius=5.0,
local_structure,
reciprocal_radius=5.0,
)
assert len(diffraction.indices) == len(diffraction.coordinates)
assert len(diffraction.coordinates) == len(diffraction.intensities)

def test_custom_shape_func(self, diffraction_calculator_custom,
local_structure):

def test_custom_shape_func(self, diffraction_calculator_custom, local_structure):
diffraction = diffraction_calculator_custom.calculate_ed_data(
local_structure, reciprocal_radius=5.0,
local_structure,
reciprocal_radius=5.0,
)
assert len(diffraction.indices) == len(diffraction.coordinates)
assert len(diffraction.coordinates) == len(diffraction.intensities)
Expand Down Expand Up @@ -211,9 +216,7 @@ def test_appropriate_intensities(self, diffraction_calculator, local_structure):
assert np.all(smaller)

def test_shape_factor_strings(self, diffraction_calculator, local_structure):
_ = diffraction_calculator.calculate_ed_data(
local_structure, 2
)
_ = diffraction_calculator.calculate_ed_data(local_structure, 2)

def test_shape_factor_custom(self, diffraction_calculator, local_structure):

Expand Down

0 comments on commit d24eb78

Please sign in to comment.