/
_merge_crystal_maps.py
227 lines (198 loc) · 8.95 KB
/
_merge_crystal_maps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
# -*- coding: utf-8 -*-
# Copyright 2019-2021 The kikuchipy developers
#
# This file is part of kikuchipy.
#
# kikuchipy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# kikuchipy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with kikuchipy. If not, see <http://www.gnu.org/licenses/>.
from math import copysign
from typing import List, Optional, Tuple, Union
import warnings
import numpy as np
from orix.crystal_map import CrystalMap, PhaseList
from orix.quaternion.rotation import Rotation
from kikuchipy.indexing.similarity_metrics import SimilarityMetric, _SIMILARITY_METRICS
def merge_crystal_maps(
    crystal_maps: List[CrystalMap],
    mean_n_best: int = 1,
    metric: Optional[Union[str, SimilarityMetric]] = None,
    scores_prop: str = "scores",
    simulation_indices_prop: Optional[str] = None,
) -> CrystalMap:
    """Merge a list of at least two single phase
    :class:`~orix.crystal_map.crystal_map.CrystalMap` with a 1D or 2D
    navigation shape into one multi phase map.

    It is required that all maps have the same navigation shape and the
    same number of rotations and scores (and simulation indices if
    applicable) per point.

    Parameters
    ----------
    crystal_maps : list of CrystalMap
        A list of crystal maps with simulated indices and scores among
        their properties.
    mean_n_best : int, optional
        Number of best metric results to take the mean of before
        comparing. Default is 1.
    metric : str or SimilarityMetric, optional
        Similarity metric, default is None. If given, the metric's
        `sign` attribute decides whether higher or lower scores are
        considered better.
    scores_prop : str, optional
        Name of scores array in the crystal maps' properties. Default
        is "scores".
    simulation_indices_prop : str, optional
        Name of simulated indices array in the crystal maps' properties.
        If None (default), the merged crystal map will not contain
        an array of merged simulation indices from the input crystal
        maps' properties. If a string, there must be as many simulation
        indices per point as there are scores.

    Returns
    -------
    merged_xmap : ~orix.crystal_map.crystal_map.CrystalMap
        A crystal map where the rotation of the phase with the best
        matching score(s) is assigned to each point. The best matching
        scores, merge sorted, are added to its properties with a name
        equal to whatever passed to `scores_prop` with "merged_" as a
        prefix. If `simulation_indices_prop` is passed, the best
        matching simulation indices are added in the same way as the
        scores.

    Raises
    ------
    ValueError
        If the crystal maps do not all have the same navigation shape,
        if they do not all have the same number of rotations per point,
        or if `simulation_indices_prop` is given and the number of
        simulation indices does not match the number of scores.

    Warns
    -----
    UserWarning
        If adding a phase to the merged phase list fails because of a
        duplicate, in which case the phase is renamed by appending the
        index of its map in `crystal_maps` to its name.

    Notes
    -----
    `mean_n_best` can be given with a negative sign if `metric` is not
    given, in order to choose the lowest valued metric results. The sign
    is only stripped in that case; when `metric` is given,
    `mean_n_best` is used as-is as a slice length and should therefore
    be positive.
    """
    # Compare the shapes as plain tuples. Building a NumPy array from
    # them (as np.diff() would) fails with an unrelated error when maps
    # of different dimensionality (1D and 2D) are mixed, since the list
    # of shapes is then ragged.
    map_shapes = [tuple(xmap.shape) for xmap in crystal_maps]
    if len(set(map_shapes)) > 1:
        raise ValueError("All crystal maps must have the same navigation shape")

    rot_per_point_per_map = [xmap.rotations_per_point for xmap in crystal_maps]
    if not all(np.diff(rot_per_point_per_map) == 0):
        raise ValueError(
            "All crystal maps must have the same number of rotations and scores"
            " per point."
        )

    if metric is None:
        # Without a metric, the sign of mean_n_best selects whether the
        # highest (+) or lowest (-) scores are the best ones
        sign = copysign(1, mean_n_best)
        mean_n_best = abs(mean_n_best)
    else:
        # Either a registered metric name or a metric object with a
        # `sign` attribute
        sign = _SIMILARITY_METRICS.get(metric, metric).sign

    # Notation used in the comments below:
    # - M: number of map points
    # - N: number of scores per point
    # - I: number of simulation indices per point
    # - K: number of maps to merge

    # Shape of the combined (unsorted) scores array, and the total
    # number of scores per point. Shape: (M, N, K) or (M, K) if only one
    # score is available (e.g. refined dot products from EMsoft)
    (comb_shape, n_scores_per_point) = _get_combined_scores_shape(
        crystal_maps=crystal_maps, scores_prop=scores_prop
    )

    # Combined (unsorted) scores array of shape (M, N, K) or (M, K)
    combined_scores = np.dstack([xmap.prop[scores_prop] for xmap in crystal_maps])
    combined_scores = combined_scores.reshape(comb_shape)

    # Best score in each map point
    if n_scores_per_point > 1:  # (M, N, K)
        best_scores = np.nanmean(combined_scores[:, :mean_n_best], axis=1)
    else:  # (M, K)
        best_scores = combined_scores

    # Phase of best score in each map point
    phase_id = np.argmax(sign * best_scores, axis=1)

    # Get the new CrystalMap's rotations, scores and indices, restricted
    # to one phase per point (uncombined)
    new_rotations = Rotation(np.zeros_like(crystal_maps[0].rotations.data))
    new_scores = np.zeros_like(crystal_maps[0].prop[scores_prop])
    if simulation_indices_prop is not None:
        new_indices = np.zeros_like(crystal_maps[0].prop[simulation_indices_prop])
    phase_list = PhaseList()
    for i, xmap in enumerate(crystal_maps):
        # Points where map i has the best score
        mask = phase_id == i
        new_rotations[mask] = xmap.rotations[mask]
        new_scores[mask] = xmap.prop[scores_prop][mask]
        if simulation_indices_prop is not None:
            new_indices[mask] = xmap.prop[simulation_indices_prop][mask]
        # Only phases which are actually present in the merged map are
        # added to the merged phase list
        if np.sum(mask) != 0:
            current_id = xmap.phases_in_data.ids[0]
            phase = xmap.phases_in_data[current_id].deepcopy()
            try:
                phase_list.add(phase)
            except ValueError:
                name = phase.name
                warnings.warn(
                    f"There are duplicates of phase {name}, will therefore "
                    f"rename this phase's name to {name + str(i)} in the merged"
                    " PhaseList",
                )
                phase.name = name + str(i)
                phase_list.add(phase)

    # To get the combined, best, sorted scores and simulation indices
    # from all maps (phases), we collapse the second and (potentially)
    # third axis to get (M, N * K) or (M, K)
    mergesort_shape = (comb_shape[0], np.prod(comb_shape[1:]))
    comb_scores_reshaped = combined_scores.reshape(mergesort_shape)
    best_sorted_idx = np.argsort(
        sign * -comb_scores_reshaped, kind="mergesort", axis=1
    )

    # Best, sorted scores in all maps (for all phases) per point
    merged_best_scores = np.take_along_axis(
        comb_scores_reshaped, best_sorted_idx, axis=-1
    )

    # Set up merged map's properties
    props = {scores_prop: new_scores, f"merged_{scores_prop}": merged_best_scores}

    if simulation_indices_prop is not None:
        # Combined (unsorted) simulation indices array of shape
        # (M, N, K) or (M, K), accounting for the case where there are
        # more simulation indices per point than scores (e.g. refined
        # dot products from EMsoft)
        comb_sim_idx = np.dstack(
            [xmap.prop[simulation_indices_prop] for xmap in crystal_maps]
        )
        if comb_sim_idx.size != np.prod(mergesort_shape):
            raise ValueError(
                "Cannot merge maps with more simulation indices than scores per point"
            )

        # To enable calculation of an orientation similarity map from
        # the combined, sorted simulation indices array, we must make
        # the indices unique across all maps
        for i in range(1, comb_sim_idx.shape[-1]):
            increment = (
                abs(comb_sim_idx[..., i - 1].max() - comb_sim_idx[..., i].min()) + 1
            )
            comb_sim_idx[..., i] += increment

        # Collapse axes as for the combined scores array above
        comb_sim_idx = comb_sim_idx.reshape(mergesort_shape)

        # Best, sorted simulation indices in all maps (for all phases)
        # per point
        merged_simulated_indices = np.take_along_axis(
            comb_sim_idx, best_sorted_idx, axis=-1
        )

        # Finally, add to properties
        props[simulation_indices_prop] = new_indices
        props[f"merged_{simulation_indices_prop}"] = merged_simulated_indices

    return CrystalMap(
        rotations=new_rotations,
        phase_id=phase_id,
        phase_list=phase_list,
        x=crystal_maps[0].x,
        y=crystal_maps[0].y,
        z=crystal_maps[0].z,
        prop=props,
        scan_unit=crystal_maps[0].scan_unit,
    )
def _get_combined_scores_shape(
    crystal_maps: List[CrystalMap], scores_prop: str = "scores"
) -> Tuple[tuple, int]:
    """Return the shape of the combined scores array of all maps and
    the number of scores per point.

    Only the first crystal map is inspected for the number of points
    and the shape of its scores array.

    Parameters
    ----------
    crystal_maps
        Crystal maps to combine the scores of.
    scores_prop
        Name of the scores array in the crystal maps' properties.
        Default is "scores".

    Returns
    -------
    all_scores_shape
        (map size, number of maps) if the scores array is 1D, otherwise
        (map size, size of the scores array's last axis, number of
        maps).
    n_scores_per_point
        1 if the scores array is 1D, otherwise the size of the scores
        array's second axis.
    """
    reference_map = crystal_maps[0]
    n_points = reference_map.size
    prop_shape = reference_map.prop[scores_prop].shape
    n_maps = len(crystal_maps)

    # A single score per point: no score axis in the combined array
    if len(prop_shape) == 1:
        return (n_points, n_maps), 1

    return (n_points, prop_shape[-1], n_maps), prop_shape[1]