Commit 9b88982

xieofxie, tezheng, and hualxie authored
update bert vit clip qdq to align with Model Lab (#1893)
## Describe your changes

## Checklist before requesting a review

- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link

---------

Co-authored-by: Zheng Te <1221537+tezheng@users.noreply.github.com>
Co-authored-by: hualxie <hualxie@microsoft.com>
1 parent 1f1f889 commit 9b88982

10 files changed (+825, -16 lines changed)

examples/bert/bert_ptq_qdq.json

Lines changed: 22 additions & 2 deletions
@@ -52,16 +52,36 @@
         }
     },
     "passes": {
-        "conversion": { "type": "OnnxConversion", "target_opset": 17 },
+        "conversion": { "type": "OnnxConversion", "target_opset": 20 },
         "dynamic_shape_to_fixed": {
             "type": "DynamicToFixedShape",
             "dim_param": [ "batch_size", "sequence_length" ],
             "dim_value": [ 1, 128 ]
         },
-        "surgery": { "type": "GraphSurgeries", "surgeries": [ { "surgeon": "ReplaceAttentionMaskValue" } ] },
+        "surgery": {
+            "type": "GraphSurgeries",
+            "surgeries": [
+                { "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
+                { "surgeon": "MatMulAddToGemm" }
+            ]
+        },
+        "transformer_optimizer": {
+            "type": "OrtTransformersOptimization",
+            "model_type": "bert",
+            "opt_level": 1,
+            "optimization_options": {
+                "enable_gelu": true,
+                "enable_bias_gelu": false,
+                "enable_layer_norm": true,
+                "enable_skip_layer_norm": false,
+                "enable_bias_skip_layer_norm": false,
+                "enable_attention": false
+            }
+        },
         "quantization": {
             "type": "OnnxStaticQuantization",
             "data_config": "glue_mrpc",
+            "quant_preprocess": true,
             "activation_type": "uint16",
             "precision": "uint8"
         }
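The updated BERT pass chain is conversion (opset 20) -> fixed input shapes -> graph surgeries (`ReplaceAttentionMaskValue`, `MatMulAddToGemm`) -> selective transformer fusions -> static QDQ quantization with 16-bit activations and 8-bit weights. As a quick smoke test, here is a minimal sketch using Olive's Python entry point (assuming it is run from `examples/bert/` so the relative paths inside the config resolve):

```python
# Minimal sketch: execute the workflow defined in bert_ptq_qdq.json.
# Assumes the current working directory is examples/bert/ so that the
# paths referenced by the JSON resolve correctly.
from olive.workflows import run as olive_run

# Runs conversion -> DynamicToFixedShape -> GraphSurgeries ->
# OrtTransformersOptimization -> OnnxStaticQuantization as configured above.
workflow_output = olive_run("bert_ptq_qdq.json")
print(workflow_output)
```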

examples/bert/google_bert_qdq.json

Lines changed: 24 additions & 4 deletions
@@ -41,16 +41,36 @@
         }
     },
     "passes": {
-        "conversion": { "type": "OnnxConversion", "target_opset": 17 },
-        "dynamic_shape_to_fixed": {
+        "conversion": { "type": "OnnxConversion", "target_opset": 20 },
+        "to_fixed_shape": {
             "type": "DynamicToFixedShape",
             "dim_param": [ "batch_size", "sequence_length" ],
             "dim_value": [ 1, 128 ]
         },
-        "surgery": { "type": "GraphSurgeries", "surgeries": [ { "surgeon": "ReplaceAttentionMaskValue" } ] },
-        "quantization": {
+        "surgery": {
+            "type": "GraphSurgeries",
+            "surgeries": [
+                { "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
+                { "surgeon": "MatMulAddToGemm" }
+            ]
+        },
+        "transformer_optimizer": {
+            "type": "OrtTransformersOptimization",
+            "model_type": "bert",
+            "opt_level": 1,
+            "optimization_options": {
+                "enable_gelu": true,
+                "enable_bias_gelu": false,
+                "enable_layer_norm": true,
+                "enable_skip_layer_norm": false,
+                "enable_bias_skip_layer_norm": false,
+                "enable_attention": false
+            }
+        },
+        "OnnxQuantization": {
             "type": "OnnxStaticQuantization",
             "data_config": "xnli",
+            "quant_preprocess": true,
             "activation_type": "uint16",
             "precision": "uint8"
         }
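The `transformer_optimizer` pass added to both BERT configs is Olive's wrapper around the onnxruntime transformer optimizer, and the `optimization_options` keys map onto onnxruntime's fusion options. A hedged sketch of a roughly equivalent direct call (the model path, `num_heads`, and `hidden_size` are illustrative placeholders, not values from this PR):

```python
# Sketch of the onnxruntime fusion settings the Olive pass above enables.
# "model.onnx", num_heads and hidden_size are illustrative placeholders.
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

opts = FusionOptions("bert")
opts.enable_gelu = True                   # fuse Gelu subgraphs
opts.enable_bias_gelu = False
opts.enable_layer_norm = True             # fuse LayerNormalization subgraphs
opts.enable_skip_layer_norm = False
opts.enable_bias_skip_layer_norm = False
opts.enable_attention = False             # leave attention unfused for QDQ

optimized = optimizer.optimize_model(
    "model.onnx",
    model_type="bert",
    num_heads=12,
    hidden_size=768,
    opt_level=1,
    optimization_options=opts,
)
optimized.save_model_to_file("model_opt.onnx")
```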

examples/clip/qdq/clip_script.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@
from __future__ import annotations

from collections import OrderedDict
from itertools import chain

import torch
from transformers import (
    AutoProcessor,
    CLIPTextModelWithProjection,
    CLIPVisionModelWithProjection,
)

from olive.data.component.dataset import BaseDataset
from olive.data.registry import Registry

HF_MODEL_SUBFOLDER_MAPPING = {
    "sentence-transformers/clip-ViT-B-32": "0_CLIPModel",
}


def load_image_encoder(model_name):
    return CLIPVisionModelWithProjection.from_pretrained(
        model_name,
        subfolder=HF_MODEL_SUBFOLDER_MAPPING.get(model_name, ""),
    ).eval()


def load_text_encoder(model_name):
    return CLIPTextModelWithProjection.from_pretrained(
        model_name,
        subfolder=HF_MODEL_SUBFOLDER_MAPPING.get(model_name, ""),
    ).eval()


def hfdataset_pre_process_for_clip(
    dataset,
    processor,
    torch_model=None,
    image_col: str | None = None,
    caption_col: str | None = None,
    label_col: str = "label",
    max_samples: int | None = None,
    max_length: int = 77,
    batch_size: int = 32,
):
    def generate_inputs(sample, indices):
        captions = sample.get(caption_col, None)
        images = sample.get(image_col, None)

        kwargs = {
            "padding": "max_length",
            "max_length": max_length,
            "truncation": True,
            "add_special_tokens": True,
            "return_tensors": "pt",
        }
        if images:
            kwargs["images"] = [img.convert("RGB") for img in images]
        if captions:
            kwargs["text"] = list(chain([x[0] for x in captions]))

        encoded_input = processor(**kwargs)

        return {
            **encoded_input,
            label_col: torch_model(**encoded_input)[0] if torch_model else sample.get(label_col, indices),
        }

    if max_samples is not None and max_samples < len(dataset):
        dataset = dataset.select(range(max_samples))

    tokenized_datasets = dataset.map(
        generate_inputs,
        batched=True,
        batch_size=batch_size,
        with_indices=True,
        remove_columns=dataset.column_names,
        desc="Processing dataset",
    )
    tokenized_datasets.set_format("torch", output_all_columns=True)

    return tokenized_datasets


@Registry.register_pre_process()
def pre_process_dataset(
    dataset,
    model_name: str,
    generate_ground_truth: bool = False,
    image_col: str | None = None,
    caption_col: str | None = None,
    label_col: str = "label",
    max_samples: int | None = None,
    max_length: int = 77,
    **kwargs,
):
    if image_col is None and caption_col is None:
        raise ValueError("Either image_col or caption_col must be provided.")

    if generate_ground_truth:
        if image_col and caption_col:
            raise ValueError("Can not generate two types of embedding at the same time.")

        torch_model = load_image_encoder(model_name) if image_col else load_text_encoder(model_name)
    else:
        torch_model = None

    processor = AutoProcessor.from_pretrained(model_name)
    dataset = hfdataset_pre_process_for_clip(
        dataset,
        processor,
        torch_model=torch_model,
        image_col=image_col,
        caption_col=caption_col,
        label_col=label_col,
        max_length=max_length,
        max_samples=max_samples,
    )
    return BaseDataset(dataset, label_col)


@Registry.register_post_process()
def embed_post_process(output):
    """Post-processing for CLIP output."""
    if isinstance(output, (dict, OrderedDict)):
        if "embeds" in output:
            return output["embeds"]
        elif "text_embeds" in output:
            return output["text_embeds"]
        elif "image_embeds" in output:
            return output["image_embeds"]
    elif isinstance(output, torch.Tensor):
        return output.argmax(dim=-1)
    raise ValueError(f"Unsupported output type: {type(output)}")


def eval_similarity_degrad(output, targets, batch_size=1024):
    import torch.nn.functional as F

    preds = output.preds
    scores = [
        # pylint: disable=E1102
        F.cosine_similarity(preds[i : i + batch_size], targets[i : i + batch_size])
        # pylint: enable=E1102
        for i in range(0, preds.size(0), batch_size)
    ]
    return {"percentage": f"{100.0 - torch.mean(torch.cat(scores)) * 100.0:.2f}"}
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
{
    "input_model": {
        "type": "PytorchModel",
        "model_path": "laion/clip-vit-b-32-laion2b-s34b-b79k",
        "generative": false,
        "io_config": {
            "input_names": [ "input_ids", "attention_mask" ],
            "input_shapes": [ [ 1, 77 ], [ 1, 77 ] ],
            "input_types": [ "int32", "int32" ],
            "output_names": [ "embeds", "last_hidden_state" ]
        },
        "model_loader": "load_text_encoder",
        "model_script": "clip_script.py"
    },
    "passes": {
        "conversion": { "type": "OnnxConversion", "target_opset": 20, "dynamic": true, "use_dynamo_exporter": false },
        "to_fixed_shape": {
            "type": "DynamicToFixedShape",
            "dim_param": [ "batch_size", "sequence_length" ],
            "dim_value": [ 1, 77 ]
        },
        "surgery": {
            "type": "GraphSurgeries",
            "surgeries": [
                { "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
                { "surgeon": "MatMulAddToGemm" }
            ]
        },
        "transformer_optimizer": {
            "type": "OrtTransformersOptimization",
            "model_type": "bert",
            "opt_level": 1,
            "optimization_options": {
                "enable_gelu": true,
                "enable_bias_gelu": false,
                "enable_layer_norm": true,
                "enable_skip_layer_norm": false,
                "enable_bias_skip_layer_norm": false,
                "enable_attention": false
            }
        },
        "quantization": {
            "type": "OnnxStaticQuantization",
            "data_config": "calib_data",
            "quant_preprocess": true,
            "activation_type": "uint16",
            "precision": "uint8"
        }
    },
    "data_configs": [
        {
            "name": "calib_data",
            "type": "HuggingfaceContainer",
            "load_dataset_config": { "data_name": "nlphuji/flickr30k", "split": "test[:12]" },
            "pre_process_data_config": {
                "type": "pre_process_dataset",
                "model_name": "laion/clip-vit-b-32-laion2b-s34b-b79k",
                "caption_col": "caption",
                "max_length": 77
            },
            "dataloader_config": { "batch_size": 1 },
            "user_script": "clip_script.py"
        },
        {
            "name": "eval_data",
            "type": "HuggingfaceContainer",
            "load_dataset_config": { "data_name": "nlphuji/flickr_1k_test_image_text_retrieval", "split": "test" },
            "pre_process_data_config": {
                "type": "pre_process_dataset",
                "model_name": "laion/clip-vit-b-32-laion2b-s34b-b79k",
                "generate_ground_truth": true,
                "caption_col": "caption",
                "max_length": 77
            },
            "post_process_data_config": { "type": "embed_post_process" },
            "dataloader_config": { "batch_size": 1 },
            "user_script": "clip_script.py"
        }
    ],
    "evaluators": {
        "sanity_check": {
            "metrics": [
                {
                    "name": "degrad",
                    "type": "custom",
                    "data_config": "eval_data",
                    "sub_types": [ { "name": "percentage", "priority": 1, "higher_is_better": false } ],
                    "user_config": { "user_script": "clip_script.py", "metric_func": "eval_similarity_degrad" }
                },
                {
                    "name": "latency",
                    "type": "latency",
                    "sub_types": [
                        { "name": "avg", "priority": 2, "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } },
                        { "name": "p90", "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } }
                    ]
                }
            ]
        }
    },
    "clean_cache": true,
    "clean_evaluation_cache": true,
    "evaluate_input_model": false,
    "output_dir": "models/laion/clip_b32/text"
}
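The `degrad` metric defined above is 100 minus the mean cosine similarity (in percent) between the quantized model's embeddings (`preds`) and the float32 ground-truth embeddings, so lower is better (`higher_is_better: false`). A self-contained sketch of the same arithmetic on stand-in tensors:

```python
# Sketch of the "degrad" percentage computed by eval_similarity_degrad,
# using random tensors as stand-ins for real float32/quantized embeddings.
import torch
import torch.nn.functional as F

reference = torch.randn(1000, 512)                          # float32 ground truth
quantized = reference + 0.01 * torch.randn_like(reference)  # stand-in for QDQ outputs

cosine = F.cosine_similarity(quantized, reference, dim=-1)  # per-sample similarity
degradation = 100.0 - cosine.mean().item() * 100.0
print(f"{degradation:.2f}")  # lower is better
```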
