Skip to content

Commit 914f2bf

Browse files
authoredJul 5, 2021
Release 0.2.0 (#430)
* new in_channels != 3 initialization * docs fixes * version resolving
1 parent 225823b commit 914f2bf

32 files changed

+233
-366
lines changed
 

‎.github/workflows/tests.yml

-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ jobs:
2929
python -m pip install codecov pytest mock
3030
pip3 install torch==1.9.0+cpu torchvision==0.10.0+cpu torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
3131
pip install .
32-
pip install -U git+https://github.com/rwightman/pytorch-image-models
3332
- name: Test
3433
run: |
3534
python -m pytest -s tests

‎README.md

+10-21
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The main features of this library are:
1212

1313
- High level API (just two lines to create a neural network)
1414
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
15-
- 115 available encoders
15+
- 113 available encoders
1616
- All encoders have pre-trained weights for faster and better convergence
1717

1818
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -297,8 +297,12 @@ The following is a list of supported encoders in the SMP. Select the appropriate
297297
|Encoder |Weights |Params, M |
298298
|--------------------------------|:------------------------------:|:------------------------------:|
299299
|mobilenet_v2 |imagenet |2M |
300-
|mobilenet_v3_large |imagenet |3M |
301-
|mobilenet_v3_small |imagenet |1M |
300+
|timm-mobilenetv3_large_075 |imagenet |1.78M |
301+
|timm-mobilenetv3_large_100 |imagenet |2.97M |
302+
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
303+
|timm-mobilenetv3_small_075 |imagenet |0.57M |
304+
|timm-mobilenetv3_small_100 |imagenet |0.93M |
305+
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |
302306

303307
</div>
304308
</details>
@@ -337,22 +341,6 @@ The following is a list of supported encoders in the SMP. Select the appropriate
337341
</div>
338342
</details>
339343

340-
<details>
341-
<summary style="margin-left: 25px;">MobileNetV3</summary>
342-
<div style="margin-left: 25px;">
343-
344-
|Encoder |Weights |Params, M |
345-
|--------------------------------|:------------------------------:|:------------------------------:|
346-
|timm-mobilenetv3_large_075 |imagenet |1.78M |
347-
|timm-mobilenetv3_large_100 |imagenet |2.97M |
348-
|timm-mobilenetv3_large_minimal_100|imagenet |1.41M |
349-
|timm-mobilenetv3_small_075 |imagenet |0.57M |
350-
|timm-mobilenetv3_small_100 |imagenet |0.93M |
351-
|timm-mobilenetv3_small_minimal_100|imagenet |0.43M |
352-
353-
</div>
354-
</details>
355-
356344

357345
\* `ssl`, `swsl` - semi-supervised and weakly-supervised learning on ImageNet ([repo](https://github.com/facebookresearch/semi-supervised-ImageNet1K-models)).
358346

@@ -367,8 +355,9 @@ The following is a list of supported encoders in the SMP. Select the appropriate
367355

368356
##### Input channels
369357
Input channels parameter allows you to create models, which process tensors with arbitrary number of channels.
370-
If you use pretrained weights from imagenet - weights of first convolution will be reused for
371-
1- or 2- channels inputs, for input channels > 4 weights of first convolution will be initialized randomly.
358+
If you use pretrained weights from imagenet - weights of first convolution will be reused. For
359+
1-channel case it would be a sum of weights of first convolution layer, otherwise channels would be
360+
populated with weights like `new_weight[:, i] = pretrained_weight[:, i % 3]` and than scaled with `new_weight * 3 / new_in_channels`.
372361
```python
373362
model = smp.FPN('resnet34', in_channels=1)
374363
mask = model(torch.ones([1, 1, 64, 64]))

‎docs/encoders.rst

+17-28
Original file line numberDiff line numberDiff line change
@@ -265,15 +265,23 @@ EfficientNet
265265
MobileNet
266266
~~~~~~~~~
267267

268-
+---------------------+------------+-------------+
269-
| Encoder | Weights | Params, M |
270-
+=====================+============+=============+
271-
| mobilenet\_v2 | imagenet | 2M |
272-
+---------------------+------------+-------------+
273-
| mobilenet\_v3_large | imagenet | 3M |
274-
+---------------------+------------+-------------+
275-
| mobilenet\_v2_small | imagenet | 1M |
276-
+---------------------+------------+-------------+
268+
+---------------------------------------+------------+-------------+
269+
| Encoder | Weights | Params, M |
270+
+=======================================+============+=============+
271+
| mobilenet\_v2 | imagenet | 2M |
272+
+---------------------------------------+------------+-------------+
273+
| timm-mobilenetv3\_large\_075 | imagenet | 1.78M |
274+
+---------------------------------------+------------+-------------+
275+
| timm-mobilenetv3\_large\_100 | imagenet | 2.97M |
276+
+---------------------------------------+------------+-------------+
277+
| timm-mobilenetv3\_large\_minimal\_100 | imagenet | 1.41M |
278+
+---------------------------------------+------------+-------------+
279+
| timm-mobilenetv3\_small\_075 | imagenet | 0.57M |
280+
+---------------------------------------+------------+-------------+
281+
| timm-mobilenetv3\_small\_100 | imagenet | 0.93M |
282+
+---------------------------------------+------------+-------------+
283+
| timm-mobilenetv3\_small\_minimal\_100 | imagenet | 0.43M |
284+
+---------------------------------------+------------+-------------+
277285

278286
DPN
279287
~~~
@@ -316,22 +324,3 @@ VGG
316324
+-------------+------------+-------------+
317325
| vgg19\_bn | imagenet | 20M |
318326
+-------------+------------+-------------+
319-
320-
MobileNetV3
321-
~~~~~~~~~
322-
323-
+-----------------------------------+------------+-------------+
324-
| Encoder | Weights | Params, M |
325-
+===================================+============+=============+
326-
| timm-mobilenetv3_large_075 | imagenet | 1.78M |
327-
+-----------------------------------+------------+-------------+
328-
| timm-mobilenetv3_large_100 | imagenet | 2.97M |
329-
+-----------------------------------+------------+-------------+
330-
| timm-mobilenetv3_large_minimal_100| imagenet | 1.41M |
331-
+-----------------------------------+------------+-------------+
332-
| timm-mobilenetv3_small_075 | imagenet | 0.57M |
333-
+-----------------------------------+------------+-------------+
334-
| timm-mobilenetv3_small_100 | imagenet | 0.93M |
335-
+-----------------------------------+------------+-------------+
336-
| timm-mobilenetv3_small_minimal_100| imagenet | 0.43M |
337-
+-----------------------------------+------------+-------------+

‎docs/losses.rst

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@ DiceLoss
1717
~~~~~~~~
1818
.. autoclass:: segmentation_models_pytorch.losses.DiceLoss
1919

20+
TverskyLoss
21+
~~~~~~~~
22+
.. autoclass:: segmentation_models_pytorch.losses.TverskyLoss
23+
2024
FocalLoss
2125
~~~~~~~~~
2226
.. autoclass:: segmentation_models_pytorch.losses.FocalLoss

‎requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torchvision>=0.9.0
1+
torchvision>=0.5.0
22
pretrainedmodels==0.7.4
33
efficientnet-pytorch==0.6.3
44
timm==0.4.12
+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
VERSION = (0, 1, 3)
1+
VERSION = (0, 2, 0)
22

33
__version__ = '.'.join(map(str, VERSION))

‎segmentation_models_pytorch/encoders/__init__.py

+2-9
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,14 @@
1010
from .inceptionv4 import inceptionv4_encoders
1111
from .efficientnet import efficient_net_encoders
1212
from .mobilenet import mobilenet_encoders
13-
from .mobilenet_v3 import mobilenet_v3_encoders
1413
from .xception import xception_encoders
1514
from .timm_efficientnet import timm_efficientnet_encoders
1615
from .timm_resnest import timm_resnest_encoders
1716
from .timm_res2net import timm_res2net_encoders
1817
from .timm_regnet import timm_regnet_encoders
1918
from .timm_sknet import timm_sknet_encoders
2019
from .timm_mobilenetv3 import timm_mobilenetv3_encoders
21-
try:
22-
from .timm_gernet import timm_gernet_encoders
23-
except ImportError as e:
24-
timm_gernet_encoders = {}
25-
print("Current timm version doesn't support GERNet."
26-
"If GERNet support is needed please update timm")
20+
from .timm_gernet import timm_gernet_encoders
2721

2822
from ._preprocessing import preprocess_input
2923

@@ -37,7 +31,6 @@
3731
encoders.update(inceptionv4_encoders)
3832
encoders.update(efficient_net_encoders)
3933
encoders.update(mobilenet_encoders)
40-
encoders.update(mobilenet_v3_encoders)
4134
encoders.update(xception_encoders)
4235
encoders.update(timm_efficientnet_encoders)
4336
encoders.update(timm_resnest_encoders)
@@ -68,7 +61,7 @@ def get_encoder(name, in_channels=3, depth=5, weights=None):
6861
))
6962
encoder.load_state_dict(model_zoo.load_url(settings["url"]))
7063

71-
encoder.set_in_channels(in_channels)
64+
encoder.set_in_channels(in_channels, pretrained=weights is not None)
7265

7366
return encoder
7467

‎segmentation_models_pytorch/encoders/_base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def out_channels(self):
1717
"""Return channels dimensions for each tensor of forward output of encoder"""
1818
return self._out_channels[: self._depth + 1]
1919

20-
def set_in_channels(self, in_channels):
20+
def set_in_channels(self, in_channels, pretrained=True):
2121
"""Change first convolution channels"""
2222
if in_channels == 3:
2323
return
@@ -26,7 +26,7 @@ def set_in_channels(self, in_channels):
2626
if self._out_channels[0] == 3:
2727
self._out_channels = tuple([in_channels] + list(self._out_channels)[1:])
2828

29-
utils.patch_first_conv(model=self, in_channels=in_channels)
29+
utils.patch_first_conv(model=self, new_in_channels=in_channels, pretrained=pretrained)
3030

3131
def get_stages(self):
3232
"""Method should be overridden in encoder"""

‎segmentation_models_pytorch/encoders/_utils.py

+26-17
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch.nn as nn
33

44

5-
def patch_first_conv(model, in_channels):
5+
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True):
66
"""Change first convolution layer input channels.
77
In case:
88
in_channels == 1 or in_channels == 2 -> reuse original weights
@@ -11,29 +11,38 @@ def patch_first_conv(model, in_channels):
1111

1212
# get first conv
1313
for module in model.modules():
14-
if isinstance(module, nn.Conv2d):
14+
if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels:
1515
break
16-
17-
# change input channels for first conv
18-
module.in_channels = in_channels
16+
1917
weight = module.weight.detach()
20-
reset = False
21-
22-
if in_channels == 1:
23-
weight = weight.sum(1, keepdim=True)
24-
elif in_channels == 2:
25-
weight = weight[:, :2] * (3.0 / 2.0)
18+
module.in_channels = new_in_channels
19+
20+
if not pretrained:
21+
module.weight = nn.parameter.Parameter(
22+
torch.Tensor(
23+
module.out_channels,
24+
new_in_channels // module.groups,
25+
*module.kernel_size
26+
)
27+
)
28+
module.reset_parameters()
29+
30+
elif new_in_channels == 1:
31+
new_weight = weight.sum(1, keepdim=True)
32+
module.weight = nn.parameter.Parameter(new_weight)
33+
2634
else:
27-
reset = True
28-
weight = torch.Tensor(
35+
new_weight = torch.Tensor(
2936
module.out_channels,
30-
module.in_channels // module.groups,
37+
new_in_channels // module.groups,
3138
*module.kernel_size
3239
)
3340

34-
module.weight = nn.parameter.Parameter(weight)
35-
if reset:
36-
module.reset_parameters()
41+
for i in range(new_in_channels):
42+
new_weight[:, i] = weight[:, i % default_in_channels]
43+
44+
new_weight = new_weight * (default_in_channels / new_in_channels)
45+
module.weight = nn.parameter.Parameter(new_weight)
3746

3847

3948
def replace_strides_with_dilation(module, dilation_rate):

‎segmentation_models_pytorch/encoders/densenet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,8 @@ def load_state_dict(self, state_dict):
9696
del state_dict[key]
9797

9898
# remove linear
99-
state_dict.pop("classifier.bias")
100-
state_dict.pop("classifier.weight")
99+
state_dict.pop("classifier.bias", None)
100+
state_dict.pop("classifier.weight", None)
101101

102102
super().load_state_dict(state_dict)
103103

‎segmentation_models_pytorch/encoders/dpn.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,8 @@ def forward(self, x):
6868
return features
6969

7070
def load_state_dict(self, state_dict, **kwargs):
71-
state_dict.pop("last_linear.bias")
72-
state_dict.pop("last_linear.weight")
71+
state_dict.pop("last_linear.bias", None)
72+
state_dict.pop("last_linear.weight", None)
7373
super().load_state_dict(state_dict, **kwargs)
7474

7575

‎segmentation_models_pytorch/encoders/efficientnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,8 @@ def forward(self, x):
7777
return features
7878

7979
def load_state_dict(self, state_dict, **kwargs):
80-
state_dict.pop("_fc.bias")
81-
state_dict.pop("_fc.weight")
80+
state_dict.pop("_fc.bias", None)
81+
state_dict.pop("_fc.weight", None)
8282
super().load_state_dict(state_dict, **kwargs)
8383

8484

‎segmentation_models_pytorch/encoders/inceptionresnetv2.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ def forward(self, x):
7676
return features
7777

7878
def load_state_dict(self, state_dict, **kwargs):
79-
state_dict.pop("last_linear.bias")
80-
state_dict.pop("last_linear.weight")
79+
state_dict.pop("last_linear.bias", None)
80+
state_dict.pop("last_linear.weight", None)
8181
super().load_state_dict(state_dict, **kwargs)
8282

8383

‎segmentation_models_pytorch/encoders/inceptionv4.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,8 @@ def forward(self, x):
7575
return features
7676

7777
def load_state_dict(self, state_dict, **kwargs):
78-
state_dict.pop("last_linear.bias")
79-
state_dict.pop("last_linear.weight")
78+
state_dict.pop("last_linear.bias", None)
79+
state_dict.pop("last_linear.weight", None)
8080
super().load_state_dict(state_dict, **kwargs)
8181

8282

‎segmentation_models_pytorch/encoders/mobilenet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ def forward(self, x):
5959
return features
6060

6161
def load_state_dict(self, state_dict, **kwargs):
62-
state_dict.pop("classifier.1.bias")
63-
state_dict.pop("classifier.1.weight")
62+
state_dict.pop("classifier.1.bias", None)
63+
state_dict.pop("classifier.1.weight", None)
6464
super().load_state_dict(state_dict, **kwargs)
6565

6666

0 commit comments

Comments
 (0)
Failed to load comments.