+from typing import Any, Dict, Tuple, Union
 import warnings
 
 import torch
 import torch.nn as nn
 
 try:
     from inplace_abn import InPlaceABN
 except ImportError:
     InPlaceABN = None
 
+def handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if decoder_use_batchnorm is not None:
+        warnings.warn(
+            "The usage of use_batchnorm is deprecated. Please modify your code to use use_norm instead.",
+            DeprecationWarning,
+        )
+        if decoder_use_batchnorm is True:
+            decoder_use_norm = {"type": "batchnorm"}
+        elif decoder_use_batchnorm is False:
+            decoder_use_norm = {"type": "identity"}
+        elif decoder_use_batchnorm == "inplace":
+            decoder_use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        else:
+            raise ValueError("Unrecognized value for use_batchnorm")
+
+    return decoder_use_norm
+
+
+def normalize_use_norm(use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if isinstance(use_norm, str):
+        norm_str = use_norm.lower()
+        if norm_str == "inplace":
+            use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in (
+            "batchnorm",
+            "identity",
+            "layernorm",
+            "groupnorm",
+            "instancenorm",
+        ):
+            use_norm = {"type": norm_str}
+        else:
+            raise ValueError("Unrecognized normalization type string provided")
+    elif isinstance(use_norm, bool):
+        use_norm = {"type": "batchnorm" if use_norm else "identity"}
+    elif not isinstance(use_norm, dict):
+        raise ValueError("use_norm must be a dictionary, boolean, or string")
+
+    return use_norm
+
+def get_norm_layer(use_norm: Dict[str, Any], relu: nn.Module, out_channels: int) -> Tuple[nn.Module, nn.Module]:
+    norm_type = use_norm["type"]
+    extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
+
+    if norm_type == "inplace":
+        norm = InPlaceABN(out_channels, **extra_kwargs)
+        relu = nn.Identity()
+    elif norm_type == "batchnorm":
+        norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
+    elif norm_type == "identity":
+        norm = nn.Identity()
+    elif norm_type == "layernorm":
+        norm = nn.LayerNorm(out_channels, **extra_kwargs)
+    elif norm_type == "groupnorm":
+        # nn.GroupNorm's first positional argument is num_groups, so pass
+        # num_channels by keyword; num_groups must come via extra_kwargs.
+        norm = nn.GroupNorm(num_channels=out_channels, **extra_kwargs)
+    elif norm_type == "instancenorm":
+        norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
+    else:
+        raise ValueError(f"Unrecognized normalization type: {norm_type}")
+
+    return norm, relu
+
 
 class Conv2dReLU(nn.Sequential):
     def __init__(
@@ -17,50 +88,9 @@ def __init__(
         kernel_size,
         padding=0,
         stride=1,
-        use_batchnorm=True,
         use_norm="batchnorm",
     ):
-        if use_batchnorm is not None:
-            warnings.warn(
-                "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
-                DeprecationWarning,
-            )
-            if use_batchnorm is True:
-                use_norm = {"type": "batchnorm"}
-            elif use_batchnorm is False:
-                use_norm = {"type": "identity"}
-            elif use_batchnorm == "inplace":
-                use_norm = {
-                    "type": "inplace",
-                    "activation": "leaky_relu",
-                    "activation_param": 0.0,
-                }
-            else:
-                raise ValueError("Unrecognized value for use_batchnorm")
-
-        if isinstance(use_norm, str):
-            norm_str = use_norm.lower()
-            if norm_str == "inplace":
-                use_norm = {
-                    "type": "inplace",
-                    "activation": "leaky_relu",
-                    "activation_param": 0.0,
-                }
-            elif norm_str in (
-                "batchnorm",
-                "identity",
-                "layernorm",
-                "groupnorm",
-                "instancenorm",
-            ):
-                use_norm = {"type": norm_str}
-            else:
-                raise ValueError("Unrecognized normalization type string provided")
-        elif isinstance(use_norm, bool):
-            use_norm = {"type": "batchnorm" if use_norm else "identity"}
-        elif not isinstance(use_norm, dict):
-            raise ValueError("use_norm must be a dictionary, boolean, or string")
-
+        use_norm = normalize_use_norm(use_norm)
         if use_norm["type"] == "inplace" and InPlaceABN is None:
             raise RuntimeError(
                 "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
@@ -77,24 +107,7 @@ def __init__(
         )
         relu = nn.ReLU(inplace=True)
 
-        norm_type = use_norm["type"]
-        extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
-
-        if norm_type == "inplace":
-            norm = InPlaceABN(out_channels, **extra_kwargs)
-            relu = nn.Identity()
-        elif norm_type == "batchnorm":
-            norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
-        elif norm_type == "identity":
-            norm = nn.Identity()
-        elif norm_type == "layernorm":
-            norm = nn.LayerNorm(out_channels, **extra_kwargs)
-        elif norm_type == "groupnorm":
-            norm = nn.GroupNorm(out_channels, **extra_kwargs)
-        elif norm_type == "instancenorm":
-            norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
-        else:
-            raise ValueError(f"Unrecognized normalization type: {norm_type}")
+        norm, relu = get_norm_layer(use_norm, relu, out_channels)
 
         super(Conv2dReLU, self).__init__(conv, norm, relu)
 
@@ -180,9 +193,3 @@ def __init__(self, name, **params):
 
     def forward(self, x):
         return self.attention(x)
-
-
-if __name__ == "__main__":
-    print(Conv2dReLU(3, 12, 4))
-    print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"}))
-    print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3}))
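
A quick behavioral sketch of the extracted helpers (illustrative values, assuming the functions above are in scope; the mappings are read directly from the code):

    # handle_decoder_use_batchnorm_deprecation translates the deprecated flag:
    #   True      -> {"type": "batchnorm"}
    #   False     -> {"type": "identity"}
    #   "inplace" -> {"type": "inplace", "activation": "leaky_relu", "activation_param": 0.0}
    # normalize_use_norm coerces every accepted spelling to the same dict form:
    assert normalize_use_norm(True) == {"type": "batchnorm"}
    assert normalize_use_norm("instancenorm") == {"type": "instancenorm"}
    assert normalize_use_norm({"type": "layernorm", "eps": 1e-3}) == {"type": "layernorm", "eps": 1e-3}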
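
And a minimal end-to-end use of the refactored Conv2dReLU, mirroring the __main__ demo deleted above (kernel size, padding, and input shape are illustrative choices, not from the commit):

    import torch

    block = Conv2dReLU(3, 12, kernel_size=3, padding=1, use_norm={"type": "groupnorm", "num_groups": 4})
    y = block(torch.randn(2, 3, 64, 64))  # Conv2d -> GroupNorm -> ReLU
    print(y.shape)                        # torch.Size([2, 12, 64, 64])

Since nn.GroupNorm takes num_groups as its first positional argument, get_norm_layer passes num_channels by keyword and expects num_groups to arrive via the use_norm dict, as shown here.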