-
Notifications
You must be signed in to change notification settings - Fork 4
/
kbnufft.py
298 lines (266 loc) · 12.5 KB
/
kbnufft.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
import warnings
import numpy as np
import tensorflow as tf
# from .functional.kbnufft import AdjKbNufftFunction, KbNufftFunction
# ToepNufftFunction)
from .kbmodule import KbModule
from .nufft.fft_functions import scale_and_fft_on_image_volume, ifft_and_scale_on_gridded_data
from .nufft.interp_functions import kbinterp, adjkbinterp
from .nufft.utils import build_spmatrix, build_table, compute_scaling_coefs
from .utils.itertools import cartesian_product
class KbNufftModule(KbModule):
    """Parent class for KbNufft classes.

    This implementation collects all init functions into one place: it
    normalizes the per-dimension NUFFT settings, precomputes the
    interpolation tables and scaling coefficients, and keeps TensorFlow
    tensor copies of them for the NUFFT operators.

    Args:
        im_size (tuple of ints): Size of base image. Its length is the number
            of image dimensions (``im_rank``).
        grid_size (int or tuple of ints, default=2*im_size): Size of the grid
            to interpolate from. A single int is used for every dimension.
        numpoints (int or tuple of ints, default=6): Number of points to use
            for interpolation in each dimension. Default is six points in each
            direction.
        n_shift (int or tuple of ints, default=im_size//2): Number of points to
            shift for fftshifts. A single int is used for every dimension.
        table_oversamp (int, float or tuple, default=2^10): Table oversampling
            factor, either shared by all dimensions or given per dimension.
        kbwidth (double or tuple, default=2.34): Kaiser-Bessel width
            parameter; multiplied by ``numpoints`` to give the per-dimension
            ``alpha``.
        order (double or tuple, default=0): Order of Kaiser-Bessel kernel.
        norm (str, default='None'): Normalization for FFT. Default uses no
            normalization. Use 'ortho' to use orthogonal FFTs and preserve
            energy.
        coil_broadcast (bool, default=False): Deprecated flag; stored and
            forwarded in the interpolation object. Emits a
            DeprecationWarning when set.
        matadj (bool, default=False): Deprecated flag; stored and forwarded in
            the interpolation object. Emits a DeprecationWarning when set.
        grad_traj (bool, default=False): Also request gradients with respect
            to the trajectory (experimental); forwarded in the interpolation
            object. Emits a warning when set.

    Raises:
        AssertionError: If a per-dimension setting does not have exactly
            ``len(im_size)`` entries.
    """

    def __init__(self, im_size, grid_size=None, numpoints=6, n_shift=None,
                 table_oversamp=2**10, kbwidth=2.34, order=0, norm='None',
                 coil_broadcast=False, matadj=False, grad_traj=False):
        super().__init__()

        # Scalar types accepted in place of a per-dimension tuple. numpy
        # scalars are included so e.g. np.int64(6) behaves like 6.
        int_types = (int, np.integer)
        real_types = (int, float, np.integer, np.floating)

        def _per_dim(value, scalar_types, ndims):
            # Repeat a scalar setting once per dimension; pass sequences through.
            if isinstance(value, scalar_types):
                return (value,) * ndims
            return value

        self.im_size = im_size
        self.im_rank = len(im_size)
        self.grad_traj = grad_traj
        if self.grad_traj:
            warnings.warn('The gradient w.r.t trajectory is Experimental and WIP. '
                          'Please use with caution')

        if grid_size is None:
            self.grid_size = tuple(np.array(self.im_size) * 2)
        else:
            self.grid_size = _per_dim(grid_size, int_types, self.im_rank)
        if n_shift is None:
            self.n_shift = tuple(np.array(self.im_size) // 2)
        else:
            self.n_shift = _per_dim(n_shift, int_types, self.im_rank)

        ndims = len(self.grid_size)
        self.numpoints = _per_dim(numpoints, int_types, ndims)
        self.alpha = tuple(np.array(kbwidth) * np.array(self.numpoints))
        self.order = _per_dim(order, real_types, ndims)
        self.table_oversamp = _per_dim(table_oversamp, real_types, ndims)

        # dimension checking: every per-dimension setting needs one entry per
        # image axis.
        for name in ('grid_size', 'n_shift', 'numpoints', 'alpha', 'order',
                     'table_oversamp'):
            assert len(getattr(self, name)) == self.im_rank, (
                f'{name} must have {self.im_rank} entries')

        self.table = build_table(
            numpoints=self.numpoints,
            table_oversamp=self.table_oversamp,
            grid_size=self.grid_size,
            im_size=self.im_size,
            ndims=self.im_rank,
            order=self.order,
            alpha=self.alpha,
        )
        assert len(self.table) == self.im_rank

        self.scaling_coef = compute_scaling_coefs(
            im_size=self.im_size,
            grid_size=self.grid_size,
            numpoints=self.numpoints,
            alpha=self.alpha,
            order=self.order,
        )
        self.norm = norm
        self.coil_broadcast = coil_broadcast
        self.matadj = matadj
        if coil_broadcast:
            warnings.warn(
                'coil_broadcast will be deprecated in a future release',
                DeprecationWarning)
        if matadj:
            warnings.warn(
                'matadj will be deprecated in a future release',
                DeprecationWarning)

        # TensorFlow has no register_buffer equivalent: keeping the tensors as
        # plain attributes is enough.
        self.scaling_coef_tensor = tf.convert_to_tensor(self.scaling_coef)
        self.table_tensors = [tf.convert_to_tensor(item) for item in self.table]
        self.n_shift_tensor = tf.convert_to_tensor(np.array(self.n_shift, dtype=np.int64))
        self.grid_size_tensor = tf.convert_to_tensor(np.array(self.grid_size, dtype=np.int64))
        self.im_size_tensor = tf.convert_to_tensor(np.array(self.im_size, dtype=np.int64))
        self.numpoints_tensor = tf.convert_to_tensor(np.array(self.numpoints, dtype=np.double))
        self.table_oversamp_tensor = tf.convert_to_tensor(
            np.array(self.table_oversamp, dtype=np.double))

    def _extract_nufft_interpob(self):
        """Extracts interpolation object from self.

        Returns:
            dict: An interpolation object for the NUFFT operation. It holds the
            tensor copies built in ``__init__`` ('scaling_coef', 'table',
            'n_shift', 'grid_size', 'im_size', 'numpoints', 'table_oversamp'),
            the Python-level settings ('im_rank', 'norm', 'coil_broadcast',
            'matadj', 'grad_traj') and 'Jlist': the output of
            ``cartesian_product`` over ``np.arange(numpoints[d])`` for each
            dimension, cast to int64.
        """
        # One index range per dimension: the kernel uses numpoints[d] points.
        jlist = cartesian_product(
            [np.arange(self.numpoints[i]) for i in range(self.im_rank)])
        return {
            'scaling_coef': self.scaling_coef_tensor,
            'table': self.table_tensors,
            'n_shift': self.n_shift_tensor,
            'grid_size': self.grid_size_tensor,
            'im_size': self.im_size_tensor,
            'im_rank': self.im_rank,
            'numpoints': self.numpoints_tensor,
            'table_oversamp': self.table_oversamp_tensor,
            'norm': self.norm,
            'coil_broadcast': self.coil_broadcast,
            'matadj': self.matadj,
            'grad_traj': self.grad_traj,
            'Jlist': jlist.astype('int64'),
        }
def kbnufft_forward(interpob, multiprocessing=False):
    """Build the forward NUFFT operator (image -> off-grid signal).

    Args:
        interpob (dict): Interpolation object, such as the one returned by
            ``KbNufftModule._extract_nufft_interpob``. This function reads
            'scaling_coef', 'grid_size', 'im_size', 'norm', 'grad_traj',
            'im_rank' (defaults to 2 when absent) and, in the trajectory
            gradient only, 'n_shift'. The whole dict is also handed to
            ``kbinterp`` / ``adjkbinterp``.
        multiprocessing (bool, default=False): Forwarded to
            ``scale_and_fft_on_image_volume`` in the forward pass.

    Returns:
        callable: ``f(x, om)`` wrapped in ``tf.function`` with a custom
        gradient. The gradient w.r.t. ``x`` is the adjoint operation. The
        gradient w.r.t. ``om`` is computed only when ``interpob['grad_traj']``
        is truthy, and is ``None`` otherwise.
    """
    @tf.function(experimental_relax_shapes=True)
    @tf.custom_gradient
    def kbnufft_forward_for_interpob(x, om):
        """Apply FFT and interpolate from gridded data to scattered data.

        Inputs are assumed to be batch/chans x coil x image dims.
        Om should be nbatch x ndims x klength.

        Args:
            x (tensor): The original image.
            om (tensor): The omega coordinates at which to calculate the
                signal, in radians/voxel.

        Returns:
            tensor: x computed at off-grid locations in om.
        """
        # extract interpolation params
        scaling_coef = interpob['scaling_coef']
        grid_size = interpob['grid_size']
        im_size = interpob['im_size']
        norm = interpob['norm']
        grad_traj = interpob['grad_traj']
        im_rank = interpob.get('im_rank', 2)

        fft_x = scale_and_fft_on_image_volume(
            x, scaling_coef, grid_size, im_size, norm, im_rank=im_rank,
            multiprocessing=multiprocessing)
        y = kbinterp(fft_x, om, interpob)

        def grad(dy):
            # Gradient w.r.t. the image: adjoint interpolation then iFFT.
            grid_dy = adjkbinterp(dy, om, interpob)
            ifft_dy = ifft_and_scale_on_gridded_data(
                grid_dy, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
            if grad_traj:
                # Gradient w.r.t. the trajectory locations (experimental).
                # Pixel coordinates are measured from the fftshift origin, i.e.
                # n - n_shift. This matches the previous -N/2..N/2-1 grid for
                # even sizes with the default shift. It is also correct for odd
                # sizes and custom n_shift.
                # NOTE(review): assumes kbinterp applies the n_shift phase as
                # exp(1j * om . n_shift); confirm against interp_functions.
                n_shift = interpob['n_shift']
                r = [
                    tf.cast(tf.range(im_size[i]), tf.float64)
                    - tf.cast(n_shift[i], tf.float64)
                    for i in range(im_rank)
                ]
                grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), x.dtype)[None, ...]
                fft_dx_dom = scale_and_fft_on_image_volume(
                    x * grid_r, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
                # Fold the im_rank coordinate-weighted images into the batch
                # axis, one single-channel image each, and repeat om to match.
                # Presumably only valid for a single coil; TODO confirm.
                fft_dx_dom = tf.reshape(fft_dx_dom, shape=(-1, 1, *fft_dx_dom.shape[2:]))
                nufft_dx_dom = kbinterp(fft_dx_dom, tf.repeat(om, im_rank, axis=0), interpob)
                # Unfold back so axis 1 indexes the im_rank dimensions.
                nufft_dx_dom = tf.reshape(
                    nufft_dx_dom, shape=(-1, im_rank, *nufft_dx_dom.shape[2:]))
                # Casting to the (real) dtype of om keeps only the real part.
                dy_dom = tf.cast(-1j * tf.math.conj(dy) * nufft_dx_dom, om.dtype)
            else:
                dy_dom = None
            return ifft_dy, dy_dom

        return y, grad

    return kbnufft_forward_for_interpob
def kbnufft_adjoint(interpob, multiprocessing=False):
    """Build the adjoint NUFFT operator (off-grid signal -> image).

    Args:
        interpob (dict): Interpolation object, such as the one returned by
            ``KbNufftModule._extract_nufft_interpob``. This function reads
            'scaling_coef', 'grid_size', 'im_size', 'norm', 'grad_traj',
            'im_rank' (defaults to 2 when absent) and, in the trajectory
            gradient only, 'n_shift'. The whole dict is also handed to
            ``kbinterp`` / ``adjkbinterp``.
        multiprocessing (bool, default=False): Forwarded to
            ``ifft_and_scale_on_gridded_data`` in the adjoint pass.

    Returns:
        callable: ``f(y, om)`` wrapped in ``tf.function`` with a custom
        gradient. The gradient w.r.t. ``y`` is the forward NUFFT. The
        gradient w.r.t. ``om`` is computed only when ``interpob['grad_traj']``
        is truthy, and is ``None`` otherwise.
    """
    @tf.function(experimental_relax_shapes=True)
    @tf.custom_gradient
    def kbnufft_adjoint_for_interpob(y, om):
        """Interpolate from scattered data to gridded data and then iFFT.

        Inputs are assumed to be batch/chans x coil x kspace
        length. Om should be nbatch x ndims x klength.

        Args:
            y (tensor): The off-grid signal.
            om (tensor): The off-grid coordinates in radians/voxel.

        Returns:
            tensor: The image after adjoint NUFFT.
        """
        grid_y = adjkbinterp(y, om, interpob)
        scaling_coef = interpob['scaling_coef']
        grid_size = interpob['grid_size']
        im_size = interpob['im_size']
        norm = interpob['norm']
        grad_traj = interpob['grad_traj']
        im_rank = interpob.get('im_rank', 2)
        ifft_y = ifft_and_scale_on_gridded_data(
            grid_y, scaling_coef, grid_size, im_size, norm, im_rank=im_rank,
            multiprocessing=multiprocessing)

        def grad(dx):
            # Gradient w.r.t. the off-grid signal: FFT then interpolation.
            fft_dx = scale_and_fft_on_image_volume(
                dx, scaling_coef, grid_size, im_size, norm, im_rank=im_rank)
            dx_dy = kbinterp(fft_dx, om, interpob)
            if grad_traj:
                # Gradient w.r.t. the trajectory locations (experimental).
                # Pixel coordinates are measured from the fftshift origin, i.e.
                # n - n_shift. This matches the previous -N/2..N/2-1 grid for
                # even sizes with the default shift. It is also correct for odd
                # sizes and custom n_shift.
                # NOTE(review): assumes kbinterp applies the n_shift phase as
                # exp(1j * om . n_shift); confirm against interp_functions.
                n_shift = interpob['n_shift']
                r = [
                    tf.cast(tf.range(im_size[i]), tf.float64)
                    - tf.cast(n_shift[i], tf.float64)
                    for i in range(im_rank)
                ]
                # This won't work for the multicoil case, as the dimension of
                # dx is `batch_size x coil x Nx x Ny`.
                grid_r = tf.cast(tf.meshgrid(*r, indexing='ij'), dx.dtype)[None, ...]
                ifft_dxr = scale_and_fft_on_image_volume(
                    tf.math.conj(dx) * grid_r, scaling_coef, grid_size, im_size, norm,
                    im_rank=im_rank, do_ifft=True)
                # Fold the im_rank coordinate-weighted images into the batch
                # axis, one single-channel image each, and repeat om to match.
                ifft_dxr = tf.reshape(ifft_dxr, shape=(-1, 1, *ifft_dxr.shape[2:]))
                inufft_dxr = kbinterp(
                    ifft_dxr, tf.repeat(om, im_rank, axis=0), interpob, conj=True)
                # Unfold back so axis 1 indexes the im_rank dimensions.
                inufft_dxr = tf.reshape(
                    inufft_dxr, shape=(-1, im_rank, *inufft_dxr.shape[2:]))
                # Casting to the (real) dtype of om keeps only the real part.
                dx_dom = tf.cast(1j * y * inufft_dxr, om.dtype)
            else:
                dx_dom = None
            return dx_dy, dx_dom

        return ifft_y, grad

    return kbnufft_adjoint_for_interpob
# class ToepNufft(KbModule):
# """Forward/backward NUFFT with Toeplitz embedding.
#
# This module applies Tx, where T is a matrix such that T = A'A, where A is
# a NUFFT matrix. Using Toeplitz embedding, this module computes the A'A
# operation without interpolations, which is extremely fast.
#
# The module is intended to be used in combination with an fft kernel
# computed to be the frequency response of an embedded Toeplitz matrix. The
# kernel is calculated offline via
#
# torchkbnufft.nufft.toep_functions.calc_toep_kernel
#
# The corresponding kernel is then passed to this module in its forward
# operation, which applies a (zero-padded) FFT filter using the
# kernel.
# """
#
# def __init__(self):
# super(ToepNufft, self).__init__()
#
# def forward(self, x, kern, norm=None):
# """Toeplitz NUFFT forward function.
#
# Args:
# x (tensor): The image (or images) to apply the forward/backward
# Toeplitz-embedded NUFFT to.
# kern (tensor): The filter response taking into account Toeplitz
# embedding.
# norm (str, default=None): Use 'ortho' if kern was designed to use
# orthogonal FFTs.
#
# Returns:
# tensor: x after applying the Toeplitz NUFFT.
# """
# x = ToepNufftFunction.apply(x, kern, norm)
#
# return x