Support for partial encoder freezing #341

Open · wants to merge 7 commits into master
9 changes: 9 additions & 0 deletions README.rst
@@ -107,6 +107,15 @@ Change input shape of the model:
# if you set input channels not equal to 3, you have to set encoder_weights=None
# how to handle such case with encoder_weights='imagenet' described in docs
model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)

Freeze the backbone (encoder):

.. code:: python

# freeze all layers of the encoder
model = Unet('resnet34', encoder_freeze=True)

# freeze only the first 80% of encoder layers (starting with the earliest)
model = Unet('resnet34', encoder_freeze=0.8)

Simple training pipeline
~~~~~~~~~~~~~~~~~~~~~~~~
8 changes: 5 additions & 3 deletions segmentation_models/models/_utils.py
@@ -1,10 +1,12 @@
from keras_applications import get_submodules_from_kwargs


- def freeze_model(model, **kwargs):
- """Set all layers non trainable, excluding BatchNormalization layers"""
+ def freeze_model(model, fraction=1.0, **kwargs):
+ """Set layers non-trainable, excluding BatchNormalization layers.
+ If a fraction is specified, only that fraction of the layers is
+ frozen (starting with the earliest layers)."""
_, layers, _, _ = get_submodules_from_kwargs(kwargs)
- for layer in model.layers:
+ for layer in model.layers[:int(len(model.layers) * fraction)]:
if not isinstance(layer, layers.BatchNormalization):
layer.trainable = False
return
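A minimal standalone sketch (not part of this PR) of how the ``fraction`` argument maps onto frozen layers: it mirrors the slicing above using plain ``tf.keras`` layers, and the ten-layer toy model is purely illustrative.

.. code:: python

    import tensorflow as tf

    def freeze_fraction(model, fraction=1.0):
        # mirror of freeze_model above: mark the first `fraction` of layers
        # non-trainable, leaving BatchNormalization layers trainable
        n_frozen = int(len(model.layers) * fraction)
        for layer in model.layers[:n_frozen]:
            if not isinstance(layer, tf.keras.layers.BatchNormalization):
                layer.trainable = False

    # toy "backbone" with 10 layers
    toy = tf.keras.Sequential([tf.keras.layers.Dense(8) for _ in range(10)])
    freeze_fraction(toy, fraction=0.8)
    print(sum(not layer.trainable for layer in toy.layers))  # -> 8 (first 80% of 10 layers)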
6 changes: 4 additions & 2 deletions segmentation_models/models/fpn.py
@@ -199,7 +199,8 @@ def FPN(
weights: optional, path to model weights.
activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``).
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
- encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+ encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float, freezes
+     only that fraction of the encoder layers (starting with the earliest layers).
encoder_features: a list of layer numbers or names starting from top of the model.
Each of these layers will be used to build features pyramid. If ``default`` is used
layer names are taken from ``DEFAULT_FEATURE_PYRAMID_LAYERS``.
@@ -245,7 +246,8 @@ def FPN(

# lock encoder weights for fine-tuning
if encoder_freeze:
- freeze_model(backbone, **kwargs)
+ fraction = encoder_freeze if isinstance(encoder_freeze, float) else 1.0
+ freeze_model(backbone, fraction=fraction, **kwargs)

# loading model weights
if weights is not None:
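A hedged usage sketch (not from the diff itself) of how the float form would be used when building an FPN; the exact number of frozen layers depends on the chosen backbone.

.. code:: python

    from segmentation_models import FPN

    # freeze roughly the first half of the encoder;
    # BatchNormalization layers inside that range stay trainable
    model = FPN('resnet34', encoder_weights='imagenet', encoder_freeze=0.5)

    print(sum(not layer.trainable for layer in model.layers))  # count of frozen layers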
6 changes: 4 additions & 2 deletions segmentation_models/models/linknet.py
@@ -212,7 +212,8 @@ def Linknet(
(e.g. ``sigmoid``, ``softmax``, ``linear``).
weights: optional, path to model weights.
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
- encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+ encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float, freezes
+     only that fraction of the encoder layers (starting with the earliest layers).
encoder_features: a list of layer numbers or names starting from top of the model.
Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
@@ -268,7 +269,8 @@ def Linknet(

# lock encoder weights for fine-tuning
if encoder_freeze:
- freeze_model(backbone, **kwargs)
+ fraction = encoder_freeze if isinstance(encoder_freeze, float) else 1.0
+ freeze_model(backbone, fraction=fraction, **kwargs)

# loading model weights
if weights is not None:
3 changes: 2 additions & 1 deletion segmentation_models/models/pspnet.py
@@ -179,7 +179,8 @@ def PSPNet(
(e.g. ``sigmoid``, ``softmax``, ``linear``).
weights: optional, path to model weights.
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
- encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+ encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float, freezes
+     only that fraction of the encoder layers (starting with the earliest layers).
downsample_factor: one of 4, 8 and 16. Downsampling rate or in other words backbone depth
to construct PSP module on it.
psp_conv_filters: number of filters in ``Conv2D`` layer in each PSP block.
17 changes: 14 additions & 3 deletions segmentation_models/models/unet.py
@@ -57,6 +57,7 @@ def wrapper(input_tensor, skip=None):
x = layers.UpSampling2D(size=2, name=up_name)(input_tensor)

if skip is not None:
+ #skip = layers.Dropout(0.3)(skip)  # optional skip-connection dropout, left disabled
x = layers.Concatenate(axis=concat_axis, name=concat_name)([x, skip])

x = Conv3x3BnReLU(filters, use_batchnorm, name=conv1_name)(x)
@@ -115,6 +116,7 @@ def build_unet(
classes=1,
activation='sigmoid',
use_batchnorm=True,
+ center_dropout=0.0,
):
input_ = backbone.input
x = backbone.output
@@ -123,7 +125,11 @@
skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
else backbone.get_layer(index=i).output for i in skip_connection_layers])

- # add center block if previous operation was maxpooling (for vgg models)
+ # optional dropout between encoder and decoder
+ if center_dropout:
+     x = layers.Dropout(center_dropout)(x)
+
+ # add center block if last encoder operation was maxpooling (for vgg models)
if isinstance(backbone.layers[-1], layers.MaxPooling2D):
x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x)
x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x)
@@ -171,6 +177,7 @@ def Unet(
decoder_block_type='upsampling',
decoder_filters=(256, 128, 64, 32, 16),
decoder_use_batchnorm=True,
+ center_dropout=0.0,
**kwargs
):
""" Unet_ is a fully convolution neural network for image semantic segmentation
@@ -186,7 +193,8 @@
(e.g. ``sigmoid``, ``softmax``, ``linear``).
weights: optional, path to model weights.
encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
- encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+ encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float, freezes
+     only that fraction of the encoder layers (starting with the earliest layers).
encoder_features: a list of layer numbers or names starting from top of the model.
Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
@@ -198,6 +206,7 @@
decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks
decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and ``Activation`` layers
is used.
+ center_dropout: Dropout fraction to apply at the center block, between encoder and decoder. Default is 0.0 (none).

Returns:
``keras.models.Model``: **Unet**
@@ -239,11 +248,13 @@ def Unet(
activation=activation,
n_upsample_blocks=len(decoder_filters),
use_batchnorm=decoder_use_batchnorm,
+ center_dropout=center_dropout,
)

# lock encoder weights for fine-tuning
if encoder_freeze:
- freeze_model(backbone, **kwargs)
+ fraction = encoder_freeze if isinstance(encoder_freeze, float) else 1.0
+ freeze_model(backbone, fraction=fraction, **kwargs)

# loading model weights
if weights is not None:
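A hedged end-to-end sketch (not part of the diff) combining the two options this PR adds to ``Unet``; the compile settings and the commented training call are illustrative only.

.. code:: python

    from segmentation_models import Unet

    model = Unet(
        'resnet34',
        encoder_weights='imagenet',
        encoder_freeze=0.8,   # freeze the first 80% of encoder layers
        center_dropout=0.3,   # Dropout(0.3) on the encoder output, before the decoder
    )
    model.compile('Adam', loss='binary_crossentropy', metrics=['accuracy'])
    # model.fit(x_train, y_train, batch_size=8, epochs=10)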