In [4]:
# Step 1: Export to ONNX is already provided in the repository
import torch

from mobile_sam import sam_model_registry
from mobile_sam.utils.onnx import SamOnnxModel

import warnings

# Load the MobileSAM checkpoint through the model registry.
checkpoint = "../weights/mobile_sam.pt"
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint=checkpoint)

# Single source of truth for the ONNX opset: it drives both the exported file
# name and the `opset_version` argument of torch.onnx.export below, so the two
# cannot drift apart when experimenting with other opsets (11 and 13 were
# tried previously).
# NOTE(review): an earlier note here stated that opset 16 is required for
# dynamic axes -- TODO confirm. With opset 16 the later ONNX -> TensorFlow
# conversion fails with:
#   "BackendIsNotSupposedToImplementIt: Unsqueeze version 13 is not implemented."
opset_version = 16
onnx_model_path = f"sam_onnx_opset{opset_version}.onnx"
onnx_model = SamOnnxModel(sam, return_single_mask=True)

# Axis 1 of both prompt inputs is exported as a dynamic "num_points" dimension.
dynamic_axes = {
    "point_coords": {1: "num_points"},
    "point_labels": {1: "num_points"},
}

embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]

# Example inputs used for tracing. The dict's insertion order matters: the
# values are passed positionally to the model and the keys become the ONNX
# input names.
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}

output_names = ["masks", "iou_predictions", "low_res_masks"]

with warnings.catch_warnings():
    # Silence tracer and user warnings for the duration of the export only;
    # the previous warning filters are restored when this context exits.
    warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
    warnings.filterwarnings("ignore", category=UserWarning)

    with open(onnx_model_path, "wb") as f:
        torch.onnx.export(
            onnx_model,
            tuple(dummy_inputs.values()),
            f,
            export_params=True,
            verbose=False,
            opset_version=opset_version,
            do_constant_folding=True,
            input_names=list(dummy_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )


verbose: False, log level: Level.ERROR



In [None]:
# Optional: Quantize the model
# Dynamic quantization is opt-in: set the flag to True to write a quantized
# copy of the exported ONNX model next to the original.
should_quantize = False

if should_quantize:
    # Imported lazily so onnxruntime is only required when quantization is on.
    from onnxruntime.quantization import QuantType
    from onnxruntime.quantization.quantize import quantize_dynamic

    onnx_model_quantized_path = "sam_onnx_quantized.onnx"
    # `optimize_model=True` is intentionally not passed: newer onnxruntime
    # releases no longer accept that keyword (the call raises TypeError),
    # while older releases already default it to True, so omitting it keeps
    # the call working on both. NOTE(review): confirm against the pinned
    # onnxruntime version.
    quantize_dynamic(
        model_input=onnx_model_path,
        model_output=onnx_model_quantized_path,
        per_channel=False,
        reduce_range=False,
        weight_type=QuantType.QUInt8,
    )

In [5]:
# Step 2: Convert ONNX to TensorFlow (using onnx-tf)
from onnx_tf.backend import prepare
import onnx

# Input: the ONNX file exported in Step 1 (the opset 11 and opset 13 exports,
# "sam_onnx_opset11.onnx" / "sam_onnx_opset13.onnx", were tried here as well).
# Output: the directory that receives the exported TensorFlow graph.
onnx_model_path = "sam_onnx_opset16.onnx"
exported_mobilesam_tf_model_dir = "exported_mobilesam_tf_model"

# Load the ONNX model, wrap it in an onnx-tf backend representation and export
# the resulting graph.
# NOTE: as recorded in this cell's output, the export currently fails with
# "BackendIsNotSupposedToImplementIt: Unsqueeze version 13 is not implemented."
onnx_model = onnx.load(onnx_model_path)
tf_rep = prepare(onnx_model)
tf_rep.export_graph(exported_mobilesam_tf_model_dir)

# TODO: consider the ONNX -> TFLite approach described in this tutorial:
#  https://github.com/SiliconLabs/mltk/blob/master/mltk/tutorials/onnx_to_tflite.ipynb

2023-09-04 03:58:32.669710: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



BackendIsNotSupposedToImplementIt: in user code:

    File "/Users/MRutkowski/.pyenv/versions/3.10.13/envs/mobile-sam/lib/python3.10/site-packages/onnx_tf/backend_tf_module.py", line 99, in __call__  *
        output_ops = self.backend._onnx_node_to_tensorflow_op(onnx_node,
    File "/Users/MRutkowski/.pyenv/versions/3.10.13/envs/mobile-sam/lib/python3.10/site-packages/onnx_tf/backend.py", line 347, in _onnx_node_to_tensorflow_op  *
        return handler.handle(node, tensor_dict=tensor_dict, strict=strict)
    File "/Users/MRutkowski/.pyenv/versions/3.10.13/envs/mobile-sam/lib/python3.10/site-packages/onnx_tf/handlers/handler.py", line 61, in handle  *
        raise BackendIsNotSupposedToImplementIt("{} version {} is not implemented.".format(node.op_type, cls.SINCE_VERSION))

    BackendIsNotSupposedToImplementIt: Unsqueeze version 13 is not implemented.


In [None]:
# Step 3: Convert TensorFlow to TFLite
import tensorflow as tf

from pathlib import Path

# Source SavedModel directory (written in Step 2) and destination TFLite file.
saved_model_dir = "exported_mobilesam_tf_model"
tflite_model_path = Path("mobilesam_tflite_model/model.tflite")

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
# Enable both the TFLite builtin op set and the select-TF-ops set for conversion.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()

# open(..., "wb") does not create missing parent directories, so make sure the
# output directory exists; otherwise the write raises FileNotFoundError after
# the conversion has already been done.
tflite_model_path.parent.mkdir(parents=True, exist_ok=True)
with open(tflite_model_path, "wb") as f:
    f.write(tflite_model)