-
Notifications
You must be signed in to change notification settings - Fork 41
/
Copy pathFastFlowNet_v2.py
182 lines (152 loc) · 7.03 KB
/
FastFlowNet_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from spatial_correlation_sampler import SpatialCorrelationSampler
class Correlation(nn.Module):
    """Channel-normalised correlation between two feature maps.

    Wraps ``SpatialCorrelationSampler`` with a square search window of
    ``2 * max_displacement + 1`` positions per axis and flattens the sampler
    output into a single channel dimension.
    """

    def __init__(self, max_displacement):
        """Create the sampler for the given maximum displacement.

        Args:
            max_displacement: search radius; the window side length stored in
                ``self.kernel_size`` is ``2 * max_displacement + 1``.
        """
        super().__init__()
        window = 2 * max_displacement + 1
        self.max_displacement = max_displacement
        self.kernel_size = window
        # NOTE(review): positional arguments are presumably
        # (kernel_size, patch_size, stride, padding, dilation) -- confirm
        # against the installed spatial_correlation_sampler version.
        self.corr = SpatialCorrelationSampler(1, window, 1, 0, 1)

    def forward(self, x, y):
        """Correlate ``x`` with ``y``.

        Args:
            x: 4-D tensor unpacked as (batch, channels, height, width).
            y: second feature map handed to the sampler alongside ``x``.

        Returns:
            The sampler output reshaped to (batch, -1, height, width) and
            divided by the channel count of ``x``.
        """
        batch, channels, height, width = x.shape
        volume = self.corr(x, y)
        # Collapse all displacement axes into one channel axis.
        volume = volume.view(batch, -1, height, width)
        return volume / channels
def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
    """Build a 2-D convolution followed by an in-place LeakyReLU (slope 0.1).

    All arguments are forwarded unchanged to ``nn.Conv2d``.

    Returns:
        ``nn.Sequential`` holding the convolution at index 0 and the
        activation at index 1.
    """
    conv = nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
        bias=bias,
    )
    activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
    return nn.Sequential(conv, activation)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Build a biased 2-D transposed convolution.

    With the default kernel size 4, stride 2 and padding 1 the layer doubles
    the spatial size of its input.
    """
    return nn.ConvTranspose2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )
class Decoder(nn.Module):
    """Seven-layer convolutional head that maps features to a 2-channel map.

    Layers 2-4 are grouped 96-channel convolutions; when ``groups`` is not 1
    each of them is followed by a channel shuffle so information can mix
    across groups.
    """

    def __init__(self, in_channels, groups):
        """Create the layer stack.

        Args:
            in_channels: channel count of the input to ``forward``.
            groups: group count for conv2-conv4 (must divide 96).
        """
        super().__init__()
        self.in_channels = in_channels
        self.groups = groups
        width = 96
        self.conv1 = convrelu(in_channels, width, 3, 1)
        self.conv2 = convrelu(width, width, 3, 1, groups=groups)
        self.conv3 = convrelu(width, width, 3, 1, groups=groups)
        self.conv4 = convrelu(width, width, 3, 1, groups=groups)
        self.conv5 = convrelu(width, 64, 3, 1)
        self.conv6 = convrelu(64, 32, 3, 1)
        # Final projection has no activation.
        self.conv7 = nn.Conv2d(32, 2, 3, 1, 1)

    def channel_shuffle(self, x, groups):
        """Interleave the channels of ``x`` across ``groups`` groups.

        The channel axis is viewed as (groups, channels // groups), the two
        axes are swapped, and the result is flattened back to 4-D.
        """
        batch, channels, height, width = x.shape
        per_group = channels // groups
        grouped = x.view(batch, groups, per_group, height, width)
        swapped = grouped.transpose(1, 2).contiguous()
        return swapped.view(batch, -1, height, width)

    def forward(self, x):
        """Run the layer stack on ``x`` and return the 2-channel output."""
        shuffle = self.groups != 1
        out = self.conv1(x)
        for grouped_conv in (self.conv2, self.conv3, self.conv4):
            out = grouped_conv(out)
            if shuffle:
                out = self.channel_shuffle(out, self.groups)
        for layer in (self.conv5, self.conv6, self.conv7):
            out = layer(out)
        return out
class FastFlowNet(nn.Module):
    """Coarse-to-fine flow estimator over a shared six-level feature pyramid.

    ``forward`` takes two RGB images stacked along the channel axis, builds a
    feature pyramid for each with shared weights, and estimates a 2-channel
    flow field from the coarsest level (6) down to level 2. Each level
    correlates the first image's features with the second image's features
    warped by the upsampled flow of the level above, and the decoder output
    is added to that upsampled flow.

    Submodule names and their registration order are part of the checkpoint
    format (``state_dict`` keys) and must not be changed.
    """

    def __init__(self, groups=3):
        """Build all layers and initialise their weights.

        Args:
            groups: group count used by the grouped convolutions inside each
                ``Decoder``.
        """
        super(FastFlowNet, self).__init__()
        self.groups = groups

        # Shared feature extractor: three stride-2 stages (levels 1-3).
        self.pconv1_1 = convrelu(3, 16, 3, 2)
        self.pconv1_2 = convrelu(16, 16, 3, 1)
        self.pconv2_1 = convrelu(16, 32, 3, 2)
        self.pconv2_2 = convrelu(32, 32, 3, 1)
        self.pconv2_3 = convrelu(32, 32, 3, 1)
        self.pconv3_1 = convrelu(32, 64, 3, 2)
        self.pconv3_2 = convrelu(64, 64, 3, 1)
        self.pconv3_3 = convrelu(64, 64, 3, 1)

        self.corr = Correlation(4)
        # Channels of the correlation output that are kept (53 indices in the
        # range 0..80). Stored as a plain tensor attribute, not a buffer, so
        # it does not appear in the state_dict.
        self.index = torch.tensor([0, 2, 4, 6, 8,
                                   10, 12, 14, 16,
                                   18, 20, 21, 22, 23, 24, 26,
                                   28, 29, 30, 31, 32, 33, 34,
                                   36, 38, 39, 40, 41, 42, 44,
                                   46, 47, 48, 49, 50, 51, 52,
                                   54, 56, 57, 58, 59, 60, 62,
                                   64, 66, 68, 70,
                                   72, 74, 76, 78, 80])

        # Per-level reduction of the first image's features to 32 channels.
        self.rconv2 = convrelu(32, 32, 3, 1)
        self.rconv3 = convrelu(64, 32, 3, 1)
        self.rconv4 = convrelu(64, 32, 3, 1)
        self.rconv5 = convrelu(64, 32, 3, 1)
        self.rconv6 = convrelu(64, 32, 3, 1)

        # Learned 2x upsampling of the flow between levels.
        self.up3 = deconv(2, 2)
        self.up4 = deconv(2, 2)
        self.up5 = deconv(2, 2)
        self.up6 = deconv(2, 2)

        # Decoder input: 53 correlation + 32 reduced feature + 2 flow = 87.
        self.decoder2 = Decoder(87, groups)
        self.decoder3 = Decoder(87, groups)
        self.decoder4 = Decoder(87, groups)
        self.decoder5 = Decoder(87, groups)
        self.decoder6 = Decoder(87, groups)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def warp(self, x, flo):
        """Sample ``x`` at the pixel grid displaced by ``flo``.

        Args:
            x: tensor of shape (B, C, H, W) to sample from.
            flo: displacement in pixels, broadcastable to (B, 2, H, W);
                channel 0 is added to the x (width) coordinate and channel 1
                to the y (height) coordinate.

        Returns:
            Tensor of shape (B, C, H, W), bilinearly sampled with
            ``align_corners=True``. Out-of-range samples use
            ``F.grid_sample``'s default padding.
        """
        _, _, H, W = x.size()
        # Build the base pixel grid directly on x's device using broadcast
        # views instead of materialising repeated CPU tensors and copying.
        xx = torch.arange(W, device=x.device).view(1, 1, 1, W).expand(1, 1, H, W)
        yy = torch.arange(H, device=x.device).view(1, 1, H, 1).expand(1, 1, H, W)
        grid = torch.cat([xx, yy], 1).to(x.dtype)
        vgrid = grid + flo
        # Map pixel coordinates to [-1, 1]; max(..., 1) guards size-1 axes.
        # Done out of place so no autograd-tracked tensor is overwritten.
        norm_x = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0
        norm_y = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0
        vgrid = torch.stack([norm_x, norm_y], dim=-1)  # (B, H, W, 2)
        return F.grid_sample(x, vgrid, mode='bilinear', align_corners=True)

    def _pyramid(self, img):
        """Return the level-2..6 features (f2, f3, f4, f5, f6) of one image."""
        f1 = self.pconv1_2(self.pconv1_1(img))
        f2 = self.pconv2_3(self.pconv2_2(self.pconv2_1(f1)))
        f3 = self.pconv3_3(self.pconv3_2(self.pconv3_1(f2)))
        # Levels 4-6 have no learned weights: plain 2x2 average pooling.
        f4 = F.avg_pool2d(f3, kernel_size=(2, 2), stride=(2, 2))
        f5 = F.avg_pool2d(f4, kernel_size=(2, 2), stride=(2, 2))
        f6 = F.avg_pool2d(f5, kernel_size=(2, 2), stride=(2, 2))
        return f2, f3, f4, f5, f6

    def _estimate(self, feat1, feat2, reduce, decoder, index, flow_up):
        """Decode one level from correlation, reduced features and prior flow."""
        cost = torch.index_select(self.corr(feat1, feat2), dim=1, index=index)
        return decoder(torch.cat([cost, reduce(feat1), flow_up], 1))

    def forward(self, x):
        """Estimate flow from the first image to the second.

        Args:
            x: tensor of shape (B, >=6, H, W); channels 0-2 are the first
                image and channels 3-5 the second. H and W should be
                multiples of 64 so that every pyramid level is exactly half
                the size of the previous one; otherwise the upsampled flow
                and the features of the next level will not line up.

        Returns:
            In training mode the tuple ``(flow2, flow3, flow4, flow5, flow6)``
            ordered from finest to coarsest; in eval mode only ``flow2``.
            Each flow has 2 channels at the resolution of its pyramid level.
        """
        img1 = x[:, :3, :, :]
        img2 = x[:, 3:6, :, :]
        f12, f13, f14, f15, f16 = self._pyramid(img1)
        f22, f23, f24, f25, f26 = self._pyramid(img2)

        # Move the channel-selection index to the feature device once per
        # call instead of once per level (and without a float round trip).
        index = self.index.to(device=f16.device, dtype=torch.long)

        # Level 6 starts from zero flow, allocated directly on device.
        flow7_up = f16.new_zeros(f16.size(0), 2, f16.size(2), f16.size(3))
        flow6 = self._estimate(f16, f26, self.rconv6, self.decoder6, index, flow7_up)

        # NOTE(review): the warp scale factors 0.625 / 1.25 / 2.5 / 5.0 are
        # presumably 20 / 2**level (PWC-Net style flow scaling) -- confirm
        # against the training code before changing them.
        flow6_up = self.up6(flow6)
        f25_w = self.warp(f25, flow6_up * 0.625)
        flow5 = self._estimate(f15, f25_w, self.rconv5, self.decoder5, index, flow6_up) + flow6_up

        flow5_up = self.up5(flow5)
        f24_w = self.warp(f24, flow5_up * 1.25)
        flow4 = self._estimate(f14, f24_w, self.rconv4, self.decoder4, index, flow5_up) + flow5_up

        flow4_up = self.up4(flow4)
        f23_w = self.warp(f23, flow4_up * 2.5)
        flow3 = self._estimate(f13, f23_w, self.rconv3, self.decoder3, index, flow4_up) + flow4_up

        flow3_up = self.up3(flow3)
        f22_w = self.warp(f22, flow3_up * 5.0)
        flow2 = self._estimate(f12, f22_w, self.rconv2, self.decoder2, index, flow3_up) + flow3_up

        if self.training:
            return flow2, flow3, flow4, flow5, flow6
        return flow2