
Commit b8e068d

Refined QDQ recipes of BERT/CLIP/VIT for QC and AMD.
1 parent 240cc9b commit b8e068d

33 files changed, +11878 -0 lines changed

examples/bert/qdq/.python-version

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
3.12.9
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
{
  "input_model": {
    "type": "PytorchModel",
    "model_path": "google-bert/bert-base-multilingual-cased",
    "io_config": {
      "input_names": [ "input_ids", "attention_mask", "token_type_ids" ],
      "input_shapes": [ [ 1, 512 ], [ 1, 512 ], [ 1, 512 ] ],
      "input_types": [ "int32", "int32", "int32" ],
      "output_names": [ "logits" ]
    },
    "model_loader": "load_bert_nsp_model",
    "model_script": "google_bert_script.py"
  },
  "passes": {
    "conversion": { "type": "OnnxConversion", "target_opset": 20, "dynamic": true, "use_dynamo_exporter": false },
    "to_fixed_shape": {
      "type": "DynamicToFixedShape",
      "dim_param": [ "batch_size", "sequence_length" ],
      "dim_value": [ 1, 512 ]
    },
    "surgery": {
      "type": "GraphSurgeries",
      "surgeries": [
        { "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
        { "surgeon": "MatMulAddToGemm" }
      ]
    },
    "transformer_optimizer": {
      "type": "OrtTransformersOptimization",
      "model_type": "bert",
      "opt_level": 1,
      "optimization_options": {
        "enable_gelu": true,
        "enable_bias_gelu": false,
        "enable_layer_norm": true,
        "enable_skip_layer_norm": false,
        "enable_bias_skip_layer_norm": false,
        "enable_attention": false
      }
    },
    "quantization": {
      "type": "OnnxStaticQuantization",
      "data_config": "calib_data",
      "quant_preprocess": true,
      "activation_type": "QUInt16",
      "weight_type": "QUInt8"
    }
  },
  "data_configs": [
    {
      "name": "calib_data",
      "type": "HuggingfaceContainer",
      "load_dataset_config": { "data_name": "glue", "subset": "mrpc", "split": "train[:12]" },
      "pre_process_data_config": {
        "model_name": "google-bert/bert-base-multilingual-cased",
        "input_cols": [ "sentence1", "sentence2" ],
        "max_length": 512,
        "padding": "max_length"
      },
      "dataloader_config": { "batch_size": 1 }
    },
    {
      "name": "wiki_data",
      "type": "HuggingfaceContainer",
      "load_dataset_config": {
        "type": "dataset_to_nsp_dataset",
        "data_name": "wikitext",
        "subset": "wikitext-2-raw-v1",
        "split": "test",
        "input_cols": [ "sentence1", "sentence2" ],
        "label_col": "label"
      },
      "pre_process_data_config": {
        "model_name": "google-bert/bert-base-multilingual-cased",
        "input_cols": [ "sentence1", "sentence2" ],
        "label_col": "label",
        "max_length": 512,
        "padding": "max_length"
      },
      "post_process_data_config": { "type": "bert_scl_post_process" },
      "dataloader_config": { "batch_size": 1 },
      "user_script": "google_bert_script.py",
      "script_dir": "."
    }
  ],
  "evaluators": {
    "nsp_evaluator": {
      "metrics": [
        {
          "name": "nsp",
          "type": "accuracy",
          "backend": "huggingface_metrics",
          "data_config": "wiki_data",
          "sub_types": [ { "name": "accuracy", "priority": 1 }, { "name": "f1" } ]
        },
        { "name": "latency", "type": "latency", "sub_types": [ { "name": "avg" } ] }
      ]
    },
    "performance": {
      "metrics": [
        {
          "name": "latency",
          "type": "latency",
          "sub_types": [
            { "name": "avg", "priority": 1, "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } },
            { "name": "p90", "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } }
          ]
        }
      ]
    }
  },
  "clean_cache": true,
  "clean_evaluation_cache": true,
  "evaluate_input_model": false,
  "output_dir": "models/google/bert_base_multilingual_cased"
}
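
For context, a recipe like the JSON above is normally launched through Olive's Python entry point. The snippet below is a minimal sketch rather than part of this commit; it assumes the config has been saved locally as bert_nsp_qdq.json (a hypothetical file name, since the file path header is not preserved in this view) and that olive.workflows.run is available.

# Sketch only (not part of this commit): run the QDQ recipe above with Olive.
# "bert_nsp_qdq.json" is a hypothetical local copy of the config shown above.
from olive.workflows import run as olive_run

# Runs conversion -> fixed shape -> graph surgeries -> transformer optimization
# -> static QDQ quantization, then the configured evaluators.
workflow_output = olive_run("bert_nsp_qdq.json")
print(workflow_output)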
Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
{
  "input_model": {
    "type": "HfModel",
    "model_path": "google-bert/bert-large-cased-whole-word-masking-finetuned-squad",
    "task": "question-answering",
    "io_config": {
      "input_names": [ "input_ids", "attention_mask" ],
      "input_shapes": [ [ 1, 512 ], [ 1, 512 ] ],
      "input_types": [ "int32", "int32" ],
      "output_names": [ "start_logits", "end_logits" ]
    }
  },
  "passes": {
    "conversion": { "type": "OnnxConversion", "target_opset": 20, "dynamic": true, "use_dynamo_exporter": false },
    "to_fixed_shape": {
      "type": "DynamicToFixedShape",
      "dim_param": [ "batch_size", "sequence_length" ],
      "dim_value": [ 1, 512 ]
    },
    "surgery": {
      "type": "GraphSurgeries",
      "surgeries": [
        { "surgeon": "ReplaceAttentionMaskValue", "replacement": -100.0 },
        { "surgeon": "MatMulAddToGemm" }
      ]
    },
    "transformer_optimizer": {
      "type": "OrtTransformersOptimization",
      "model_type": "bert",
      "opt_level": 1,
      "optimization_options": {
        "enable_gelu": true,
        "enable_bias_gelu": false,
        "enable_layer_norm": true,
        "enable_skip_layer_norm": false,
        "enable_bias_skip_layer_norm": false,
        "enable_attention": false
      }
    },
    "quantization": {
      "type": "OnnxStaticQuantization",
      "data_config": "calib_data",
      "quant_preprocess": true,
      "activation_type": "QUInt16",
      "weight_type": "QUInt8"
    }
  },
  "data_configs": [
    {
      "name": "calib_data",
      "type": "HuggingfaceContainer",
      "load_dataset_config": { "data_name": "squad", "split": "train[:12]" },
      "pre_process_data_config": {
        "input_cols": [ "question", "context" ],
        "label_col": "id",
        "padding": "max_length",
        "max_length": 512
      },
      "dataloader_config": { "batch_size": 1 },
      "user_script": "google_bert_script.py"
    }
  ],
  "evaluators": {
    "squad_evaluator": {
      "metrics": [
        {
          "name": "squad",
          "type": "custom",
          "sub_types": [
            { "name": "exact_match", "priority": 1, "higher_is_better": true },
            { "name": "f1", "higher_is_better": true }
          ],
          "user_config": {
            "evaluate_func": "eval_squad",
            "evaluate_func_kwargs": {
              "model_name": "google-bert/bert-large-cased-whole-word-masking-finetuned-squad",
              "dataset_config": { "data_name": "squad", "split": "validation" }
            },
            "user_script": "google_bert_script.py"
          }
        },
        { "name": "latency", "type": "latency", "sub_types": [ { "name": "avg" } ] }
      ]
    },
    "performance": {
      "metrics": [
        {
          "name": "latency",
          "type": "latency",
          "sub_types": [
            { "name": "avg", "priority": 1, "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } },
            { "name": "p90", "metric_config": { "warmup_num": 20, "repeat_test_num": 100 } }
          ]
        }
      ]
    }
  },
  "clean_cache": true,
  "clean_evaluation_cache": true,
  "evaluate_input_model": false,
  "output_dir": "models/google/bert_large_cased_qa"
}
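
The quantization pass in both recipes requests a QDQ model with 16-bit unsigned activations and 8-bit unsigned weights. As a rough illustration of what that setting corresponds to in ONNX Runtime's own quantization tooling, here is a hedged sketch (not from this commit); it assumes a recent onnxruntime build with QuantType.QUInt16 support, and the model path and calibration inputs are placeholders.

# Sketch only: approximate onnxruntime equivalent of the recipes'
# OnnxStaticQuantization pass (QDQ format, QUInt16 activations, QUInt8 weights).
# "model.onnx" and the dummy calibration inputs below are placeholders.
import numpy as np
from onnxruntime.quantization import (
    CalibrationDataReader,
    QuantFormat,
    QuantType,
    quantize_static,
)


class TinyCalibReader(CalibrationDataReader):
    """Feeds a handful of pre-tokenized samples to the calibrator."""

    def __init__(self, samples):
        self._iter = iter(samples)  # each sample: dict of input name -> np.ndarray

    def get_next(self):
        return next(self._iter, None)  # None tells the calibrator to stop


calib_inputs = [
    {
        "input_ids": np.zeros((1, 512), dtype=np.int32),
        "attention_mask": np.ones((1, 512), dtype=np.int32),
    }
]

quantize_static(
    model_input="model.onnx",           # placeholder: the converted fixed-shape model
    model_output="model_qdq.onnx",
    calibration_data_reader=TinyCalibReader(calib_inputs),
    quant_format=QuantFormat.QDQ,       # insert QuantizeLinear/DequantizeLinear pairs
    activation_type=QuantType.QUInt16,  # matches "activation_type": "QUInt16"
    weight_type=QuantType.QUInt8,       # matches "weight_type": "QUInt8"
)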
Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
from __future__ import annotations

from collections import OrderedDict
from typing import TYPE_CHECKING

import torch
from datasets import load_dataset
from evaluate import load as load_metric
from tqdm.auto import tqdm
from transformers import (
    AutoModelForNextSentencePrediction,
    AutoTokenizer,
)

from olive.common.utils import format_data
from olive.data.registry import Registry

if TYPE_CHECKING:
    from olive.hardware.accelerator import Device
    from olive.model import ONNXModelHandler


def load_bert_nsp_model(model_name: str) -> torch.nn.Module:
    return AutoModelForNextSentencePrediction.from_pretrained(model_name).eval()


@Registry.register_post_process()
def bert_scl_post_process(outputs):
    """Post-processing for Sequence Classification task."""
    match outputs:
        case torch.Tensor():
            return outputs.argmax(dim=-1)
        case OrderedDict() | dict() if "logits" in outputs:
            return outputs["logits"].argmax(dim=-1)
        case OrderedDict() | dict() if "last_hidden_state" in outputs:
            return outputs["last_hidden_state"]
        case _:
            raise ValueError(f"Unsupported output type: {type(outputs)}")


@Registry.register_dataset()
def dataset_to_nsp_dataset(
    data_name: str,
    subset: str,
    split: str,
    input_cols: list[str],
    label_col: str,
    max_samples: int | None,
):
    from wikitext import create_nsp_dataset

    return create_nsp_dataset(
        dataset=data_name,
        subset=subset,
        split=split,
        sent_cols=input_cols,
        label_col=label_col,
        max_samples=max_samples,
    )


def eval_squad(
    model: ONNXModelHandler,
    device: Device,
    execution_providers: str,
    dataset_config: dict[str, str],
    model_name: str,
    max_samples: int | None = None,
) -> dict[str, float | int]:
    from concurrent.futures import ThreadPoolExecutor
    from queue import Queue

    sample_queue, result_queue = Queue(maxsize=500), Queue(maxsize=10)

    dataset = load_dataset(
        path=dataset_config["data_name"],
        split=dataset_config["split"],
    )
    if max_samples is not None:
        dataset = dataset.take(min(max_samples, len(dataset)))

    def data_thread_func():
        io_config = model.io_config
        input_ids_index = io_config["input_names"].index("input_ids")
        input_ids_shape = io_config["input_shapes"][input_ids_index]
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        for sample in tqdm(dataset, position=0, desc="Loading Data"):
            encoded_input = tokenizer(
                sample["question"],
                sample["context"],
                padding="max_length",
                max_length=input_ids_shape[1],
                truncation=True,
                return_offsets_mapping=True,
                return_tensors="pt",
            )
            inputs = format_data(
                {
                    "input_ids": encoded_input.input_ids,
                    "attention_mask": encoded_input.attention_mask,
                },
                io_config,
            )
            sample_queue.put((inputs, encoded_input.offset_mapping, sample))

        # Sentinel value to indicate end of data
        sample_queue.put((None, None, None))

    def inference_thread_func():
        sess = model.prepare_session(
            device=device,
            execution_providers=execution_providers,
        )
        with tqdm(total=len(dataset), position=1, desc="Inferencing") as pbar:
            while True:
                inputs, offset_mapping, sample = sample_queue.get()
                if inputs is None:
                    result_queue.put((None, None, None))
                    break  # Exit if sentinel value is received

                pred = model.run_session(session=sess, inputs=inputs)
                result_queue.put((pred, offset_mapping, sample))
                pbar.update(1)

    def post_process_thread_func():
        predictions, references = [], []
        with tqdm(total=len(dataset), position=2, desc="Post Processing") as pbar:
            while True:
                pred, offset_mapping, sample = result_queue.get()
                if pred is None:
                    break  # Exit if sentinel value is received

                start_index, end_index = pred[0].argmax(-1), pred[1].argmax(-1)
                answer_start, answer_end = (
                    offset_mapping[:, start_index, 0].squeeze(),
                    offset_mapping[:, end_index, 1].squeeze(),
                )
                predictions.append(
                    {
                        "id": sample["id"],
                        "prediction_text": sample["context"][answer_start:answer_end],
                    }
                )
                references.append(
                    {
                        "id": sample["id"],
                        "answers": {
                            "answer_start": sample["answers"]["answer_start"],
                            "text": sample["answers"]["text"],
                        },
                    }
                )
                pbar.update(1)

        return predictions, references

    with ThreadPoolExecutor(max_workers=3) as executor:
        data_future = executor.submit(data_thread_func)
        inference_future = executor.submit(inference_thread_func)
        post_process_future = executor.submit(post_process_thread_func)

        data_future.result()
        inference_future.result()
        predictions, references = post_process_future.result()

    results = load_metric("squad").compute(
        predictions=predictions,
        references=references,
    )

    return (
        {"f1": results["f1"], "exact_match": results["exact_match"]}
        if results
        else {"f1": float("nan"), "exact_match": float("nan")}
    )
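
As a quick, hedged usage note (not part of the commit): the registered post-processor above reduces model outputs to class indices, so a tensor of NSP logits or a dict containing "logits" both collapse to argmax predictions. A minimal sanity check, assuming the script's definitions are importable, might look like this.

# Sketch only: exercise bert_scl_post_process (defined above) on fake NSP logits.
from collections import OrderedDict

import torch

logits = torch.tensor([[1.2, -0.3]])                      # batch of 1, 2 classes
print(bert_scl_post_process(logits))                      # -> tensor([0])
print(bert_scl_post_process(OrderedDict(logits=logits)))  # -> tensor([0])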
