# OPS-SAT PyTorch to TF model conversion notebook

ESA's [Kelvins](https://kelvins.esa.int) competition "[the OPS-SAT case](https://kelvins.esa.int/opssat/home/)" is a novel data-centric challenge that asks you to work with the raw data of a satellite and very few provided labels to find the best parameters for a given machine learning model. <br> In previous competitions on Kelvins (like the [Pose Estimation](https://kelvins.esa.int/pose-estimation-2021/) or the [Proba-V Super-resolution challenge](https://kelvins.esa.int/proba-v-super-resolution/)), the test set was provided and participants submitted the inferred results. For the OPS-SAT case, inference runs directly on the Kelvins server instead. To this end, you need to submit the parameters of a [Keras](https://keras.io/api/models/model/) implementation of `EfficientNet-lite-0` (the `EfficientNetLiteB0` model), included in the file `efficientnet_lite.py`. <br>This file is provided in our [starter-kit on Gitlab](https://gitlab.com/EuropeanSpaceAgency/the_opssat_case_starter_kit) and is based on [efficientnet-lite-keras](https://github.com/sebastian-sz/efficientnet-lite-keras), with some modifications. <br><br>To help [PyTorch](https://pytorch.org/) developers, this notebook provides some utilities to convert a PyTorch model (`.pth` file) based on [efficientnet-lite-pytorch](https://pypi.org/project/efficientnet-lite-pytorch/) (a PyTorch implementation of `EfficientNet-lite-0`) into the required `EfficientNetLiteB0`.

**DISCLAIMER**: When comparing the `efficientnet-lite-pytorch` and `EfficientNetLiteB0` models (as in the original implementation, [efficientnet-lite-keras](https://github.com/sebastian-sz/efficientnet-lite-keras)), we noticed some differences in the way padding is performed. As a result, the output shape after the `_blocks[5]._depthwise_conv` layer in PyTorch differs from that of the corresponding Keras `block4a_dwconv` layer (PyTorch output shape `[1, 240, 12, 12]` vs. Keras output shape `[1, 240, 13, 13]`).
**This could impact the performance of your model**.
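To see how a padding difference alone can change the spatial size, the sketch below applies the standard convolution output-size formula. It assumes a 25x25 feature map entering a 3x3, stride-2 depthwise convolution. The padding totals (2 and 1) are illustrative assumptions chosen to reproduce the reported shapes, not values read from either library.

```python
def conv_output_size(size, kernel, stride, pad_total):
    """Output length of a convolution along one spatial axis.

    pad_total is the sum of the padding added on both sides of that axis.
    """
    return (size + pad_total - kernel) // stride + 1


# Assumed (hypothetical) setting: 25x25 input, 3x3 kernel, stride 2.
size, kernel, stride = 25, 3, 2

# One padded row/column on each side (total padding of 2).
symmetric = conv_output_size(size, kernel, stride, pad_total=2)

# A single padded row/column in total, e.g. on one side only.
one_sided = conv_output_size(size, kernel, stride, pad_total=1)

print(symmetric, one_sided)  # 13 12
```

A single row or column of padding is enough to move the output from 13x13 to 12x12, so every layer downstream sees a slightly different feature map.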

# 1. Module imports

If you do not have a GPU, uncomment and run the next commands.

In [None]:
#import os
#os.environ["CUDA_VISIBLE_DEVICES"]="-1"

Other imports.

In [None]:
import efficientnet_lite_pytorch
from efficientnet_lite0_pytorch_model import EfficientnetLite0ModelFile
from efficientnet_lite import EfficientNetLiteB0
import torch
import tensorflow as tf
import os

# 2. Utility Functions

The next function transfers the trained parameters of your PyTorch model to the required Keras model.

In [None]:
def convert_efficientnet_lite0_parameters(model_pytorch, model_tf):
    """Converts a trained EfficientNet-lite0 PyTorch model by transferring its parameters to an untrained EfficientNet-lite0 TensorFlow model.

    Args:
        model_pytorch (pytorch model): trained EfficientNet-lite0 PyTorch model.
        model_tf (tf model): untrained EfficientNet-lite0 TensorFlow model.

    Returns:
        model_tf: EfficientNet-lite0 TensorFlow model carrying the weights trained in PyTorch.
    """
    #Torch layer lists
    model_pytorch_layers=[model_pytorch._conv_stem, model_pytorch._bn0, model_pytorch._blocks[0]._depthwise_conv,model_pytorch._blocks[0]._bn1,model_pytorch._blocks[0]._project_conv, model_pytorch._blocks[0]._bn2, model_pytorch._blocks[0]._swish,model_pytorch._blocks[1:],model_pytorch._conv_head, model_pytorch._bn1, model_pytorch._avg_pooling, model_pytorch._dropout, model_pytorch._fc, model_pytorch._swish]
    mb_block_layers=["_expand_conv","_bn0","_depthwise_conv","_bn1","_project_conv","_bn2","_swish"]
    weight_list=[]
    n=0
    print("Extracting Pytorch model parameters...")
    #Extracting pytorch layers
    for layer in model_pytorch_layers[:-1]:
        if n == 7:
            for n_mb_blocks in range(len(layer)):
                m_block=layer[n_mb_blocks]
                for n_layer_mb_conv in range(len(mb_block_layers[:-1])):
                    layer_block_l=getattr(m_block,mb_block_layers[n_layer_mb_conv])
                    if (n_layer_mb_conv != 1) and (n_layer_mb_conv != 3) and (n_layer_mb_conv != 5):
                        weight_list.append([layer_block_l.weight, layer_block_l.bias])
                    else:
                        weight_list.append([layer_block_l.weight, layer_block_l.bias, layer_block_l.momentum, layer_block_l.eps, layer_block_l.running_mean, layer_block_l.running_var])
                        
        elif (n == 1) or (n == 3) or (n == 5) or (n == 9):
            weight_list.append([layer.weight, layer.bias, layer.momentum, layer.eps, layer.running_mean, layer.running_var])
        elif (n != 6) and (n != 10) and (n != 11):
            weight_list.append([layer.weight, layer.bias])
        n+=1
    
    print("Converting layers...")
    last_tf_layer=-1
    for k in range(len(weight_list)):
        w_shape=weight_list[k][0].shape
        if len(w_shape) == 4:
            w=weight_list[k][0].permute(2, 3, 1, 0)
        else:
            w=weight_list[k][0]
            
        b=weight_list[k][1]  # None for layers without a bias term
                
        for n in range(last_tf_layer+1,len(model_tf.layers)):
            layer=model_tf.layers[n]
            if isinstance(layer, tf.keras.layers.DepthwiseConv2D ) or isinstance(layer, tf.keras.layers.Conv2D ) or isinstance(layer, tf.keras.layers.Dense):
                if isinstance(layer, tf.keras.layers.DepthwiseConv2D ):
                    w=weight_list[k][0].permute(2, 3, 0, 1)
                elif isinstance(layer, tf.keras.layers.Dense):
                    w=weight_list[k][0].permute(1,0)
                    
                if not(b is None):
                    model_tf.layers[n].set_weights([w.detach().cpu().numpy(), b.detach().cpu().numpy()]) 
                else:
                    model_tf.layers[n].set_weights([w.detach().cpu().numpy()])
                last_tf_layer=n
                break
                
            elif isinstance(layer, tf.keras.layers.BatchNormalization):
                gamma=weight_list[k][0].detach().cpu().numpy()
                beta=weight_list[k][1].detach().cpu().numpy()
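                # Keras weights the previous moving average by `momentum`, while PyTorch weights the new batch statistic, hence 1 - momentum.
                momentum=1.0-weight_list[k][2]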
                epsilon=weight_list[k][3]
                running_mean=weight_list[k][4].detach().cpu().numpy()
                running_var=weight_list[k][5].detach().cpu().numpy()
                model_tf.layers[n].set_weights([gamma, beta, running_mean, running_var])
                model_tf.layers[n].momentum=momentum
                model_tf.layers[n].epsilon=epsilon
                
                last_tf_layer=n
                break
    print("Model converted.")
    return model_tf
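The axis permutations used in the function above follow from the different weight layouts of the two frameworks. PyTorch stores convolution kernels as (out_channels, in_channels, height, width), while Keras expects (height, width, in_channels, out_channels), and dense weights are transposed. The minimal sketch below tracks only the shapes, in plain Python. `permuted_shape` is a hypothetical helper, and the example shapes are illustrative assumptions rather than values read from the models.

```python
def permuted_shape(shape, axes):
    """Shape that results from reordering the axes of a tensor, as `permute` does."""
    return tuple(shape[axis] for axis in axes)


# Regular convolution: (out, in, h, w) -> (h, w, in, out)
conv = permuted_shape((32, 3, 3, 3), (2, 3, 1, 0))

# Depthwise convolution: (channels, 1, h, w) -> (h, w, channels, depth_multiplier)
depthwise = permuted_shape((240, 1, 3, 3), (2, 3, 0, 1))

# Fully connected layer: (out_features, in_features) -> (in_features, out_features)
dense = permuted_shape((8, 1280), (1, 0))

print(conv)       # (3, 3, 3, 32)
print(depthwise)  # (3, 3, 240, 1)
print(dense)      # (1280, 8)
```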

# 3. Loading PyTorch model

Instantiating an `efficientnet_lite_pytorch` model.

In [None]:
weights_path = EfficientnetLite0ModelFile.get_model_file_path()
model_pytorch= efficientnet_lite_pytorch.EfficientNet.from_pretrained('efficientnet-lite0', weights_path = weights_path,num_classes=8, in_channels=3)
model_pytorch.eval()

Uncomment the next line and set the path to your trained `efficientnet_lite_pytorch` checkpoint (`.pth`).

In [None]:
#checkpoint_path="Path to the .pth file."

If you need to load your model on the CPU, please change `torch.load(checkpoint_path)` to `torch.load(checkpoint_path, map_location=torch.device('cpu'))`.

In [None]:
#Loading model
model_pytorch.load_state_dict(torch.load(checkpoint_path)['eval_model'])

# 4. Loading Keras model

Instantiating an `EfficientNetLiteB0` model.

In [None]:
model_tf=EfficientNetLiteB0(classes=8, weights=None, input_shape=(200, 200,3), classifier_activation=None)

# 5. PyTorch to Keras conversion

The next function parses the trained PyTorch model, extracts its parameters and loads them into the Keras model.

In [None]:
model_tf=convert_efficientnet_lite0_parameters(model_pytorch, model_tf)
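One detail of the transfer concerns the batch-normalization momentum, which the two frameworks define in opposite ways: PyTorch weights the new batch statistic by `momentum`, while Keras weights the previous moving average by it. The sketch below, in plain Python with made-up numbers, shows that the two update rules agree when the Keras value is one minus the PyTorch value. This only matters if the converted model is trained further, since inference uses the stored moving statistics.

```python
def torch_style_update(running, batch_stat, momentum):
    """PyTorch convention: `momentum` weights the new batch statistic."""
    return (1.0 - momentum) * running + momentum * batch_stat


def keras_style_update(moving, batch_stat, momentum):
    """Keras convention: `momentum` weights the previous moving average."""
    return momentum * moving + (1.0 - momentum) * batch_stat


# Made-up values, for illustration only.
running_mean, batch_mean = 0.5, 2.0
torch_momentum = 0.01
keras_momentum = 1.0 - torch_momentum

a = torch_style_update(running_mean, batch_mean, torch_momentum)
b = keras_style_update(running_mean, batch_mean, keras_momentum)
print(abs(a - b) < 1e-12)  # True
```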

# 6. Save output Keras model

Saving the converted model, with its parameters, as an `.h5` file. Please set `output_path` to the target path of the `.h5` file and uncomment the next line.

In [None]:
#output_path="Path to the output .h5 file"

In [None]:
#Saving output keras model.
model_tf.save(output_path)