-
Notifications
You must be signed in to change notification settings - Fork 0
/
fourier_layer.py
313 lines (265 loc) · 12.8 KB
/
fourier_layer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import tensorflow as tf
import numpy as np
import DeepSaki.initializer.helper
class FourierConvolution2D(tf.keras.layers.Layer):
    """Convolution implemented as element-wise multiplication in the Fourier domain.

    Layer input is assumed to be in the *spatial* domain. The learned kernels
    are zero-padded to the input's spatial size, both are transformed with a
    real 2D FFT, multiplied element-wise, summed over the input channels, and
    transformed back with the inverse real 2D FFT.

    Args:
        filters: Number of individual output filters.
        kernels: Spatial kernel size as [height, width]. If None, the kernel
            size defaults to the input's height and width.
        use_bias: Whether to add a learnable bias after the inverse transform.
        kernel_initializer: Initializer for the filter weights.
        bias_initializer: Initializer for the bias weights.
        isChannelFirst: If True, input shape is [batch, channel, height, width];
            otherwise [batch, height, width, channel].
        **kwargs: Keyword arguments passed to tf.keras.layers.Layer.
    """

    def __init__(self,
                 filters,
                 kernels=None,
                 use_bias=True,
                 kernel_initializer=tf.keras.initializers.RandomUniform(-0.05, 0.05),
                 bias_initializer=tf.keras.initializers.Zeros(),
                 isChannelFirst=False,
                 **kwargs
                 ):
        super(FourierConvolution2D, self).__init__(**kwargs)
        self.filters = filters
        self.kernels = kernels
        self.use_bias = use_bias
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer
        self.isChannelFirst = isChannelFirst

    def build(self, input_shape):
        if self.isChannelFirst:
            _, inp_filter, inp_height, inp_width = input_shape
        else:
            _, inp_height, inp_width, inp_filter = input_shape
        if not self.kernels:
            # Default: kernel covers the full spatial extent of the input.
            self.kernels = [inp_height, inp_width]
        # Weights are batch-independent: [out_filter, inp_filter, kernel_h, kernel_w].
        # The two spatial dims are kept last so the 2D FFT is taken in one call.
        self.kernel = self.add_weight(
            name="kernel",
            shape=[self.filters, inp_filter, self.kernels[0], self.kernels[1]],
            initializer=self.kernel_initializer,
            trainable=True,
        )
        if self.use_bias:
            # Shape [filters, 1, 1] broadcasts over the spatial dims in the
            # channel-first layout used internally by call().
            self.bias = self.add_weight(
                name="bias",
                shape=[self.filters, 1, 1],
                initializer=self.bias_initializer,
                trainable=True,
            )
        super(FourierConvolution2D, self).build(input_shape)

    def call(self, inputs):
        """Apply the Fourier-domain convolution to spatial-domain `inputs`."""
        if not self.built:
            raise ValueError('This model has not yet been built.')
        # rfft2d/irfft2d operate over the last two dims -> go channel-first.
        if not self.isChannelFirst:
            inputs = tf.einsum("bhwc->bchw", inputs)

        # Zero-pad the kernel to the input's spatial size so the element-wise
        # product in the frequency domain is a valid (circular) convolution.
        signal_shape = tf.shape(inputs)
        kernel_shape = tf.shape(self.kernel)
        paddings = [[0, 0],
                    [0, 0],
                    [0, signal_shape[2] - kernel_shape[2]],
                    [0, signal_shape[3] - kernel_shape[3]]]
        kernels_padded = tf.pad(self.kernel, paddings)

        # Real 2D FFT of inputs and (padded) kernels.
        inputs_F = tf.signal.rfft2d(inputs)           # [batch, inp_filter, h, w//2+1]
        kernels_F = tf.signal.rfft2d(kernels_padded)  # [out_filter, inp_filter, h, w//2+1]
        # NOTE: the mathematically exact cross-correlation would use the
        # conjugate of kernels_F; omitted since the filter is learned anyway.

        # Multiply and contract over the input-channel axis in one einsum.
        # Replaces the original per-filter Python loop that concatenated onto
        # an *uninitialized* float64 numpy array (np.ndarray(shape=...)),
        # which breaks for unknown batch sizes and mismatched complex dtypes.
        outputs_F = tf.einsum("bihw,oihw->bohw", inputs_F, kernels_F)

        # Inverse real 2D FFT back to the spatial domain.
        output = tf.signal.irfft2d(outputs_F)

        if self.use_bias:
            output += self.bias

        # Restore the caller's channel configuration.
        if not self.isChannelFirst:
            output = tf.einsum("bchw->bhwc", output)
        return output

    def get_config(self):
        config = super(FourierConvolution2D, self).get_config()
        config.update({
            "filters": self.filters,
            "kernels": self.kernels,
            "use_bias": self.use_bias,
            "kernel_initializer": self.kernel_initializer,
            "bias_initializer": self.bias_initializer,
            "isChannelFirst": self.isChannelFirst
        })
        return config
class FourierFilter2D(tf.keras.layers.Layer):
    """Learnable filter applied in the frequency domain.

    Expects its input to already be in the Fourier domain (complex-valued).
    Each output channel is the sum over input channels of the element-wise
    product of input and a learned complex filter.

    Args:
        filters: Number of independent output filters.
        use_bias: Whether to add a learnable (complex) bias.
        filter_initializer: Initializer for the filter weights (wrapped to
            produce complex values).
        bias_initializer: Initializer for the bias weights (wrapped to
            produce complex values).
        isChannelFirst: If True, input shape is [batch, channel, height, width];
            otherwise [batch, height, width, channel].
        **kwargs: Keyword arguments passed to tf.keras.layers.Layer.
    """

    def __init__(self,
                 filters,
                 use_bias=True,
                 filter_initializer=tf.keras.initializers.RandomUniform(-0.05, 0.05),
                 bias_initializer=tf.keras.initializers.Zeros(),
                 isChannelFirst=False,
                 **kwargs
                 ):
        super(FourierFilter2D, self).__init__(**kwargs)
        self.filters = filters
        self.use_bias = use_bias
        # Wrap real-valued initializers so they emit complex64 weights.
        self.filter_initializer = DeepSaki.initializer.helper.MakeInitializerComplex(filter_initializer)
        self.bias_initializer = DeepSaki.initializer.helper.MakeInitializerComplex(bias_initializer)
        self.isChannelFirst = isChannelFirst
        self.fourier_filter = None  # shape: [inp_filter, height, width, out_filter]
        self.fourier_bias = None    # shape: [out_filter, 1, 1], broadcastable channel-first

    def build(self, input_shape):
        if self.isChannelFirst:
            _, inp_filter, inp_height, inp_width = input_shape
        else:
            _, inp_height, inp_width, inp_filter = input_shape
        # Weights are batch-independent. Filter layout differs from a spatial
        # convolution because the 2D FFT acts over the last two dimensions.
        self.fourier_filter = self.add_weight(
            name="filter",
            shape=[inp_filter, inp_height, inp_width, self.filters],
            initializer=self.filter_initializer,
            trainable=True,
            dtype=tf.dtypes.complex64,
        )
        if self.use_bias:
            # [filters, 1, 1] broadcasts over the spatial dims (channel-first).
            self.fourier_bias = self.add_weight(
                name="bias",
                shape=[self.filters, 1, 1],
                initializer=self.bias_initializer,
                trainable=True,
                dtype=tf.dtypes.complex64,
            )
        super(FourierFilter2D, self).build(input_shape)

    def call(self, inputs):
        """Multiply `inputs` (Fourier domain) by the learned complex filters.

        Broadcasting over the batch dimension is handled by einsum.
        """
        if not self.built:
            raise ValueError('This model has not yet been built.')
        if not self.isChannelFirst:  # internal layout is channel-first
            inputs = tf.einsum("bhwc->bchw", inputs)

        # Element-wise product, summed over the input-channel axis, for every
        # output filter at once. Replaces the original per-filter Python loop
        # that concatenated onto an uninitialized float64 numpy array, which
        # breaks for unknown batch sizes and complex dtypes.
        output = tf.einsum("bihw,ihwo->bohw", inputs, self.fourier_filter)

        if self.use_bias:
            output += self.fourier_bias
        if not self.isChannelFirst:  # restore the caller's channel configuration
            output = tf.einsum("bchw->bhwc", output)
        return output

    def get_config(self):
        config = super(FourierFilter2D, self).get_config()
        config.update({
            "filters": self.filters,
            "use_bias": self.use_bias,
            # Key must match the __init__ argument name ("filter_initializer",
            # not "kernel_initializer") or from_config() fails.
            "filter_initializer": self.filter_initializer,
            "bias_initializer": self.bias_initializer,
            "isChannelFirst": self.isChannelFirst
        })
        return config
class FFT2D(tf.keras.layers.Layer):
    """Calculates the 2D discrete Fourier transform of its input.

    Args:
        isChannelFirst: If True, input shape is [batch, channel, height, width];
            otherwise [batch, height, width, channel].
        applyRealFFT: If True, rfft2d is applied, which assumes real-valued
            input and roughly halves the width of the output. If False, fft2d
            is applied (input is cast to complex).
        shiftFFT: If True, low-frequency components are centered via fftshift.
        **kwargs: Keyword arguments passed to tf.keras.layers.Layer.
    """

    def __init__(self,
                 isChannelFirst=False,
                 applyRealFFT=False,
                 shiftFFT=True,
                 **kwargs
                 ):
        super(FFT2D, self).__init__(**kwargs)
        self.isChannelFirst = isChannelFirst
        self.applyRealFFT = applyRealFFT
        self.shiftFFT = shiftFFT

    def call(self, inputs):
        """Return the (optionally shifted) 2D FFT of `inputs`."""
        # fft2d/rfft2d operate over the last two dims -> go channel-first.
        if not self.isChannelFirst:
            inputs = tf.einsum("bhwc->bchw", inputs)
        if self.applyRealFFT:
            x = tf.signal.rfft2d(inputs)
            if self.shiftFFT:
                # Only the height axis is shifted: rfft2d halves the width, so
                # its frequency axis is not symmetric around the center.
                x = tf.signal.fftshift(x, axes=[-2])
        else:
            # fft2d requires complex input -> build complex with 0 imaginary part.
            inputs = tf.complex(inputs, tf.zeros_like(inputs))
            x = tf.signal.fft2d(inputs)
            if self.shiftFFT:
                # Restrict the shift to the spatial axes. The default
                # (axes=None) would also roll the batch and channel dims,
                # scrambling samples across the batch.
                x = tf.signal.fftshift(x, axes=[-2, -1])
        # Restore the caller's channel configuration.
        if not self.isChannelFirst:
            x = tf.einsum("bchw->bhwc", x)
        return x

    def get_config(self):
        config = super(FFT2D, self).get_config()
        config.update({
            "isChannelFirst": self.isChannelFirst,
            "applyRealFFT": self.applyRealFFT,
            "shiftFFT": self.shiftFFT
        })
        return config
class iFFT2D(tf.keras.layers.Layer):
    """Calculates the 2D inverse FFT, undoing an optional center-shift first.

    Args:
        isChannelFirst: If True, input shape is [batch, channel, height, width];
            otherwise [batch, height, width, channel].
        applyRealFFT: If True, irfft2d is applied (input came from rfft2d and
            the output is real-valued with restored width). If False, ifft2d
            is applied and the real part of the result is returned.
        shiftFFT: If True, the fftshift applied in the forward transform is
            reversed (ifftshift) before the inverse FFT.
        **kwargs: Keyword arguments passed to tf.keras.layers.Layer.
    """

    def __init__(self,
                 isChannelFirst=False,
                 applyRealFFT=False,
                 shiftFFT=True,
                 **kwargs
                 ):
        super(iFFT2D, self).__init__(**kwargs)
        self.isChannelFirst = isChannelFirst
        self.applyRealFFT = applyRealFFT
        self.shiftFFT = shiftFFT

    def call(self, inputs):
        """Return the real-valued inverse 2D FFT of `inputs`."""
        # ifft2d/irfft2d operate over the last two dims -> go channel-first.
        if not self.isChannelFirst:
            inputs = tf.einsum("bhwc->bchw", inputs)
        x = inputs
        if self.applyRealFFT:
            if self.shiftFFT:
                # Mirror of FFT2D: only the height axis was shifted for rfft2d.
                x = tf.signal.ifftshift(x, axes=[-2])
            x = tf.signal.irfft2d(x)
        else:
            if self.shiftFFT:
                # Restrict the un-shift to the spatial axes. The default
                # (axes=None) would also roll the batch and channel dims.
                x = tf.signal.ifftshift(x, axes=[-2, -1])
            x = tf.signal.ifft2d(x)
            # ifft2d returns a complex tensor; keep only the real part.
            x = tf.math.real(x)
        # Restore the caller's channel configuration.
        if not self.isChannelFirst:
            x = tf.einsum("bchw->bhwc", x)
        return x

    def get_config(self):
        config = super(iFFT2D, self).get_config()
        config.update({
            "isChannelFirst": self.isChannelFirst,
            "applyRealFFT": self.applyRealFFT,
            "shiftFFT": self.shiftFFT
        })
        return config