update bert vit clip qdq to align with Model Lab #1893

Merged · 14 commits · Jun 5, 2025
24 changes: 22 additions & 2 deletions examples/bert/bert_ptq_qdq.json
@@ -52,16 +52,36 @@
}
},
"passes": {
"conversion": { "type": "OnnxConversion", "target_opset": 17 },
"conversion": { "type": "OnnxConversion", "target_opset": 20 },
"dynamic_shape_to_fixed": {
"type": "DynamicToFixedShape",
"dim_param": [ "batch_size", "sequence_length" ],
"dim_value": [ 1, 128 ]
},
"surgery": { "type": "GraphSurgeries", "surgeries": [ { "surgeon": "ReplaceAttentionMaskValue" } ] },
"surgery": {
"type": "GraphSurgeries",
"surgeries": [
{ "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
{ "surgeon": "MatMulAddToGemm" }
]
},
"transformer_optimizer": {
"type": "OrtTransformersOptimization",
"model_type": "bert",
"opt_level": 1,
"optimization_options": {
"enable_gelu": true,
"enable_bias_gelu": false,
"enable_layer_norm": true,
"enable_skip_layer_norm": false,
"enable_bias_skip_layer_norm": false,
"enable_attention": false
}
},
"quantization": {
"type": "OnnxStaticQuantization",
"data_config": "glue_mrpc",
"quant_preprocess": true,
"activation_type": "uint16",
"precision": "uint8"
}
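
For reference, here is a minimal sketch of how a config like the one above is typically executed with Olive's Python API. This is not part of the diff; the `olive.workflows.run` entry point and the assumption that the command runs from `examples/bert/` are mine, not the PR's.

```python
# Hedged sketch (not part of this PR): run the updated BERT PTQ QDQ workflow.
# Assumes olive-ai and onnxruntime are installed and the working directory is
# examples/bert/ so that relative paths inside the JSON resolve.
from olive.workflows import run as olive_run

if __name__ == "__main__":
    # Runs the passes defined above in order: ONNX conversion (opset 20),
    # fixed input shapes, graph surgeries, transformer optimization, and
    # static quantization with uint16 activations / uint8 weights.
    olive_run("bert_ptq_qdq.json")
```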
28 changes: 24 additions & 4 deletions examples/bert/google_bert_qdq.json
@@ -41,16 +41,36 @@
}
},
"passes": {
"conversion": { "type": "OnnxConversion", "target_opset": 17 },
"dynamic_shape_to_fixed": {
"conversion": { "type": "OnnxConversion", "target_opset": 20 },
"to_fixed_shape": {
"type": "DynamicToFixedShape",
"dim_param": [ "batch_size", "sequence_length" ],
"dim_value": [ 1, 128 ]
},
"surgery": { "type": "GraphSurgeries", "surgeries": [ { "surgeon": "ReplaceAttentionMaskValue" } ] },
"quantization": {
"surgery": {
"type": "GraphSurgeries",
"surgeries": [
{ "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
{ "surgeon": "MatMulAddToGemm" }
]
},
"transformer_optimizer": {
"type": "OrtTransformersOptimization",
"model_type": "bert",
"opt_level": 1,
"optimization_options": {
"enable_gelu": true,
"enable_bias_gelu": false,
"enable_layer_norm": true,
"enable_skip_layer_norm": false,
"enable_bias_skip_layer_norm": false,
"enable_attention": false
}
},
"OnnxQuantization": {
"type": "OnnxStaticQuantization",
"data_config": "xnli",
"quant_preprocess": true,
"activation_type": "uint16",
"precision": "uint8"
}
147 changes: 147 additions & 0 deletions examples/clip/qdq/clip_script.py
@@ -0,0 +1,147 @@
from __future__ import annotations

from collections import OrderedDict
from itertools import chain

import torch
from transformers import (
AutoProcessor,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)

from olive.data.component.dataset import BaseDataset
from olive.data.registry import Registry

HF_MODEL_SUBFOLDER_MAPPING = {
"sentence-transformers/clip-ViT-B-32": "0_CLIPModel",
}


def load_image_encoder(model_name):
return CLIPVisionModelWithProjection.from_pretrained(
model_name,
subfolder=HF_MODEL_SUBFOLDER_MAPPING.get(model_name, ""),
).eval()


def load_text_encoder(model_name):
return CLIPTextModelWithProjection.from_pretrained(
model_name,
subfolder=HF_MODEL_SUBFOLDER_MAPPING.get(model_name, ""),
).eval()


def hfdataset_pre_process_for_clip(
dataset,
processor,
torch_model=None,
image_col: str | None = None,
caption_col: str | None = None,
label_col: str = "label",
max_samples: int | None = None,
max_length: int = 77,
batch_size: int = 32,
):
def generate_inputs(sample, indices):
captions = sample.get(caption_col, None)
images = sample.get(image_col, None)

kwargs = {
"padding": "max_length",
"max_length": max_length,
"truncation": True,
"add_special_tokens": True,
"return_tensors": "pt",
}
if images:
kwargs["images"] = [img.convert("RGB") for img in images]
if captions:
kwargs["text"] = list(chain([x[0] for x in captions]))

encoded_input = processor(**kwargs)

return {
**encoded_input,
label_col: torch_model(**encoded_input)[0] if torch_model else sample.get(label_col, indices),
}

if max_samples is not None and max_samples < len(dataset):
dataset = dataset.select(range(max_samples))

tokenized_datasets = dataset.map(
generate_inputs,
batched=True,
batch_size=batch_size,
with_indices=True,
remove_columns=dataset.column_names,
desc="Processing dataset",
)
tokenized_datasets.set_format("torch", output_all_columns=True)

return tokenized_datasets


@Registry.register_pre_process()
def pre_process_dataset(
dataset,
model_name: str,
generate_ground_truth: bool = False,
image_col: str | None = None,
caption_col: str | None = None,
label_col: str = "label",
max_samples: int | None = None,
max_length: int = 77,
**kwargs,
):
if image_col is None and caption_col is None:
raise ValueError("Either image_col or caption_col must be provided.")

if generate_ground_truth:
if image_col and caption_col:
raise ValueError("Can not generate two types of embedding at the same time.")

torch_model = load_image_encoder(model_name) if image_col else load_text_encoder(model_name)
else:
torch_model = None

processor = AutoProcessor.from_pretrained(model_name)
dataset = hfdataset_pre_process_for_clip(
dataset,
processor,
torch_model=torch_model,
image_col=image_col,
caption_col=caption_col,
label_col=label_col,
max_length=max_length,
max_samples=max_samples,
)
return BaseDataset(dataset, label_col)


@Registry.register_post_process()
def embed_post_process(output):
"""Post-processing for CLIP output."""
if isinstance(output, (dict, OrderedDict)):
if "embeds" in output:
return output["embeds"]
elif "text_embeds" in output:
return output["text_embeds"]
elif "image_embeds" in output:
return output["image_embeds"]
elif isinstance(output, torch.Tensor):
return output.argmax(dim=-1)
raise ValueError(f"Unsupported output type: {type(output)}")


def eval_similarity_degrad(output, targets, batch_size=1024):
import torch.nn.functional as F

preds = output.preds
scores = [
# pylint: disable=E1102
F.cosine_similarity(preds[i : i + batch_size], targets[i : i + batch_size])
# pylint: enable=E1102
for i in range(0, preds.size(0), batch_size)
]
return {"percentage": f"{100.0 - torch.mean(torch.cat(scores)) * 100.0:.2f}"}
105 changes: 105 additions & 0 deletions examples/clip/qdq/laion_clip_text_b32_qdq.json
@@ -0,0 +1,105 @@
{
"input_model": {
"type": "PytorchModel",
"model_path": "laion/clip-vit-b-32-laion2b-s34b-b79k",
"generative": false,
"io_config": {
"input_names": [ "input_ids", "attention_mask" ],
"input_shapes": [ [ 1, 77 ], [ 1, 77 ] ],
"input_types": [ "int32", "int32" ],
"output_names": [ "embeds", "last_hidden_state" ]
},
"model_loader": "load_text_encoder",
"model_script": "clip_script.py"
},
"passes": {
"conversion": { "type": "OnnxConversion", "target_opset": 20, "dynamic": true, "use_dynamo_exporter": false },
"to_fixed_shape": {
"type": "DynamicToFixedShape",
"dim_param": [ "batch_size", "sequence_length" ],
"dim_value": [ 1, 77 ]
},
"surgery": {
"type": "GraphSurgeries",
"surgeries": [
{ "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
{ "surgeon": "MatMulAddToGemm" }
]
},
"transformer_optimizer": {
"type": "OrtTransformersOptimization",
"model_type": "bert",
"opt_level": 1,
"optimization_options": {
"enable_gelu": true,
"enable_bias_gelu": false,
"enable_layer_norm": true,
"enable_skip_layer_norm": false,
"enable_bias_skip_layer_norm": false,
"enable_attention": false
}
},
"quantization": {
"type": "OnnxStaticQuantization",
"data_config": "calib_data",
"quant_preprocess": true,
"activation_type": "uint16",
"precision": "uint8"
}
},
"data_configs": [
{
"name": "calib_data",
"type": "HuggingfaceContainer",
"load_dataset_config": { "data_name": "nlphuji/flickr30k", "split": "test[:12]" },
"pre_process_data_config": {
"type": "pre_process_dataset",
"model_name": "laion/clip-vit-b-32-laion2b-s34b-b79k",
"caption_col": "caption",
"max_length": 77
},
"dataloader_config": { "batch_size": 1 },
"user_script": "clip_script.py"
},
{
"name": "eval_data",
"type": "HuggingfaceContainer",
"load_dataset_config": { "data_name": "nlphuji/flickr_1k_test_image_text_retrieval", "split": "test" },
"pre_process_data_config": {
"type": "pre_process_dataset",
"model_name": "laion/clip-vit-b-32-laion2b-s34b-b79k",
"generate_ground_truth": true,
"caption_col": "caption",
"max_length": 77
},
"post_process_data_config": { "type": "embed_post_process" },
"dataloader_config": { "batch_size": 1 },
"user_script": "clip_script.py"
}
],
"evaluators": {
"sanity_check": {
"metrics": [
{
"name": "degrad",
"type": "custom",
"data_config": "eval_data",
"sub_types": [ { "name": "percentage", "priority": 1, "higher_is_better": false } ],
"user_config": { "user_script": "clip_script.py", "metric_func": "eval_similarity_degrad" }
},
{
"name": "latency",
"type": "latency",
"sub_types": [
{ "name": "avg", "priority": 2, "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } },
{ "name": "p90", "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } }
]
}
]
}
},
"clean_cache": true,
"clean_evaluation_cache": true,
"evaluate_input_model": false,
"output_dir": "models/laion/clip_b32/text"
}
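
After the workflow above completes, the quantized text encoder can be smoke-tested directly with onnxruntime. A minimal sketch follows; the model path under `output_dir` is an assumption (the exact artifact layout depends on the Olive version), and the `io_config` above dictates the int32 inputs of shape [1, 77].

```python
# Hedged sketch (not part of this PR): run the quantized CLIP text encoder with onnxruntime.
import numpy as np
import onnxruntime as ort
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("laion/clip-vit-b-32-laion2b-s34b-b79k")
tokens = processor(
    text=["a photo of a dog"],
    padding="max_length",
    max_length=77,
    truncation=True,
    return_tensors="np",
)

# Path is an assumption; check the Olive run output for the actual artifact location.
session = ort.InferenceSession("models/laion/clip_b32/text/output_model/model.onnx")
embeds, last_hidden_state = session.run(
    None,
    {
        "input_ids": tokens["input_ids"].astype(np.int32),
        "attention_mask": tokens["attention_mask"].astype(np.int32),
    },
)
print(embeds.shape)  # expected (1, 512) for CLIP ViT-B/32
```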