/
gen_efficientnet.py
1679 lines (1444 loc) · 59.7 KB
/
gen_efficientnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
""" Generic EfficientNets
A generic class with building blocks to support a variety of models with efficient architectures:
* EfficientNet (B0-B5)
* MixNet (Small, Medium, and Large)
* MnasNet B1, A1 (SE), Small
* MobileNet V1, V2, and V3
* FBNet-C (TODO A & B)
* ChamNet (TODO still guessing at architecture definition)
* Single-Path NAS Pixel1
* And likely more...
TODO not all combinations and variations have been tested. Currently working on training hyper-params...
Hacked together by Ross Wightman
"""
import math
import re
import logging
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .conv2d_helpers import select_conv2d
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
# Names exported by `from <this module> import *`: only the generic model class.
__all__ = ['GenEfficientNet']
def _cfg(url='', **kwargs):
    """Return a default pretrained/eval config dict for one model variant.

    Starts from ImageNet-1k defaults (224x224 input, 7x7 final pool, bicubic resize,
    0.875 center-crop, ImageNet mean/std) and lets any keyword argument override or
    extend those entries.

    Args:
        url: download location of the pretrained weights ('' when none is listed).
        **kwargs: per-model overrides, e.g. ``input_size``, ``pool_size``, ``crop_pct``.

    Returns:
        A new dict; keyword overrides take precedence over the defaults.
    """
    cfg = dict(
        url=url,
        num_classes=1000,
        input_size=(3, 224, 224),
        pool_size=(7, 7),
        crop_pct=0.875,
        interpolation='bicubic',
        mean=IMAGENET_DEFAULT_MEAN,
        std=IMAGENET_DEFAULT_STD,
        first_conv='conv_stem',
        classifier='classifier',
    )
    cfg.update(kwargs)
    return cfg
# Per-variant default configs keyed by model name. Each entry is built by `_cfg`, so any
# field not given here (num_classes, mean/std, interpolation, ...) takes the `_cfg` default.
# NOTE(review): url='' presumably means no pretrained weights are published for that
# variant — confirm against how `load_pretrained` treats an empty url.
default_cfgs = {
    'mnasnet_050': _cfg(url=''),
    'mnasnet_075': _cfg(url=''),
    'mnasnet_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth'),
    'mnasnet_140': _cfg(url=''),
    'semnasnet_050': _cfg(url=''),
    'semnasnet_075': _cfg(url=''),
    'semnasnet_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth'),
    'semnasnet_140': _cfg(url=''),
    'mnasnet_small': _cfg(url=''),
    'mobilenetv1_100': _cfg(url=''),
    'mobilenetv2_100': _cfg(url=''),
    'mobilenetv3_050': _cfg(url=''),
    'mobilenetv3_075': _cfg(url=''),
    'mobilenetv3_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth'),
    'chamnetv1_100': _cfg(url=''),
    'chamnetv2_100': _cfg(url=''),
    # FBNet-C and Single-Path NAS weights were evaluated with bilinear resizing
    'fbnetc_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth',
        interpolation='bilinear'),
    'spnasnet_100': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
        interpolation='bilinear'),
    # EfficientNet B1+ use larger inputs; pool_size is the matching final feature-map size
    'efficientnet_b0': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0-d6904d92.pth'),
    'efficientnet_b1': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth',
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'efficientnet_b2': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth',
        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
    'efficientnet_b3': _cfg(
        url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'efficientnet_b4': _cfg(
        url='', input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
    'efficientnet_b5': _cfg(
        url='', input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
    # tf_* variants: weights ported from the TensorFlow reference implementations
    'tf_efficientnet_b0': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0-0af12548.pth',
        input_size=(3, 224, 224)),
    'tf_efficientnet_b1': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1-5c1377c4.pth',
        input_size=(3, 240, 240), pool_size=(8, 8), crop_pct=0.882),
    'tf_efficientnet_b2': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2-e393ef04.pth',
        input_size=(3, 260, 260), pool_size=(9, 9), crop_pct=0.890),
    'tf_efficientnet_b3': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3-e3bd6955.pth',
        input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904),
    'tf_efficientnet_b4': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4-74ee3bed.pth',
        input_size=(3, 380, 380), pool_size=(12, 12), crop_pct=0.922),
    'tf_efficientnet_b5': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5-c6949ce9.pth',
        input_size=(3, 456, 456), pool_size=(15, 15), crop_pct=0.934),
    'mixnet_s': _cfg(url=''),
    'mixnet_m': _cfg(url=''),
    'mixnet_l': _cfg(url=''),
    'tf_mixnet_s': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
    'tf_mixnet_m': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth'),
    'tf_mixnet_l': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth'),
}
_DEBUG = False
# Default args for PyTorch BN impl
_BN_MOMENTUM_PT_DEFAULT = 0.1
_BN_EPS_PT_DEFAULT = 1e-5
_BN_ARGS_PT = dict(momentum=_BN_MOMENTUM_PT_DEFAULT, eps=_BN_EPS_PT_DEFAULT)
# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per
# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay)
# NOTE: momentum varies btw .99 and .9997 depending on source
# .99 in official TF TPU impl
# .9997 (/w .999 in search space) for paper
_BN_MOMENTUM_TF_DEFAULT = 1 - 0.99
_BN_EPS_TF_DEFAULT = 1e-3
_BN_ARGS_TF = dict(momentum=_BN_MOMENTUM_TF_DEFAULT, eps=_BN_EPS_TF_DEFAULT)
def _resolve_bn_args(kwargs):
    """Pop BatchNorm-related options out of ``kwargs`` and return BN constructor args.

    ``bn_tf`` (bool) selects the TF-style defaults instead of the PyTorch ones;
    ``bn_momentum`` / ``bn_eps`` override individual values when not None.
    All three keys are removed from ``kwargs`` so they are not forwarded elsewhere.

    Returns:
        A fresh dict with 'momentum' and 'eps' keys (the module defaults are not mutated).
    """
    use_tf_defaults = kwargs.pop('bn_tf', False)
    resolved = dict(_BN_ARGS_TF if use_tf_defaults else _BN_ARGS_PT)
    for bn_key, kwarg_key in (('momentum', 'bn_momentum'), ('eps', 'bn_eps')):
        override = kwargs.pop(kwarg_key, None)
        if override is not None:
            resolved[bn_key] = override
    return resolved
def _round_channels(channels, multiplier=1.0, divisor=8, channel_min=None):
"""Round number of filters based on depth multiplier."""
if not multiplier:
return channels
channels *= multiplier
channel_min = channel_min or divisor
new_channels = max(
int(channels + divisor / 2) // divisor * divisor,
channel_min)
# Make sure that round down does not go down by more than 10%.
if new_channels < 0.9 * channels:
new_channels += divisor
return new_channels
def _parse_ksize(ss):
if ss.isdigit():
return int(ss)
else:
return [int(k) for k in ss.split('.')]
def _decode_block_str(block_str, depth_multiplier=1.0):
""" Decode block definition string
Gets a list of block arg (dicts) through a string notation of arguments.
E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip
All args can exist in any order with the exception of the leading string which
is assumed to indicate the block type.
leading string - block type (
ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct)
r - number of repeat blocks,
k - kernel size,
s - strides (1-9),
e - expansion ratio,
c - output channels,
se - squeeze/excitation ratio
n - activation fn ('re', 'r6', 'hs', or 'sw')
Args:
block_str: a string representation of block arguments.
Returns:
A list of block args (dicts)
Raises:
ValueError: if the string def not properly specified (TODO)
"""
assert isinstance(block_str, str)
ops = block_str.split('_')
block_type = ops[0] # take the block type off the front
ops = ops[1:]
options = {}
noskip = False
for op in ops:
# string options being checked on individual basis, combine if they grow
if op == 'noskip':
noskip = True
elif op.startswith('n'):
# activation fn
key = op[0]
v = op[1:]
if v == 're':
value = F.relu
elif v == 'r6':
value = F.relu6
elif v == 'hs':
value = hard_swish
elif v == 'sw':
value = swish
else:
continue
options[key] = value
else:
# all numeric options
splits = re.split(r'(\d.*)', op)
if len(splits) >= 2:
key, value = splits[:2]
options[key] = value
# if act_fn is None, the model default (passed to model init) will be used
act_fn = options['n'] if 'n' in options else None
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
if block_type == 'ir':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
noskip=noskip,
)
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else None,
stride=int(options['s']),
act_fn=act_fn,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or noskip,
)
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_fn=act_fn,
)
else:
assert False, 'Unknown block type (%s)' % block_type
# return a list of block args expanded by num_repeat and
# scaled by depth_multiplier
num_repeat = int(math.ceil(num_repeat * depth_multiplier))
return [deepcopy(block_args) for _ in range(num_repeat)]
def _decode_arch_args(string_list):
    """Decode each block string in ``string_list`` without depth scaling.

    Returns one entry per input string; each entry is the list of block-arg dicts
    that ``_decode_block_str`` produces for that string.
    """
    return [_decode_block_str(block_str) for block_str in string_list]
def _decode_arch_def(arch_def, depth_multiplier=1.0):
    """Decode a full architecture definition into per-stage block args.

    Args:
        arch_def: list of stages; each stage is a list of block definition strings.
        depth_multiplier: forwarded to ``_decode_block_str`` to scale repeat counts.

    Returns:
        A list (one per stage) of flat lists of block-arg dicts.
    """
    decoded_stages = []
    for stage_strings in arch_def:
        assert isinstance(stage_strings, list)
        stage_blocks = []
        for definition in stage_strings:
            assert isinstance(definition, str)
            stage_blocks += _decode_block_str(definition, depth_multiplier)
        decoded_stages.append(stage_blocks)
    return decoded_stages
def swish(x, inplace=False):
    """Swish activation, ``x * sigmoid(x)``; with ``inplace`` the input tensor is overwritten."""
    gate = x.sigmoid()
    return x.mul_(gate) if inplace else x * gate
def sigmoid(x, inplace=False):
    """Logistic sigmoid; accepts ``inplace`` so it shares the activation-fn call signature."""
    if inplace:
        return x.sigmoid_()
    return x.sigmoid()
def hard_swish(x, inplace=False):
    """Hard-swish activation (MobileNet-V3): ``x * relu6(x + 3) / 6``."""
    if not inplace:
        return x * F.relu6(x + 3.) / 6.
    return x.mul_(F.relu6(x + 3.) / 6.)
def hard_sigmoid(x, inplace=False):
    """Hard (piecewise-linear) sigmoid: ``relu6(x + 3) / 6``."""
    if not inplace:
        return F.relu6(x + 3.) / 6.
    # in-place path: shift, clip to [0, 6], scale
    return x.add_(3.).clamp_(0., 6.).div_(6.)
class _BlockBuilder:
    """ Build Trunk Blocks

    This ended up being somewhat of a cross between
    https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py
    and
    https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py

    The builder is stateful while building: ``in_chs`` tracks the output channels of the
    most recently built block, and ``block_idx`` / ``block_count`` drive a drop-connect
    rate that grows linearly with each block's global position. The block-arg dicts
    passed in are copied, never modified, so the same decoded architecture can be built
    more than once.
    """

    def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_fn=None, se_gate_fn=sigmoid, se_reduce_mid=False,
                 bn_args=_BN_ARGS_PT, drop_connect_rate=0., verbose=False):
        self.channel_multiplier = channel_multiplier
        self.channel_divisor = channel_divisor
        self.channel_min = channel_min
        self.pad_type = pad_type
        self.act_fn = act_fn
        self.se_gate_fn = se_gate_fn
        self.se_reduce_mid = se_reduce_mid
        self.bn_args = bn_args
        self.drop_connect_rate = drop_connect_rate
        self.verbose = verbose

        # updated during build
        self.in_chs = None
        self.block_idx = 0
        self.block_count = 0

    def _round_channels(self, chs):
        """Apply the builder's channel multiplier/divisor/min to ``chs``."""
        return _round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min)

    def _make_block(self, ba):
        """Instantiate one block from a block-arg dict; ``ba`` itself is left unmodified."""
        ba = dict(ba)  # shallow copy: the caller's decoded args may be reused for another build
        bt = ba.pop('block_type')
        ba['in_chs'] = self.in_chs
        ba['out_chs'] = self._round_channels(ba['out_chs'])
        ba['bn_args'] = self.bn_args
        ba['pad_type'] = self.pad_type
        # block act fn overrides the model default
        ba['act_fn'] = ba['act_fn'] if ba['act_fn'] is not None else self.act_fn
        assert ba['act_fn'] is not None
        if bt == 'ir':
            # drop-connect rate scales linearly with the block's index across the whole trunk
            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
            ba['se_gate_fn'] = self.se_gate_fn
            ba['se_reduce_mid'] = self.se_reduce_mid
            if self.verbose:
                logging.info(' InvertedResidual {}, Args: {}'.format(self.block_idx, str(ba)))
            block = InvertedResidual(**ba)
        elif bt == 'ds' or bt == 'dsa':
            # NOTE(review): se_gate_fn is not forwarded here, so DS blocks with SE always
            # use DepthwiseSeparableConv's default gate regardless of the model setting.
            ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count
            if self.verbose:
                logging.info(' DepthwiseSeparable {}, Args: {}'.format(self.block_idx, str(ba)))
            block = DepthwiseSeparableConv(**ba)
        elif bt == 'cn':
            if self.verbose:
                logging.info(' ConvBnAct {}, Args: {}'.format(self.block_idx, str(ba)))
            block = ConvBnAct(**ba)
        else:
            # explicit raise so the check is not stripped under `python -O`
            raise ValueError('Unknown block type (%s) while building model.' % bt)
        self.in_chs = ba['out_chs']  # update in_chs for arg of next block
        return block

    def _make_stack(self, stack_args):
        """Build one stage (stack) of blocks as an nn.Sequential."""
        blocks = []
        # each stack (stage) contains a list of block arguments
        for i, ba in enumerate(stack_args):
            if self.verbose:
                logging.info(' Block: {}'.format(i))
            if i >= 1:
                # only the first block in any stack can have a stride > 1;
                # override on a copy rather than the caller's dict
                ba = dict(ba, stride=1)
            blocks.append(self._make_block(ba))
            self.block_idx += 1  # incr global idx (across all stacks)
        return nn.Sequential(*blocks)

    def __call__(self, in_chs, block_args):
        """ Build the blocks

        Args:
            in_chs: Number of input-channels passed to first block
            block_args: A list of lists, outer list defines stages, inner
                list contains block-arg dicts (e.g. as produced by _decode_arch_def)

        Return:
             List of block stacks (each stack wrapped in nn.Sequential)
        """
        if self.verbose:
            logging.info('Building model trunk with %d stages...' % len(block_args))
        self.in_chs = in_chs
        self.block_count = sum([len(x) for x in block_args])
        self.block_idx = 0
        blocks = []
        # outer list of block_args defines the stacks ('stages' by some conventions)
        for stack_idx, stack in enumerate(block_args):
            if self.verbose:
                logging.info('Stack: {}'.format(stack_idx))
            assert isinstance(stack, list)
            blocks.append(self._make_stack(stack))
        return blocks
def _initialize_weight_goog(m):
# weight init as per Tensorflow Official impl
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out
m.weight.data.normal_(0, math.sqrt(2.0 / n))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
n = m.weight.size(0) # fan-out
init_range = 1.0 / math.sqrt(n)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
def _initialize_weight_default(m):
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
def drop_connect(inputs, training=False, drop_connect_rate=0.):
    """Apply drop connect: zero whole samples with prob ``drop_connect_rate``.

    Each sample along dim 0 is either kept and scaled by ``1 / keep_prob`` or zeroed,
    so the expected value is unchanged. Works for inputs of any rank >= 1; the mask
    has shape (batch, 1, 1, ...) and broadcasts over the remaining dims.

    Returns ``inputs`` unchanged (same object) when not training or when the rate
    is <= 0. Returns all zeros when the rate is >= 1, instead of dividing by zero.
    """
    if not training or drop_connect_rate <= 0.:
        return inputs
    keep_prob = 1 - drop_connect_rate
    if keep_prob <= 0.:
        # everything dropped; avoids inputs / 0 -> inf * 0 -> NaN
        return torch.zeros_like(inputs)
    mask_shape = (inputs.shape[0],) + (1,) * (inputs.dim() - 1)
    random_tensor = keep_prob + torch.rand(mask_shape, dtype=inputs.dtype, device=inputs.device)
    random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
    output = inputs.div(keep_prob) * random_tensor
    return output
class ChannelShuffle(nn.Module):
    """ShuffleNet-style channel shuffle: interleave channels across ``groups`` groups."""
    # FIXME haven't used yet

    def __init__(self, groups):
        super(ChannelShuffle, self).__init__()
        self.groups = groups

    def forward(self, x):
        """Channel shuffle: [N,C,H,W] -> [N,g,C/g,H,W] -> [N,C/g,g,H,W] -> [N,C,H,W]"""
        batch, channels, height, width = x.size()
        groups = self.groups
        assert channels % groups == 0, "Incompatible group size {} for input channel {}".format(
            groups, channels
        )
        grouped = x.view(batch, groups, channels // groups, height, width)
        # swap the group and per-group channel axes, then flatten back to C channels
        interleaved = grouped.transpose(1, 2).contiguous()
        return interleaved.view(batch, channels, height, width)
class SqueezeExcite(nn.Module):
    """Squeeze-and-excitation: global-average-pool, 1x1 reduce, act, 1x1 expand, gate.

    Args:
        in_chs: channels of the input (and of the gating output).
        reduce_chs: bottleneck width; falls back to ``in_chs`` when falsy.
        act_fn: activation applied after the reduce conv (called with ``inplace=True``).
        gate_fn: gating function applied to the expanded logits.
    """

    def __init__(self, in_chs, reduce_chs=None, act_fn=F.relu, gate_fn=sigmoid):
        super(SqueezeExcite, self).__init__()
        self.act_fn = act_fn
        self.gate_fn = gate_fn
        bottleneck = reduce_chs or in_chs
        self.conv_reduce = nn.Conv2d(in_chs, bottleneck, 1, bias=True)
        self.conv_expand = nn.Conv2d(bottleneck, in_chs, 1, bias=True)

    def forward(self, x):
        batch, chans = x.size(0), x.size(1)
        # NOTE adaptiveavgpool can be used here, but seems to cause issues with NVIDIA AMP performance
        squeezed = x.view(batch, chans, -1).mean(-1).view(batch, chans, 1, 1)
        excited = self.conv_expand(self.act_fn(self.conv_reduce(squeezed), inplace=True))
        return x * self.gate_fn(excited)
class ConvBnAct(nn.Module):
    """Plain conv -> BatchNorm -> activation block (the 'cn' block type).

    ``stride`` must be 1 or 2; ``act_fn`` is called with ``inplace=True``.
    """

    def __init__(self, in_chs, out_chs, kernel_size,
                 stride=1, pad_type='', act_fn=F.relu, bn_args=_BN_ARGS_PT):
        super(ConvBnAct, self).__init__()
        assert stride in [1, 2]
        self.act_fn = act_fn
        self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type)
        self.bn1 = nn.BatchNorm2d(out_chs, **bn_args)

    def forward(self, x):
        normalized = self.bn1(self.conv(x))
        return self.act_fn(normalized, inplace=True)
class DepthwiseSeparableConv(nn.Module):
    """ DepthwiseSeparable block

    Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
    factor of 1.0. This is an alternative to having a IR with optional first pw conv.

    Layout: depthwise conv/bn1/act -> optional SE -> point-wise conv/bn2
    -> optional act (``pw_act``) -> optional residual add (stride 1, in_chs == out_chs,
    not ``noskip``), with drop-connect applied to the branch while training.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_fn=F.relu, noskip=False,
                 pw_kernel_size=1, pw_act=False,
                 se_ratio=0., se_gate_fn=sigmoid,
                 bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
        super(DepthwiseSeparableConv, self).__init__()
        assert stride in [1, 2]
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = not noskip and stride == 1 and in_chs == out_chs
        self.has_pw_act = pw_act  # activation after point-wise conv
        self.act_fn = act_fn
        self.drop_connect_rate = drop_connect_rate

        # NOTE: submodules are registered in the same order as before, which fixes the
        # state_dict layout and the order weight init visits them.
        self.conv_dw = select_conv2d(
            in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
        self.bn1 = nn.BatchNorm2d(in_chs, **bn_args)

        # Squeeze-and-excitation
        if self.has_se:
            se_chs = max(1, int(in_chs * se_ratio))
            self.se = SqueezeExcite(in_chs, reduce_chs=se_chs, act_fn=act_fn, gate_fn=se_gate_fn)

        self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)

    def forward(self, x):
        shortcut = x
        out = self.act_fn(self.bn1(self.conv_dw(x)), inplace=True)
        if self.has_se:
            out = self.se(out)
        out = self.bn2(self.conv_pw(out))
        if self.has_pw_act:
            out = self.act_fn(out, inplace=True)
        if not self.has_residual:
            return out
        if self.drop_connect_rate > 0.:
            out = drop_connect(out, self.training, self.drop_connect_rate)
        out += shortcut
        return out
class InvertedResidual(nn.Module):
    """ Inverted residual block w/ optional SE

    Layout: point-wise expansion (conv_pw/bn1/act, mid_chs = int(in_chs * exp_ratio))
    -> optional channel shuffle -> depth-wise conv (conv_dw/bn2/act)
    -> optional squeeze-excite -> point-wise linear projection (conv_pwl/bn3, no act)
    -> optional residual add (stride 1, in_chs == out_chs, not noskip), with
    drop-connect applied to the branch while training.

    se_reduce_mid: when True the SE bottleneck is sized from mid_chs, else from in_chs.
    """

    def __init__(self, in_chs, out_chs, dw_kernel_size=3,
                 stride=1, pad_type='', act_fn=F.relu, noskip=False,
                 exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1,
                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                 shuffle_type=None, bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
        super(InvertedResidual, self).__init__()
        mid_chs = int(in_chs * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.act_fn = act_fn
        self.drop_connect_rate = drop_connect_rate

        # NOTE: submodule registration order below fixes the state_dict layout and the
        # order in which model-level weight init visits these layers; do not reorder.

        # Point-wise expansion
        self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)

        self.shuffle_type = shuffle_type
        # NOTE(review): self.shuffle is only created when exp_kernel_size is a list, but
        # forward() uses it whenever shuffle_type == "mid"; shuffle_type="mid" with an int
        # exp_kernel_size would raise AttributeError in forward.
        if shuffle_type is not None and isinstance(exp_kernel_size, list):
            self.shuffle = ChannelShuffle(len(exp_kernel_size))

        # Depth-wise convolution
        self.conv_dw = select_conv2d(
            mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True)
        self.bn2 = nn.BatchNorm2d(mid_chs, **bn_args)

        # Squeeze-and-excitation
        if self.has_se:
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

        # Point-wise linear projection
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
        self.bn3 = nn.BatchNorm2d(out_chs, **bn_args)

    def forward(self, x):
        residual = x

        # Point-wise expansion
        x = self.conv_pw(x)
        x = self.bn1(x)
        x = self.act_fn(x, inplace=True)

        # FIXME haven't tried this yet
        # for channel shuffle when using groups with pointwise convs as per FBNet variants
        if self.shuffle_type == "mid":
            x = self.shuffle(x)

        # Depth-wise convolution
        x = self.conv_dw(x)
        x = self.bn2(x)
        x = self.act_fn(x, inplace=True)

        # Squeeze-and-excitation
        if self.has_se:
            x = self.se(x)

        # Point-wise linear projection (no activation: the 'linear bottleneck')
        x = self.conv_pwl(x)
        x = self.bn3(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        # NOTE maskrcnn_benchmark building blocks have an SE module defined here for some variants
        return x
class GenEfficientNet(nn.Module):
    """ Generic EfficientNet

    An implementation of efficient network architectures, in many cases mobile optimized networks:
      * MobileNet-V1
      * MobileNet-V2
      * MobileNet-V3
      * MnasNet A1, B1, and small
      * FBNet A, B, and C
      * ChamNet (arch details are murky)
      * Single-Path NAS Pixel1
      * EfficientNet B0-B5
      * MixNet S, M, L

    Layout: conv_stem/bn1/act -> blocks (one nn.Sequential per stage) -> optional head
    conv -> global pool -> dropout -> classifier.

    Args:
        block_args: decoded stage definitions (list of lists of block-arg dicts),
            e.g. the output of ``_decode_arch_def``.
        head_conv: 'efficient' pools first, then applies a 1x1 conv + act with no BN
            (MobileNet-V3 style); 'none' or a falsy value uses no head conv, in which case
            the last block must already output ``num_features`` channels; any other value
            applies a 1x1 conv + BN + act before pooling.
        weight_init: 'goog' for the TF-reference init, anything else for Kaiming init.
    """

    def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                 channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                 pad_type='', act_fn=F.relu, drop_rate=0., drop_connect_rate=0.,
                 se_gate_fn=sigmoid, se_reduce_mid=False, bn_args=_BN_ARGS_PT,
                 global_pool='avg', head_conv='default', weight_init='goog'):
        super(GenEfficientNet, self).__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        self.act_fn = act_fn
        self.num_features = num_features

        # NOTE: submodule registration order fixes the state_dict layout and the order
        # in which weight init consumes the RNG; keep it stable.
        stem_size = _round_channels(stem_size, channel_multiplier, channel_divisor, channel_min)
        self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
        self.bn1 = nn.BatchNorm2d(stem_size, **bn_args)
        in_chs = stem_size

        builder = _BlockBuilder(
            channel_multiplier, channel_divisor, channel_min,
            pad_type, act_fn, se_gate_fn, se_reduce_mid,
            bn_args, drop_connect_rate, verbose=_DEBUG)
        self.blocks = nn.Sequential(*builder(in_chs, block_args))
        in_chs = builder.in_chs

        if not head_conv or head_conv == 'none':
            self.efficient_head = False
            self.conv_head = None
            assert in_chs == self.num_features
        else:
            self.efficient_head = head_conv == 'efficient'
            self.conv_head = select_conv2d(in_chs, self.num_features, 1, padding=pad_type)
            self.bn2 = None if self.efficient_head else nn.BatchNorm2d(self.num_features, **bn_args)

        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)

        init_fn = _initialize_weight_goog if weight_init == 'goog' else _initialize_weight_default
        for m in self.modules():
            init_fn(m)

    def get_classifier(self):
        return self.classifier

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the global pool and classifier head.

        A falsy ``num_classes`` removes the classifier (sets it to None); ``forward``
        then returns the (dropout-applied) pooled features. The new Linear keeps
        PyTorch's default init rather than the model's ``weight_init`` scheme.
        """
        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
        self.num_classes = num_classes
        del self.classifier
        if num_classes:
            self.classifier = nn.Linear(
                self.num_features * self.global_pool.feat_mult(), num_classes)
        else:
            self.classifier = None

    def forward_features(self, x, pool=True):
        """Run stem, blocks and head.

        With ``pool=True`` the output is globally pooled and flattened to (batch, features).
        With the efficient head, pooling always happens (it precedes the head conv), so
        ``pool=False`` only skips the final flatten.
        """
        x = self.conv_stem(x)
        x = self.bn1(x)
        x = self.act_fn(x, inplace=True)
        x = self.blocks(x)
        if self.efficient_head:
            # efficient head, currently only mobilenet-v3 performs pool before last 1x1 conv
            x = self.global_pool(x)  # always need to pool here regardless of flag
            x = self.conv_head(x)
            # no BN
            x = self.act_fn(x, inplace=True)
            if pool:
                # expect flattened output if pool is true, otherwise keep dim
                x = x.view(x.size(0), -1)
        else:
            if self.conv_head is not None:
                x = self.conv_head(x)
                x = self.bn2(x)
            x = self.act_fn(x, inplace=True)
            if pool:
                x = self.global_pool(x)
                x = x.view(x.size(0), -1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
        if self.classifier is None:
            # reset_classifier(0) removed the head: act as a pooled feature extractor
            # instead of failing on `None(x)`
            return x
        return self.classifier(x)
def _gen_mnasnet_a1(channel_multiplier, num_classes=1000, **kwargs):
    """Creates a mnasnet-a1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
      num_classes: number of classifier outputs.
      **kwargs: forwarded to GenEfficientNet; the BN options (bn_tf, bn_momentum, bn_eps)
        are consumed here and turned into ``bn_args``.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_noskip'],
        # stage 1, 112x112 in
        ['ir_r2_k3_s2_e6_c24'],
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40_se0.25'],
        # stage 3, 28x28 in
        ['ir_r4_k3_s2_e6_c80'],
        # stage 4, 14x14 in
        ['ir_r2_k3_s1_e6_c112_se0.25'],
        # stage 5, 14x14 in
        ['ir_r3_k5_s2_e6_c160_se0.25'],
        # stage 6, 7x7 in
        ['ir_r1_k3_s1_e6_c320'],
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        **kwargs
    )
def _gen_mnasnet_b1(channel_multiplier, num_classes=1000, **kwargs):
    """Creates a mnasnet-b1 model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
      num_classes: number of classifier outputs.
      **kwargs: forwarded to GenEfficientNet; the BN options (bn_tf, bn_momentum, bn_eps)
        are consumed here and turned into ``bn_args``.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_c16_noskip'],
        # stage 1, 112x112 in
        ['ir_r3_k3_s2_e3_c24'],
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40'],
        # stage 3, 28x28 in
        ['ir_r3_k5_s2_e6_c80'],
        # stage 4, 14x14 in
        ['ir_r2_k3_s1_e6_c96'],
        # stage 5, 14x14 in
        ['ir_r4_k5_s2_e6_c192'],
        # stage 6, 7x7 in
        ['ir_r1_k3_s1_e6_c320_noskip']
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        **kwargs
    )
def _gen_mnasnet_small(channel_multiplier, num_classes=1000, **kwargs):
    """Creates a mnasnet-small model.

    Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet
    Paper: https://arxiv.org/pdf/1807.11626.pdf.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
      num_classes: number of classifier outputs.
      **kwargs: forwarded to GenEfficientNet; the BN options (bn_tf, bn_momentum, bn_eps)
        are consumed here and turned into ``bn_args``.
    """
    arch_def = [
        ['ds_r1_k3_s1_c8'],
        ['ir_r1_k3_s2_e3_c16'],
        ['ir_r2_k3_s2_e6_c16'],
        ['ir_r4_k5_s2_e6_c32_se0.25'],
        ['ir_r3_k3_s1_e6_c32_se0.25'],
        ['ir_r3_k5_s2_e6_c88_se0.25'],
        ['ir_r1_k3_s1_e6_c144']
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=8,
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        **kwargs
    )
def _gen_mobilenet_v1(channel_multiplier, num_classes=1000, **kwargs):
    """ Generate MobileNet-V1 network

    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.py
    Paper: https://arxiv.org/abs/1704.04861

    Uses ReLU6, depthwise-separable blocks with an activation after the point-wise conv
    ('dsa'), and no head conv: the last stage's 1024 channels feed the classifier directly.
    """
    arch_def = [
        ['dsa_r1_k3_s1_c64'],
        ['dsa_r2_k3_s2_c128'],
        ['dsa_r2_k3_s2_c256'],
        ['dsa_r6_k3_s2_c512'],
        ['dsa_r2_k3_s2_c1024'],
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=32,
        num_features=1024,
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        act_fn=F.relu6,
        head_conv='none',
        **kwargs
    )
def _gen_mobilenet_v2(channel_multiplier, num_classes=1000, **kwargs):
    """ Generate MobileNet-V2 network

    Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
    Paper: https://arxiv.org/abs/1801.04381

    Uses ReLU6 throughout and the default 1x1 conv + BN head.
    """
    arch_def = [
        ['ds_r1_k3_s1_c16'],
        ['ir_r2_k3_s2_e6_c24'],
        ['ir_r3_k3_s2_e6_c32'],
        ['ir_r4_k3_s2_e6_c64'],
        ['ir_r3_k3_s1_e6_c96'],
        ['ir_r3_k3_s2_e6_c160'],
        ['ir_r1_k3_s1_e6_c320'],
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        act_fn=F.relu6,
        **kwargs
    )
def _gen_mobilenet_v3(channel_multiplier, num_classes=1000, **kwargs):
    """Creates a MobileNet-V3 model.

    Ref impl: ?
    Paper: https://arxiv.org/abs/1905.02244

    Model default activation is hard-swish; early stages override it with ReLU ('nre').
    SE blocks gate with hard-sigmoid and size their bottleneck from the expanded (mid)
    channels, and the 'efficient' head pools before the final 1x1 conv.

    Args:
      channel_multiplier: multiplier to number of channels per layer.
      num_classes: number of classifier outputs.
      **kwargs: forwarded to GenEfficientNet; the BN options (bn_tf, bn_momentum, bn_eps)
        are consumed here and turned into ``bn_args``.
    """
    arch_def = [
        # stage 0, 112x112 in
        ['ds_r1_k3_s1_e1_c16_nre_noskip'],  # relu
        # stage 1, 112x112 in
        ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'],  # relu
        # stage 2, 56x56 in
        ['ir_r3_k5_s2_e3_c40_se0.25_nre'],  # relu
        # stage 3, 28x28 in
        ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'],  # hard-swish
        # stage 4, 14x14 in
        ['ir_r2_k3_s1_e6_c112_se0.25'],  # hard-swish
        # stage 5, 14x14 in
        ['ir_r3_k5_s2_e6_c160_se0.25'],  # hard-swish
        # stage 6, 7x7 in
        ['cn_r1_k1_s1_c960'],  # hard-swish
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=16,
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        act_fn=hard_swish,
        se_gate_fn=hard_sigmoid,
        se_reduce_mid=True,
        head_conv='efficient',
        **kwargs
    )
def _gen_chamnet_v1(channel_multiplier, num_classes=1000, **kwargs):
    """ Generate Chameleon Network (ChamNet)

    Paper: https://arxiv.org/abs/1812.08934
    Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py

    FIXME: this a bit of an educated guess based on trunk def in maskrcnn_benchmark
    """
    arch_def = [
        ['ir_r1_k3_s1_e1_c24'],
        ['ir_r2_k7_s2_e4_c48'],
        ['ir_r5_k3_s2_e7_c64'],
        ['ir_r7_k5_s2_e12_c56'],
        ['ir_r5_k3_s1_e8_c88'],
        ['ir_r4_k3_s2_e7_c152'],
        ['ir_r1_k3_s1_e10_c104'],
    ]
    block_args = _decode_arch_def(arch_def)
    bn_args = _resolve_bn_args(kwargs)  # pops BN keys so they are not forwarded twice
    return GenEfficientNet(
        block_args,
        num_classes=num_classes,
        stem_size=32,
        num_features=1280,  # no idea what this is? try mobile/mnasnet default?
        channel_multiplier=channel_multiplier,
        channel_divisor=8,
        channel_min=None,
        bn_args=bn_args,
        **kwargs
    )
def _gen_chamnet_v2(channel_multiplier, num_classes=1000, **kwargs):
""" Generate Chameleon Network (ChamNet)
Paper: https://arxiv.org/abs/1812.08934
Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py
FIXME: this a bit of an educated guess based on trunk def in maskrcnn_benchmark
"""
arch_def = [
['ir_r1_k3_s1_e1_c24'],
['ir_r4_k5_s2_e8_c32'],
['ir_r6_k7_s2_e5_c48'],
['ir_r3_k5_s2_e9_c56'],
['ir_r6_k3_s1_e6_c56'],
['ir_r6_k3_s2_e2_c152'],
['ir_r1_k3_s1_e6_c112'],
]
model = GenEfficientNet(
_decode_arch_def(arch_def),
num_classes=num_classes,
stem_size=32,
num_features=1280, # no idea what this is? try mobile/mnasnet default?
channel_multiplier=channel_multiplier,