In [1]:
!pip3 install torch transformers optimum pillow

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
from pathlib import Path
from typing import Optional, Dict, Union, Tuple

import torch
import numpy as np
from PIL import Image
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPTextModelWithProjection,
    CLIPImageProcessor,
    CLIPTokenizerFast,
)
from transformers.models.clip.modeling_clip import (
    CLIPTextModelOutput,
    CLIPVisionModelOutput,
    CLIPModel,
)
from optimum.onnxruntime import ORTModelForCustomTasks
from optimum.exporters.onnx.model_configs import CLIPTextWithProjectionOnnxConfig, ViTOnnxConfig
from optimum.exporters.onnx import export_models

In [3]:
model_id = "openai/clip-vit-base-patch32"
output_dir = "split-clip-onnx"

In [4]:
class CLIPVisionModelWithProjectionOnnxConfig(ViTOnnxConfig):
    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "image_embeds": {0: "batch_size"},
        }

In [5]:
class CLIPTextModelWithProjectionAndAttentionOnnxConfig(CLIPTextWithProjectionOnnxConfig):
    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "input_ids": {0: "batch_size", 1: "sequence_length"},
            "attention_mask": {0: "batch_size", 1: "sequence_length"},
        }

In [6]:
class CLIPTextModelWithProjectionNormalized(CLIPTextModelWithProjection):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        text_outputs = super().forward(
            input_ids,
            attention_mask,
            position_ids,
            output_attentions,
            output_hidden_states,
            return_dict,
        )
        normalized_text_embeds = text_outputs.text_embeds / text_outputs.text_embeds.norm(
            p=2, dim=-1, keepdim=True
        )
        return CLIPTextModelOutput(
            text_embeds=normalized_text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )

In [7]:
class CLIPVisionModelWithProjectionNormalized(CLIPVisionModelWithProjection):
    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        vision_outputs = super().forward(pixel_values, return_dict)
        normalized_image_embeds = vision_outputs.image_embeds / vision_outputs.image_embeds.norm(
            p=2, dim=-1, keepdim=True
        )
        return CLIPVisionModelOutput(
            image_embeds=normalized_image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

In [8]:
text_model = CLIPTextModelWithProjectionNormalized.from_pretrained(model_id)

In [9]:
vision_model = CLIPVisionModelWithProjectionNormalized.from_pretrained(model_id)

In [10]:
text_config = CLIPTextModelWithProjectionAndAttentionOnnxConfig(text_model.config)
vision_config = CLIPVisionModelWithProjectionOnnxConfig(vision_model.config)

In [11]:
text_model.config.save_pretrained(f"./{output_dir}/text")
vision_model.config.save_pretrained(f"./{output_dir}/image")

In [12]:
export_models(
    models_and_onnx_configs={
        "text_model": (text_model, text_config),
        "vision_model": (vision_model, vision_config),
    },
    output_dir=Path(f"./{output_dir}"),
)


***** Exporting submodel 1/2: CLIPTextModelWithProjectionNormalized *****
Using framework PyTorch: 2.3.0
  if input_shape[-1] > 1 or self.sliding_window is not None:
  if past_key_values_length > 0:
  if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
  if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len):
  if attention_mask.size() != (bsz, 1, tgt_len, src_len):
  if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):

***** Exporting submodel 2/2: CLIPVisionModelWithProjectionNormalized *****
Using framework PyTorch: 2.3.0


[[['input_ids', 'attention_mask'], ['pixel_values']],
 [['text_embeds', 'last_hidden_state'], ['image_embeds']]]

In [13]:
os.rename(f"./{output_dir}/text_model.onnx", f"./{output_dir}/text/model.onnx")
os.rename(f"./{output_dir}/vision_model.onnx", f"./{output_dir}/image/model.onnx")

In [14]:
ort_vision_model = ORTModelForCustomTasks.from_pretrained(
    f"./{output_dir}/image", config=vision_config
)
image_processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
image_input = image_processor(images=Image.open("assets/image.jpeg"), return_tensors="pt")

with torch.inference_mode():
    image_outputs = ort_vision_model(**image_input)
image_processor.save_pretrained(f"./{output_dir}/image")

['./split-clip-onnx/image/preprocessor_config.json']

In [15]:
ort_text_model = ORTModelForCustomTasks.from_pretrained(f"./{output_dir}/text", config=text_config)
text_processor = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
text_input = text_processor("What am I using?", return_tensors="pt")

with torch.inference_mode():
    text_outputs = ort_text_model(**text_input)
text_processor.save_pretrained(f"./{output_dir}/text")

('./split-clip-onnx/text/tokenizer_config.json',
 './split-clip-onnx/text/special_tokens_map.json',
 './split-clip-onnx/text/vocab.json',
 './split-clip-onnx/text/merges.txt',
 './split-clip-onnx/text/added_tokens.json',
 './split-clip-onnx/text/tokenizer.json')

In [16]:
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
inputs = {**text_input, **image_input}
clip_model.eval()
with torch.inference_mode():
    gt_output = clip_model(**inputs)

In [17]:
print(np.allclose(gt_output.text_embeds.numpy(), text_outputs.text_embeds, atol=1e-6))
print(np.allclose(gt_output.image_embeds.numpy(), image_outputs.image_embeds, atol=1e-6))

True
True


In [None]:
# from huggingface_hub import create_repo
#
# create_repo(repo_id='jmzzomg/clip-vit-base-patch32-vision-onnx', exist_ok=True, token='')
# create_repo(repo_id='jmzzomg/clip-vit-base-patch32-text-onnx', exist_ok=True, token='')
#
# ort_text_model.push_to_hub(save_directory=f'./{output_dir}/text/', repository_id='jmzzomg/clip-vit-base-patch32-text-onnx', use_auth_token='')
# ort_vision_model.push_to_hub(save_directory=f'./{output_dir}/image', repository_id='jmzzomg/clip-vit-base-patch32-vision-onnx', use_auth_token='')