Skip to content

Commit

Permalink
Merge 70e7d8b into c024409
Browse files Browse the repository at this point in the history
  • Loading branch information
CSSFrancis committed Mar 22, 2024
2 parents c024409 + 70e7d8b commit 9d1ff18
Show file tree
Hide file tree
Showing 9 changed files with 277 additions and 105 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ Fixed
- `pyxem.signals.Diffraction2D.center_of_mass` now uses the `map` function. (#1005)
- Replace ``matplotlib.cm.get_cmap`` (removed in matplotlib 3.9) with ``matplotlib.colormaps``. (#1023)
- Documentation fixes and improvement. (#1028)
- Fixed bug with flattening diffraction Vectors when there are different scales (#1024)

Added
-----
Expand Down
4 changes: 2 additions & 2 deletions doc/dev_guide/contributing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,9 @@ Tips for writing Jupyter notebooks that are meant to be converted to reST text f
command is the last line in a cell.
- Refer to our API reference with this general markdown syntax:
:code:`:meth:`~.signals.Diffraction2D.get_azimuthal_integral1d`` which will be
displayed as :meth:`~.signals.Diffraction2D.get_azimuthal_integral1d` or
displayed as :meth:`~.signals.Diffraction2D.get_azimuthal_integral1d` or
:code:`:meth:`pyxem.signals.Diffraction2D.get_azimuthal_integral1d`` to have the full
path: :meth:`pyxem.signals.Diffraction2D.get_azimuthal_integral1d`
path: :meth:`pyxem.signals.Diffraction2D.get_azimuthal_integral1d`
- Reference external APIs via standard markdown like :code:`:class:`hyperspy.api.signals.Signal2D``,
which will be displayed as :class:`hyperspy.api.signals.Signal2D`.
- The Sphinx gallery thumbnail used for a notebook is set by adding the
Expand Down
158 changes: 131 additions & 27 deletions pyxem/signals/diffraction_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,14 @@ def __init__(self, *args, **kwargs):
_units = kwargs.pop("units", None)
super().__init__(*args, **kwargs)
self._set_up_vector(_scales, _offsets, _detector_shape, _column_names, _units)
if (
self._is_object_dtype is None
): # empty signal with data=None due to `_deepcopy_with_new_data`
pass
elif self._is_object_dtype:
self.ragged = True
elif self.ragged == True:
self.ragged = False

def _repr_html_(self):
table = '<table align="center">'
Expand All @@ -118,6 +126,8 @@ def _repr_html_(self):
vectors = self._get_current_data()
if vectors.dtype.kind == "O":
vectors = vectors[0]
if self.num_columns == 1:
vectors = np.array([vectors]).T
for i, row in enumerate(vectors):
table += "<tr>"
table += f"<td><center>{i}</center></td>"
Expand Down Expand Up @@ -253,6 +263,8 @@ def from_peaks(
cen=center,
inplace=False,
ragged=True,
output_signal_size=(),
output_dtype=object,
)
vectors.set_signal_type("diffraction_vectors")
if isinstance(peaks, LazySignal):
Expand Down Expand Up @@ -374,7 +386,13 @@ def get_pixels(x, off, scale, square_size=None, shape=None):

@property
def _is_object_dtype(self):
return self.data.dtype.kind == "O"
try:
if self.data[0] is None:
return None
else:
return self.data.dtype.kind == "O"
except IndexError:
return None

@cached_property
def num_columns(self):
Expand All @@ -385,6 +403,8 @@ def num_columns(self):
shape = self.data[self.data.ndim * (0,)].shape
if shape is None:
return 0
elif len(shape) == 1:
return 1
else:
return shape[1]
else:
Expand All @@ -401,12 +421,15 @@ def units(self):

@units.setter
def units(self, value):
if isinstance(value, str) and self.num_columns == 1:
value = [value]
if (
isiterable(value)
and len(value) == self.num_columns
and not isinstance(value, str)
):
self.metadata.VectorMetadata["units"] = value

elif isiterable(value) and len(value) != self.num_columns:
raise ValueError(
"The len of the units parameter must equal the number of"
Expand Down Expand Up @@ -448,6 +471,9 @@ def column_names(self):
def column_names(self, value):
if value is None:
value = [f"column_{i}" for i in range(self.num_columns)]

if isinstance(value, str):
value = [value]
if len(value) != self.num_columns:
raise ValueError(
f"The len of the column_names parameter: {len(value)} must equal the"
Expand All @@ -464,6 +490,7 @@ def offsets(self):
def offsets(self, value):
if isiterable(value) and len(value) == self.num_columns:
self.metadata.VectorMetadata["offsets"] = value

elif isiterable(value) and len(value) != self.num_columns:
raise ValueError(
"The len of the scales parameter must equal the number of"
Expand All @@ -475,23 +502,55 @@ def offsets(self, value):
] * self.num_columns

def __lt__(self, other):
if self.ragged:
kwargs = dict(output_signal_size=(), output_dtype=object)
else:
kwargs = dict()
return self.map(
lambda x, other: x < other, other=other, inplace=False, ragged=True
lambda x, other: x < other,
other=other,
inplace=False,
ragged=self.ragged,
**kwargs,
)

def __le__(self, other):
if self.ragged:
kwargs = dict(output_signal_size=(), output_dtype=object)
else:
kwargs = dict()
return self.map(
lambda x, other: x <= other, other=other, inplace=False, ragged=True
lambda x, other: x <= other,
other=other,
inplace=False,
ragged=self.ragged,
**kwargs,
)

def __gt__(self, other):
if self.ragged:
kwargs = dict(output_signal_size=(), output_dtype=object)
else:
kwargs = dict()
return self.map(
lambda x, other: x > other, other=other, inplace=False, ragged=True
lambda x, other: x > other,
other=other,
inplace=False,
ragged=self.ragged,
**kwargs,
)

def __ge__(self, other):
if self.ragged:
kwargs = dict(output_signal_size=(), output_dtype=object)
else:
kwargs = dict()
return self.map(
lambda x, other: x >= other, other=other, inplace=False, ragged=True
lambda x, other: x >= other,
other=other,
inplace=False,
ragged=self.ragged,
**kwargs,
)

@property
Expand Down Expand Up @@ -537,8 +596,8 @@ def _get_navigation_positions(self, flatten=False, real_units=True):
scales = [1 for a in self.axes_manager.navigation_axes]
offsets = [0 for a in self.axes_manager.navigation_axes]
else:
scales = [a.scale for a in self.axes_manager.navigation_axes]
offsets = [a.offset for a in self.axes_manager.navigation_axes]
scales = [a.scale for a in self.axes_manager.navigation_axes[::-1]]
offsets = [a.offset for a in self.axes_manager.navigation_axes[::-1]]

if flatten:
real_nav = np.array(
Expand Down Expand Up @@ -584,22 +643,28 @@ def flatten_diffraction_vectors(
"""
from pyxem.signals.diffraction_vectors2d import DiffractionVectors2D

nav_positions = self._get_navigation_positions(
flatten=True, real_units=real_units
)
if self.axes_manager._navigation_shape_in_array == ():
return self

if self._is_object_dtype:
nav_positions = self._get_navigation_positions(
flatten=False, real_units=real_units
)
vectors = np.vstack(
[
np.hstack(
[np.tile(nav_pos, (len(self.data[ind]), 1)), self.data[ind]]
[
np.tile(nav_positions[ind][::-1], (len(self.data[ind]), 1)),
self.data[ind],
]
)
for ind, nav_pos in zip(np.ndindex(self.data.shape), nav_positions)
for ind in np.ndindex(self.axes_manager._navigation_shape_in_array)
]
)
else:
nav_positions = self._get_navigation_positions(
flatten=True, real_units=real_units
)
navs = np.repeat(nav_positions, self.num_rows, axis=0)
data = self.data.reshape((-1, self.num_columns))
vectors = np.vstack((navs, data))
Expand Down Expand Up @@ -627,14 +692,18 @@ def flatten_diffraction_vectors(
column_offsets = np.append(column_offsets, offsets)
column_scale = np.append(column_scale, scales)

column_names = [
a.name for a in self.axes_manager.navigation_axes
] + self.column_names
column_names = np.append(
[a.name for a in self.axes_manager.navigation_axes], self.column_names
)

if real_units:
units = [a.units for a in self.axes_manager.navigation_axes] + self.units
units = np.append(
[a.units for a in self.axes_manager.navigation_axes], self.units
)
else:
units = ["pixels"] * len(self.axes_manager.navigation_axes) + self.units
units = np.append(
["pixels"] * len(self.axes_manager.navigation_axes), self.units
)

return DiffractionVectors2D(
vectors,
Expand Down Expand Up @@ -817,11 +886,14 @@ def plot_diffraction_vectors_on_signal(self, signal, *args, **kwargs):
)
signal.add_marker(marker, plot_marker=True, permanent=False)

def get_magnitudes(self, *args, **kwargs):
def get_magnitudes(self, columns=None, *args, **kwargs):
"""Calculate the magnitude of diffraction vectors.
Parameters
----------
columns : list, optional
The columns of the diffraction vectors to be used to calculate
the magnitude. If not given, the first two columns will be used.
*args:
Arguments to be passed to map().
**kwargs:
Expand All @@ -835,9 +907,13 @@ def get_magnitudes(self, *args, **kwargs):
navigation position.
"""
magnitudes = self.map(
np.linalg.norm, inplace=False, axis=-1, ragged=True, *args, **kwargs
)
if columns is None:
columns = [0, 1]

def get_magnitude(x):
return np.linalg.norm(x[:, columns], axis=-1)

magnitudes = self.map(get_magnitude, inplace=False, *args, **kwargs)

return magnitudes

Expand Down Expand Up @@ -911,17 +987,27 @@ def cluster(
if columns is None:
columns = list(range(self.data.shape[-1]))

if self.ragged:
signal_shape = ()
dtype = object
else:
signal_shape = self.axes_manager._signal_shape_in_array
signal_shape = signal_shape[:-1] + (signal_shape[-1] + 1,)
dtype = float
new_signal = self.map(
cluster,
inplace=False,
method=method,
columns=columns,
column_scale_factors=column_scale_factors,
min_vectors=min_vectors,
ragged=self.ragged,
remove_nan=remove_nan,
output_signal_size=signal_shape,
output_dtype=dtype,
)
new_signal.column_names = self.column_names + ["cluster"]
new_signal.units = self.units + ["n.a."]
new_signal.column_names = np.append(self.column_names, ["cluster"])
new_signal.units = np.append(self.units, ["n.a."])

if not self.has_navigation_axis:
new_signal.set_signal_type("labeled_diffraction_vectors")
Expand Down Expand Up @@ -1000,13 +1086,16 @@ def filter_magnitude(self, min_magnitude, max_magnitude, *args, **kwargs):
filtered_vectors : DiffractionVectors
Diffraction vectors within allowed magnitude tolerances.
"""
# If ragged the signal axes will not be defined

if self.ragged:
kwargs["output_signal_size"] = ()
kwargs["output_dtype"] = object

filtered_vectors = self.map(
filter_vectors_ragged,
min_magnitude=min_magnitude,
max_magnitude=max_magnitude,
inplace=False,
ragged=True,
*args,
**kwargs,
)
Expand Down Expand Up @@ -1097,15 +1186,30 @@ def filter_detector_edge(self, exclude_width, *args, **kwargs):
)
return filtered_vectors

def to_polar(self):
def to_polar(self, columns=None, **kwargs):
"""Convert the diffraction vectors to polar coordinates.
Parameters
----------
columns : list
The columns of the diffraction vectors to be converted to polar
coordinates. The default is the first two columns (kx, ky) in most
cases.
kwargs : dict
Any other parameters passed to the `hyperspy.signal.BaseSignal.map` function.
Returns
-------
polar_vectors : DiffractionVectors
Diffraction vectors in polar coordinates.
"""
polar_vectors = self.map(vectors_to_polar, inplace=False, ragged=True)
polar_vectors = self.map(
vectors_to_polar,
inplace=False,
ragged=self.ragged,
columns=columns,
**kwargs,
)
polar_vectors.set_signal_type("polar_vectors")
polar_vectors.column_names[0] = "r"
polar_vectors.column_names[1] = "theta"
Expand Down

0 comments on commit 9d1ff18

Please sign in to comment.