# Segmentation models + Timm encoders

Traditionally, most classification models, such as the classic VGG and ResNet families, followed a 5-"stage" architecture. In each stage, these models increased the number of feature channels while **halving the spatial resolution**. Classic decoders for semantic segmentation were designed with this architecture in mind and expect features whose spatial resolution is progressively reduced at each stage. For an input of size h x w, the 5 feature maps should have the following resolutions:

 - h // 2,  w // 2
 - h // 4,  w // 4
 - h // 8,  w // 8
 - h // 16, w // 16 
 - h // 32, w // 32
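
As a quick illustration of the list above, here is a minimal sketch that computes the expected feature map sizes for a given input size. The helper name `classic_feature_sizes` is hypothetical and not part of any library; it simply applies the halving rule and uses floor division, as in the list.

```python
def classic_feature_sizes(h, w, num_stages=5):
    """Return the (height, width) of each stage's feature map.

    Stage i (1-based) reduces the input resolution by a factor of 2**i.
    """
    return [(h // 2**i, w // 2**i) for i in range(1, num_stages + 1)]


print(classic_feature_sizes(224, 224))
# [(112, 112), (56, 56), (28, 28), (14, 14), (7, 7)]
```

Input sizes divisible by 32 keep every stage free of rounding, which is why such sizes are usually the safest choice for decoders built around this scheme.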

Recent advances in convolutional and transformer model architectures have significantly diversified the intermediate feature representations. Modern models might produce feature maps with a uniform spatial resolution, such as multiple maps all at \[h // 16, w // 16\] or \[h // 8, w // 8\], or a different number of feature maps, such as just 3 instead of the traditional 5.

The [timm](https://github.com/huggingface/pytorch-image-models) library provides detailed feature information for various architectures, allowing us to adapt these features to our standard approach.

Here’s how we can adapt to these changes:

1) **Feature Selection:** Choose the appropriate feature maps when the feature extractor provides more output features than needed.
2) **Feature Padding:** Add ("pad") dummy feature maps when the number of output features is less than required.
3) **Spatial Resolution Adjustment:** Adjust the spatial resolutions to match the classic scheme [h/2, w/2; h/4, w/4; h/8, w/8; h/16, w/16; h/32, w/32], which is expected by most decoders.
4) **Channel Reduction:** Optionally reduce the number of channels in the features to create a lighter, more efficient model.

By implementing these adaptations, we can ensure compatibility with various decoder architectures and leverage the full potential of modern feature extractor models.

Below, we are going to explore the following encoder cases:

 - [Traditional encoder with 5 stages](#traditional-encoder)
 - [Selecting features from encoder](#encoder-features-selection)
 - [Encoders with fewer than 5 blocks](#encoders-with-less-than-5-blocks)
 - [Specifying number of channels for encoder](#specifying-number-of-output-channels-for-encoder)

In [1]:
import segmentation_models_pytorch as smp



### Traditional encoder

For a traditional encoder, with the standard number of features and standard spatial resolutions, the approach is simple: we take all features and pass them to the decoder as is. As you can see below, the timm model output features, the selected features and the adapted features all have the same shapes.

In [2]:
model = smp.Unet(encoder_name="tu-resnet18")
print(model.encoder.features_info_str)

index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 0       (act1):         64, hw /  2     ->         64, hw /  2     ->         64, hw /  2    
 1     (layer1):         64, hw /  4     ->         64, hw /  4     ->         64, hw /  4    
 2     (layer2):        128, hw /  8     ->        128, hw /  8     ->        128, hw /  8    
 3     (layer3):        256, hw / 16     ->        256, hw / 16     ->        256, hw / 16    
 4     (layer4):        512, hw / 32     ->        512, hw / 32     ->        512, hw / 32    


### Encoder Features Selection

You can also customize a model with a `timm` encoder by reducing its depth and specifying which features to use.
For example, if you set `encoder_depth=3`, the first 3 feature maps of the encoder will be used and the rest will be ignored:

In [4]:
model = smp.Unet(encoder_name="tu-resnet18", encoder_depth=3)
print(model.encoder.features_info_str)

index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 0       (act1):         64, hw /  2     ->         64, hw /  2     ->         64, hw /  2    
 1     (layer1):         64, hw /  4     ->         64, hw /  4     ->         64, hw /  4    
 2     (layer2):        128, hw /  8     ->        128, hw /  8     ->        128, hw /  8    
 3     (layer3):        256, hw / 16     -x
 4     (layer4):        512, hw / 32     -x


Alternatively, you can choose the last encoder features instead of the first ones:

In [5]:
model = smp.Unet(encoder_name="tu-resnet18", encoder_depth=3, encoder_indices="last")
print(model.encoder.features_info_str)

index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 0       (act1):         64, hw /  2     -x
 1     (layer1):         64, hw /  4     -x
 2     (layer2):        128, hw /  8     ->        128, hw /  8     ->        128, hw /  2    
 3     (layer3):        256, hw / 16     ->        256, hw / 16     ->        256, hw /  4    
 4     (layer4):        512, hw / 32     ->        512, hw / 32     ->        512, hw /  8    


Or you can specify the particular indices of the encoder features to use:

In [6]:
model = smp.Unet(encoder_name="tu-resnet18", encoder_depth=3, encoder_indices=[0, 2, 4])
print(model.encoder.features_info_str)

index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 0       (act1):         64, hw /  2     ->         64, hw /  2     ->         64, hw /  2    
 1     (layer1):         64, hw /  4     -x
 2     (layer2):        128, hw /  8     ->        128, hw /  8     ->        128, hw /  4    
 3     (layer3):        256, hw / 16     -x
 4     (layer4):        512, hw / 32     ->        512, hw / 32     ->        512, hw /  8    


As you can see, the final encoder features are always adapted to have reductions `2^1 .. 2^encoder_depth`. In the examples above, with `encoder_depth=3`, the reductions are [2, 4, 8]. With `encoder_depth=4`, the reductions would be [2, 4, 8, 16].
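
The selection and resolution adjustment shown in the three tables above can be sketched in plain Python. This is only an illustration of the rule, not the library's implementation: the helpers `select_indices` and `scale_factors` are hypothetical names, and using `None` to mean "take the first features" is an assumption made for this sketch.

```python
def select_indices(num_features, depth, indices=None):
    """Pick which encoder feature maps to keep (hypothetical helper)."""
    if indices is None:  # assumed default: the first `depth` features
        return list(range(depth))
    if indices == "last":  # the last `depth` features
        return list(range(num_features - depth, num_features))
    return list(indices)  # explicit indices


def scale_factors(selected_reductions):
    """Resize factor per feature map so that reductions become 2, 4, 8, ...

    A value above 1 means upsampling, a value below 1 means downsampling.
    """
    targets = [2 ** (i + 1) for i in range(len(selected_reductions))]
    return [src / tgt for src, tgt in zip(selected_reductions, targets)]


resnet18_reductions = [2, 4, 8, 16, 32]
for indices in (None, "last", [0, 2, 4]):
    picked = select_indices(len(resnet18_reductions), depth=3, indices=indices)
    selected = [resnet18_reductions[i] for i in picked]
    print(picked, selected, scale_factors(selected))
# [0, 1, 2] [2, 4, 8] [1.0, 1.0, 1.0]
# [2, 3, 4] [8, 16, 32] [4.0, 4.0, 4.0]
# [0, 2, 4] [2, 8, 32] [1.0, 2.0, 4.0]
```

The three printed rows correspond to the three tables: the selected reductions match the "Selected features" column, and dividing them by the printed factors gives the [2, 4, 8] reductions of the "Adapted features" column.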

### Encoders with less than 5 blocks

For encoders with fewer than 5 feature blocks, the following approaches are applied:
 - The depth is adjusted to match the maximum feature reduction.
 - Missing features are filled with `dummy` feature maps, which have 0 channels and do not influence training or inference.

**How is the depth adjusted?**

The depth is set to the larger of two values: the number of feature maps, and log2 of the maximum reduction. Let's look at two examples:
 - Encoder with 3 feature maps and [16, 16, 16] reductions. The maximum reduction is 16 = 2^**4**, so the encoder depth will be adjusted to 4.
 - Encoder with 3 feature maps and [4, 4, 4] reductions. The maximum reduction is 4 = 2^**2**, but the number of feature maps is 3, so the encoder depth will be adjusted to 3.
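
The depth rule can be written as a short function. This is a minimal sketch under stated assumptions: `adjusted_depth` is a hypothetical helper, reductions are assumed to be powers of two, and capping the result at the requested depth is an assumption made here so that an explicitly smaller `encoder_depth` is respected.

```python
import math


def adjusted_depth(reductions, requested_depth=5):
    """Sketch of the depth adjustment rule (hypothetical helper)."""
    depth_by_count = len(reductions)
    # Assumes reductions are powers of two, so log2 is exact.
    depth_by_reduction = int(math.log2(max(reductions)))
    # Assumption: never exceed the depth the user asked for.
    return min(requested_depth, max(depth_by_count, depth_by_reduction))


print(adjusted_depth([16, 16, 16]))     # 4
print(adjusted_depth([4, 4, 4]))        # 3
print(adjusted_depth([4, 8, 16, 32]))   # 5
```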

##### Example 1:
Encoder with 4 feature maps and reductions [4, 8, 16, 32]. The maximum reduction is 32 = 2^**5**, so the encoder depth is 5. However, there are only 4 features: the feature map with reduction `2` is missing. It will be filled with a `dummy` feature of shape `[0, h / 2, w / 2]`.

In [7]:
model = smp.Unet(encoder_name="tu-efficientformer_l1", encoder_weights=None)
print(model.encoder.features_info_str)

2024-06-08 22:48:48.695 | DEBUG    | segmentation_models_pytorch.encoders.timm_universal:__init__:172 - Encoder has 1 dummy feature(s), because the real number of features (4) is less than specified encoder depth (5).


index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 x     (-none-):                         ->          0, hw /  2     ->          0, hw /  2    
 0   (stages.0):         48, hw /  4     ->         48, hw /  4     ->         48, hw /  4    
 1   (stages.1):         96, hw /  8     ->         96, hw /  8     ->         96, hw /  8    
 2   (stages.2):        224, hw / 16     ->        224, hw / 16     ->        224, hw / 16    
 3   (stages.3):        448, hw / 32     ->        448, hw / 32     ->        448, hw / 32    


##### Example 2:

Although `encoder_depth=5` is specified, the encoder has only 3 feature maps, with [16, 16, 16] reductions. The maximum reduction is 16 = 2^**4**, so the encoder depth will be adjusted to **4**.
One `dummy` feature of shape `[0, h / 2, w / 2]` is created. The other features are resized to match the [4, 8, 16] reductions.

In [8]:
model = smp.Unet(encoder_name="tu-xcit_tiny_24_p16_224", encoder_depth=5)
print(model.encoder.features_info_str)

2024-06-08 22:48:49.088 | INFO     | segmentation_models_pytorch.encoders.timm_universal:__init__:167 - Encoder depth is adjusted to `encoder_depth=4` to match `timm` model features reductions [16, 16, 16].
2024-06-08 22:48:49.089 | DEBUG    | segmentation_models_pytorch.encoders.timm_universal:__init__:172 - Encoder has 1 dummy feature(s), because the real number of features (3) is less than specified encoder depth (4).


index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 x     (-none-):                         ->          0, hw /  2     ->          0, hw /  2    
21  (blocks.21):        192, hw / 16     ->        192, hw / 16     ->        192, hw /  4    
22  (blocks.22):        192, hw / 16     ->        192, hw / 16     ->        192, hw /  8    
23  (blocks.23):        192, hw / 16     ->        192, hw / 16     ->        192, hw / 16    
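
The padding step from both examples can be sketched in plain Python. The helper `pad_and_adapt` is hypothetical; it assumes that feature selection has already been done and that missing maps are added at the front (the highest resolution) as zero-channel dummies, as the two tables above suggest.

```python
def pad_and_adapt(channels, depth):
    """Return (channels, target_reduction) for each adapted feature map.

    Assumption: dummy maps with 0 channels are prepended until there
    are `depth` maps, then map i gets the reduction 2 ** (i + 1).
    """
    num_dummy = max(depth - len(channels), 0)
    padded = [0] * num_dummy + list(channels)
    return [(ch, 2 ** (i + 1)) for i, ch in enumerate(padded)]


# Example 1: efficientformer_l1, 4 real features, depth 5
print(pad_and_adapt([48, 96, 224, 448], depth=5))
# [(0, 2), (48, 4), (96, 8), (224, 16), (448, 32)]

# Example 2: xcit_tiny_24_p16_224, 3 real features, depth 4
print(pad_and_adapt([192, 192, 192], depth=4))
# [(0, 2), (192, 4), (192, 8), (192, 16)]
```

Each printed pair matches one row of the "Adapted features" column: the channel count followed by the reduction.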


### Specifying number of output channels for encoder

By default, the adapted features will have the same number of channels as the selected features. You can change this by passing the `encoder_channels` argument to the encoder constructor. Note the difference in shape between the selected and adapted features:

In [9]:
model = smp.Unet(encoder_name="tu-resnet18", encoder_channels=[64, 64, 64, 64, 64])
print(model.encoder.features_info_str)

index / module :    Timm model features  ->     Selected features   ->     Adapted features   
-----------------------------------------------------------------------------------------------
 0       (act1):         64, hw /  2     ->         64, hw /  2     ->         64, hw /  2    
 1     (layer1):         64, hw /  4     ->         64, hw /  4     ->         64, hw /  4    
 2     (layer2):        128, hw /  8     ->        128, hw /  8     ->         64, hw /  8    
 3     (layer3):        256, hw / 16     ->        256, hw / 16     ->         64, hw / 16    
 4     (layer4):        512, hw / 32     ->        512, hw / 32     ->         64, hw / 32    


The number of channels is changed by applying a `1x1` convolution without any non-linearity.
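
To make the effect of a `1x1` convolution concrete, here is a minimal pure-Python sketch: every output channel is a weighted sum of the input channels at the same pixel, so the spatial size is unchanged and only the channel count differs. The function `conv1x1` and the toy data are illustrative assumptions (no bias term, nested lists instead of tensors); in PyTorch the same operation is typically expressed with `torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)`.

```python
def conv1x1(feature, weight):
    """Apply a 1x1 convolution without bias or non-linearity.

    feature: nested list of shape [C_in][H][W]
    weight:  nested list of shape [C_out][C_in]
    returns: nested list of shape [C_out][H][W]
    """
    c_in = len(feature)
    height, width = len(feature[0]), len(feature[0][0])
    return [
        [
            [sum(row[c] * feature[c][y][x] for c in range(c_in)) for x in range(width)]
            for y in range(height)
        ]
        for row in weight
    ]


# Toy feature map with 3 channels and 2x2 spatial size.
feature = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[10.0, 20.0], [30.0, 40.0]],
    [[100.0, 200.0], [300.0, 400.0]],
]
weight = [[1.0, 1.0, 1.0]]  # reduce 3 channels to 1 by summing them

reduced = conv1x1(feature, weight)
print(reduced)  # [[[111.0, 222.0], [333.0, 444.0]]]
```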