# pt_deconv.py
# Forked from titu1994/tf_neural_deconvolution.
""" Modified from https://github.com/yechengxi/deconvolution/blob/master/models/deconv.py """
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules import conv
from torch.nn.modules.utils import _pair
# iteratively solve for inverse sqrt of a matrix
def isqrt_newton_schulz_autograd(A, numIters):
    """Approximate the inverse square root of a square matrix.

    Runs ``numIters`` steps of the coupled Newton-Schulz iteration using only
    matrix products, so the result stays differentiable through autograd.

    Args:
        A: Square 2-D tensor of shape ``(dim, dim)``.
        numIters: Number of iterations to run. With ``0`` the result is
            ``I / sqrt(||A||)``.

    Returns:
        Tensor with the same shape, dtype and device as ``A`` that
        approximates ``A^(-1/2)``.
    """
    dim = A.shape[0]
    # Normalise by the (Frobenius) norm; the scaling is undone in the final
    # division below.
    normA = A.norm()
    Y = A.div(normA)
    I = torch.eye(dim, dtype=A.dtype, device=A.device)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device)
    for _ in range(numIters):
        T = 0.5 * (3.0 * I - Z @ Y)
        Y = Y @ T
        Z = T @ Z
    # Y * sqrt(normA) would be the matrix square root; only the inverse
    # square root is returned.
    return Z / torch.sqrt(normA)
def isqrt_newton_schulz_autograd_batch(A, numIters):
    """Approximate the inverse square root of each matrix in a batch.

    Batched counterpart of the single-matrix Newton-Schulz routine: every
    ``(dim, dim)`` slice of ``A`` is normalised by its own norm and iterated
    independently with batched matrix products.

    Args:
        A: 3-D tensor of shape ``(batch, dim, dim)``.
        numIters: Number of iterations to run.

    Returns:
        Tensor of shape ``(batch, dim, dim)`` approximating ``A[i]^(-1/2)``
        for every ``i``.
    """
    batchSize, dim, _ = A.shape
    # Per-matrix 2-norm of the flattened entries, shaped for broadcasting.
    # reshape (instead of view) also accepts non-contiguous input.
    normA = A.reshape(batchSize, -1).norm(2, 1).view(batchSize, 1, 1)
    Y = A.div(normA)
    I = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    Z = torch.eye(dim, dtype=A.dtype, device=A.device).unsqueeze(0).expand_as(A)
    for _ in range(numIters):
        T = 0.5 * (3.0 * I - Z.bmm(Y))
        Y = Y.bmm(T)
        Z = T.bmm(Z)
    # Y * sqrt(normA) would be the matrix square root; only the inverse
    # square root is returned.
    return Z / torch.sqrt(normA)
# deconvolve channels
class ChannelDeconv2D(nn.Module):
    """Decorrelate activations across channels in groups of ``block`` channels.

    The first ``(C // block) * block`` channels are mean-centred and multiplied
    by an approximate inverse square root of their covariance. Any remaining
    channels are standardised with a single scalar mean and variance.

    In training mode the statistics come from the current batch and are folded
    into the running buffers with an exponential moving average. In eval mode
    only the running buffers are read and nothing is mutated, so a module that
    has never been trained leaves the deconvolved channels unchanged.

    Args:
        block: Number of channels decorrelated together.
        eps: Value added to the covariance diagonal and to the variance.
        n_iter: Newton-Schulz iterations used for the inverse square root.
        momentum: Weight of the current batch in the running averages.
        sampling_stride: When > 1 and both spatial sizes are at least this
            large, statistics are estimated from every
            ``sampling_stride ** 2``-th flattened sample.
    """

    def __init__(self, block, eps=1e-2, n_iter=5, momentum=0.1, sampling_stride=3):
        super(ChannelDeconv2D, self).__init__()
        self.eps = eps
        self.n_iter = n_iter
        self.momentum = momentum
        self.block = block
        self.register_buffer('running_mean1', torch.zeros(block, 1))
        self.register_buffer('running_deconv', torch.eye(block))
        self.register_buffer('running_mean2', torch.zeros(1, 1))
        self.register_buffer('running_var', torch.ones(1, 1))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        self.sampling_stride = sampling_stride

    def forward(self, x):
        """Apply channel deconvolution to ``x``.

        Args:
            x: Tensor of shape ``(N, C, H, W)`` or ``(N, C)``.

        Returns:
            Tensor with the same shape as the input.

        Raises:
            ValueError: If ``x`` is neither 2-D nor 4-D, or if it has fewer
                than ``block`` channels.
        """
        x_shape = x.shape
        if len(x_shape) == 2:
            x = x.view(x.shape[0], x.shape[1], 1, 1)
        if x.dim() != 4:
            raise ValueError(
                'Unsupported tensor shape {}: expected a 2-D or 4-D input.'.format(tuple(x_shape)))

        N, C, H, W = x.size()
        B = self.block

        # Only whole blocks of channels are deconvolved.
        c = (C // B) * B
        if c == 0:
            raise ValueError(
                'block ({}) must not exceed the number of channels ({}).'.format(B, C))

        training = self.training
        # The very first training batch seeds the running buffers directly.
        first_batch = training and int(self.num_batches_tracked) == 0
        subsample = (self.sampling_stride > 1
                     and H >= self.sampling_stride
                     and W >= self.sampling_stride)
        step = self.sampling_stride ** 2

        # step 1. remove mean
        x1 = x[:, :c].permute(1, 0, 2, 3).contiguous().view(B, -1)
        x1_s = x1[:, ::step] if subsample else x1

        if training:
            mean1 = x1_s.mean(-1, keepdim=True)
            if first_batch:
                self.running_mean1.copy_(mean1.detach())
            self.running_mean1.mul_(1 - self.momentum)
            self.running_mean1.add_(mean1.detach() * self.momentum)
        else:
            mean1 = self.running_mean1

        # Rebinds x1; x1_s still refers to the samples before centring.
        x1 = x1 - mean1

        # step 2. calculate deconv @ x1 = cov^(-0.5) @ x1
        if training:
            # NOTE(review): cov is built from x1_s, which was taken before the
            # mean was subtracted -- kept as-is; confirm this is intended.
            eye = torch.eye(B, dtype=x.dtype, device=x.device)
            cov = x1_s @ x1_s.t() / x1_s.shape[1] + self.eps * eye
            deconv = isqrt_newton_schulz_autograd(cov, self.n_iter)
            if first_batch:
                self.running_deconv.copy_(deconv.detach())
            self.running_deconv.mul_(1 - self.momentum)
            self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            deconv = self.running_deconv

        x1 = deconv @ x1

        # reshape to N, c, H, W
        x1 = x1.view(c, N, H, W).contiguous().permute(1, 0, 2, 3)

        # normalize the remaining channels
        if c != C:
            x_rest = x[:, c:]
            if training:
                x_tmp = x_rest.reshape(N, -1)
                x_s = x_tmp[:, ::step] if subsample else x_tmp
                mean2 = x_s.mean()
                var = x_s.var()
                if first_batch:
                    self.running_mean2.copy_(mean2.detach())
                    self.running_var.copy_(var.detach())
                self.running_mean2.mul_(1 - self.momentum)
                self.running_mean2.add_(mean2.detach() * self.momentum)
                self.running_var.mul_(1 - self.momentum)
                self.running_var.add_(var.detach() * self.momentum)
            else:
                mean2 = self.running_mean2
                var = self.running_var
            x_rest = (x_rest - mean2) / (var + self.eps).sqrt()
            x1 = torch.cat([x1, x_rest], dim=1)

        if training:
            self.num_batches_tracked.add_(1)

        if len(x_shape) == 2:
            # x1 may be a permuted (non-contiguous) tensor here.
            x1 = x1.reshape(x_shape)

        return x1
class FastDeconv2D(conv._ConvNd):
    """2-D convolution whose input patches are decorrelated before the kernel is applied.

    In training mode the layer unfolds (and sub-samples) the input into
    patches, estimates their mean and covariance, and computes an approximate
    inverse square root of the covariance. That transform is folded into the
    convolution weight and bias, so the actual computation is a single
    ``F.conv2d`` call. Running averages of the mean and of the transform are
    kept in buffers and used in eval mode, or while the statistics are frozen.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias: As for a regular 2-D convolution; ``kernel_size``,
            ``stride``, ``padding`` and ``dilation`` may be ints or pairs.
        eps: Value added to the covariance diagonal.
        n_iter: Newton-Schulz iterations used for the inverse square root.
        momentum: Weight of the current batch in the running averages.
        block: Number of input channels decorrelated together. It is clipped
            to ``in_channels``, reduced to ``gcd(block, in_channels)`` when it
            does not divide ``in_channels``, and forced to
            ``in_channels // groups`` for grouped convolutions.
        sampling_stride: Base stride used when sampling patches for the
            statistics; multiplied by the convolution stride.
        freeze: When true, statistics stop being recomputed once the internal
            counter exceeds ``freeze_iter``.
        freeze_iter: Threshold for ``freeze``; the counter wraps around at
            ``freeze_iter * 10``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3, freeze=False, freeze_iter=100):
        super(FastDeconv2D, self).__init__(
            in_channels, out_channels, _pair(kernel_size), _pair(stride), _pair(padding), _pair(dilation),
            False, _pair(0), groups, bias, padding_mode='zeros')

        self.momentum = momentum
        self.n_iter = n_iter
        self.eps = eps
        self.counter = 0
        self.track_running_stats = True

        if block > in_channels:
            block = in_channels
        elif in_channels % block != 0:
            block = math.gcd(block, in_channels)

        if groups > 1:
            # grouped conv: one block per group
            block = in_channels // groups

        self.block = block

        # Read the normalised pairs stored by the base class so that plain
        # int arguments work as well as tuples.
        k_h, k_w = self.kernel_size
        self.num_features = k_h * k_w * block

        if groups == 1:
            self.register_buffer('running_mean', torch.zeros(self.num_features))
            self.register_buffer('running_deconv', torch.eye(self.num_features))
        else:
            self.register_buffer('running_mean', torch.zeros(k_h * k_w * in_channels))
            self.register_buffer('running_deconv', torch.eye(self.num_features).repeat(in_channels // block, 1, 1))

        self.sampling_stride = sampling_stride * self.stride[0]
        self.freeze_iter = freeze_iter
        self.freeze = freeze

    def forward(self, x):
        """Convolve ``x`` with the deconvolution folded into weight and bias.

        Args:
            x: Input tensor of shape ``(N, C, H, W)``.

        Returns:
            The result of ``F.conv2d`` with the adjusted weight and bias.
        """
        N, C, H, W = x.shape
        B = self.block
        nf = self.num_features

        frozen = self.freeze and (self.counter > self.freeze_iter)
        if self.training and self.track_running_stats:
            self.counter += 1
            self.counter %= (self.freeze_iter * 10)

        if self.training and (not frozen):
            # 1. im2col: N x cols x pixels -> N*pixels x cols
            if self.kernel_size[0] > 1:
                X = F.unfold(x, self.kernel_size, self.dilation, self.padding,
                             self.sampling_stride).transpose(1, 2).contiguous()
            else:
                # channel wise
                X = x.permute(0, 2, 3, 1).contiguous().view(-1, C)[::self.sampling_stride ** 2, :]

            if self.groups == 1:
                # (C//B*N*pixels, k*k*B)
                X = X.view(-1, nf, C // B).transpose(1, 2).contiguous().view(-1, nf)
            else:
                X = X.view(-1, X.shape[-1])

            # 2. subtract mean
            X_mean = X.mean(0)
            X = X - X_mean.unsqueeze(0)

            # 3. calculate COV, COV^(-0.5), then deconv
            if self.groups == 1:
                # Cov = X.t() @ X / X.shape[0] + eps * I, via the keyword
                # (non-deprecated) form of addmm.
                Id = torch.eye(X.shape[1], dtype=X.dtype, device=X.device)
                Cov = torch.addmm(Id, X.t(), X, beta=self.eps, alpha=1. / X.shape[0])
                deconv = isqrt_newton_schulz_autograd(Cov, self.n_iter)
            else:
                X = X.view(-1, self.groups, nf).transpose(0, 1)
                Id = torch.eye(nf, dtype=X.dtype, device=X.device).expand(self.groups, nf, nf)
                Cov = torch.baddbmm(Id, X.transpose(1, 2), X, beta=self.eps, alpha=1. / X.shape[1])
                deconv = isqrt_newton_schulz_autograd_batch(Cov, self.n_iter)

            if self.track_running_stats:
                self.running_mean.mul_(1 - self.momentum)
                self.running_mean.add_(X_mean.detach() * self.momentum)
                # track stats for evaluation
                self.running_deconv.mul_(1 - self.momentum)
                self.running_deconv.add_(deconv.detach() * self.momentum)
        else:
            X_mean = self.running_mean
            deconv = self.running_deconv

        # 4. X * deconv * conv = X * (deconv * conv)
        if self.groups == 1:
            w = self.weight.view(-1, nf, C // B).transpose(1, 2).contiguous().view(-1, nf) @ deconv
            mean_shift = (w @ (X_mean.unsqueeze(1))).view(self.weight.shape[0], -1).sum(1)
            w = w.view(-1, C // B, nf).transpose(1, 2).contiguous()
        else:
            w = self.weight.view(C // B, -1, nf) @ deconv
            mean_shift = (w @ (X_mean.view(-1, nf, 1))).view(-1)

        # Mean removal is folded into the bias; the layer may have been
        # built with bias=False, in which case only the shift remains.
        b = -mean_shift if self.bias is None else self.bias - mean_shift

        w = w.view(self.weight.shape)
        return F.conv2d(x, w, b, self.stride, self.padding, self.dilation, self.groups)
""" 1D Conv Wrapper """
class ChannelDeconv1D(ChannelDeconv2D):
    """Channel deconvolution for sequence inputs.

    Thin wrapper around ``ChannelDeconv2D``: a 3-D input gets a trailing
    singleton dimension so the 2-D implementation can process it, and that
    dimension is removed again from the result. Inputs of any other rank are
    passed to the 2-D implementation untouched.

    The constructor arguments are forwarded unchanged to ``ChannelDeconv2D``;
    only the default ``eps`` differs (``1e-5`` here).
    """

    def __init__(self, block, eps=1e-5, n_iter=5, momentum=0.1, sampling_stride=3):
        super(ChannelDeconv1D, self).__init__(block=block, eps=eps, n_iter=n_iter,
                                              momentum=momentum, sampling_stride=sampling_stride)

    def forward(self, x: torch.Tensor):
        """Apply channel deconvolution.

        Args:
            x: Tensor that is either 3-D (treated as ``[N, C, T]``) or any
                shape accepted by ``ChannelDeconv2D.forward``.

        Returns:
            Tensor with the same rank as ``x``.
        """
        if x.dim() != 3:
            return super(ChannelDeconv1D, self).forward(x)

        # Insert a dummy trailing dimension: [N, C, T] -> [N, C, T, 1].
        out = super(ChannelDeconv1D, self).forward(x.unsqueeze(-1))
        # Remove the dummy dimension again.
        return out.squeeze(-1)
class FastDeconv1D(FastDeconv2D):
    """1-D variant of ``FastDeconv2D``.

    The scalar ``kernel_size``, ``stride``, ``padding`` and ``dilation`` are
    turned into 2-D pairs whose second entry is the neutral value (kernel 1,
    stride 1, padding 0, dilation 1). The input is given a trailing singleton
    dimension so the 2-D layer can process it. All other arguments are
    forwarded unchanged to ``FastDeconv2D``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True,
                 eps=1e-5, n_iter=5, momentum=0.1, block=64, sampling_stride=3, freeze=False, freeze_iter=100):
        super(FastDeconv1D, self).__init__(in_channels=in_channels, out_channels=out_channels,
                                           kernel_size=(kernel_size, 1), stride=(stride, 1),
                                           padding=(padding, 0), dilation=(dilation, 1),
                                           bias=bias, groups=groups, eps=eps,
                                           n_iter=n_iter, momentum=momentum, block=block,
                                           sampling_stride=sampling_stride, freeze=freeze,
                                           freeze_iter=freeze_iter)

    def forward(self, x: torch.Tensor):
        """Run the layer on a sequence tensor.

        Args:
            x: Input tensor; treated as ``[N, C, T]``.

        Returns:
            Output of the 2-D layer with the trailing dummy dimension removed.
        """
        # Insert a dummy trailing dimension: [N, C, T] -> [N, C, T, 1].
        out = super(FastDeconv1D, self).forward(x.unsqueeze(-1))
        # Remove the dummy dimension again.
        return out.squeeze(-1)