In [None]:
!pip3 install -q torch transformers optimum pillow

In [None]:
import os
from pathlib import Path
from typing import Optional, Dict, Union, Tuple

import torch
import numpy as np
from PIL import Image
from transformers import (
    CLIPVisionModelWithProjection,
    CLIPTextModelWithProjection,
    CLIPImageProcessor,
    CLIPTokenizerFast,
)
from transformers.models.clip.modeling_clip import (
    CLIPTextModelOutput,
    CLIPVisionModelOutput,
    CLIPModel,
)
from optimum.onnxruntime import ORTModelForCustomTasks
from optimum.exporters.onnx.model_configs import CLIPTextWithProjectionOnnxConfig, ViTOnnxConfig
from optimum.exporters.onnx import export_models

In [None]:
# Hugging Face checkpoint whose text and vision towers are exported separately.
model_id: str = "openai/clip-vit-base-patch32"
# Local working directory; the `text/` and `image/` sub-folders are written under it.
output_dir: str = "split-clip-onnx"

In [None]:
class CLIPVisionModelWithProjectionOnnxConfig(ViTOnnxConfig):
    """ONNX export config for the CLIP vision tower with its projection head.

    Inputs are inherited unchanged from `ViTOnnxConfig`; the outputs are replaced
    so that `image_embeds` is the only declared output, with a dynamic batch axis.
    """

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        batch_axis = {0: "batch_size"}
        return dict(image_embeds=batch_axis)

In [None]:
class CLIPTextModelWithProjectionAndAttentionOnnxConfig(CLIPTextWithProjectionOnnxConfig):
    """ONNX export config for the CLIP text tower that also accepts an attention mask.

    Declares `input_ids` and `attention_mask` as inputs, each with dynamic
    batch and sequence axes. Outputs are inherited unchanged from
    `CLIPTextWithProjectionOnnxConfig`.
    """

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        # Insertion order is kept: presumably it determines the exported graph's input order.
        input_names = ("input_ids", "attention_mask")
        return {name: {0: "batch_size", 1: "sequence_length"} for name in input_names}

In [None]:
class CLIPTextModelWithProjectionNormalized(CLIPTextModelWithProjection):
    """CLIP text encoder + projection whose `text_embeds` are L2-normalized.

    Doing the normalization inside `forward` puts it into the exported ONNX graph,
    so the exported embeddings can be compared directly with `CLIPModel`'s
    `text_embeds` in the parity check.
    """

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        """Run the parent forward pass and L2-normalize `text_embeds` over the last dim.

        Args mirror `CLIPTextModelWithProjection.forward`. `return_dict` is accepted
        for interface compatibility, but the parent is always called with
        `return_dict=True` because its outputs are read by attribute below. The
        result is always a `CLIPTextModelOutput`.
        """
        # Keyword arguments: a reordered or extended parent signature cannot then
        # silently bind a value to the wrong parameter.
        text_outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        text_embeds = text_outputs.text_embeds
        normalized_text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
        return CLIPTextModelOutput(
            text_embeds=normalized_text_embeds,
            last_hidden_state=text_outputs.last_hidden_state,
            hidden_states=text_outputs.hidden_states,
            attentions=text_outputs.attentions,
        )

In [None]:
class CLIPVisionModelWithProjectionNormalized(CLIPVisionModelWithProjection):
    """CLIP vision encoder + projection whose `image_embeds` are L2-normalized.

    Doing the normalization inside `forward` puts it into the exported ONNX graph,
    so the exported embeddings can be compared directly with `CLIPModel`'s
    `image_embeds` in the parity check.
    """

    def forward(
        self,
        pixel_values: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPVisionModelOutput]:
        """Run the parent forward pass and L2-normalize `image_embeds` over the last dim.

        Args mirror `CLIPVisionModelWithProjection.forward`. `return_dict` is accepted
        for interface compatibility, but the parent is always called with
        `return_dict=True` because its outputs are read by attribute below. The
        result is always a `CLIPVisionModelOutput`.
        """
        # Pass by keyword: the parent's second positional parameter is
        # `output_attentions`, so passing `return_dict` positionally would land there.
        vision_outputs = super().forward(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True,
        )
        image_embeds = vision_outputs.image_embeds
        normalized_image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
        return CLIPVisionModelOutput(
            image_embeds=normalized_image_embeds,
            last_hidden_state=vision_outputs.last_hidden_state,
            hidden_states=vision_outputs.hidden_states,
            attentions=vision_outputs.attentions,
        )

In [None]:
# Text tower with the L2-normalizing forward, loaded from the `model_id` checkpoint.
text_model = CLIPTextModelWithProjectionNormalized.from_pretrained(
    pretrained_model_name_or_path=model_id
)

In [None]:
# Vision tower with the L2-normalizing forward, loaded from the `model_id` checkpoint.
vision_model = CLIPVisionModelWithProjectionNormalized.from_pretrained(
    pretrained_model_name_or_path=model_id
)

In [None]:
# ONNX export descriptions for each tower, built from the loaded models' own configs.
vision_config = CLIPVisionModelWithProjectionOnnxConfig(config=vision_model.config)
text_config = CLIPTextModelWithProjectionAndAttentionOnnxConfig(config=text_model.config)

In [None]:
# Write each tower's model config into the sub-folder that will also hold its ONNX graph.
text_model.config.save_pretrained(save_directory=f"./{output_dir}/text")
vision_model.config.save_pretrained(save_directory=f"./{output_dir}/image")

In [None]:
# Each entry maps an output name to (PyTorch model, ONNX export config); both graphs
# are written into `output_dir`.
onnx_exports = {
    "text_model": (text_model, text_config),
    "vision_model": (vision_model, vision_config),
}
export_models(models_and_onnx_configs=onnx_exports, output_dir=Path(f"./{output_dir}"))

In [None]:
# Move each exported graph into the sub-folder holding its config, as `model.onnx`.
# Safe to re-run:
# - a graph that was already moved is skipped;
# - `os.replace` (unlike `os.rename` on Windows) overwrites a stale destination.
for exported_name, sub_dir in (("text_model.onnx", "text"), ("vision_model.onnx", "image")):
    src = f"./{output_dir}/{exported_name}"
    dst = f"./{output_dir}/{sub_dir}/model.onnx"
    if os.path.exists(src):
        os.replace(src, dst)
    elif not os.path.exists(dst):
        raise FileNotFoundError(f"Neither {src} nor {dst} exists; run the export cell first.")

In [None]:
# Reload the exported vision graph through ONNX Runtime and embed a sample image.
# NOTE(review): `config=vision_config` passes the ONNX export config rather than the
# `config.json` saved next to the graph; kept as-is — confirm this is intended.
ort_vision_model = ORTModelForCustomTasks.from_pretrained(
    f"./{output_dir}/image", config=vision_config
)
# Use `model_id` so the preprocessor always matches the exported checkpoint.
image_processor = CLIPImageProcessor.from_pretrained(model_id)
# The context manager closes the image file once preprocessing is done.
with Image.open("assets/image.jpeg") as image:
    image_input = image_processor(images=image, return_tensors="pt")

with torch.inference_mode():
    image_outputs = ort_vision_model(**image_input)
# Saved next to model.onnx so the folder pushed to the Hub is self-contained.
image_processor.save_pretrained(f"./{output_dir}/image")

In [None]:
# Reload the exported text graph through ONNX Runtime and embed a sample sentence.
# NOTE(review): `config=text_config` passes the ONNX export config rather than the
# `config.json` saved next to the graph; kept as-is — confirm this is intended.
ort_text_model = ORTModelForCustomTasks.from_pretrained(f"./{output_dir}/text", config=text_config)
# Use `model_id` so the tokenizer always matches the exported checkpoint.
text_processor = CLIPTokenizerFast.from_pretrained(model_id)
text_input = text_processor("What am I using?", return_tensors="pt")

with torch.inference_mode():
    text_outputs = ort_text_model(**text_input)
# Saved next to model.onnx so the folder pushed to the Hub is self-contained.
text_processor.save_pretrained(f"./{output_dir}/text")

In [None]:
# Reference embeddings from the full PyTorch CLIP model, loaded from the same
# `model_id` as the exported towers so the parity check compares like with like.
clip_model = CLIPModel.from_pretrained(model_id)
inputs = {**text_input, **image_input}
clip_model.eval()
with torch.inference_mode():
    gt_output = clip_model(**inputs)

In [None]:
# Parity check: the ONNX Runtime embeddings must match the PyTorch CLIPModel reference.
text_match = np.allclose(gt_output.text_embeds.numpy(), text_outputs.text_embeds, atol=1e-6)
image_match = np.allclose(gt_output.image_embeds.numpy(), image_outputs.image_embeds, atol=1e-6)
print(text_match)
print(image_match)
# Fail loudly so a Run-All cannot continue on to the push-to-Hub cell with a divergent export.
assert text_match and image_match, "ONNX embeddings diverge from the PyTorch CLIPModel reference"

In [None]:
# Read the Hub write token from the environment instead of hardcoding it here,
# where it would be saved in (and leak through) the .ipynb file.
token = os.environ.get("HF_TOKEN")
if not token:
    raise RuntimeError("Set the HF_TOKEN environment variable to a Hugging Face write token.")
# If the target repos do not exist yet, create them first, e.g. with
# `huggingface_hub.create_repo(repo_id=..., exist_ok=True, token=token)`.

ort_text_model.push_to_hub(
    save_directory=f"./{output_dir}/text/",
    repository_id="Qdrant/clip-ViT-B-32-text",
    use_auth_token=token,
)
ort_vision_model.push_to_hub(
    save_directory=f"./{output_dir}/image",
    repository_id="Qdrant/clip-ViT-B-32-vision",
    use_auth_token=token,
)