-
Notifications
You must be signed in to change notification settings - Fork 722
/
voxelnet.py
1142 lines (1066 loc) · 47 KB
/
voxelnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import time
from enum import Enum
from functools import reduce
import numpy as np
import sparseconvnet as scn
import torch
from torch import nn
from torch.nn import functional as F
import torchplus
from torchplus import metrics
from torchplus.nn import Empty, GroupNorm, Sequential
from torchplus.ops.array_ops import gather_nd, scatter_nd
from torchplus.tools import change_default_args
from second.pytorch.core import box_torch_ops
from second.pytorch.core.losses import (WeightedSigmoidClassificationLoss,
WeightedSmoothL1LocalizationLoss,
WeightedSoftmaxClassificationLoss)
def _get_pos_neg_loss(cls_loss, labels):
# cls_loss: [N, num_anchors, num_class]
# labels: [N, num_anchors]
batch_size = cls_loss.shape[0]
if cls_loss.shape[-1] == 1 or len(cls_loss.shape) == 2:
cls_pos_loss = (labels > 0).type_as(cls_loss) * cls_loss.view(
batch_size, -1)
cls_neg_loss = (labels == 0).type_as(cls_loss) * cls_loss.view(
batch_size, -1)
cls_pos_loss = cls_pos_loss.sum() / batch_size
cls_neg_loss = cls_neg_loss.sum() / batch_size
else:
cls_pos_loss = cls_loss[..., 1:].sum() / batch_size
cls_neg_loss = cls_loss[..., 0].sum() / batch_size
return cls_pos_loss, cls_neg_loss
def get_paddings_indicator(actual_num, max_num, axis=0):
    """Build a mask marking the real (non-padding) slots of a padded tensor.

    Args:
        actual_num: integer tensor with the number of valid entries per row,
            e.g. ``[3, 4, 2]``.
        max_num: padded length (number of slots per row).
        axis: axis of ``actual_num`` after which the slot axis is inserted.

    Returns:
        Comparison tensor of shape ``actual_num.shape`` with a new axis of
        length ``max_num`` inserted at ``axis + 1``; entry ``[..., j]`` is
        true when ``j < actual_num[...]``. For ``actual_num=[3, 1]`` and
        ``max_num=4`` this is ``[[1, 1, 1, 0], [1, 0, 0, 0]]``.
    """
    counts = torch.unsqueeze(actual_num, axis + 1)
    # Shape the slot indices so they broadcast along the inserted axis only.
    index_shape = [1] * counts.dim()
    index_shape[axis + 1] = -1
    slot_index = torch.arange(
        max_num, dtype=torch.int, device=counts.device).view(index_shape)
    return counts.int() > slot_index
class VFELayer(nn.Module):
    """VFE (voxel feature encoding) layer.

    Applies a shared linear layer, a normalization (``Empty`` when
    ``use_norm`` is False) and ReLU to every point. It then max-pools over
    the point axis and appends the pooled vector to every point, so the
    output has ``2 * int(out_channels / 2)`` channels.
    """

    def __init__(self, in_channels, out_channels, use_norm=True, name='vfe'):
        super(VFELayer, self).__init__()
        self.name = name
        self.units = int(out_channels / 2)
        if not use_norm:
            norm_cls = Empty
            linear_cls = change_default_args(bias=True)(nn.Linear)
        else:
            norm_cls = change_default_args(
                eps=1e-3, momentum=0.01)(nn.BatchNorm1d)
            linear_cls = change_default_args(bias=False)(nn.Linear)
        self.linear = linear_cls(in_channels, self.units)
        self.norm = norm_cls(self.units)

    def forward(self, inputs):
        """Map ``[K, T, C_in]`` point features to ``[K, T, 2 * units]``."""
        num_points = inputs.shape[1]
        hidden = self.linear(inputs)
        # The norm sees channels on axis 1, so swap the last two axes around it.
        hidden = self.norm(hidden.permute(0, 2, 1).contiguous())
        pointwise = F.relu(hidden.permute(0, 2, 1).contiguous())
        # Per-voxel max over the point axis: [K, 1, units].
        pooled = torch.max(pointwise, dim=1, keepdim=True)[0]
        tiled = pooled.repeat(1, num_points, 1)
        return torch.cat([pointwise, tiled], dim=2)
class VoxelFeatureExtractor(nn.Module):
    """Two-stage VFE encoder producing one feature vector per voxel.

    Each point's first three channels are taken as coordinates. The
    features are augmented with the offset of those coordinates from the
    voxel's coordinate mean, plus (when ``with_distance``) their L2 norm.
    The result passes through two ``VFELayer`` blocks and a final
    linear + norm + ReLU, with padded point slots zeroed by a mask after
    every stage, and is max-pooled over the point axis.
    """

    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=[32, 128],
                 with_distance=False,
                 name='VoxelFeatureExtractor'):
        super(VoxelFeatureExtractor, self).__init__()
        self.name = name
        if use_norm:
            BatchNorm1d = change_default_args(
                eps=1e-3, momentum=0.01)(nn.BatchNorm1d)
            # Bias disabled when a norm layer follows (presumably because
            # the norm's shift makes it redundant).
            Linear = change_default_args(bias=False)(nn.Linear)
        else:
            BatchNorm1d = Empty
            Linear = change_default_args(bias=True)(nn.Linear)
        assert len(num_filters) == 2
        num_input_features += 3  # add mean features
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance
        self.vfe1 = VFELayer(num_input_features, num_filters[0], use_norm)
        self.vfe2 = VFELayer(num_filters[0], num_filters[1], use_norm)
        self.linear = Linear(num_filters[1], num_filters[1])
        self.norm = BatchNorm1d(num_filters[1])

    def forward(self, features, num_voxels):
        """Encode padded per-voxel point features.

        Args:
            features: ``[num_voxels, max_points, C]`` point features whose
                first three channels are used as coordinates.
            num_voxels: per-voxel count of real points (``VoxelNet.forward``
                passes ``example["num_points"]`` here), used both for the
                mean and for the padding mask.

        Returns:
            ``[num_voxels, num_filters[1]]`` max-pooled voxel features.
        """
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
        features_relative = features[:, :, :3] - points_mean
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features = torch.cat(
                [features, features_relative, points_dist], dim=-1)
        else:
            features = torch.cat([features, features_relative], dim=-1)
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        x = self.vfe1(features)
        x *= mask
        x = self.vfe2(x)
        x *= mask
        x = self.linear(x)
        x = self.norm(x.permute(0, 2, 1).contiguous()).permute(0, 2,
                                                               1).contiguous()
        x = F.relu(x)
        # Out-of-place on purpose: ReLU's backward reads its own output, so
        # an in-place multiply here breaks autograd ("modified by an inplace
        # operation") on current PyTorch.
        x = x * mask
        # x: [num_voxels, max_points, num_filters[1]]
        voxelwise = torch.max(x, dim=1)[0]
        return voxelwise
class VoxelFeatureExtractorV2(nn.Module):
    """VFE encoder with a configurable stack of ``VFELayer`` blocks.

    Each point's first three channels are taken as coordinates. The
    features are augmented with the offset of those coordinates from the
    voxel's coordinate mean, plus (when ``with_distance``) their L2 norm.
    The result passes through one ``VFELayer`` per entry of
    ``num_filters`` and a final linear + norm + ReLU, with padded point
    slots zeroed after every stage, and is max-pooled over the point axis.

    NOTE(review): the default ``name`` is 'VoxelFeatureExtractor', the same
    as the two-layer variant; kept for compatibility.
    """

    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=[32, 128],
                 with_distance=False,
                 name='VoxelFeatureExtractor'):
        super(VoxelFeatureExtractorV2, self).__init__()
        self.name = name
        if use_norm:
            BatchNorm1d = change_default_args(
                eps=1e-3, momentum=0.01)(nn.BatchNorm1d)
            Linear = change_default_args(bias=False)(nn.Linear)
        else:
            BatchNorm1d = Empty
            Linear = change_default_args(bias=True)(nn.Linear)
        assert len(num_filters) > 0
        num_input_features += 3  # add mean features
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance
        num_filters = [num_input_features] + num_filters
        # Consecutive (in, out) channel pairs, one per VFE layer.
        filters_pairs = [[num_filters[i], num_filters[i + 1]]
                         for i in range(len(num_filters) - 1)]
        self.vfe_layers = nn.ModuleList(
            [VFELayer(i, o, use_norm) for i, o in filters_pairs])
        self.linear = Linear(num_filters[-1], num_filters[-1])
        self.norm = BatchNorm1d(num_filters[-1])

    def forward(self, features, num_voxels):
        """Encode padded per-voxel point features.

        Args:
            features: ``[num_voxels, max_points, C]`` point features whose
                first three channels are used as coordinates.
            num_voxels: per-voxel count of real points, used both for the
                mean and for the padding mask.

        Returns:
            ``[num_voxels, num_filters[-1]]`` max-pooled voxel features.
        """
        points_mean = features[:, :, :3].sum(
            dim=1, keepdim=True) / num_voxels.type_as(features).view(-1, 1, 1)
        features_relative = features[:, :, :3] - points_mean
        if self._with_distance:
            points_dist = torch.norm(features[:, :, :3], 2, 2, keepdim=True)
            features = torch.cat(
                [features, features_relative, points_dist], dim=-1)
        else:
            features = torch.cat([features, features_relative], dim=-1)
        voxel_count = features.shape[1]
        mask = get_paddings_indicator(num_voxels, voxel_count, axis=0)
        mask = torch.unsqueeze(mask, -1).type_as(features)
        for vfe in self.vfe_layers:
            features = vfe(features)
            features *= mask
        features = self.linear(features)
        features = self.norm(features.permute(0, 2, 1).contiguous()).permute(
            0, 2, 1).contiguous()
        features = F.relu(features)
        # Out-of-place on purpose: ReLU's backward reads its own output, so
        # an in-place multiply here breaks autograd on current PyTorch.
        features = features * mask
        voxelwise = torch.max(features, dim=1)[0]
        return voxelwise
class SparseMiddleExtractor(nn.Module):
    """Sparse 3D convolutional middle layer built on SparseConvNet (``scn``).

    Feeds voxel features into an ``scn.InputLayer`` grid. It then applies
    two groups of submanifold convolutions (``num_filters_down1``, then
    ``num_filters_down2``), each followed by a (3, 1, 1) convolution with
    stride (2, 1, 1). Finally it densifies and folds the channel and depth
    axes together, returning ``[N, C * D, H, W]``.

    NOTE(review): ``use_norm`` is accepted for interface compatibility but
    ignored; every convolution is followed by ``scn.BatchNormReLU``.
    """

    def __init__(self,
                 output_shape,
                 use_norm=True,
                 num_input_features=128,
                 num_filters_down1=[64],
                 num_filters_down2=[64, 64],
                 name='SparseMiddleExtractor'):
        super(SparseMiddleExtractor, self).__init__()
        self.name = name
        # NOTE(review): grid = output_shape[1:4] with one extra cell on the
        # first axis; presumably sized so the stride-(2, 1, 1) convolutions
        # below give the intended output depth — confirm.
        sparse_shape = np.array(output_shape[1:4]) + [1, 0, 0]
        self.scn_input = scn.InputLayer(3, sparse_shape.tolist())
        self.voxel_output_shape = output_shape
        middle_layers = []
        num_filters = [num_input_features] + num_filters_down1
        for i, o in zip(num_filters[:-1], num_filters[1:]):
            middle_layers.append(scn.SubmanifoldConvolution(3, i, o, 3, False))
            middle_layers.append(scn.BatchNormReLU(o, eps=1e-3, momentum=0.99))
        middle_layers.append(
            scn.Convolution(
                3,
                num_filters[-1],
                num_filters[-1], (3, 1, 1), (2, 1, 1),
                bias=False))
        middle_layers.append(
            scn.BatchNormReLU(num_filters[-1], eps=1e-3, momentum=0.99))
        # num_filters[-1] is num_filters_down1[-1] when that list is
        # non-empty, and num_input_features otherwise.
        num_filters = [num_filters[-1]] + num_filters_down2
        for i, o in zip(num_filters[:-1], num_filters[1:]):
            middle_layers.append(scn.SubmanifoldConvolution(3, i, o, 3, False))
            middle_layers.append(scn.BatchNormReLU(o, eps=1e-3, momentum=0.99))
        middle_layers.append(
            scn.Convolution(
                3,
                num_filters[-1],
                num_filters[-1], (3, 1, 1), (2, 1, 1),
                bias=False))
        middle_layers.append(
            scn.BatchNormReLU(num_filters[-1], eps=1e-3, momentum=0.99))
        middle_layers.append(scn.SparseToDense(3, num_filters[-1]))
        self.middle_conv = Sequential(*middle_layers)

    def forward(self, voxel_features, coors, batch_size):
        """Run the sparse middle layers.

        Args:
            voxel_features: ``[num_voxels, C]`` features.
            coors: ``[num_voxels, 4]`` integer coordinates; column 0 is moved
                to the end before being handed to ``scn`` (presumably the
                batch index, which ``scn.InputLayer`` expects last — confirm).
            batch_size: number of examples in the batch.

        Returns:
            ``[N, C * D, H, W]`` dense feature map.
        """
        coors = coors.int()[:, [1, 2, 3, 0]]
        ret = self.scn_input((coors.cpu(), voxel_features, batch_size))
        ret = self.middle_conv(ret)
        N, C, D, H, W = ret.shape
        ret = ret.view(N, C * D, H, W)
        return ret
class ZeroPad3d(nn.ConstantPad3d):
    """``nn.ConstantPad3d`` whose fill value is fixed to zero.

    ``padding`` is forwarded unchanged (an int, or a 6-element sequence
    ordered left/right, top/bottom, front/back over the last three axes).
    """

    def __init__(self, padding):
        super(ZeroPad3d, self).__init__(padding, value=0)
class MiddleExtractor(nn.Module):
    """Dense 3D convolutional middle layer.

    Scatters voxel features into a dense ``[batch, *output_shape[1:]]``
    grid and moves the last axis to position 1 (channels). It then applies
    two (3, 1, 1) convolutions with stride (2, 1, 1), each followed by norm
    and ReLU. Finally it folds channels and depth together into
    ``[N, C * D, H, W]``.

    NOTE(review): ``num_filters_down1`` / ``num_filters_down2`` are accepted
    for interface compatibility but ignored; both convolutions output 64
    channels.
    """

    def __init__(self,
                 output_shape,
                 use_norm=True,
                 num_input_features=128,
                 num_filters_down1=[64],
                 num_filters_down2=[64, 64],
                 name='MiddleExtractor'):
        super(MiddleExtractor, self).__init__()
        self.name = name
        if use_norm:
            BatchNorm3d = change_default_args(
                eps=1e-3, momentum=0.01)(nn.BatchNorm3d)
            Conv3d = change_default_args(bias=False)(nn.Conv3d)
        else:
            BatchNorm3d = Empty
            Conv3d = change_default_args(bias=True)(nn.Conv3d)
        self.voxel_output_shape = output_shape
        self.middle_conv = Sequential(
            # Pad one slice at the back of the depth axis only.
            ZeroPad3d([0, 0, 0, 0, 0, 1]),
            Conv3d(num_input_features, 64, (3, 1, 1), stride=(2, 1, 1)),
            BatchNorm3d(64),
            nn.ReLU(),
            Conv3d(64, 64, (3, 1, 1), stride=(2, 1, 1)),
            BatchNorm3d(64),
            nn.ReLU(),
        )

    def forward(self, voxel_features, coors, batch_size):
        """Densify voxel features and run the 3D convolutions.

        Args:
            voxel_features: per-voxel features passed to ``scatter_nd``.
            coors: voxel coordinates used as ``scatter_nd`` indices.
            batch_size: number of examples in the batch.

        Returns:
            ``[N, C * D, H, W]`` feature map.
        """
        # Normalize to a list of ints: with a tuple ``output_shape`` the
        # concatenation would raise, and with an ndarray ``[b] + arr`` would
        # silently broadcast-add ``batch_size`` to every dimension.
        output_shape = [batch_size] + [
            int(dim) for dim in self.voxel_output_shape[1:]
        ]
        ret = scatter_nd(coors.long(), voxel_features, output_shape)
        ret = ret.permute(0, 4, 1, 2, 3)
        ret = self.middle_conv(ret)
        N, C, D, H, W = ret.shape
        ret = ret.view(N, C * D, H, W)
        return ret
class RPN(nn.Module):
    """Three-stage 2D region proposal head over a bird's-eye-view feature map.

    Each stage starts with a padded, strided 3x3 convolution followed by
    ``layer_nums[i]`` 3x3 convolutions. Each stage's output is upsampled by
    a transposed convolution. The three upsampled maps are concatenated on
    channels, and 1x1 convolutions predict box regressions, class logits
    and (optionally) 2-way direction logits per anchor.

    When ``use_norm`` is True, convolutions have no bias and are followed by
    BatchNorm2d (or GroupNorm when ``use_groupnorm``). Otherwise ``Empty``
    replaces the norm and convolutions have a bias. When ``use_bev`` is set,
    a small extractor over a 6-channel BEV map is concatenated onto the
    first stage's output before stage two.
    """

    def __init__(self,
                 use_norm=True,
                 num_class=2,
                 layer_nums=[3, 5, 5],
                 layer_strides=[2, 2, 2],
                 num_filters=[128, 128, 256],
                 upsample_strides=[1, 2, 4],
                 num_upsample_filters=[256, 256, 256],
                 num_input_filters=128,
                 num_anchor_per_loc=2,
                 encode_background_as_zeros=True,
                 use_direction_classifier=True,
                 use_groupnorm=False,
                 num_groups=32,
                 use_bev=False,
                 box_code_size=7,
                 name='rpn'):
        super(RPN, self).__init__()
        self._num_anchor_per_loc = num_anchor_per_loc
        self._use_direction_classifier = use_direction_classifier
        self._use_bev = use_bev
        assert len(layer_nums) == 3
        assert len(layer_strides) == len(layer_nums)
        assert len(num_filters) == len(layer_nums)
        assert len(upsample_strides) == len(layer_nums)
        assert len(num_upsample_filters) == len(layer_nums)
        # Every stage must upsample back to the same resolution so the three
        # outputs can be concatenated.
        factors = []
        for i in range(len(layer_nums)):
            assert int(np.prod(layer_strides[:i + 1])) % upsample_strides[i] == 0
            factors.append(np.prod(layer_strides[:i + 1]) // upsample_strides[i])
        assert all([x == factors[0] for x in factors])
        if use_norm:
            if use_groupnorm:
                BatchNorm2d = change_default_args(
                    num_groups=num_groups, eps=1e-3)(GroupNorm)
            else:
                BatchNorm2d = change_default_args(
                    eps=1e-3, momentum=0.01)(nn.BatchNorm2d)
            Conv2d = change_default_args(bias=False)(nn.Conv2d)
            ConvTranspose2d = change_default_args(bias=False)(
                nn.ConvTranspose2d)
        else:
            BatchNorm2d = Empty
            Conv2d = change_default_args(bias=True)(nn.Conv2d)
            ConvTranspose2d = change_default_args(bias=True)(
                nn.ConvTranspose2d)

        # note that when stride > 1, conv2d with same padding isn't
        # equal to pad-conv2d. we should use pad-conv2d.
        block2_input_filters = num_filters[0]
        if use_bev:
            self.bev_extractor = Sequential(
                Conv2d(6, 32, 3, padding=1),
                BatchNorm2d(32),
                nn.ReLU(),
                Conv2d(32, 64, 3, padding=1),
                BatchNorm2d(64),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            )
            block2_input_filters += 64

        self.block1 = Sequential(
            nn.ZeroPad2d(1),
            Conv2d(
                num_input_filters, num_filters[0], 3, stride=layer_strides[0]),
            BatchNorm2d(num_filters[0]),
            nn.ReLU(),
        )
        for i in range(layer_nums[0]):
            self.block1.add(
                Conv2d(num_filters[0], num_filters[0], 3, padding=1))
            self.block1.add(BatchNorm2d(num_filters[0]))
            self.block1.add(nn.ReLU())
        self.deconv1 = Sequential(
            ConvTranspose2d(
                num_filters[0],
                num_upsample_filters[0],
                upsample_strides[0],
                stride=upsample_strides[0]),
            BatchNorm2d(num_upsample_filters[0]),
            nn.ReLU(),
        )
        self.block2 = Sequential(
            nn.ZeroPad2d(1),
            Conv2d(
                block2_input_filters,
                num_filters[1],
                3,
                stride=layer_strides[1]),
            BatchNorm2d(num_filters[1]),
            nn.ReLU(),
        )
        for i in range(layer_nums[1]):
            self.block2.add(
                Conv2d(num_filters[1], num_filters[1], 3, padding=1))
            self.block2.add(BatchNorm2d(num_filters[1]))
            self.block2.add(nn.ReLU())
        self.deconv2 = Sequential(
            ConvTranspose2d(
                num_filters[1],
                num_upsample_filters[1],
                upsample_strides[1],
                stride=upsample_strides[1]),
            BatchNorm2d(num_upsample_filters[1]),
            nn.ReLU(),
        )
        self.block3 = Sequential(
            nn.ZeroPad2d(1),
            Conv2d(num_filters[1], num_filters[2], 3, stride=layer_strides[2]),
            BatchNorm2d(num_filters[2]),
            nn.ReLU(),
        )
        for i in range(layer_nums[2]):
            self.block3.add(
                Conv2d(num_filters[2], num_filters[2], 3, padding=1))
            self.block3.add(BatchNorm2d(num_filters[2]))
            self.block3.add(nn.ReLU())
        self.deconv3 = Sequential(
            ConvTranspose2d(
                num_filters[2],
                num_upsample_filters[2],
                upsample_strides[2],
                stride=upsample_strides[2]),
            BatchNorm2d(num_upsample_filters[2]),
            nn.ReLU(),
        )
        if encode_background_as_zeros:
            num_cls = num_anchor_per_loc * num_class
        else:
            # One extra logit per anchor for the background class.
            num_cls = num_anchor_per_loc * (num_class + 1)
        self.conv_cls = nn.Conv2d(sum(num_upsample_filters), num_cls, 1)
        self.conv_box = nn.Conv2d(
            sum(num_upsample_filters), num_anchor_per_loc * box_code_size, 1)
        if use_direction_classifier:
            self.conv_dir_cls = nn.Conv2d(
                sum(num_upsample_filters), num_anchor_per_loc * 2, 1)

    def forward(self, x, bev=None):
        """Predict per-anchor outputs from a ``[N, C, H, W]`` feature map.

        Args:
            x: input feature map.
            bev: BEV map, only used (and required) when ``use_bev`` is set.
                It is not modified.

        Returns:
            dict with "box_preds" and "cls_preds" (plus "dir_cls_preds" when
            the direction classifier is enabled), each permuted to
            ``[N, H, W, C]``.
        """
        x = self.block1(x)
        up1 = self.deconv1(x)
        if self._use_bev:
            # Rescale the last BEV channel on a copy so the caller's tensor
            # (e.g. example["bev_map"]) is not changed in place.
            bev = bev.clone()
            bev[:, -1] = torch.clamp(
                torch.log(1 + bev[:, -1]) / np.log(16.0), max=1.0)
            x = torch.cat([x, self.bev_extractor(bev)], dim=1)
        x = self.block2(x)
        up2 = self.deconv2(x)
        x = self.block3(x)
        up3 = self.deconv3(x)
        x = torch.cat([up1, up2, up3], dim=1)
        box_preds = self.conv_box(x)
        cls_preds = self.conv_cls(x)
        # [N, C, y(H), x(W)] -> [N, H, W, C]
        box_preds = box_preds.permute(0, 2, 3, 1).contiguous()
        cls_preds = cls_preds.permute(0, 2, 3, 1).contiguous()
        ret_dict = {
            "box_preds": box_preds,
            "cls_preds": cls_preds,
        }
        if self._use_direction_classifier:
            dir_cls_preds = self.conv_dir_cls(x)
            dir_cls_preds = dir_cls_preds.permute(0, 2, 3, 1).contiguous()
            ret_dict["dir_cls_preds"] = dir_cls_preds
        return ret_dict
class LossNormType(Enum):
    """Options for how losses are normalized, identified by string value."""

    # normalize by the number of positive anchors
    NormByNumPositives = "norm_by_num_positives"
    # normalize by the number of examples
    NormByNumExamples = "norm_by_num_examples"
    # normalize positives and negatives separately
    NormByNumPosNeg = "norm_by_num_pos_neg"
class VoxelNet(nn.Module):
def __init__(self,
             output_shape,
             num_class=2,
             num_input_features=4,
             vfe_class_name="VoxelFeatureExtractor",
             vfe_num_filters=[32, 128],
             with_distance=False,
             middle_class_name="SparseMiddleExtractor",
             middle_num_filters_d1=[64],
             middle_num_filters_d2=[64, 64],
             rpn_class_name="RPN",
             rpn_layer_nums=[3, 5, 5],
             rpn_layer_strides=[2, 2, 2],
             rpn_num_filters=[128, 128, 256],
             rpn_upsample_strides=[1, 2, 4],
             rpn_num_upsample_filters=[256, 256, 256],
             use_norm=True,
             use_groupnorm=False,
             num_groups=32,
             use_sparse_rpn=False,
             use_direction_classifier=True,
             use_sigmoid_score=False,
             encode_background_as_zeros=True,
             use_rotate_nms=True,
             multiclass_nms=False,
             nms_score_threshold=0.5,
             nms_pre_max_size=1000,
             nms_post_max_size=20,
             nms_iou_threshold=0.1,
             target_assigner=None,
             use_bev=False,
             lidar_only=False,
             cls_loss_weight=1.0,
             loc_loss_weight=1.0,
             pos_cls_weight=1.0,
             neg_cls_weight=1.0,
             direction_loss_weight=1.0,
             loss_norm_type=LossNormType.NormByNumPositives,
             encode_rad_error_by_sin=False,
             loc_loss_ftor=None,
             cls_loss_ftor=None,
             name='voxelnet'):
    """Assemble the VFE encoder, middle extractor, RPN head and metrics.

    Sub-modules are chosen by name: ``vfe_class_name`` from
    {"VoxelFeatureExtractor", "VoxelFeatureExtractorV2"},
    ``middle_class_name`` from {"MiddleExtractor", "SparseMiddleExtractor"}
    and ``rpn_class_name`` from {"RPN"}; an unknown name raises KeyError.

    ``target_assigner`` must be provided despite its None default: its
    ``box_coder`` and ``num_anchors_per_location`` are read here.
    ``loc_loss_ftor`` / ``cls_loss_ftor`` are stored and used by ``forward``
    to compute the training losses.
    """
    super().__init__()
    self.name = name
    self._num_class = num_class
    self._use_rotate_nms = use_rotate_nms
    self._multiclass_nms = multiclass_nms
    self._nms_score_threshold = nms_score_threshold
    self._nms_pre_max_size = nms_pre_max_size
    self._nms_post_max_size = nms_post_max_size
    self._nms_iou_threshold = nms_iou_threshold
    self._use_sigmoid_score = use_sigmoid_score
    self._encode_background_as_zeros = encode_background_as_zeros
    self._use_sparse_rpn = use_sparse_rpn
    self._use_direction_classifier = use_direction_classifier
    self._use_bev = use_bev
    # Timing accumulators read by avg_forward_time / avg_postprocess_time.
    self._total_forward_time = 0.0
    self._total_postprocess_time = 0.0
    self._total_inference_count = 0
    self._num_input_features = num_input_features
    self._box_coder = target_assigner.box_coder
    self._lidar_only = lidar_only
    self.target_assigner = target_assigner
    self._pos_cls_weight = pos_cls_weight
    self._neg_cls_weight = neg_cls_weight
    self._encode_rad_error_by_sin = encode_rad_error_by_sin
    self._loss_norm_type = loss_norm_type
    self._dir_loss_ftor = WeightedSoftmaxClassificationLoss()

    self._loc_loss_ftor = loc_loss_ftor
    self._cls_loss_ftor = cls_loss_ftor
    self._direction_loss_weight = direction_loss_weight
    self._cls_loss_weight = cls_loss_weight
    self._loc_loss_weight = loc_loss_weight

    # NOTE: sub-modules are created in a fixed order; that order determines
    # parameters() order (and thus optimizer state layout), so keep it.
    vfe_class_dict = {
        "VoxelFeatureExtractor": VoxelFeatureExtractor,
        "VoxelFeatureExtractorV2": VoxelFeatureExtractorV2,
    }
    vfe_class = vfe_class_dict[vfe_class_name]
    self.voxel_feature_extractor = vfe_class(
        num_input_features,
        use_norm,
        num_filters=vfe_num_filters,
        with_distance=with_distance)

    mid_class_dict = {
        "MiddleExtractor": MiddleExtractor,
        "SparseMiddleExtractor": SparseMiddleExtractor,
    }
    mid_class = mid_class_dict[middle_class_name]
    self.middle_feature_extractor = mid_class(
        output_shape,
        use_norm,
        num_input_features=vfe_num_filters[-1],
        num_filters_down1=middle_num_filters_d1,
        num_filters_down2=middle_num_filters_d2)
    # Channel count of the middle extractor's last layer: the last non-empty
    # of d2, d1, falling back to the VFE output width.
    if len(middle_num_filters_d2) == 0:
        if len(middle_num_filters_d1) == 0:
            num_rpn_input_filters = vfe_num_filters[-1]
        else:
            num_rpn_input_filters = middle_num_filters_d1[-1]
    else:
        num_rpn_input_filters = middle_num_filters_d2[-1]

    rpn_class_dict = {
        "RPN": RPN,
    }
    rpn_class = rpn_class_dict[rpn_class_name]
    # NOTE(review): use_norm is hard-coded to True for the RPN regardless of
    # the ``use_norm`` argument; changing it would alter the RPN's modules and
    # checkpoint keys. The "* 2" presumably accounts for the depth axis folded
    # into channels by the middle extractor — confirm against output_shape.
    self.rpn = rpn_class(
        use_norm=True,
        num_class=num_class,
        layer_nums=rpn_layer_nums,
        layer_strides=rpn_layer_strides,
        num_filters=rpn_num_filters,
        upsample_strides=rpn_upsample_strides,
        num_upsample_filters=rpn_num_upsample_filters,
        num_input_filters=num_rpn_input_filters * 2,
        num_anchor_per_loc=target_assigner.num_anchors_per_location,
        encode_background_as_zeros=encode_background_as_zeros,
        use_direction_classifier=use_direction_classifier,
        use_bev=use_bev,
        use_groupnorm=use_groupnorm,
        num_groups=num_groups,
        box_code_size=target_assigner.box_coder.code_size)
    # NOTE(review): no ``self.sparse_rpn`` is created here even when
    # use_sparse_rpn=True, yet forward() calls it in that case.

    self.rpn_acc = metrics.Accuracy(
        dim=-1, encode_background_as_zeros=encode_background_as_zeros)
    self.rpn_precision = metrics.Precision(dim=-1)
    self.rpn_recall = metrics.Recall(dim=-1)
    self.rpn_metrics = metrics.PrecisionRecall(
        dim=-1,
        thresholds=[0.1, 0.3, 0.5, 0.7, 0.8, 0.9, 0.95],
        use_sigmoid_score=use_sigmoid_score,
        encode_background_as_zeros=encode_background_as_zeros)

    self.rpn_cls_loss = metrics.Scalar()
    self.rpn_loc_loss = metrics.Scalar()
    self.rpn_total_loss = metrics.Scalar()
    # Registered as a buffer so the step counter is saved in state_dict.
    self.register_buffer("global_step", torch.LongTensor(1).zero_())
def update_global_step(self):
    """Increment the ``global_step`` buffer by one, in place."""
    self.global_step.add_(1)
def get_global_step(self):
    """Return the current ``global_step`` buffer value as a Python int."""
    step_tensor = self.global_step.cpu()
    return int(step_tensor.numpy()[0])
def forward(self, example):
    """Run the network on one batch.

    module's forward should always accept dict and return loss.

    Args:
        example: batch dict. Always reads "voxels", "num_points",
            "coordinates" and "anchors" (plus "bev_map" when use_bev). In
            training mode also reads "labels" and "reg_targets"; otherwise
            the dict is handed to ``predict``.

    Returns:
        In training mode, a dict with "loss", "cls_loss", "loc_loss",
        "cls_pos_loss", "cls_neg_loss", "cls_preds", "dir_loss_reduced"
        (zero when the direction classifier is disabled),
        "cls_loss_reduced", "loc_loss_reduced" and "cared". Otherwise the
        list returned by ``predict``.
    """
    voxels = example["voxels"]
    num_points = example["num_points"]
    coors = example["coordinates"]
    batch_anchors = example["anchors"]
    batch_size_dev = batch_anchors.shape[0]
    t = time.time()
    # features: [num_voxels, max_num_points_per_voxel, 7]
    # num_points: [num_voxels]
    # coors: [num_voxels, 4]
    voxel_features = self.voxel_feature_extractor(voxels, num_points)
    if self._use_sparse_rpn:
        # NOTE(review): __init__ never creates self.sparse_rpn, so this
        # raises AttributeError unless it is attached elsewhere.
        preds_dict = self.sparse_rpn(voxel_features, coors, batch_size_dev)
    else:
        spatial_features = self.middle_feature_extractor(
            voxel_features, coors, batch_size_dev)
        if self._use_bev:
            preds_dict = self.rpn(spatial_features, example["bev_map"])
        else:
            preds_dict = self.rpn(spatial_features)
    box_preds = preds_dict["box_preds"]
    cls_preds = preds_dict["cls_preds"]
    self._total_forward_time += time.time() - t

    if not self.training:
        return self.predict(example, preds_dict)

    labels = example['labels']
    reg_targets = example['reg_targets']

    cls_weights, reg_weights, cared = prepare_loss_weights(
        labels,
        pos_cls_weight=self._pos_cls_weight,
        neg_cls_weight=self._neg_cls_weight,
        loss_norm_type=self._loss_norm_type,
        dtype=voxels.dtype)
    # Labels of anchors that are not "cared" about are zeroed out.
    cls_targets = labels * cared.type_as(labels)
    cls_targets = cls_targets.unsqueeze(-1)

    loc_loss, cls_loss = create_loss(
        self._loc_loss_ftor,
        self._cls_loss_ftor,
        box_preds=box_preds,
        cls_preds=cls_preds,
        cls_targets=cls_targets,
        cls_weights=cls_weights,
        reg_targets=reg_targets,
        reg_weights=reg_weights,
        num_class=self._num_class,
        encode_rad_error_by_sin=self._encode_rad_error_by_sin,
        encode_background_as_zeros=self._encode_background_as_zeros,
        box_code_size=self._box_coder.code_size,
    )
    loc_loss_reduced = loc_loss.sum() / batch_size_dev
    loc_loss_reduced *= self._loc_loss_weight
    # Positive/negative split is reported only; it is not part of ``loss``.
    cls_pos_loss, cls_neg_loss = _get_pos_neg_loss(cls_loss, labels)
    cls_pos_loss /= self._pos_cls_weight
    cls_neg_loss /= self._neg_cls_weight
    cls_loss_reduced = cls_loss.sum() / batch_size_dev
    cls_loss_reduced *= self._cls_loss_weight
    loss = loc_loss_reduced + cls_loss_reduced
    # Zero when the direction classifier is disabled, so the
    # "dir_loss_reduced" entry below is always defined.
    dir_loss = loss.new_zeros(())
    if self._use_direction_classifier:
        dir_targets = get_direction_target(example['anchors'],
                                           reg_targets)
        dir_logits = preds_dict["dir_cls_preds"].view(
            batch_size_dev, -1, 2)
        # Only positive anchors contribute, normalized per example.
        weights = (labels > 0).type_as(dir_logits)
        weights /= torch.clamp(weights.sum(-1, keepdim=True), min=1.0)
        dir_loss = self._dir_loss_ftor(
            dir_logits, dir_targets, weights=weights)
        dir_loss = dir_loss.sum() / batch_size_dev
        loss += dir_loss * self._direction_loss_weight

    return {
        "loss": loss,
        "cls_loss": cls_loss,
        "loc_loss": loc_loss,
        "cls_pos_loss": cls_pos_loss,
        "cls_neg_loss": cls_neg_loss,
        "cls_preds": cls_preds,
        "dir_loss_reduced": dir_loss,
        "cls_loss_reduced": cls_loss_reduced,
        "loc_loss_reduced": loc_loss_reduced,
        "cared": cared,
    }
def predict(self, example, preds_dict):
    """Decode raw RPN outputs into per-example detections.

    For each example, the method:

    1. Filters anchors by ``example["anchors_mask"]`` when it is present.
    2. Turns class logits into scores.
    3. Runs multi-class or class-agnostic NMS in bird's-eye view.
    4. Flips box headings by pi using the direction classifier.
    5. Projects surviving boxes to camera coordinates and 2D image boxes.

    Args:
        example: batch dict; reads "anchors", "rect", "Trv2c", "P2",
            "image_idx" and optionally "anchors_mask".
        preds_dict: RPN output with "box_preds", "cls_preds" and, when the
            direction classifier is enabled, "dir_cls_preds".

    Returns:
        list with one dict per example holding "bbox", "box3d_camera",
        "box3d_lidar", "scores", "label_preds" (all None when no box
        survives) and "image_idx".
    """
    t = time.time()
    batch_size = example['anchors'].shape[0]
    batch_anchors = example["anchors"].view(batch_size, -1, 7)

    self._total_inference_count += batch_size
    batch_rect = example["rect"]
    batch_Trv2c = example["Trv2c"]
    batch_P2 = example["P2"]
    if "anchors_mask" not in example:
        batch_anchors_mask = [None] * batch_size
    else:
        batch_anchors_mask = example["anchors_mask"].view(batch_size, -1)
    batch_imgidx = example['image_idx']

    self._total_forward_time += time.time() - t
    t = time.time()
    batch_box_preds = preds_dict["box_preds"]
    batch_cls_preds = preds_dict["cls_preds"]
    batch_box_preds = batch_box_preds.view(batch_size, -1,
                                           self._box_coder.code_size)
    num_class_with_bg = self._num_class
    if not self._encode_background_as_zeros:
        num_class_with_bg = self._num_class + 1

    batch_cls_preds = batch_cls_preds.view(batch_size, -1,
                                           num_class_with_bg)
    batch_box_preds = self._box_coder.decode_torch(batch_box_preds,
                                                   batch_anchors)
    if self._use_direction_classifier:
        batch_dir_preds = preds_dict["dir_cls_preds"]
        batch_dir_preds = batch_dir_preds.view(batch_size, -1, 2)
    else:
        batch_dir_preds = [None] * batch_size

    predictions_dicts = []
    for box_preds, cls_preds, dir_preds, rect, Trv2c, P2, img_idx, a_mask in zip(
            batch_box_preds, batch_cls_preds, batch_dir_preds, batch_rect,
            batch_Trv2c, batch_P2, batch_imgidx, batch_anchors_mask):
        if a_mask is not None:
            box_preds = box_preds[a_mask]
            cls_preds = cls_preds[a_mask]
        if self._use_direction_classifier:
            if a_mask is not None:
                dir_preds = dir_preds[a_mask]
            dir_labels = torch.max(dir_preds, dim=-1)[1]
        if self._encode_background_as_zeros:
            # softmax is not supported in this encoding
            assert self._use_sigmoid_score is True
            total_scores = torch.sigmoid(cls_preds)
        else:
            # background is the first element of the one-hot vector
            if self._use_sigmoid_score:
                total_scores = torch.sigmoid(cls_preds)[..., 1:]
            else:
                total_scores = F.softmax(cls_preds, dim=-1)[..., 1:]
        # Apply NMS in bird's-eye view
        if self._use_rotate_nms:
            nms_func = box_torch_ops.rotate_nms
        else:
            nms_func = box_torch_ops.nms
        selected_boxes = None
        selected_labels = None
        selected_scores = None
        selected_dir_labels = None

        if self._multiclass_nms:
            # currently only supports class-agnostic boxes.
            # NOTE(review): columns 0, 1, 3, 4, 6 are presumably x, y, the
            # two BEV extents and yaw of a 7-value box — confirm with coder.
            boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
            if not self._use_rotate_nms:
                box_preds_corners = box_torch_ops.center_to_corner_box2d(
                    boxes_for_nms[:, :2], boxes_for_nms[:, 2:4],
                    boxes_for_nms[:, 4])
                boxes_for_nms = box_torch_ops.corner_to_standup_nd(
                    box_preds_corners)
            boxes_for_mcnms = boxes_for_nms.unsqueeze(1)
            selected_per_class = box_torch_ops.multiclass_nms(
                nms_func=nms_func,
                boxes=boxes_for_mcnms,
                scores=total_scores,
                num_class=self._num_class,
                pre_max_size=self._nms_pre_max_size,
                post_max_size=self._nms_post_max_size,
                iou_threshold=self._nms_iou_threshold,
                score_thresh=self._nms_score_threshold,
            )
            selected_boxes, selected_labels, selected_scores = [], [], []
            selected_dir_labels = []
            for i, selected in enumerate(selected_per_class):
                if selected is not None:
                    num_dets = selected.shape[0]
                    selected_boxes.append(box_preds[selected])
                    # Create labels on the prediction device, matching the
                    # class-agnostic branch below.
                    selected_labels.append(
                        torch.full([num_dets], i, dtype=torch.int64,
                                   device=box_preds.device))
                    if self._use_direction_classifier:
                        selected_dir_labels.append(dir_labels[selected])
                    selected_scores.append(total_scores[selected, i])
            if len(selected_boxes) > 0:
                selected_boxes = torch.cat(selected_boxes, dim=0)
                selected_labels = torch.cat(selected_labels, dim=0)
                selected_scores = torch.cat(selected_scores, dim=0)
                if self._use_direction_classifier:
                    selected_dir_labels = torch.cat(
                        selected_dir_labels, dim=0)
            else:
                selected_boxes = None
                selected_labels = None
                selected_scores = None
                selected_dir_labels = None
        else:
            # get highest score per prediction, then apply nms
            # to remove overlapped boxes.
            if num_class_with_bg == 1:
                top_scores = total_scores.squeeze(-1)
                top_labels = torch.zeros(
                    total_scores.shape[0],
                    device=total_scores.device,
                    dtype=torch.long)
            else:
                top_scores, top_labels = torch.max(total_scores, dim=-1)

            if self._nms_score_threshold > 0.0:
                thresh = torch.tensor(
                    [self._nms_score_threshold],
                    device=total_scores.device).type_as(total_scores)
                top_scores_keep = (top_scores >= thresh)
                top_scores = top_scores.masked_select(top_scores_keep)
            if top_scores.shape[0] != 0:
                if self._nms_score_threshold > 0.0:
                    box_preds = box_preds[top_scores_keep]
                    if self._use_direction_classifier:
                        dir_labels = dir_labels[top_scores_keep]
                    top_labels = top_labels[top_scores_keep]
                boxes_for_nms = box_preds[:, [0, 1, 3, 4, 6]]
                if not self._use_rotate_nms:
                    box_preds_corners = box_torch_ops.center_to_corner_box2d(
                        boxes_for_nms[:, :2], boxes_for_nms[:, 2:4],
                        boxes_for_nms[:, 4])
                    boxes_for_nms = box_torch_ops.corner_to_standup_nd(
                        box_preds_corners)
                # the nms in 3d detection just removes overlapping boxes.
                selected = nms_func(
                    boxes_for_nms,
                    top_scores,
                    pre_max_size=self._nms_pre_max_size,
                    post_max_size=self._nms_post_max_size,
                    iou_threshold=self._nms_iou_threshold,
                )
            else:
                selected = None
            if selected is not None:
                selected_boxes = box_preds[selected]
                if self._use_direction_classifier:
                    selected_dir_labels = dir_labels[selected]
                selected_labels = top_labels[selected]
                selected_scores = top_scores[selected]
        # finally generate predictions.

        if selected_boxes is not None:
            box_preds = selected_boxes
            scores = selected_scores
            label_preds = selected_labels
            if self._use_direction_classifier:
                dir_labels = selected_dir_labels
                # Add pi where the heading sign disagrees with the predicted
                # direction bin. Both operands are comparison results, so the
                # torch.where mask has a boolean dtype rather than uint8.
                opp_labels = (box_preds[..., -1] > 0) ^ (dir_labels > 0)
                box_preds[..., -1] += torch.where(
                    opp_labels,
                    torch.tensor(np.pi).type_as(box_preds),
                    torch.tensor(0.0).type_as(box_preds))
            final_box_preds = box_preds
            final_scores = scores
            final_labels = label_preds
            final_box_preds_camera = box_torch_ops.box_lidar_to_camera(
                final_box_preds, rect, Trv2c)
            locs = final_box_preds_camera[:, :3]
            dims = final_box_preds_camera[:, 3:6]
            angles = final_box_preds_camera[:, 6]
            camera_box_origin = [0.5, 1.0, 0.5]
            box_corners = box_torch_ops.center_to_corner_box3d(
                locs, dims, angles, camera_box_origin, axis=1)
            box_corners_in_image = box_torch_ops.project_to_image(
                box_corners, P2)
            # box_corners_in_image: [N, 8, 2]
            minxy = torch.min(box_corners_in_image, dim=1)[0]
            maxxy = torch.max(box_corners_in_image, dim=1)[0]
            box_2d_preds = torch.cat([minxy, maxxy], dim=1)
            # predictions
            predictions_dict = {
                "bbox": box_2d_preds,
                "box3d_camera": final_box_preds_camera,
                "box3d_lidar": final_box_preds,
                "scores": final_scores,
                "label_preds": label_preds,
                "image_idx": img_idx,
            }
        else:
            predictions_dict = {
                "bbox": None,
                "box3d_camera": None,
                "box3d_lidar": None,
                "scores": None,
                "label_preds": None,
                "image_idx": img_idx,
            }
        predictions_dicts.append(predictions_dict)

    self._total_postprocess_time += time.time() - t
    return predictions_dicts
@property
def avg_forward_time(self):
    """Accumulated forward time (seconds) per inferred example.

    ``_total_forward_time`` grows in both ``forward`` and ``predict``, but
    ``_total_inference_count`` only grows in ``predict``. Returns 0.0 while
    the count is zero instead of raising ZeroDivisionError.
    """
    if self._total_inference_count == 0:
        return 0.0
    return self._total_forward_time / self._total_inference_count
@property
def avg_postprocess_time(self):
    """Accumulated post-processing time (seconds) per inferred example.

    Returns 0.0 while no example has gone through ``predict`` instead of
    raising ZeroDivisionError.
    """
    if self._total_inference_count == 0:
        return 0.0
    return self._total_postprocess_time / self._total_inference_count
def clear_time_metrics(self):
    """Reset the timing accumulators and the inferred-example counter."""
    self._total_forward_time, self._total_postprocess_time = 0.0, 0.0
    self._total_inference_count = 0
def metrics_to_float(self):
    """Call ``.float()`` on each RPN metric object, in a fixed order.

    ``rpn_precision`` and ``rpn_recall`` are not included.
    """
    for metric in (self.rpn_acc, self.rpn_metrics, self.rpn_cls_loss,
                   self.rpn_loc_loss, self.rpn_total_loss):
        metric.float()
def update_metrics(self,
cls_loss,
loc_loss,
cls_preds,
labels,
sampled):
batch_size = cls_preds.shape[0]
num_class = self._num_class
if not self._encode_background_as_zeros:
num_class += 1
cls_preds = cls_preds.view(batch_size, -1, num_class)
rpn_acc = self.rpn_acc(labels, cls_preds, sampled).numpy()[0]
prec, recall = self.rpn_metrics(labels, cls_preds, sampled)
prec = prec.numpy()
recall = recall.numpy()
rpn_cls_loss = self.rpn_cls_loss(cls_loss).numpy()[0]