diff --git a/README.rst b/README.rst
index 0b37743c..c89f0961 100644
--- a/README.rst
+++ b/README.rst
@@ -107,6 +107,15 @@ Change input shape of the model:
     # if you set input channels not equal to 3, you have to set encoder_weights=None
     # how to handle such case with encoder_weights='imagenet' described in docs
     model = Unet('resnet34', input_shape=(None, None, 6), encoder_weights=None)
+
+Freeze the backbone (encoder):
+
+.. code:: python
+
+    # freeze all encoder layers
+    model = Unet('resnet34', encoder_weights='imagenet', encoder_freeze=True)
+    # freeze only the first 80% of encoder layers
+    model = Unet('resnet34', encoder_weights='imagenet', encoder_freeze=0.8)
 
 Simple training pipeline
 ~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/segmentation_models/models/_utils.py b/segmentation_models/models/_utils.py
index c59d427a..e2ee0339 100644
--- a/segmentation_models/models/_utils.py
+++ b/segmentation_models/models/_utils.py
@@ -1,10 +1,12 @@
 from keras_applications import get_submodules_from_kwargs
 
 
-def freeze_model(model, **kwargs):
-    """Set all layers non trainable, excluding BatchNormalization layers"""
+def freeze_model(model, fraction=1.0, **kwargs):
+    """Set layers non-trainable, excluding BatchNormalization layers.
+    If a fraction is specified, only that fraction of the layers is
+    frozen, starting with the earliest layers."""
     _, layers, _, _ = get_submodules_from_kwargs(kwargs)
-    for layer in model.layers:
+    for layer in model.layers[:int(len(model.layers) * fraction)]:
         if not isinstance(layer, layers.BatchNormalization):
             layer.trainable = False
     return
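For reference, a minimal sketch of the new ``fraction`` semantics end to end; the import follows the library's README, and the printed count is illustrative only:

.. code:: python

    from segmentation_models import Unet

    # encoder_freeze=0.8 is forwarded as fraction=0.8 to freeze_model, which sets
    # trainable=False on the first int(len(layers) * 0.8) backbone layers while
    # leaving BatchNormalization layers trainable
    model = Unet('resnet34', encoder_weights='imagenet', encoder_freeze=0.8)
    print(sum(1 for layer in model.layers if not layer.trainable))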
diff --git a/segmentation_models/models/fpn.py b/segmentation_models/models/fpn.py
index deab7f54..724c93c2 100644
--- a/segmentation_models/models/fpn.py
+++ b/segmentation_models/models/fpn.py
@@ -199,7 +199,8 @@ def FPN(
         weights: optional, path to model weights.
         activation: name of one of ``keras.activations`` for last model layer (e.g. ``sigmoid``, ``softmax``, ``linear``).
         encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
-        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float
+            between 0 and 1, freezes just that fraction of the encoder layers, starting with the earliest layers.
         encoder_features: a list of layer numbers or names starting from top of the model.
             Each of these layers will be used to build features pyramid. If ``default`` is used
             layer names are taken from ``DEFAULT_FEATURE_PYRAMID_LAYERS``.
@@ -245,7 +246,8 @@ def FPN(
 
     # lock encoder weights for fine-tuning
     if encoder_freeze:
-        freeze_model(backbone, **kwargs)
+        fraction = encoder_freeze if isinstance(encoder_freeze, float) else 1.0
+        freeze_model(backbone, fraction=fraction, **kwargs)
 
     # loading model weights
     if weights is not None:
diff --git a/segmentation_models/models/linknet.py b/segmentation_models/models/linknet.py
index 74c533c9..5231e1d0 100644
--- a/segmentation_models/models/linknet.py
+++ b/segmentation_models/models/linknet.py
@@ -212,7 +212,8 @@ def Linknet(
             (e.g. ``sigmoid``, ``softmax``, ``linear``).
         weights: optional, path to model weights.
         encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
-        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float
+            between 0 and 1, freezes just that fraction of the encoder layers, starting with the earliest layers.
         encoder_features: a list of layer numbers or names starting from top of the model.
             Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
             layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
@@ -268,7 +269,8 @@ def Linknet(
 
     # lock encoder weights for fine-tuning
     if encoder_freeze:
-        freeze_model(backbone, **kwargs)
+        fraction = encoder_freeze if isinstance(encoder_freeze, float) else 1.0
+        freeze_model(backbone, fraction=fraction, **kwargs)
 
     # loading model weights
     if weights is not None:
diff --git a/segmentation_models/models/pspnet.py b/segmentation_models/models/pspnet.py
index 001b28c9..57661ed6 100644
--- a/segmentation_models/models/pspnet.py
+++ b/segmentation_models/models/pspnet.py
@@ -179,7 +179,8 @@ def PSPNet(
             (e.g. ``sigmoid``, ``softmax``, ``linear``).
         weights: optional, path to model weights.
         encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
-        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float
+            between 0 and 1, freezes just that fraction of the encoder layers, starting with the earliest layers.
         downsample_factor: one of 4, 8 and 16. Downsampling rate or in other words backbone depth
             to construct PSP module on it.
         psp_conv_filters: number of filters in ``Conv2D`` layer in each PSP block.
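The same value forwarding works through each architecture's factory function; a minimal usage sketch (imports as in the README):

.. code:: python

    from segmentation_models import FPN, Linknet

    # freeze the first half of the encoder; the rest of the encoder
    # and the whole decoder remain trainable
    fpn = FPN('resnet34', encoder_weights='imagenet', encoder_freeze=0.5)
    linknet = Linknet('resnet34', encoder_weights='imagenet', encoder_freeze=0.5)

Note that this diff only updates PSPNet's docstring; for a float ``encoder_freeze`` to take effect there as documented, its ``freeze_model`` call would presumably need the same ``fraction`` forwarding as FPN and Linknet.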
diff --git a/segmentation_models/models/unet.py b/segmentation_models/models/unet.py
index 7da2b391..767877fa 100644
--- a/segmentation_models/models/unet.py
+++ b/segmentation_models/models/unet.py
@@ -115,6 +115,7 @@ def build_unet(
         classes=1,
         activation='sigmoid',
         use_batchnorm=True,
+        center_dropout=0.0,
 ):
     input_ = backbone.input
     x = backbone.output
@@ -123,7 +124,11 @@ def build_unet(
     skips = ([backbone.get_layer(name=i).output if isinstance(i, str)
               else backbone.get_layer(index=i).output for i in skip_connection_layers])
 
-    # add center block if previous operation was maxpooling (for vgg models)
+    # dropout between encoder and decoder
+    if center_dropout:
+        x = layers.Dropout(center_dropout)(x)
+
+    # add center block if last encoder operation was maxpooling (for vgg models)
     if isinstance(backbone.layers[-1], layers.MaxPooling2D):
         x = Conv3x3BnReLU(512, use_batchnorm, name='center_block1')(x)
         x = Conv3x3BnReLU(512, use_batchnorm, name='center_block2')(x)
@@ -171,6 +176,7 @@ def Unet(
         decoder_block_type='upsampling',
         decoder_filters=(256, 128, 64, 32, 16),
         decoder_use_batchnorm=True,
+        center_dropout=0.0,
         **kwargs
 ):
     """ Unet_ is a fully convolution neural network for image semantic segmentation
@@ -186,7 +192,8 @@ def Unet(
             (e.g. ``sigmoid``, ``softmax``, ``linear``).
         weights: optional, path to model weights.
         encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
-        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable.
+        encoder_freeze: if ``True`` set all layers of encoder (backbone model) as non-trainable. If a float
+            between 0 and 1, freezes just that fraction of the encoder layers, starting with the earliest layers.
         encoder_features: a list of layer numbers or names starting from top of the model.
             Each of these layers will be concatenated with corresponding decoder block. If ``default`` is used
             layer names are taken from ``DEFAULT_SKIP_CONNECTIONS``.
@@ -198,6 +205,7 @@ def Unet(
         decoder_filters: list of numbers of ``Conv2D`` layer filters in decoder blocks
         decoder_use_batchnorm: if ``True``, ``BatchNormalisation`` layer between ``Conv2D`` and
             ``Activation`` layers is used.
+        center_dropout: dropout fraction to apply at the center block, between encoder and decoder. Default is 0.0 (no dropout).
 
     Returns:
         ``keras.models.Model``: **Unet**
@@ -239,11 +247,13 @@ def Unet(
         activation=activation,
         n_upsample_blocks=len(decoder_filters),
         use_batchnorm=decoder_use_batchnorm,
+        center_dropout=center_dropout,
     )
 
     # lock encoder weights for fine-tuning
    if encoder_freeze:
-        freeze_model(backbone, **kwargs)
+        fraction = encoder_freeze if isinstance(encoder_freeze, float) else 1.0
+        freeze_model(backbone, fraction=fraction, **kwargs)
 
     # loading model weights
     if weights is not None:
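Finally, a minimal usage sketch for the new ``center_dropout`` argument (imports as in the README; the 0.3 rate is illustrative only):

.. code:: python

    from segmentation_models import Unet

    # 30% dropout on the encoder output, applied just before the decoder;
    # combined here with a fully frozen encoder for fine-tuning
    model = Unet('resnet34', encoder_weights='imagenet',
                 encoder_freeze=True, center_dropout=0.3)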