

# Converting PyTorch CLIP model to ONNX

In this notebook I'd like to show how to convert the PyTorch CLIP model to ONNX, simplify it, load it into onnxruntime, and get a noticeable speed boost on CPU (the benchmark run at the end of this notebook measured about 48% lower mean latency).

Why? In the context of this particular competition there is no reason: I don't think it pays off to run inference on CPU rather than GPU. Outside of Kaggle, however, running inference on CPU may be a necessity.

Another reason:

<center><img src="https://i.imgflip.com/1ch27o.jpg" alt="drawing" width="450"/></center>

In [1]:
!pip install git+https://github.com/openai/CLIP.git onnxruntime onnx-simplifier

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-lcsch6nq
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-lcsch6nq
  Resolved https://github.com/openai/CLIP.git to commit d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
  Preparing metadata (setup.py) ... done
Collecting onnxruntime
  Downloading onnxruntime-1.12.1-cp37-cp37m-manylinux_2_27_x86_64.whl (4.9 MB)
Collecting onnx-simplifier
  Downloading onnx_simplifier-0.4.7-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.0 MB)
Collecting ftfy
  Downloading ftfy-6.1.1-py3-none-any.whl (53 kB)

In [1]:
import clip
import time
import torch
import onnx
import onnxruntime as ort
from onnxsim import simplify

from typing import Tuple



## Config

In [2]:
# CLIP_BACKBONE = 'RN50'
CLIP_BACKBONE = 'ViT-B/32'
CLIP_ONNX_EXPORT_PATH = 'clip_resnet.onnx'
CLIP_ONNX_EXPORT_PATH_SIMP = 'clip_resnet_simplified.onnx'

ONNX_INPUT_NAMES = ["IMAGE", "TEXT"]
ONNX_OUTPUT_NAMES = ["LOGITS_PER_IMAGE", "LOGITS_PER_TEXT"]
ONNX_DYNAMIC_AXES = {
    "IMAGE": {
        0: "image_batch_size",
    },
    "TEXT": {
        0: "text_batch_size",
    },
    "LOGITS_PER_IMAGE": {
        0: "image_batch_size",
        1: "text_batch_size",
    },
    "LOGITS_PER_TEXT": {
        0: "text_batch_size",
        1: "image_batch_size",
    },
}
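To make the role of the symbolic axis names more concrete, here is a small sketch in plain Python (no ONNX involved). The helper `resolve_shape` is hypothetical and written only for this illustration. It substitutes concrete batch sizes for the names used in the dynamic axes mapping, and the batch sizes 10 and 2 are the ones used for the dummy inputs later in this notebook.

```python
# Illustration only: `resolve_shape` is a hypothetical helper, not part of
# torch.onnx or onnxruntime. It shows how the shared axis names tie the
# output shapes to the two input batch sizes.
DYNAMIC_AXES = {
    "LOGITS_PER_IMAGE": {0: "image_batch_size", 1: "text_batch_size"},
    "LOGITS_PER_TEXT": {0: "text_batch_size", 1: "image_batch_size"},
}


def resolve_shape(tensor_name, sizes):
    """Replace each symbolic axis name with its concrete size."""
    axes = DYNAMIC_AXES[tensor_name]
    return tuple(sizes[axes[axis]] for axis in sorted(axes))


# 10 images and 2 text prompts, as in the dummy inputs of this notebook
sizes = {"image_batch_size": 10, "text_batch_size": 2}
logits_per_image_shape = resolve_shape("LOGITS_PER_IMAGE", sizes)
logits_per_text_shape = resolve_shape("LOGITS_PER_TEXT", sizes)
print(logits_per_image_shape, logits_per_text_shape)
```

Because both outputs reuse the same two names, changing either batch size at inference time changes both output shapes consistently.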

## Helpers 

Define some basic helper functions to easily load, export, and benchmark models.

In [4]:
def measure_mean_time_no_warmup(
    func, 
    func_inputs, 
    num_iters=250
) -> float:
    start_time = time.perf_counter()
    for _ in range(num_iters):
        func(*func_inputs)
    return (time.perf_counter() - start_time) / num_iters


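def load_clip(backbone='RN50') -> Tuple[clip.model.CLIP, Tuple[torch.Tensor, torch.Tensor]]:
    # Load on CPU explicitly: clip.load defaults to CUDA when it is available,
    # while the dummy inputs below (and the whole benchmark) live on CPU.
    pytorch_model, _ = clip.load(backbone, device="cpu")
    npx = pytorch_model.visual.input_resolution
    print(f"npx is {npx}")
    # 10 random images and 2 tokenized prompts serve as export/benchmark inputs
    dummy_image = torch.randn(10, 3, npx, npx)
    dummy_texts = clip.tokenize(["quick brown fox", "lorem ipsum"])

    return pytorch_model, (dummy_image, dummy_texts)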


def export_onnx(
    model, 
    inputs, 
    input_names,
    output_names,
    dynamic_axes,
    export_path
) -> None:
    torch.onnx.export(
        model=model, 
        args=inputs, 
        f=export_path, 
        export_params=True,
        input_names=input_names,
        output_names=output_names,
        opset_version=14,
        dynamic_axes=dynamic_axes
    )
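The timing helper above deliberately skips warm-up, so the first calls (which may include one-off costs such as lazy initialization or memory allocation) count towards the mean. If you want to exclude them, a variant along these lines could work. This is only a sketch: the name `measure_time_with_warmup` and its parameters `num_warmup` and `num_iters` are my own assumptions, not something used elsewhere in this notebook, and it runs on a toy workload instead of a model.

```python
import statistics
import time


def measure_time_with_warmup(func, func_inputs, num_warmup=10, num_iters=50):
    """Return (mean, median) seconds per call, ignoring the warm-up calls."""
    # Warm-up calls are executed but not timed
    for _ in range(num_warmup):
        func(*func_inputs)

    # Time each call separately so that a median can be reported as well
    timings = []
    for _ in range(num_iters):
        start = time.perf_counter()
        func(*func_inputs)
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings), statistics.median(timings)


# Toy workload standing in for a model call
mean_s, median_s = measure_time_with_warmup(
    sum, (range(10_000),), num_warmup=3, num_iters=20
)
print(f"mean: {mean_s:.6f} s, median: {median_s:.6f} s")
```

The median is reported next to the mean because a few slow outlier calls can skew the mean on a shared machine.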

## Load and export
Load the PyTorch version of CLIP and export it to ONNX.

In [5]:
pytorch_model, dummy_input = load_clip(backbone=CLIP_BACKBONE)
pytorch_model.eval()

export_onnx(
    model=pytorch_model,
    inputs=dummy_input,
    input_names=ONNX_INPUT_NAMES,
    output_names=ONNX_OUTPUT_NAMES,
    dynamic_axes=ONNX_DYNAMIC_AXES,
    export_path=CLIP_ONNX_EXPORT_PATH,
)

npx is 224




## Check + simplify
Make sure the ONNX model was exported successfully, then run onnx-simplifier on it.

In [6]:
# run checks
onnx_model = onnx.load(CLIP_ONNX_EXPORT_PATH)
onnx.checker.check_model(onnx_model)

# run additional checks and simplify
model_simp, check = simplify(onnx_model, skip_fuse_bn=True)
assert check, "Simplified ONNX model could not be validated"
onnx.save(model_simp, CLIP_ONNX_EXPORT_PATH_SIMP)

## onnxruntime

Load the simplified ONNX model into onnxruntime.

In [7]:
ort_sess = ort.InferenceSession(CLIP_ONNX_EXPORT_PATH_SIMP)

## Inference

Run inference with both the PyTorch and the ONNX version to verify that the results match.

In [8]:
with torch.no_grad():
    pytorch_output = pytorch_model(*dummy_input)
onnx_output = ort_sess.run(ONNX_OUTPUT_NAMES, {"IMAGE": dummy_input[0].numpy(), "TEXT": dummy_input[1].numpy()})

assert all([torch.allclose(pt_pred, torch.tensor(onnx_pred)) for pt_pred, onnx_pred in zip(pytorch_output, onnx_output)])

print(f'Pytorch output: {pytorch_output}\n\nONNX output: {onnx_output}')

Pytorch output: (tensor([[20.0043, 24.3510],
        [19.5753, 24.7952],
        [19.9016, 25.3230],
        [20.2808, 24.6213],
        [19.5199, 24.3743],
        [19.7097, 24.8914],
        [20.0616, 24.7587],
        [19.5600, 24.4032],
        [20.2513, 24.6054],
        [19.4823, 24.6365]]), tensor([[20.0043, 19.5753, 19.9016, 20.2808, 19.5199, 19.7097, 20.0616, 19.5600,
         20.2513, 19.4823],
        [24.3510, 24.7952, 25.3230, 24.6213, 24.3743, 24.8914, 24.7587, 24.4032,
         24.6054, 24.6365]]))

ONNX output: [array([[20.00431 , 24.351057],
       [19.57537 , 24.795254],
       [19.901592, 25.322998],
       [20.280855, 24.621252],
       [19.519897, 24.37428 ],
       [19.709751, 24.891382],
       [20.061575, 24.758688],
       [19.560017, 24.403189],
       [20.251362, 24.605425],
       [19.482271, 24.636532]], dtype=float32), array([[20.00431 , 19.57537 , 19.901592, 20.280855, 19.519897, 19.709751,
        20.061575, 19.560017, 20.251362, 19.482271],
       [24.3
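The assertion in the cell above relies on `torch.allclose`, which by default accepts two values when `|a - b| <= atol + rtol * |b|` with `rtol=1e-5` and `atol=1e-8`. The sketch below re-implements that criterion in plain Python for single numbers, just to show how much slack it leaves for logits of this magnitude. The helper `is_close` is hypothetical, and the "drifted" value is made up for illustration; it is not a result from this notebook.

```python
def is_close(a, b, rtol=1e-5, atol=1e-8):
    """Scalar version of the closeness criterion used by torch.allclose."""
    return abs(a - b) <= atol + rtol * abs(b)


# First logit as printed above (the PyTorch value is rounded to 4 decimals
# by print, so the real difference is even smaller)
pt_value, onnx_value = 20.0043, 20.00431
default_ok = is_close(pt_value, onnx_value)

# Made-up example: a drift of about 1e-3 fails the default tolerances ...
drifted_default = is_close(20.0043, 20.0053)
# ... but passes with looser, explicitly chosen tolerances
drifted_loose = is_close(20.0043, 20.0053, rtol=1e-3, atol=1e-5)

print(default_ok, drifted_default, drifted_loose)
```

If a different backbone or opset produces slightly larger numerical differences, passing explicit `rtol` and `atol` to `torch.allclose` is probably a better fix than dropping the check.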

In [9]:
with torch.no_grad():
    pytorch_mean_time = measure_mean_time_no_warmup(func=pytorch_model, func_inputs=dummy_input)
onnx_inputs = {"IMAGE": dummy_input[0].numpy(), "TEXT": dummy_input[1].numpy()}
onnx_runtime_mean_time = measure_mean_time_no_warmup(
    func=ort_sess.run,
    func_inputs=(ONNX_OUTPUT_NAMES, onnx_inputs),
)

print(f'PyTorch mean time: {round(pytorch_mean_time, 3)} sec\nONNX Runtime mean time: {round(onnx_runtime_mean_time, 3)} sec\nBoost from PT -> ONNX (%) {100*round(1 - onnx_runtime_mean_time/pytorch_mean_time, 2)}')

PyTorch mean time: 1.401 sec
ONNX Runtime mean time: 0.723 sec
Boost from PT -> ONNX (%) 48.0
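For reference, here is the arithmetic behind the last line, redone on the two rounded timings printed above. The "boost" is the relative reduction in mean latency; the same numbers can also be read as a speed-up factor. The variable names `latency_reduction_pct` and `speedup_factor` are mine.

```python
# Rounded timings copied from the output above (seconds per call)
pytorch_mean_time = 1.401
onnx_runtime_mean_time = 0.723

# Relative reduction in mean latency, in percent
latency_reduction_pct = 100 * (1 - onnx_runtime_mean_time / pytorch_mean_time)
# How many times faster a single call is
speedup_factor = pytorch_mean_time / onnx_runtime_mean_time

print(f"Latency reduction: {latency_reduction_pct:.1f}%")
print(f"Speed-up factor: {speedup_factor:.2f}x")
```

This gives about 48.4% lower latency, or roughly a 1.94x speed-up. The benchmark cell prints 48.0 because it rounds the ratio to two decimals before multiplying by 100.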


## Where to go next?

The best next step would be to take the converted model and load it into OpenVINO for additional speed-ups. However, I couldn't get that to work myself, because some layers of CLIP don't seem to be supported by OpenVINO. If you manage to make it work, please let me know; I'm curious about further improvements.

<center><img src="https://i.pinimg.com/originals/06/82/e2/0682e26f337825b366e8e3e3e0003ad1.jpg" alt="drawing" width="450"/></center>