In [4]:
# Reference implementation: https://github.com/Lednik7/CLIP-ONNX

In [5]:
# !pip install git+https://github.com/Lednik7/CLIP-ONNX.git
# !pip install git+https://github.com/openai/CLIP.git
# !pip install onnxruntime-gpu

## Export CLIP embeddings to ONNX


In [1]:
import torch
import clip
from PIL import Image

# 1) Load the pretrained CLIP ViT-B/32 model together with its image-preprocessing transform.
device = "cpu"  # the whole export pipeline below runs on the CPU
clip_model, preprocess = clip.load(
    "ViT-B/32",
    device=device,
    jit=False,  # NOTE(review): the non-TorchScript model is presumably what tracing-based ONNX export needs — confirm
)
clip_model.eval()  # inference mode so training-only layers cannot influence the traced graph

# 2) Define a wrapper module that returns image embeddings
class CLIPImageEncoder(torch.nn.Module):
    """Wrap CLIP's image tower so it can be exported as a single-input module.

    ``torch.onnx.export`` traces ``forward``. Because this forward only calls
    ``encode_image``, only the image path ends up in the exported graph.

    Args:
        clip_model: a loaded CLIP model exposing ``encode_image``; kept on
            ``self.clip_model``.
    """

    def __init__(self, clip_model) -> None:
        super(CLIPImageEncoder, self).__init__()
        self.clip_model = clip_model

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Return the raw (unnormalized) CLIP image embeddings for ``image``.

        NOTE(review): for ViT-B/32 the input is a preprocessed batch such as
        [batch_size, 3, 224, 224] and the output is [batch_size, 512]. Other
        CLIP variants use different resolutions/widths — confirm per model.
        """
        features = self.clip_model.encode_image(image)
        return features

image_encoder = CLIPImageEncoder(clip_model).to(device)

# 3) Trace with a real preprocessed image plus a leading batch dimension,
#    giving a [1, 3, 224, 224] tensor for ViT-B/32.
dummy_image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)

# 4) Trace the wrapper and write it to ONNX. Axis 0 of both the input and the
#    output is exported as a symbolic "batch_size" dimension. The saved graph
#    therefore accepts any batch size, not only the 1 used for tracing.
torch.onnx.export(
    image_encoder,
    dummy_image,
    "clip_image_encoder.onnx",
    input_names=["image_input"],
    output_names=["image_features"],
    dynamic_axes={
        tensor_name: {0: "batch_size"}
        for tensor_name in ("image_input", "image_features")
    },
    export_params=True,        # embed the trained weights in the .onnx file
    do_constant_folding=True,
    opset_version=14,          # scaled_dot_product_attention needs opset >= 14
)
print("Exported CLIP image-encoder to clip_image_encoder.onnx")

OSError: [Errno 28] No space left on device: '/tmp/tmp2rimi2vf'

## Export CLIP text embeddings to ONNX

In [None]:
class CLIPTextEncoder(torch.nn.Module):
    """Wrap CLIP's text tower so it can be exported as a single-input module.

    ``torch.onnx.export`` traces ``forward``. Because this forward only calls
    ``encode_text``, only the text path ends up in the exported graph.

    Args:
        clip_model: a loaded CLIP model exposing ``encode_text``; kept on
            ``self.clip_model``.
    """

    def __init__(self, clip_model) -> None:
        super(CLIPTextEncoder, self).__init__()
        self.clip_model = clip_model

    def forward(self, text: torch.Tensor) -> torch.Tensor:
        """Return CLIP text embeddings for a batch of token ids.

        NOTE(review): ``text`` is presumably the output of ``clip.tokenize``
        ([batch_size, context_length] integer ids) — confirm against callers.
        """
        embeddings = self.clip_model.encode_text(text)
        return embeddings

# Mirror the image-encoder setup: same device, inference mode.
text_encoder = CLIPTextEncoder(clip_model).to(device).eval()

# Dummy tokenized prompt used only to trace the graph.
# NOTE(review): clip.tokenize presumably pads every prompt to a fixed context
# length, so only the batch axis is marked dynamic below. Also confirm the
# token dtype it returns (int32 vs int64), since ONNX consumers must feed the
# same dtype.
dummy_text = clip.tokenize(["hello world"]).to(device)

# Export with the same options as the image encoder. The original call had a
# literal `...` placeholder here and no dynamic_axes, which either broke the
# call or pinned the graph to batch size 1.
torch.onnx.export(
    text_encoder,
    dummy_text,
    "clip_text_encoder.onnx",
    export_params=True,               # embed the trained weights in the .onnx file
    opset_version=14,                 # for scaled_dot_product_attention, as in the image export
    do_constant_folding=True,
    input_names=["text_input"],       # single input
    output_names=["text_features"],   # single output
    dynamic_axes={
        "text_input": {0: "batch_size"},
        "text_features": {0: "batch_size"},
    },
)
print("Exported CLIP text-encoder to clip_text_encoder.onnx")
