Skip to content

Commit

Permalink
Mass black reformatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pc494 committed Aug 27, 2020
1 parent fe2c2a4 commit 1f1195b
Show file tree
Hide file tree
Showing 18 changed files with 1,648 additions and 886 deletions.
71 changes: 49 additions & 22 deletions diffsims/generators/diffraction_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def calculate_profile_data(


class AtomicDiffractionGenerator:
'''
"""
Computes electron diffraction patterns for an atomic lattice.
Parameters
Expand All @@ -304,22 +304,37 @@ class AtomicDiffractionGenerator:
If True then `detector` is assumed to be a reciprocal grid, else
(default) it is assumed to be a real grid.
'''
"""

def __init__(self, accelerating_voltage, detector,
reciprocal_mesh=False, debye_waller_factors=None):
def __init__(
self,
accelerating_voltage,
detector,
reciprocal_mesh=False,
debye_waller_factors=None,
):
self.wavelength = get_electron_wavelength(accelerating_voltage)
# Always store a 'real' mesh
self.detector = detector if not reciprocal_mesh else from_recip(detector)

if debye_waller_factors:
raise NotImplementedError('Not implemented for this simulator')
raise NotImplementedError("Not implemented for this simulator")
self.debye_waller_factors = debye_waller_factors or {}

def calculate_ed_data(self, structure, probe, slice_thickness,
probe_centre=None, z_range=200, precessed=False, dtype='float64',
ZERO=1e-14, mode='kinematic', **kwargs):
'''
def calculate_ed_data(
self,
structure,
probe,
slice_thickness,
probe_centre=None,
z_range=200,
precessed=False,
dtype="float64",
ZERO=1e-14,
mode="kinematic",
**kwargs
):
"""
Calculates single electron diffraction image for particular atomic
structure and probe.
Expand Down Expand Up @@ -368,7 +383,7 @@ def calculate_ed_data(self, structure, probe, slice_thickness,
Diffraction data to be interpreted as a discretisation on the original
detector mesh.
'''
"""

species = structure.element
coordinates = structure.xyz_cartn.reshape(species.size, -1)
Expand All @@ -388,8 +403,8 @@ def calculate_ed_data(self, structure, probe, slice_thickness,
precessed = (float(precessed), 30)

dtype = np.dtype(dtype)
dtype = round(dtype.itemsize / (1 if dtype.kind == 'f' else 2))
dtype = 'f' + str(dtype), 'c' + str(2 * dtype)
dtype = round(dtype.itemsize / (1 if dtype.kind == "f" else 2))
dtype = "f" + str(dtype), "c" + str(2 * dtype)

assert ZERO > 0

Expand All @@ -401,16 +416,28 @@ def calculate_ed_data(self, structure, probe, slice_thickness,
coordinates, species = coordinates[ind, :], species[ind]

# Add z-coordinate
z_range = max(z_range, coordinates[:, -1].ptp()) # enforce minimal resolution in reciprocal space
x = [self.detector[0], self.detector[1],
np.arange(coordinates[:, -1].min() - 20, coordinates[:, -1].min() + z_range + 20, slice_thickness)]

if mode == 'kinematic':
z_range = max(
z_range, coordinates[:, -1].ptp()
) # enforce minimal resolution in reciprocal space
x = [
self.detector[0],
self.detector[1],
np.arange(
coordinates[:, -1].min() - 20,
coordinates[:, -1].min() + z_range + 20,
slice_thickness,
),
]

if mode == "kinematic":
from diffsims.sims import kinematic_simulation as simlib
else:
raise NotImplementedError('<mode> = %s is not currently supported' % repr(mode))
raise NotImplementedError(
"<mode> = %s is not currently supported" % repr(mode)
)

kwargs['dtype'] = dtype
kwargs['ZERO'] = ZERO
return simlib.get_diffraction_image(coordinates, species, probe, x,
self.wavelength, precessed, **kwargs)
kwargs["dtype"] = dtype
kwargs["ZERO"] = ZERO
return simlib.get_diffraction_image(
coordinates, species, probe, x, self.wavelength, precessed, **kwargs
)
34 changes: 20 additions & 14 deletions diffsims/generators/library_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,7 @@ def get_diffraction_library(
return diffraction_library


def _generate_lookup_table(recip_latt,
reciprocal_radius: float,
unique: bool = True):
def _generate_lookup_table(recip_latt, reciprocal_radius: float, unique: bool = True):
"""Generate a look-up table with all combinations of indices,
including their reciprocal distances and the angle between
them.
Expand All @@ -163,8 +161,8 @@ def _generate_lookup_table(recip_latt,
"""
miller_indices, coordinates, distances = get_points_in_sphere(
recip_latt,
reciprocal_radius)
recip_latt, reciprocal_radius
)

# Create pair_indices for selecting all point pair combinations
num_indices = len(miller_indices)
Expand All @@ -184,18 +182,26 @@ def _generate_lookup_table(recip_latt,
pair_indices = np.vstack([pair_a_indices, pair_b_indices])

# Create library entries
angles = get_angle_cartesian_vec(coordinates[pair_a_indices], coordinates[pair_b_indices])
angles = get_angle_cartesian_vec(
coordinates[pair_a_indices], coordinates[pair_b_indices]
)
pair_distances = distances[pair_indices.T]
# Ensure longest vector is first
len_sort = np.fliplr(pair_distances.argsort(axis=1))
# phase_index_pairs is a list of [hkl1, hkl2]
phase_index_pairs = np.take_along_axis(miller_indices[pair_indices.T], len_sort[:, :, np.newaxis], axis=1)
phase_index_pairs = np.take_along_axis(
miller_indices[pair_indices.T], len_sort[:, :, np.newaxis], axis=1
)
# phase_measurements is a list of [len1, len2, angle]
phase_measurements = np.column_stack((np.take_along_axis(pair_distances, len_sort, axis=1), angles))
phase_measurements = np.column_stack(
(np.take_along_axis(pair_distances, len_sort, axis=1), angles)
)

if unique:
# Only keep unique triplets
measurements, measurement_indices = np.unique(phase_measurements, axis=0, return_index=True)
measurements, measurement_indices = np.unique(
phase_measurements, axis=0, return_index=True
)
indices = phase_index_pairs[measurement_indices]
else:
measurements = phase_measurements
Expand Down Expand Up @@ -245,13 +251,13 @@ def get_vector_library(self, reciprocal_radius):
# Get reciprocal lattice points within reciprocal_radius
recip_latt = structure.lattice.reciprocal()

measurements, indices = _generate_lookup_table(recip_latt=recip_latt,
reciprocal_radius=reciprocal_radius,
unique=True)
measurements, indices = _generate_lookup_table(
recip_latt=recip_latt, reciprocal_radius=reciprocal_radius, unique=True
)

vector_library[phase_name] = {
'indices': indices,
'measurements': measurements
"indices": indices,
"measurements": measurements,
}

# Pass attributes to diffraction library from structure library.
Expand Down
20 changes: 13 additions & 7 deletions diffsims/generators/rotation_list_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@
"triclinic": [180, 360, 0],
}

def get_list_from_orix(grid,rounding=2):

def get_list_from_orix(grid, rounding=2):
"""
Converts an orix sample to a rotation list
Expand All @@ -64,12 +65,13 @@ def get_list_from_orix(grid,rounding=2):
rotation_list = z.data.tolist()
i = 0
while i < len(rotation_list):
rotation_list[i] = tuple(np.round(rotation_list[i],decimals=rounding))
rotation_list[i] = tuple(np.round(rotation_list[i], decimals=rounding))
i += 1

return rotation_list

def get_fundamental_zone_grid(resolution=2, point_group=None,space_group=None):

def get_fundamental_zone_grid(resolution=2, point_group=None, space_group=None):
"""
Generates an equispaced grid of rotations within a fundamental zone.
Expand All @@ -88,10 +90,11 @@ def get_fundamental_zone_grid(resolution=2, point_group=None,space_group=None):
Grid of rotations lying within the specified fundamental zone
"""

orix_grid = get_sample_fundamental(resolution=resolution,space_group=space_group)
rotation_list = get_list_from_orix(orix_grid,rounding=2)
orix_grid = get_sample_fundamental(resolution=resolution, space_group=space_group)
rotation_list = get_list_from_orix(orix_grid, rounding=2)
return rotation_list


def get_local_grid(resolution=2, center=None, grid_width=10):
"""
Generates a grid of rotations about a given rotation
Expand All @@ -112,10 +115,13 @@ def get_local_grid(resolution=2, center=None, grid_width=10):
-------
rotation_list : list of tuples
"""
orix_grid = get_sample_local(resolution=resolution, center=center, grid_width=grid_width)
rotation_list = get_list_from_orix(orix_grid,rounding=2)
orix_grid = get_sample_local(
resolution=resolution, center=center, grid_width=grid_width
)
rotation_list = get_list_from_orix(orix_grid, rounding=2)
return rotation_list


def get_grid_around_beam_direction(beam_rotation, resolution, angular_range=(0, 360)):
"""
Creates a rotation list of rotations for which the rotation is about given beam direction
Expand Down
4 changes: 3 additions & 1 deletion diffsims/libraries/structure_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def from_crystal_systems(
NotImplementedError:
"This function has been removed in version 0.3.0, in favour of creation from orientation lists"
"""
raise NotImplementedError("This function has been removed in version 0.3.0, in favour of creation from orientation lists")
raise NotImplementedError(
"This function has been removed in version 0.3.0, in favour of creation from orientation lists"
)

def get_library_size(self, to_print=False):
"""
Expand Down
9 changes: 5 additions & 4 deletions diffsims/sims/diffraction_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,17 @@ def get_diffraction_pattern(self, size=512, sigma=10):
the order of 0.5nm and a the default size and sigma are used.
"""
side_length = np.min(np.multiply((size / 2), self.calibration))
mask_for_sides = np.all((np.abs(self.coordinates[:, 0:2]) < side_length), axis=1)

mask_for_sides = np.all(
(np.abs(self.coordinates[:, 0:2]) < side_length), axis=1
)

spot_coords = np.add(
self.calibrated_coordinates[mask_for_sides], size / 2
).astype(int)
spot_intens = self.intensities[mask_for_sides]
pattern = np.zeros([size, size])
#checks that we have some spots
if spot_intens.shape[0]==0:
# checks that we have some spots
if spot_intens.shape[0] == 0:
return pattern
else:
pattern[spot_coords[:, 0], spot_coords[:, 1]] = spot_intens
Expand Down
72 changes: 46 additions & 26 deletions diffsims/sims/kinematic_simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,31 @@
from diffsims.utils.discretise_utils import get_discretisation
from numpy import array, pi, sin, cos, empty, maximum, sqrt
from scipy.interpolate import interpn
from diffsims.utils.fourier_transform import (get_DFT, to_recip, fftshift_phase,
plan_fft, fast_abs)
from diffsims.utils.fourier_transform import (
get_DFT,
to_recip,
fftshift_phase,
plan_fft,
fast_abs,
)
from diffsims.utils.generic_utils import to_mesh


def normalise(arr): return arr / arr.max()
def normalise(arr):
return arr / arr.max()


def get_diffraction_image(coordinates, species, probe, x, wavelength,
precession, GPU=True, pointwise=False, **kwargs):
def get_diffraction_image(
coordinates,
species,
probe,
x,
wavelength,
precession,
GPU=True,
pointwise=False,
**kwargs
):
"""
Return kinematically simulated diffraction pattern
Expand Down Expand Up @@ -71,15 +86,15 @@ def get_diffraction_image(coordinates, species, probe, x, wavelength,
The two-dimensional diffraction pattern evaluated on the reciprocal grid
corresponding to the first two vectors of `x`.
"""
FTYPE = kwargs['dtype'][0]
kwargs['GPU'] = GPU
kwargs['pointwise'] = pointwise
FTYPE = kwargs["dtype"][0]
kwargs["GPU"] = GPU
kwargs["pointwise"] = pointwise

x = [X.astype(FTYPE, copy=False) for X in x]
y = to_recip(x)
if wavelength == 0:
p = probe(x).mean(-1)
# vol = get_discretisation(coordinates, species, x, **kwargs).mean(-1)
# vol = get_discretisation(coordinates, species, x, **kwargs).mean(-1)
vol = get_discretisation(coordinates, species, x[:2], **kwargs)[..., 0]
ft = get_DFT(x[:-1], y[:-1])[0]
else:
Expand All @@ -95,13 +110,20 @@ def get_diffraction_image(coordinates, species, probe, x, wavelength,
else:
return normalise(grid2sphere(arr, y, None, 2 * pi / wavelength))

R = [precess_mat(precession[0], i * 360 / precession[1]) for i in range(precession[1])]
R = [
precess_mat(precession[0], i * 360 / precession[1])
for i in range(precession[1])
]

if wavelength == 0:
return normalise(sum(get_diffraction_image(coordinates.dot(r),
species, probe, x, wavelength,
(0, 1), **kwargs)
for r in R))
return normalise(
sum(
get_diffraction_image(
coordinates.dot(r), species, probe, x, wavelength, (0, 1), **kwargs
)
for r in R
)
)

fftshift_phase(vol) # removes need for fftshift after fft
buf = empty(vol.shape, dtype=FTYPE)
Expand Down Expand Up @@ -144,11 +166,9 @@ def precess_mat(alpha, theta):
if alpha == 0:
return array([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
alpha, theta = alpha * pi / 180, theta * pi / 180
R_a = array([[1, 0, 0], [0, cos(alpha), -sin(alpha)],
[0, sin(alpha), cos(alpha)]])
R_t = array([[cos(theta), -sin(theta), 0],
[sin(theta), cos(theta), 0], [0, 0, 1]])
R = (R_t.T.dot(R_a.dot(R_t)))
R_a = array([[1, 0, 0], [0, cos(alpha), -sin(alpha)], [0, sin(alpha), cos(alpha)]])
R_t = array([[cos(theta), -sin(theta), 0], [sin(theta), cos(theta), 0], [0, 0, 1]])
R = R_t.T.dot(R_a.dot(R_t))

return R

Expand Down Expand Up @@ -184,12 +204,12 @@ def grid2sphere(arr, x, dx, C):
return arr[:, :, 0]

y = to_mesh((x[0], x[1], array([0])), dx).reshape(-1, 3)
# if C is not None: # project straight up
# w = C - sqrt(maximum(0, C ** 2 - (y ** 2).sum(-1)))
# if dx is None:
# y[:, 2] = w.reshape(-1)
# else:
# y += w.reshape(y.shape[0], 1) * dx[2].reshape(1, 3)
# if C is not None: # project straight up
# w = C - sqrt(maximum(0, C ** 2 - (y ** 2).sum(-1)))
# if dx is None:
# y[:, 2] = w.reshape(-1)
# else:
# y += w.reshape(y.shape[0], 1) * dx[2].reshape(1, 3)

if C is not None: # project on line to centre
w = 1 / (1 + (y ** 2).sum(-1) / C ** 2)
Expand All @@ -199,6 +219,6 @@ def grid2sphere(arr, x, dx, C):
else:
y += C * (1 - w)[:, None] * dx[2]

out = interpn(x, arr, y, method='linear', bounds_error=False, fill_value=0)
out = interpn(x, arr, y, method="linear", bounds_error=False, fill_value=0)

return out.reshape(x[0].size, x[1].size)
Loading

0 comments on commit 1f1195b

Please sign in to comment.