/
_similarity_metric.py
254 lines (217 loc) · 8.62 KB
/
_similarity_metric.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
# Copyright 2019-2023 The kikuchipy developers
#
# This file is part of kikuchipy.
#
# kikuchipy is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# kikuchipy is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with kikuchipy. If not, see <http://www.gnu.org/licenses/>.
import abc
from typing import List, Optional, Union
import numpy as np
class SimilarityMetric(abc.ABC):
    """Abstract class implementing a similarity metric to match
    experimental and simulated EBSD patterns in a dictionary.

    For use in :meth:`~kikuchipy.signals.EBSD.dictionary_indexing` or
    directly on pattern arrays if a :meth:`__call__` method is
    implemented. Note that `dictionary_indexing()` will always reshape
    the dictionary pattern array to 2D (1 navigation dimension, 1 signal
    dimension) before calling :meth:`prepare_dictionary` and
    :meth:`match`.

    Take a look at the implementation of
    :class:`~kikuchipy.indexing.NormalizedCrossCorrelationMetric` for
    how to write a concrete custom metric.

    When writing a custom similarity metric class, the methods listed as
    `abstract` below must be implemented. Any number of custom
    parameters can be passed. Also listed are the attributes available
    to the methods if set properly during initialization or after.

    A concrete subclass should also set two class attributes:
    ``_allowed_dtypes``, the list of data types accepted by
    :meth:`raise_error_if_invalid` (an empty list accepts any data
    type), and ``_sign``, which is ``1`` if a greater similarity is
    better and ``-1`` if a lower one is.

    Parameters
    ----------
    n_experimental_patterns
        Number of experimental patterns. If not given, this is set
        to ``None`` and must be set later. Must be at least ``1``.
    n_dictionary_patterns
        Number of dictionary patterns. If not given, this is set to
        ``None`` and must be set later. Must be at least ``1``.
    navigation_mask
        A boolean mask equal to the experimental patterns' navigation
        (map) shape, where only patterns equal to ``False`` are matched.
        If not given, all patterns are used.
    signal_mask
        A boolean mask equal to the experimental patterns' detector
        shape ``(s rows, s columns)``, where only pixels equal to
        ``False`` are matched. If not given, all pixels are used.
    dtype
        Which data type to cast the patterns to before matching.
        Default is ``"float32"``.
    rechunk
        Whether to allow rechunking of arrays before matching.
        Default is ``False``.
    """

    # Data types accepted by raise_error_if_invalid(); empty means any.
    _allowed_dtypes: List[type] = []
    # +1: greater similarity is better, -1: lower is better. Subclasses
    # are expected to override the None default.
    _sign: Optional[int] = None

    def __init__(
        self,
        n_experimental_patterns: Optional[int] = None,
        n_dictionary_patterns: Optional[int] = None,
        navigation_mask: Optional[np.ndarray] = None,
        signal_mask: Optional[np.ndarray] = None,
        dtype: Union[str, np.dtype, type] = "float32",
        rechunk: bool = False,
    ):
        """Create a similarity metric matching experimental and
        simulated EBSD patterns in a dictionary.
        """
        self._n_experimental_patterns = n_experimental_patterns
        self._n_dictionary_patterns = n_dictionary_patterns
        self._navigation_mask = navigation_mask
        self._signal_mask = signal_mask
        # Normalize any dtype-like input (str, type, np.dtype) once here
        self._dtype = np.dtype(dtype)
        self._rechunk = rechunk

    def __repr__(self) -> str:
        sign_string = {1: "greater is better", -1: "lower is better"}
        # Fall back to the raw value so that a metric whose ``_sign`` is
        # still the ``None`` default can be printed instead of raising a
        # KeyError.
        sign_desc = sign_string.get(self.sign, f"sign: {self.sign}")
        return (
            f"{self.__class__.__name__}: {np.dtype(self.dtype).name}, "
            f"{sign_desc}, rechunk: {self.rechunk}, "
            f"navigation mask: {self.navigation_mask is not None}, "
            f"signal mask: {self.signal_mask is not None}"
        )

    @property
    def allowed_dtypes(self) -> List[type]:
        """Return the list of allowed array data types used during
        matching.
        """
        return self._allowed_dtypes

    @property
    def dtype(self) -> np.dtype:
        """Return or set which data type to cast the patterns to before
        matching.

        Parameters
        ----------
        value
            Data type listed in :attr:`allowed_dtypes`.
        """
        return self._dtype

    @dtype.setter
    def dtype(self, value: Union[str, np.dtype, type]):
        """Set which data type to cast the patterns to before
        matching.
        """
        self._dtype = np.dtype(value)

    @property
    def n_dictionary_patterns(self) -> Optional[int]:
        """Return or set the number of dictionary patterns to match.

        This information might be necessary when reshaping the
        dictionary array in :meth:`prepare_dictionary`.

        Parameters
        ----------
        value
            Number of dictionary patterns to match.
        """
        return self._n_dictionary_patterns

    @n_dictionary_patterns.setter
    def n_dictionary_patterns(self, value: Optional[int]):
        """Set the number of dictionary patterns to match."""
        self._n_dictionary_patterns = value

    @property
    def n_experimental_patterns(self) -> Optional[int]:
        """Return or set the number of experimental patterns to match.

        This information might be necessary when reshaping the
        experimental array in :meth:`prepare_experimental`.

        Parameters
        ----------
        value
            Number of experimental patterns to match.
        """
        return self._n_experimental_patterns

    @n_experimental_patterns.setter
    def n_experimental_patterns(self, value: Optional[int]):
        """Set the number of experimental patterns to match."""
        self._n_experimental_patterns = value

    @property
    def navigation_mask(self) -> Optional[np.ndarray]:
        """Return or set the boolean mask of patterns to match, equal to
        the navigation (map) shape.

        Parameters
        ----------
        value
            Navigation mask where points set to ``False`` are matched.
            ``None`` means all patterns are matched.
        """
        return self._navigation_mask

    @navigation_mask.setter
    def navigation_mask(self, value: Optional[np.ndarray]):
        """Set the boolean mask of patterns to match, equal to the
        navigation (map) shape.
        """
        self._navigation_mask = value

    @property
    def signal_mask(self) -> Optional[np.ndarray]:
        """Return or set the boolean mask equal to the experimental
        patterns' detector shape ``(s rows, s columns)``.

        Parameters
        ----------
        value
            Signal mask where pixels set to ``False`` are matched.
            ``None`` means all pixels are matched.
        """
        return self._signal_mask

    @signal_mask.setter
    def signal_mask(self, value: Optional[np.ndarray]):
        """Set the boolean mask equal to the experimental patterns'
        detector shape ``(s rows, s columns)``.
        """
        self._signal_mask = value

    @property
    def sign(self) -> Optional[int]:
        """Return the sign signifying whether a greater match is better,
        either +1 (greater is better) or -1 (lower is better).

        This is ``None`` if the subclass has not set ``_sign``.
        """
        return self._sign

    @property
    def rechunk(self) -> bool:
        """Return or set whether to allow rechunking of arrays before
        matching.

        Parameters
        ----------
        value
            Whether to allow rechunking of arrays before matching.
        """
        return self._rechunk

    @rechunk.setter
    def rechunk(self, value: bool):
        """Set whether to allow rechunking of arrays before matching."""
        self._rechunk = value

    @abc.abstractmethod
    def prepare_dictionary(self, *args, **kwargs):
        """Prepare dictionary patterns before matching to experimental
        patterns in :meth:`match`.
        """
        return NotImplemented  # pragma: no cover

    @abc.abstractmethod
    def prepare_experimental(self, *args, **kwargs):
        """Prepare experimental patterns before matching to dictionary
        patterns in :meth:`match`.
        """
        return NotImplemented  # pragma: no cover

    @abc.abstractmethod
    def match(self, *args, **kwargs):
        """Match all experimental patterns to all dictionary patterns
        and return their similarities.
        """
        return NotImplemented  # pragma: no cover

    def raise_error_if_invalid(self):
        """Raise a ValueError if :attr:`dtype` is not among
        :attr:`allowed_dtypes` and the latter is not an empty list.

        Raises
        ------
        ValueError
            If :attr:`allowed_dtypes` is non-empty and :attr:`dtype`
            does not compare equal to any of its entries.
        """
        allowed_dtypes = self.allowed_dtypes
        # An empty list means the metric accepts any data type
        if len(allowed_dtypes) != 0 and self.dtype not in allowed_dtypes:
            raise ValueError(
                f"Data type {self.dtype} not among supported data types "
                f"{allowed_dtypes}"
            )