-
Notifications
You must be signed in to change notification settings - Fork 84
/
insitu_diffraction2d.py
381 lines (329 loc) · 12.6 KB
/
insitu_diffraction2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
# -*- coding: utf-8 -*-
# Copyright 2016-2023 The pyXem developers
#
# This file is part of pyXem.
#
# pyXem is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# pyXem is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with pyXem. If not, see <http://www.gnu.org/licenses/>.
from hyperspy.signals import Signal1D
from pyxem.signals import Diffraction2D
from hyperspy._signals.lazy import LazySignal
import numpy as np
from hyperspy.roi import RectangularROI
import dask.array as da
from dask.graph_manipulation import clone
from pyxem.utils.dask_tools import _get_dask_array, _get_chunking
from pyxem.utils.insitu_utils import (
_register_drift_5d,
_register_drift_2d,
_g2_2d,
_interpolate_g2_2d,
_get_resample_time,
)
import pyxem.utils.pixelated_stem_tools as pst
class InSituDiffraction2D(Diffraction2D):
    """Signal class for in-situ 4D-STEM data.

    The methods of this class treat the data as having three navigation
    axes, one of which is time (navigation index 2 by default), and two
    signal axes.

    Parameters
    ----------
    *args:
        See :class:`~hyperspy._signals.signal2d.Signal2D`.
    **kwargs:
        See :class:`~hyperspy._signals.signal2d.Signal2D`
    """

    _signal_type = "insitu_diffraction"

    def roll_time_axis(self, time_axis):
        """Roll the time axis to the default index (2).

        Parameters
        ----------
        time_axis: int
            Current index of the time axis.

        Returns
        -------
        rolled:
            Result of ``self.rollaxis(time_axis, 2)``.
        """
        return self.rollaxis(time_axis, 2)

    def get_time_series(self, roi=None, time_axis=2):
        """Create an intensity time series from a virtual aperture defined by roi.

        Parameters
        ----------
        roi: :obj:`~hyperspy.roi.BaseInteractiveROI`
            Roi for virtual detector. If None, a rectangular roi spanning the
            full extent of the diffraction plane is used
        time_axis: int
            Index of time axis. Default is 2

        Returns
        -------
        virtual_series: Signal2D
            Time series of virtual detector images

        Raises
        ------
        ValueError
            If ``time_axis`` is not one of 0, 1 or 2.
        """
        # The two navigation axes that are not time become the signal axes
        # of the returned series.
        out_axes = [0, 1, 2]
        out_axes.remove(time_axis)

        if roi is None:
            # signal_extent is passed to the ROI in the order [0], [2], [1],
            # [3] (presumably left, top, right, bottom -- see hyperspy docs).
            extent = self.axes_manager.signal_extent
            roi = RectangularROI(extent[0], extent[2], extent[1], extent[3])

        virtual_series = self.get_integrated_intensity(roi, out_signal_axes=out_axes)
        virtual_series.metadata.General.title = "Integrated intensity time series"
        return virtual_series

    def get_drift_vectors(
        self, time_axis=2, reference="cascade", sub_pixel_factor=10, **kwargs
    ):
        """Calculate real space drift vectors from the time series of images.

        The time series is obtained from
        :meth:`~pyxem.signals.InSituDiffraction2D.get_time_series` and the
        shifts are estimated with
        :meth:`~hyperspy.api.signals.Signal2D.estimate_shift2D`.

        Parameters
        ----------
        time_axis: int
            Index of time axis. Default is 2
        reference: 'current', 'cascade', or 'stat'
            reference argument passed to
            :meth:`~hyperspy.api.signals.Signal2D.estimate_shift2D`
            function. Default is 'cascade'
        sub_pixel_factor: float
            sub_pixel_factor passed to
            :meth:`~hyperspy.api.signals.Signal2D.estimate_shift2D`
            function. Default is 10
        **kwargs:
            ``roi`` (optional) is forwarded to
            :meth:`~pyxem.signals.InSituDiffraction2D.get_time_series`; every
            other keyword is forwarded to
            :meth:`~hyperspy.api.signals.Signal2D.estimate_shift2D`

        Returns
        -------
        shift_vectors: Signal1D
            Estimated shifts wrapped in a Signal1D whose navigation axis
            receives the metadata of this signal's time axis
        """
        roi = kwargs.pop("roi", None)
        ref = self.get_time_series(roi=roi, time_axis=time_axis)

        shifts = ref.estimate_shift2D(
            reference=reference, sub_pixel_factor=sub_pixel_factor, **kwargs
        )
        shift_vectors = Signal1D(shifts)

        pst._copy_axes_object_metadata(
            self.axes_manager.navigation_axes[time_axis],
            shift_vectors.axes_manager.navigation_axes[0],
        )
        return shift_vectors

    def correct_real_space_drift(
        self, shifts=None, time_axis=2, order=1, lazy_result=True
    ):
        """Perform real space drift registration on the dataset.

        Parameters
        ----------
        shifts: Signal1D
            shift vectors to register, must be in the shape of <N_time | 2>.
            If None, shift vectors will be calculated automatically
        time_axis: int
            Index of time axis. Default is 2
        order: int
            The order of the spline interpolation for registration. Default is 1
        lazy_result: bool, default True
            Whether to return lazy result.

        Returns
        -------
        registered_data: InSituDiffraction2D
            Real space drift corrected version of the original dataset
        """
        if shifts is None:
            shifts = self.get_drift_vectors(time_axis=time_axis)

        s_ = self if time_axis == 2 else self.roll_time_axis(time_axis)

        dask_data = _get_dask_array(s_)
        if self._lazy:
            time_chunks = s_.get_chunk_size()[0][0]
        else:
            time_chunks = _get_chunking(s_)[0][0]

        xdrift = shifts.data[:, 0]
        ydrift = shifts.data[:, 1]

        # Give the per-frame shifts five dimensions, chunked along the first
        # axis like the data, so map_blocks can pair them block by block.
        drift_chunks = (time_chunks, 1, 1, 1, 1)
        xdrift_dask = da.from_array(
            xdrift[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis],
            chunks=drift_chunks,
        )
        ydrift_dask = da.from_array(
            ydrift[:, np.newaxis, np.newaxis, np.newaxis, np.newaxis],
            chunks=drift_chunks,
        )

        # Overlap neighbouring chunks along array axes 1 and 2 by the largest
        # absolute shift so shifted data can be pulled in from adjacent blocks.
        depthx = np.ceil(np.max(np.abs(xdrift))).astype(int)
        depthy = np.ceil(np.max(np.abs(ydrift))).astype(int)
        overlapped_depth = {0: 0, 1: depthx, 2: depthy, 3: 0, 4: 0}

        data_overlapped = da.overlap.overlap(
            dask_data, depth=overlapped_depth, boundary={a: "none" for a in range(5)}
        )

        # Clone original overlap dask array to work around memory release issue in map_overlap
        data_clones = da.concatenate(
            [clone(b, omit=data_overlapped) for b in data_overlapped.blocks]
        )

        mapped = data_clones.map_blocks(
            _register_drift_5d,
            shifts1=xdrift_dask,
            shifts2=ydrift_dask,
            order=order,
            dtype="float32",
        )

        registered_data = InSituDiffraction2D(
            da.overlap.trim_internal(mapped, overlapped_depth)
        ).as_lazy()

        # Set axes info for registered signal
        for nav_axis_old, nav_axis_new in zip(
            s_.axes_manager.navigation_axes,
            registered_data.axes_manager.navigation_axes,
        ):
            pst._copy_axes_object_metadata(nav_axis_old, nav_axis_new)

        for sig_axis_old, sig_axis_new in zip(
            s_.axes_manager.signal_axes, registered_data.axes_manager.signal_axes
        ):
            pst._copy_axes_object_metadata(sig_axis_old, sig_axis_new)

        if not lazy_result:
            # Called for its effect on registered_data; the return value is
            # intentionally unused (hyperspy's compute() works in place).
            registered_data.compute()

        return registered_data

    def correct_real_space_drift_fast(
        self, shifts=None, time_axis=2, order=1, **kwargs
    ):
        """Perform real space drift registration with fast performance over
        spatial axes.

        If signal is lazy, spatial axes must not be chunked.

        Parameters
        ----------
        shifts: Signal1D
            shift vectors to register, must be in the shape of <N_time | 2>.
            If None, shift vectors will be calculated automatically
        time_axis: int
            Index of time axis. Default is 2
        order: int
            The order of the spline interpolation for registration. Default is 1
        **kwargs:
            Passed to :meth:`~hyperspy.signal.BaseSignal.map`

        Returns
        -------
        registered_data: InSituDiffraction2D
            Real space drift corrected version of the original dataset

        Raises
        ------
        ValueError
            If the signal is lazy and either spatial axis is split into more
            than one chunk.
        """
        if self._lazy:
            # Chunk tuples are looked up at index ``2 - time_axis`` for time,
            # so the two remaining entries are the spatial axes.
            nav_axes = [0, 1, 2]
            nav_axes.remove(2 - time_axis)
            chunkings = self.get_chunk_size()
            if len(chunkings[nav_axes[0]]) != 1 or len(chunkings[nav_axes[1]]) != 1:
                raise ValueError(
                    "Spatial axes are chunked. Please rechunk signal or use 'correct_real_space_drift' "
                    "instead"
                )

        if shifts is None:
            shifts = self.get_drift_vectors(time_axis=time_axis)

        s_ = self if time_axis == 2 else self.roll_time_axis(time_axis=time_axis)
        s_transposed = s_.transpose(signal_axes=(0, 1))

        n_last = s_transposed.axes_manager.navigation_axes[0].size
        n_mid = s_transposed.axes_manager.navigation_axes[1].size

        def _tile_drift(drift):
            # Repeat the per-frame shift over the two remaining navigation
            # dimensions and append a length-1 signal axis: (N, n_mid, n_last, 1).
            tiled = np.repeat(
                drift[:, np.newaxis, np.newaxis], repeats=n_last, axis=-1
            )
            tiled = np.repeat(tiled, repeats=n_mid, axis=1)
            return Signal1D(tiled[:, :, :, np.newaxis])

        xs = _tile_drift(shifts.data[:, 0])
        ys = _tile_drift(shifts.data[:, 1])

        registered_data = s_transposed.map(
            _register_drift_2d,
            shift1=xs,
            shift2=ys,
            order=order,
            inplace=False,
            **kwargs
        )

        registered_data_t = registered_data.transpose(navigation_axes=[-2, -1, -3])
        registered_data_t.set_signal_type("insitu_diffraction")
        return registered_data_t

    def get_g2_2d_kresolved(
        self,
        time_axis=2,
        normalization="split",
        k1bin=1,
        k2bin=1,
        tbin=1,
        resample_time=None,
    ):
        """Calculate k resolved g2 from in situ diffraction signal.

        Parameters
        ----------
        time_axis: int
            Index of time axis. Default is 2
        normalization: string, Default is 'split'
            Normalization format for time autocorrelation, 'split' or 'self'
        k1bin: int
            Binning factor for k1 axis
        k2bin: int
            Binning factor for k2 axis
        tbin: int
            Binning factor for t axis
        resample_time: int or 1d array-like, Default is None
            If int (Python or NumPy integer), time is resampled into log
            linear with resample_time as number of sampling. If a 1d list,
            tuple or array, it is used as resampled time axis instead. No
            resampling is performed if None

        Returns
        -------
        g2kt: Signal2D or Correlation2D
            k resolved time correlation signal

        Raises
        ------
        TypeError
            If ``resample_time`` is neither None, an integer, nor a 1d
            list/tuple/array.
        """
        if time_axis != 2:
            transposed_signal = self.roll_time_axis(time_axis).transpose(
                navigation_axes=[0, 1]
            )
        else:
            transposed_signal = self.transpose(navigation_axes=[0, 1])

        # Resolve the resampled time axis before the (potentially expensive)
        # g2 calculation so that an invalid argument fails early.
        t_rs = None
        dt = None
        if resample_time is not None:
            t_axis = transposed_signal.axes_manager.signal_axes[-1]
            dt = t_axis.scale * tbin
            if isinstance(resample_time, (int, np.integer)):
                t_rs = _get_resample_time(
                    t_size=t_axis.size / tbin,
                    dt=dt,
                    t_rs_size=resample_time,
                )
            elif (
                isinstance(resample_time, (list, tuple, np.ndarray))
                and np.ndim(resample_time) == 1
            ):
                # np.asarray lets lists and tuples take part in the division.
                t_rs = np.asarray(resample_time) / tbin
            else:
                raise TypeError("'resample_time' must be int or 1d array")

        g2kt = transposed_signal.map(
            _g2_2d,
            normalization=normalization,
            k1bin=k1bin,
            k2bin=k2bin,
            tbin=tbin,
            inplace=False,
        )

        if t_rs is None:
            g2kt.set_signal_type("correlation")
            return g2kt

        g2rs = g2kt.map(
            _interpolate_g2_2d,
            t_rs=t_rs,
            dt=dt,
            inplace=False,
        )
        g2rs.set_signal_type("correlation")
        return g2rs
class LazyInSituDiffraction2D(LazySignal, InSituDiffraction2D):
    """Lazy counterpart of :class:`InSituDiffraction2D`.

    Combines ``LazySignal`` with ``InSituDiffraction2D`` and defines no
    attributes or methods of its own.
    """