Commit 1b16b25

First fix following review
1 parent d65001b commit 1b16b25

11 files changed, +134 -106 lines

segmentation_models_pytorch/base/modules.py

Lines changed: 73 additions & 66 deletions

@@ -1,3 +1,4 @@
+from typing import Any, Dict, Tuple, Union
 import warnings
 
 import torch
@@ -8,6 +9,76 @@
 except ImportError:
     InPlaceABN = None
 
+def handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm: Union[bool, str, None], decoder_use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if decoder_use_batchnorm is not None:
+        warnings.warn(
+            "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
+            DeprecationWarning,
+        )
+        if decoder_use_batchnorm is True:
+            decoder_use_norm = {"type": "batchnorm"}
+        elif decoder_use_batchnorm is False:
+            decoder_use_norm = {"type": "identity"}
+        elif decoder_use_batchnorm == "inplace":
+            decoder_use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        else:
+            raise ValueError("Unrecognized value for use_batchnorm")
+
+    return decoder_use_norm
+
+
+def normalize_use_norm(use_norm: Union[bool, str, Dict[str, Any]]) -> Dict[str, Any]:
+    if isinstance(use_norm, str):
+        norm_str = use_norm.lower()
+        if norm_str == "inplace":
+            use_norm = {
+                "type": "inplace",
+                "activation": "leaky_relu",
+                "activation_param": 0.0,
+            }
+        elif norm_str in (
+            "batchnorm",
+            "identity",
+            "layernorm",
+            "groupnorm",
+            "instancenorm",
+        ):
+            use_norm = {"type": norm_str}
+        else:
+            raise ValueError("Unrecognized normalization type string provided")
+    elif isinstance(use_norm, bool):
+        use_norm = {"type": "batchnorm" if use_norm else "identity"}
+    elif not isinstance(use_norm, dict):
+        raise ValueError("use_norm must be a dictionary, boolean, or string")
+
+    return use_norm
+
+def get_norm_layer(use_norm: Dict[str, Any], relu: nn.Module, out_channels: int) -> Tuple[nn.Module, nn.Module]:
+    norm_type = use_norm["type"]
+    extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
+
+    if norm_type == "inplace":
+        norm = InPlaceABN(out_channels, **extra_kwargs)
+        relu = nn.Identity()
+    elif norm_type == "batchnorm":
+        norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
+    elif norm_type == "identity":
+        norm = nn.Identity()
+    elif norm_type == "layernorm":
+        norm = nn.LayerNorm(out_channels, **extra_kwargs)
+    elif norm_type == "groupnorm":
+        norm = nn.GroupNorm(out_channels, **extra_kwargs)
+    elif norm_type == "instancenorm":
+        norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
+    else:
+        raise ValueError(f"Unrecognized normalization type: {norm_type}")
+
+    return norm, relu
+
 
 class Conv2dReLU(nn.Sequential):
     def __init__(
@@ -17,50 +88,9 @@ def __init__(
         kernel_size,
         padding=0,
         stride=1,
-        use_batchnorm=True,
         use_norm="batchnorm",
     ):
-        if use_batchnorm is not None:
-            warnings.warn(
-                "The usage of use_batchnorm is deprecated. Please modify your code for use_norm",
-                DeprecationWarning,
-            )
-            if use_batchnorm is True:
-                use_norm = {"type": "batchnorm"}
-            elif use_batchnorm is False:
-                use_norm = {"type": "identity"}
-            elif use_batchnorm == "inplace":
-                use_norm = {
-                    "type": "inplace",
-                    "activation": "leaky_relu",
-                    "activation_param": 0.0,
-                }
-            else:
-                raise ValueError("Unrecognized value for use_batchnorm")
-
-        if isinstance(use_norm, str):
-            norm_str = use_norm.lower()
-            if norm_str == "inplace":
-                use_norm = {
-                    "type": "inplace",
-                    "activation": "leaky_relu",
-                    "activation_param": 0.0,
-                }
-            elif norm_str in (
-                "batchnorm",
-                "identity",
-                "layernorm",
-                "groupnorm",
-                "instancenorm",
-            ):
-                use_norm = {"type": norm_str}
-            else:
-                raise ValueError("Unrecognized normalization type string provided")
-        elif isinstance(use_norm, bool):
-            use_norm = {"type": "batchnorm" if use_norm else "identity"}
-        elif not isinstance(use_norm, dict):
-            raise ValueError("use_norm must be a dictionary, boolean, or string")
-
+        use_norm = normalize_use_norm(use_norm)
         if use_norm["type"] == "inplace" and InPlaceABN is None:
             raise RuntimeError(
                 "In order to use `use_batchnorm='inplace'` or `use_norm='inplace'` the inplace_abn package must be installed. "
@@ -77,24 +107,7 @@ def __init__(
         )
         relu = nn.ReLU(inplace=True)
 
-        norm_type = use_norm["type"]
-        extra_kwargs = {k: v for k, v in use_norm.items() if k != "type"}
-
-        if norm_type == "inplace":
-            norm = InPlaceABN(out_channels, **extra_kwargs)
-            relu = nn.Identity()
-        elif norm_type == "batchnorm":
-            norm = nn.BatchNorm2d(out_channels, **extra_kwargs)
-        elif norm_type == "identity":
-            norm = nn.Identity()
-        elif norm_type == "layernorm":
-            norm = nn.LayerNorm(out_channels, **extra_kwargs)
-        elif norm_type == "groupnorm":
-            norm = nn.GroupNorm(out_channels, **extra_kwargs)
-        elif norm_type == "instancenorm":
-            norm = nn.InstanceNorm2d(out_channels, **extra_kwargs)
-        else:
-            raise ValueError(f"Unrecognized normalization type: {norm_type}")
+        norm, relu = get_norm_layer(use_norm, relu, out_channels)
 
         super(Conv2dReLU, self).__init__(conv, norm, relu)
 
@@ -180,9 +193,3 @@ def __init__(self, name, **params):
 
     def forward(self, x):
        return self.attention(x)
-
-
-if __name__ == "__main__":
-    print(Conv2dReLU(3, 12, 4))
-    print(Conv2dReLU(3, 12, 4, use_norm={"type": "batchnorm"}))
-    print(Conv2dReLU(3, 12, 4, use_norm={"type": "layernorm", "eps": 1e-3}))
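
For reference, a construction-only sketch of the use_norm forms accepted after this commit (it mirrors the deleted __main__ demo; the mappings in the comments come from normalize_use_norm above):

from segmentation_models_pytorch.base.modules import Conv2dReLU

Conv2dReLU(3, 12, kernel_size=4)                  # default "batchnorm" -> {"type": "batchnorm"}
Conv2dReLU(3, 12, kernel_size=4, use_norm=False)  # bool -> {"type": "identity"}, no norm layer
Conv2dReLU(3, 12, kernel_size=4, use_norm="instancenorm")  # plain string form
Conv2dReLU(3, 12, kernel_size=4, use_norm={"type": "layernorm", "eps": 1e-3})  # extra keys forwarded to nn.LayerNorm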

segmentation_models_pytorch/decoders/linknet/decoder.py

Lines changed: 2 additions & 8 deletions

@@ -10,7 +10,6 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        use_batchnorm: Union[bool, str, None] = True,
         use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
@@ -21,7 +20,7 @@ def __init__(
             nn.ReLU(inplace=True),
         ]
 
-        if use_batchnorm or use_norm:
+        if use_norm:
             layers.insert(1, nn.BatchNorm2d(out_channels))
 
         super().__init__(*layers)
@@ -32,7 +31,6 @@ def __init__(
         self,
         in_channels: int,
         out_channels: int,
-        use_batchnorm: Union[bool, str, None] = True,
         use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
@@ -42,17 +40,15 @@ def __init__(
                 in_channels,
                 in_channels // 4,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
                 use_norm=use_norm,
             ),
             TransposeX2(
-                in_channels // 4, in_channels // 4, use_batchnorm=use_batchnorm
+                in_channels // 4, in_channels // 4, use_norm=use_norm
             ),
             modules.Conv2dReLU(
                 in_channels // 4,
                 out_channels,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
                 use_norm=use_norm,
             ),
         )
@@ -72,7 +68,6 @@ def __init__(
         encoder_channels: List[int],
         prefinal_channels: int = 32,
         n_blocks: int = 5,
-        use_batchnorm: Union[bool, str, None] = True,
         use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
@@ -89,7 +84,6 @@ def __init__(
                 DecoderBlock(
                     channels[i],
                     channels[i + 1],
-                    use_batchnorm=use_batchnorm,
                     use_norm=use_norm,
                 )
                 for i in range(n_blocks)
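
A construction-only sketch of the narrowed DecoderBlock interface (channel sizes are arbitrary examples), with one caveat visible in the diff:

from segmentation_models_pytorch.decoders.linknet.decoder import DecoderBlock

block = DecoderBlock(64, 32, use_norm={"type": "batchnorm"})

# Caveat: TransposeX2 gates on plain truthiness (`if use_norm:`), so a non-empty
# dict such as {"type": "identity"} is truthy and still inserts nn.BatchNorm2d.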

segmentation_models_pytorch/decoders/linknet/model.py

Lines changed: 2 additions & 1 deletion

@@ -7,6 +7,7 @@
 )
 from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
+from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation
 
 from .decoder import LinknetDecoder
 
@@ -101,11 +102,11 @@ def __init__(
             **kwargs,
         )
 
+        decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm)
         self.decoder = LinknetDecoder(
             encoder_channels=self.encoder.out_channels,
             n_blocks=encoder_depth,
             prefinal_channels=32,
-            use_batchnorm=decoder_use_batchnorm,
             use_norm=decoder_use_norm,
         )
 
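
Both call paths after this change, as a sketch (encoder name and weights are example values):

import segmentation_models_pytorch as smp

# New style: pass the norm spec straight through.
model = smp.Linknet(encoder_name="resnet34", encoder_weights=None,
                    decoder_use_norm={"type": "batchnorm"})

# Old style: still accepted, but routed through handle_decoder_use_batchnorm_deprecation,
# which maps True -> {"type": "batchnorm"} and emits a DeprecationWarning.
model = smp.Linknet(encoder_name="resnet34", encoder_weights=None,
                    decoder_use_batchnorm=True)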

segmentation_models_pytorch/decoders/manet/decoder.py

Lines changed: 1 addition & 10 deletions

@@ -49,7 +49,6 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: Union[bool, str, None] = True,
         use_norm: Union[bool, str, Dict[str, Any]] = True,
         reduction: int = 16,
     ):
@@ -61,14 +60,12 @@ def __init__(
                 in_channels,
                 kernel_size=3,
                 padding=1,
-                use_batchnorm=use_batchnorm,
                 use_norm=use_norm,
             ),
             md.Conv2dReLU(
                 in_channels,
                 skip_channels,
                 kernel_size=1,
-                use_batchnorm=use_batchnorm,
                 use_norm=use_norm,
             ),
         )
@@ -93,15 +90,13 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
             use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
             use_norm=use_norm,
         )
 
@@ -127,7 +122,6 @@ def __init__(
         in_channels: int,
         skip_channels: int,
         out_channels: int,
-        use_batchnorm: Union[bool, str, None] = True,
         use_norm: Union[bool, str, Dict[str, Any]] = True,
     ):
         super().__init__()
@@ -136,15 +130,13 @@ def __init__(
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
             use_norm=use_norm,
         )
         self.conv2 = md.Conv2dReLU(
             out_channels,
             out_channels,
             kernel_size=3,
             padding=1,
-            use_batchnorm=use_batchnorm,
             use_norm=use_norm,
         )
 
@@ -166,7 +158,6 @@ def __init__(
         decoder_channels: List[int],
         n_blocks: int = 5,
         reduction: int = 16,
-        use_batchnorm: Union[bool, str, None] = True,
         use_norm: Union[bool, str, Dict[str, Any]] = True,
         pab_channels: int = 64,
     ):
@@ -195,7 +186,7 @@ def __init__(
 
         # combine decoder keyword arguments
         kwargs = dict(
-            use_batchnorm=use_batchnorm, use_norm=use_norm
+            use_norm=use_norm
         )  # no attention type here
         blocks = [
             MFABBlock(in_ch, skip_ch, out_ch, reduction=reduction, **kwargs)
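
The decoder-side effect is that every block now receives a single norm spec; a construction-only sketch with hypothetical channel counts:

from segmentation_models_pytorch.decoders.manet.decoder import MFABBlock

# Same keywords the decoder forwards via `kwargs = dict(use_norm=use_norm)`.
block = MFABBlock(in_channels=64, skip_channels=32, out_channels=32,
                  reduction=16, use_norm={"type": "instancenorm"})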

segmentation_models_pytorch/decoders/manet/model.py

Lines changed: 3 additions & 1 deletion

@@ -7,6 +7,7 @@
 )
 from segmentation_models_pytorch.encoders import get_encoder
 from segmentation_models_pytorch.base.hub_mixin import supports_config_loading
+from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation
 
 from .decoder import MAnetDecoder
 
@@ -101,11 +102,12 @@ def __init__(
             **kwargs,
         )
 
+        decoder_use_norm = handle_decoder_use_batchnorm_deprecation(decoder_use_batchnorm, decoder_use_norm)
+
         self.decoder = MAnetDecoder(
             encoder_channels=self.encoder.out_channels,
             decoder_channels=decoder_channels,
             n_blocks=encoder_depth,
-            use_batchnorm=decoder_use_batchnorm,
             use_norm=decoder_use_norm,
             pab_channels=decoder_pab_channels,
         )
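
To make the shared helper concrete, a small sketch of handle_decoder_use_batchnorm_deprecation as defined in modules.py above (expected values in comments):

import warnings
from segmentation_models_pytorch.base.modules import handle_decoder_use_batchnorm_deprecation

with warnings.catch_warnings():
    warnings.simplefilter("ignore", DeprecationWarning)
    print(handle_decoder_use_batchnorm_deprecation(True, "batchnorm"))
    # {'type': 'batchnorm'}
    print(handle_decoder_use_batchnorm_deprecation("inplace", "batchnorm"))
    # {'type': 'inplace', 'activation': 'leaky_relu', 'activation_param': 0.0}

# None means the deprecated argument was not supplied: use_norm passes through untouched.
print(handle_decoder_use_batchnorm_deprecation(None, {"type": "groupnorm", "num_groups": 8}))
# {'type': 'groupnorm', 'num_groups': 8}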
