# Scripts for Exporting PyTorch Models to ONNX and CoreML

Depending on the backend, we prefer different quantization schemes.

- For ONNX we use `uint8` quantization.
- For PyTorch we use `bfloat16` quantization.
- For CoreML we use `float32` representation.

In [None]:
# Install (or upgrade) UForm with its `torch` extra, plus CoreML Tools used by the CoreML export.
!pip install --upgrade "uform[torch]" coremltools

In [None]:
import os

# Repository root (relative to this notebook) and the model being exported.
working_directory = "../.."
model_name = "uform3-image-text-english-small"

# Every artifact of one model lives under `<root>/models/<model_name>/`.
model_directory = os.path.join(working_directory, "models", model_name)
model_weights_path, config_path, tokenizer_path = (
    os.path.join(model_directory, file_name)
    for file_name in ("torch_weight.pt", "config.json", "tokenizer.json")
)

In [None]:
import torch

# `map_location="cpu"` lets the checkpoint load on machines that lack the
# device it was saved from (for example, a GPU-saved checkpoint on a laptop).
# NOTE(review): `torch.load` unpickles the file, so only load checkpoints you
# trust; consider `weights_only=True` if the installed PyTorch supports it.
state_dict = torch.load(model_weights_path, map_location="cpu")

# Trailing expression: the notebook displays the names of all stored tensors.
list(state_dict.keys())

In [None]:
from uform.torch_encoders import ImageEncoder, TextEncoder
from uform.torch_processors import ImageProcessor, TextProcessor

In [None]:
# Build both encoders from the shared config file and the loaded weights.
image_encoder = ImageEncoder.from_pretrained(config_path, state_dict)
text_encoder = TextEncoder.from_pretrained(config_path, state_dict)
# Trailing expression so the notebook displays both objects.
image_encoder, text_encoder

In [None]:
# Pre-processors that turn raw text / images into encoder inputs.
# The text processor additionally needs the tokenizer file.
text_processor = TextProcessor(config_path, tokenizer_path)
image_processor = ImageProcessor(config_path)
# Trailing expression so the notebook displays both objects.
text_processor, image_processor

In [None]:
import uform
from PIL import Image

# A sample text/image pair; its processed form is reused below as the example
# input for tracing and exporting.
text = 'a small red panda in a zoo'
image = Image.open('../../assets/unum.png')

# `text_data` is indexed by "input_ids" / "attention_mask" and `image_data`
# by "images" in the cells that follow.
text_data = text_processor(text)
image_data = image_processor(image)

# With `return_features=True` each call is unpacked into a (features, embedding) pair.
image_features, image_embedding = image_encoder.forward(image_data, return_features=True)
text_features, text_embedding = text_encoder.forward(text_data, return_features=True)

# Trailing expression: display the four output shapes as a sanity check.
image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape

## CoreML

In [None]:
import coremltools as ct
import torch

In [None]:
# Compute precision handed to `ct.convert` for the CoreML exports.
precision = ct.precision.FLOAT32

CoreML Tools provides a way to convert PyTorch models to CoreML models. This script demonstrates how to convert a traced (TorchScript) PyTorch model to a CoreML model. For that, we need to provide an example input, and the tensor shapes will be inferred from that.

```python
        image_input = ct.TensorType(name="images", shape=image_data["images"].shape)
        text_input = ct.TensorType(name="input_ids", shape=text_data["input_ids"].shape)
        text_attention_input = ct.TensorType(name="attention_mask", shape=text_data["attention_mask"].shape)
```

That, however, will only work for batch-size one. To support larger batches, we need to override the input shapes.

```python
        ct.RangeDim(lower_bound=25, upper_bound=100, default=45)
```

In [None]:
def generalize_first_dimensions(input_shape, upper_bound=64):
    """Return `input_shape` with its leading dimension made flexible.

    The first dimension is replaced by a `ct.RangeDim` spanning
    `1..upper_bound` (default 1); the remaining dimensions are kept as-is.
    When `upper_bound` is 1 there is nothing to generalize, so the shape is
    returned unchanged.
    """
    if upper_bound != 1:
        flexible_dim = ct.RangeDim(lower_bound=1, upper_bound=upper_bound, default=1)
        return (flexible_dim,) + input_shape[1:]
    return input_shape

# Preview the generalized shapes of all three model inputs (default upper bound).
(
    generalize_first_dimensions(image_data["images"].shape),
    generalize_first_dimensions(text_data["input_ids"].shape),
    generalize_first_dimensions(text_data["attention_mask"].shape),
)

In [None]:
# Named CoreML input descriptors. With an upper bound of 1,
# `generalize_first_dimensions` returns the example shape unchanged, so the
# exported models keep the batch size of the example inputs.
image_input = ct.TensorType(name="images", shape=generalize_first_dimensions(image_data["images"].shape, 1))
text_input = ct.TensorType(name="input_ids", shape=generalize_first_dimensions(text_data["input_ids"].shape, 1))
text_attention_input = ct.TensorType(name="attention_mask", shape=generalize_first_dimensions(text_data["attention_mask"].shape, 1))
# Named CoreML output descriptors.
# NOTE: `text_features` and `image_features` rebind the names that held the
# tensors from the example forward pass.
text_features = ct.TensorType(name="features")
text_embeddings = ct.TensorType(name="embeddings")
image_features = ct.TensorType(name="features")
image_embeddings = ct.TensorType(name="embeddings")

In [None]:
# Trace the image encoder in eval mode using the example image batch.
module = image_encoder
module.eval()
# Presumably makes the traced forward pass return both features and
# embeddings — confirm against the encoder implementation.
module.return_features = True

traced_script_module = torch.jit.trace(module, example_inputs=image_data["images"])
# Trailing expression so the notebook displays the traced module.
traced_script_module

In [None]:
# Convert the traced image encoder into a CoreML ML Program.
coreml_model = ct.convert(
    traced_script_module,
    source="pytorch",
    inputs=[image_input],
    outputs=[image_features, image_embeddings],
    convert_to="mlprogram",
    compute_precision=precision,
)

# Attach package metadata, then write the model into the model directory.
coreml_model.author = "Unum Cloud"
coreml_model.license = "Apache 2.0"
coreml_model.short_description = "Pocket-Sized Multimodal AI for Content Understanding"
coreml_model.save(os.path.join(model_directory, "image_encoder.mlpackage"))

In [None]:
# Trace the text encoder in eval mode using the example token ids and mask.
module = text_encoder
module.eval()
# Presumably makes the traced forward pass return both features and
# embeddings — confirm against the encoder implementation.
module.return_features = True

traced_script_module = torch.jit.trace(module, example_inputs=[text_data['input_ids'], text_data['attention_mask']])
# Trailing expression so the notebook displays the traced module.
traced_script_module

In [None]:
# Convert the traced text encoder into a CoreML ML Program.
coreml_model = ct.convert(
    traced_script_module,
    source="pytorch",
    inputs=[text_input, text_attention_input],
    outputs=[text_features, text_embeddings],
    convert_to="mlprogram",
    compute_precision=precision,
)

# Attach package metadata, then write the model into the model directory.
coreml_model.author = "Unum Cloud"
coreml_model.license = "Apache 2.0"
coreml_model.short_description = "Pocket-Sized Multimodal AI for Content Understanding"
coreml_model.save(os.path.join(model_directory, "text_encoder.mlpackage"))

# PyTorch

Let's ensure:

- the `model.text_encoder` inputs are called `input_ids` and `attention_mask`, and outputs are `embeddings` and `features`.
- the `model.image_encoder` input is called `images`, and outputs are `embeddings` and `features`.
- the model itself works fine in `bf16` (`bfloat16`) half-precision, so that the model is lighter and easier to download.

In [None]:
import torch
from safetensors import safe_open
from safetensors.torch import save_file

In [None]:
# Switch the image encoder to eval mode and cast it to bfloat16 before saving.
image_encoder.eval()
image_encoder.to(dtype=torch.bfloat16)

In [None]:
# Save the image encoder's state dict in PyTorch's native format.
torch.save(image_encoder.state_dict(), os.path.join(model_directory, "image_encoder.pt"))

In [None]:
# Save the same state dict in the safetensors format.
save_file(image_encoder.state_dict(), os.path.join(model_directory, "image_encoder.safetensors"))

In [None]:
# Switch the text encoder to eval mode and cast it to bfloat16 before saving.
text_encoder.eval()
text_encoder.to(dtype=torch.bfloat16)

In [None]:
# Save the text encoder's state dict in PyTorch's native format.
torch.save(text_encoder.state_dict(), os.path.join(model_directory, "text_encoder.pt"))

In [None]:
# Save the same state dict in the safetensors format.
save_file(text_encoder.state_dict(), os.path.join(model_directory, "text_encoder.safetensors"))

In [None]:
# Sanity-check the bfloat16 encoders: the image tensor is cast to bfloat16 to
# match the model, while the text inputs are passed through unchanged.
image_features, image_embedding = image_encoder.forward(image_data["images"].to(dtype=torch.bfloat16), return_features=True)
text_features, text_embedding = text_encoder.forward(text_data, return_features=True)

# Trailing expression: display the four output shapes.
image_features.shape, text_features.shape, image_embedding.shape, text_embedding.shape

## ONNX

In [None]:
!pip install onnx onnxconverter-common

In [None]:
from torch.onnx import export as onnx_export
import torch

We can't immediately export to `bfloat16` as it's not supported by ONNX, but we also can't export to `float16`, as the forward pass (that will be traced) will fail. So let's export to a `float32` ONNX file first.

In [None]:
# Export the text encoder to ONNX in float32, with a dynamic batch axis on
# every input and output.
module = text_encoder
module.eval()
module.return_features = True
module.to(dtype=torch.float32)

onnx_export(
    module,
    (text_data["input_ids"], text_data["attention_mask"]),
    os.path.join(model_directory, "text_encoder.onnx"),
    export_params=True,
    opset_version=15,
    do_constant_folding=True,
    input_names=["input_ids", "attention_mask"],
    output_names=["features", "embeddings"],
    dynamic_axes={
        "input_ids": {0: "batch_size"},
        "attention_mask": {0: "batch_size"},
        "features": {0: "batch_size"},
        "embeddings": {0: "batch_size"},
    },
)

Now repeat the same for images.

In [None]:
# Export the image encoder to ONNX in float32, with a dynamic batch axis on
# the input and on both outputs.
module = image_encoder
module.eval()
module.return_features = True
module.to(dtype=torch.float32)

# `onnx_export` is the imported alias of `torch.onnx.export`, as used for the
# text encoder.
onnx_export(
    module,
    image_data["images"],
    os.path.join(model_directory, "image_encoder.onnx"),
    export_params=True,
    opset_version=15,
    do_constant_folding=True,
    input_names=["images"],
    output_names=["features", "embeddings"],
    dynamic_axes={
        "images": {0: "batch_size"},
        "features": {0: "batch_size"},
        "embeddings": {0: "batch_size"},
    },
)

### Quantizing to `float16`

Let's use [additional ONNX tooling](https://onnxruntime.ai/docs/performance/model-optimizations/float16.html#mixed-precision) to convert to half-precision.

In [None]:
import onnx
from onnxconverter_common import float16

In [None]:
# Convert the exported text encoder to float16 and overwrite the file in place.
module_path = os.path.join(model_directory, "text_encoder.onnx")
module = onnx.load(module_path)
module_fp16 = float16.convert_float_to_float16(module)
onnx.save(module_fp16, module_path)

In [None]:
# Convert the exported image encoder to float16 and overwrite the file in place.
module_path = os.path.join(model_directory, "image_encoder.onnx")
module = onnx.load(module_path)
module_fp16 = float16.convert_float_to_float16(module)
onnx.save(module_fp16, module_path)

### Quantizing to `uint8`

We can further quantize the model into `uint8` using ONNX quantization tools.
`int8` is the default variant, but [some of the operators don't support it](https://github.com/microsoft/onnxruntime/issues/15888).

In [None]:
from onnxruntime.quantization import quantize_dynamic, QuantType

In [None]:
# Dynamically quantize the text encoder's weights to uint8; input and output
# paths are the same, so the file is overwritten in place.
module_path = os.path.join(model_directory, "text_encoder.onnx")
quantize_dynamic(module_path, module_path, weight_type=QuantType.QUInt8)

In [None]:
# Dynamically quantize the image encoder's weights to uint8; input and output
# paths are the same, so the file is overwritten in place.
module_path = os.path.join(model_directory, "image_encoder.onnx")
quantize_dynamic(module_path, module_path, weight_type=QuantType.QUInt8)

Let's make sure that all the text inputs are integers of identical type - `int32`.

In [None]:
import onnx
import os
from onnx import helper

# Load the exported text encoder so its declared inputs can be patched in place.
module_path = os.path.join(model_directory, "text_encoder.onnx")
module = onnx.load(module_path)
graph = module.graph

# Declare both token inputs (`input_ids` and `attention_mask`) as INT32.
# NOTE(review): only the declared input type changes and no Cast nodes are
# added — confirm the downstream operators accept int32.
for input_tensor in graph.input:
    if input_tensor.name in ('input_ids', 'attention_mask'):
        input_tensor.type.tensor_type.elem_type = onnx.TensorProto.INT32

# Validate the patched graph before overwriting the file.
onnx.checker.check_model(module)
onnx.save(module, module_path)

We can use the following function to print and validate the input and output types of the ONNX model files.

In [None]:
def _print_value_infos(value_infos):
    """Print the name, element type, and shape of each graph input/output."""
    for value_info in value_infos:
        tensor_type = value_info.type.tensor_type
        # Convert the numeric element type to its readable enum name.
        readable_type = onnx.TensorProto.DataType.Name(tensor_type.elem_type)
        # Dynamic axes carry a symbolic `dim_param` (e.g. "batch_size") and
        # leave `dim_value` at 0, so prefer the symbol whenever it is set.
        shape = [dim.dim_param or dim.dim_value for dim in tensor_type.shape.dim]
        print(f"Name: {value_info.name}, Type: {readable_type}, Shape: {shape}")


def print_model_inputs_and_outputs(onnx_model_path):
    """Print the declared inputs and outputs of an ONNX model file.

    Args:
        onnx_model_path: Path to the `.onnx` file to inspect.

    For every graph input and output, the name, element type, and shape are
    printed. Symbolic (dynamic) dimensions are shown by name instead of 0.
    """
    graph = onnx.load(onnx_model_path).graph

    print("Model Inputs:")
    _print_value_infos(graph.input)

    print("\nModel Outputs:")
    _print_value_infos(graph.output)

Let's check that the runtime can actually load those models.

In [None]:
import onnxruntime as ort
# Shared session options with all graph optimizations enabled.
session_options = ort.SessionOptions()
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

In [None]:
# Check that ONNX Runtime can load the final text encoder file.
module_path = os.path.join(model_directory, "text_encoder.onnx")
session = ort.InferenceSession(module_path, sess_options=session_options)

In [None]:
# Check that ONNX Runtime can load the final image encoder file
# (rebinds `session`).
module_path = os.path.join(model_directory, "image_encoder.onnx")
session = ort.InferenceSession(module_path, sess_options=session_options)

# Upload to Hugging Face

In [None]:
# Upload the whole model directory to the root of the Hub repository,
# excluding the original combined weights file.
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../models/uform3-image-text-english-small/ . --exclude="torch_weight.pt"

In [None]:
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../image_encoder.onnx image_encoder.onnx
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../text_encoder.onnx text_encoder.onnx
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../image_encoder.safetensors image_encoder.safetensors
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../text_encoder.safetensors text_encoder.safetensors
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../image_encoder.pt image_encoder.pt
!huggingface-cli upload unum-cloud/uform3-image-text-english-small ../../text_encoder.pt text_encoder.pt