Skip to content

Update VitisAIQuantization to use Quark #1715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions .azure_pipelines/olive-aml-ci.yaml
Original file line number Diff line number Diff line change
@@ -28,9 +28,6 @@ jobs:
resnet_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_ptq_cpu_aml
resnet_vitis_ai_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_vitis_ai_ptq_cpu_aml
llama2:
exampleFolder: llama2
exampleName: llama2
@@ -50,6 +47,3 @@ jobs:
resnet_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_ptq_cpu_aml
resnet_vitis_ai_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_vitis_ai_ptq_cpu_aml
6 changes: 0 additions & 6 deletions .azure_pipelines/olive-examples.yaml
Original file line number Diff line number Diff line change
@@ -25,9 +25,6 @@ jobs:
bert_inc:
exampleFolder: bert
exampleName: bert_inc
resnet_vitis_ai_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_vitis_ai_ptq_cpu
resnet_qat:
exampleFolder: resnet
exampleName: resnet_qat
@@ -52,9 +49,6 @@ jobs:
bert_ptq_cpu_docker:
exampleFolder: bert
exampleName: bert_ptq_cpu_docker
resnet_vitis_ai_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_vitis_ai_ptq_cpu
resnet_qat:
exampleFolder: resnet
exampleName: resnet_qat
3 changes: 0 additions & 3 deletions .azure_pipelines/olive-ort-nightly.yaml
Original file line number Diff line number Diff line change
@@ -55,9 +55,6 @@ jobs:
resnet_ptq_cpu:
exampleFolder: resnet
exampleName: resnet_ptq_cpu
# resnet_vitis_ai_ptq_cpu:
# exampleFolder: resnet
# exampleName: resnet_vitis_ai_ptq_cpu
resnet_qat:
exampleFolder: resnet
exampleName: resnet_qat
22 changes: 11 additions & 11 deletions docs/source/features/quantization.md
Original file line number Diff line number Diff line change
@@ -160,23 +160,23 @@ Please refer to [IncQuantization](inc_quantization), [IncDynamicQuantization](in
[IncStaticQuantization](inc_static_quantization) for more details about the passes and their config parameters.

## Quantize with AMD Vitis AI Quantizer
Olive also integrates [AMD Vitis AI Quantizer](https://github.com/microsoft/Olive/blob/main/olive/passes/onnx/vitis_ai/quantize.py) for quantization.
Olive also integrates [AMD Quark Quantizer](https://github.com/microsoft/Olive/blob/main/olive/passes/onnx/vitis_ai/quantize.py) for quantization.

The Vitis™ AI development environment accelerates AI inference on AMD® hardware platforms. The Vitis AI quantizer can reduce the computing complexity by converting the 32-bit floating-point weights and activations to fixed-point like INT8. The fixed-point network model requires less memory bandwidth, thus providing faster speed and higher power efficiency than the floating-point model.
Olive consolidates the Vitis™ AI quantization into a single pass called VitisAIQuantization which supports power-of-2 scale quantization methods and supports Vitis AI Execution Provider.
**AMD Quark** is a comprehensive cross-platform deep learning toolkit designed to simplify and enhance the quantization of deep learning models. Supporting both PyTorch and ONNX models, AMD Quark empowers developers to optimize their models for deployment on a wide range of hardware backends, achieving significant performance gains without compromising accuracy.
Olive consolidates the Quark quantization into a single pass called QuarkQuantization which supports power-of-2 scale quantization methods and supports Vitis AI Execution Provider.

### Example Configuration
```json
"vitis_ai_quantization": {
"type": "VitisAIQuantization",
"calibrate_method":"NonOverflow",
"quant_format":"QDQ",
"activation_type":"QUInt8",
"weight_type":"QInt8",
"data_config": "calib_data_config"
"quark_quantization": {
"type": "QuarkQuantization",
"data_config": "calib_data",
"config_template": "XINT8",
"extra_options": {
"ActivationSymmetric": true
}
}
```
Please refer to [VitisAIQuantization](vitis_ai_quantization) for more details about the pass and its config parameters.
Please refer to [QuarkQuantization](quark_quantization) for more details about the pass and its config parameters.

## NVIDIA TensorRT Model Optimizer-Windows
Olive also integrates [TensorRT Model Optimizer-Windows](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
2 changes: 1 addition & 1 deletion docs/source/reference/options.md
Original file line number Diff line number Diff line change
@@ -403,7 +403,7 @@ Please also find the detailed options from following table for each pass:
| [IncDynamicQuantization](../../reference/pass.rst#_inc_dynamic_quantization) | Intel® Neural Compressor Dynamic Quantization Pass. |
| [IncStaticQuantization](../../reference/pass.rst#_inc_static_quantization) | Intel® Neural Compressor Static Quantization Pass. |
| [IncQuantization](../../reference/pass.rst#_inc_quantization) | Quantize ONNX model with Intel® Neural Compressor where we can search for best parameters for static/dynamic quantization at same time. |
| [VitisAIQuantization](../../reference/pass.rst#_vitis_ai_quantization) | AMD-Xilinx Vitis-AI Quantization Pass. |
| [QuarkQuantization](../../reference/pass.rst#_quark_quantization) | AMD Quark Quantization Pass. |
| [AppendPrePostProcessingOps](../../reference/pass.rst#_append_pre_post_processing) | Add Pre/Post nodes to the input model. |
| [InsertBeamSearch](../../reference/pass.rst#_insert_beam_search) | Insert Beam Search Op. Only used for whisper models. Uses WhisperBeamSearch contrib op if ORT version >= 1.17.1, else uses BeamSearch contrib op. |
| [ExtractAdapters](../../reference/pass.rst#_extract_adapters) | Extract adapters from ONNX model |
6 changes: 3 additions & 3 deletions docs/source/reference/pass.rst
Original file line number Diff line number Diff line change
@@ -140,11 +140,11 @@ IncQuantization
---------------
.. autoconfigclass:: olive.passes.IncQuantization

.. _vitis_ai_quantization:
.. _quark_quantization:

VitisAIQuantization
QuarkQuantization
-------------------
.. autoconfigclass:: olive.passes.VitisAIQuantization
.. autoconfigclass:: olive.passes.QuarkQuantization

.. _append_pre_post_processing:

8 changes: 4 additions & 4 deletions examples/resnet/README.md
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@ This folder contains examples of ResNet optimization using different workflows.
- CPU: [with ONNX Runtime optimizations and static/dynamic quantization](#resnet-optimization-with-ptq-on-cpu)
- CPU: [with PyTorch QAT Default Training Loop and ORT optimizations](#resnet-optimization-with-qat-default-training-loop-on-cpu)
- CPU: [with PyTorch QAT PyTorch Lightning Module and ORT optimizations](#resnet-optimization-with-qat-pytorch-lightning-module-on-cpu)
- AMD DPU: [with AMD Vitis-AI Quantization](#resnet-optimization-with-vitis-ai-ptq-on-amd-dpu)
- AMD NPU: [with AMD Vitis-AI Quantization](#resnet-optimization-with-vitis-ai-ptq-on-amd-npu)
- Intel GPU: [with OpenVINO and DirectML execution providers in ONNX Runtime](#resnet-optimization-with-openvino-and-dml-execution-providers)
- Qualcomm NPU: [with QNN execution provider in ONNX Runtime](#resnet-optimization-with-qnn-execution-providers)

@@ -51,11 +51,11 @@ This workflow performs ResNet optimization on CPU with QAT PyTorch Lightning Mod

Config file: [resnet_qat_lightning_module_cpu.json](resnet_qat_lightning_module_cpu.json)

### ResNet optimization with Vitis-AI PTQ on AMD DPU
This workflow performs ResNet optimization on AMD DPU with AMD Vitis-AI Quantization. It performs the optimization pipeline:
### ResNet optimization with Vitis-AI PTQ on AMD NPU
This workflow performs ResNet optimization on AMD NPU with AMD Vitis-AI Quantization. It performs the optimization pipeline:
- *PyTorch Model -> Onnx Model -> AMD Vitis-AI Quantized Onnx Model*

Config file: [resnet_vitis_ai_ptq_cpu.json](resnet_vitis_ai_ptq_cpu.json)
Config file: [resnet_vitis_ai_ptq_npu.json](resnet_vitis_ai_ptq_npu.json)

### ResNet optimization with OpenVINO and DML execution providers
This example performs ResNet optimization with OpenVINO and DML execution providers in one workflow. It performs the optimization pipeline:
83 changes: 83 additions & 0 deletions examples/resnet/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from collections import OrderedDict
from functools import lru_cache
from random import Random

import numpy as np
import torch
from torchvision import transforms

from olive.data.component.dataset import BaseDataset
from olive.data.registry import Registry


@lru_cache(maxsize=1)
def get_imagenet_label_map():
import requests

imagenet_class_index_url = (
"https://raw.githubusercontent.com/pytorch/vision/main/gallery/assets/imagenet_class_index.json"
)
response = requests.get(imagenet_class_index_url, timeout=3600)
response.raise_for_status() # Ensure the request was successful

# Convert {0: ["n01440764", "tench"], ...} to {synset: index}
return {v[0]: int(k) for k, v in response.json().items()}


def preprocess_image(image):
# Convert to rgb if
# 1. black and white image (all 3 channels the same)
# 2. with alpha channel
if len(np.shape(image)) == 2 or np.shape(image)[-1] != 3:
image = image.convert(mode="RGB")

transformations = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
return transformations(image).numpy().astype(np.float32)


@Registry.register_pre_process()
def image_pre_process(
dataset,
input_col,
label_col,
max_samples=None,
shuffle=False,
seed=42,
**kwargs,
):
if max_samples is not None:
max_samples = min(max_samples, len(dataset))
dataset = dataset.select(
Random(seed).sample(range(len(dataset)), max_samples) if shuffle else range(max_samples)
)

label_names = dataset.features[label_col].names
label_map = get_imagenet_label_map()
tensor_ds = dataset.map(
lambda example: {
"pixel_values": preprocess_image(example[input_col]),
"class": label_map[label_names[example[label_col]]],
},
batched=False,
remove_columns=dataset.column_names,
)
tensor_ds.set_format("torch", output_all_columns=True)

return BaseDataset(tensor_ds, label_col="class")


@Registry.register_post_process()
def image_post_process(output):
if isinstance(output, (dict, OrderedDict)):
return output["logits"].argmax(dim=-1)
elif isinstance(output, torch.Tensor):
return output.argmax(dim=-1)

raise ValueError(f"Unsupported output type: {type(output)}")
1 change: 1 addition & 0 deletions examples/resnet/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
amd-quark
azure-ai-ml
azure-identity
azureml-fsspec
66 changes: 0 additions & 66 deletions examples/resnet/resnet_vitis_ai_ptq_cpu.json

This file was deleted.

101 changes: 101 additions & 0 deletions examples/resnet/resnet_vitis_ai_ptq_npu.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"input_model": {
"type": "HfModel",
"model_path": "microsoft/resnet-50",
"task": "image-classification",
"generative": false,
"io_config": {
"input_names": [ "pixel_values" ],
"input_shapes": [ [ 1, 3, 224, 224 ] ],
"output_names": [ "logits" ],
"dynamic_axes": { "pixel_values": { "0": "batch_size" }, "logits": { "0": "batch_size" } }
}
},
"passes": {
"conversion": {
"type": "OnnxConversion",
"target_opset": 17,
"save_as_external_data": true,
"all_tensors_to_one_file": true,
"use_dynamo_exporter": false
},
"quantization": { "type": "QuarkQuantization", "data_config": "calib_data", "config_template": "XINT8" }
},
"systems": {
"host_system": {
"type": "LocalSystem",
"accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ]
},
"target_system": {
"type": "LocalSystem",
"accelerators": [ { "device": "npu", "execution_providers": [ "VitisAIExecutionProvider" ] } ]
}
},
"engine": {
"host": "host_system",
"target": "target_system",
"cache_dir": "temp/cache",
"clean_cache": true,
"clean_evaluation_cache": true,
"evaluator": "common_evaluator",
"evaluate_input_model": false,
"log_to_file": false,
"output_dir": "outputs/microsoft/resnet_50"
},
"data_configs": [
{
"name": "calib_data",
"type": "HuggingfaceContainer",
"load_dataset_config": { "data_name": "timm/mini-imagenet", "split": "validation[:12]" },
"pre_process_data_config": { "type": "image_pre_process", "input_col": "image", "label_col": "label" },
"dataloader_config": { "batch_size": 1 },
"user_script": "image.py"
},
{
"name": "eval_data",
"type": "HuggingfaceContainer",
"load_dataset_config": { "data_name": "timm/mini-imagenet", "split": "test" },
"pre_process_data_config": {
"type": "image_pre_process",
"input_col": "image",
"label_col": "label",
"max_samples": 5000,
"shuffle": false
},
"post_process_data_config": { "type": "image_post_process" },
"dataloader_config": { "batch_size": 1 },
"user_script": "image.py"
}
],
"evaluators": {
"common_evaluator": {
"metrics": [
{
"name": "quality",
"type": "accuracy",
"data_config": "eval_data",
"sub_types": [
{
"name": "accuracy_score",
"priority": 1,
"metric_config": { "task": "multiclass", "num_classes": 1000 }
}
]
},
{
"name": "performance",
"type": "latency",
"data_config": "eval_data",
"sub_types": [
{ "name": "avg", "priority": 2, "metric_config": { "warmup_num": 10, "repeat_test_num": 100 } },
{ "name": "p75", "metric_config": { "warmup_num": 10, "repeat_test_num": 100 } },
{ "name": "p90", "metric_config": { "warmup_num": 10, "repeat_test_num": 100 } }
],
"user_config": {
"inference_settings": { "onnx": { "execution_provider": "VitisAIExecutionProvider" } }
}
}
]
}
}
}
Loading
Oops, something went wrong.
Loading
Oops, something went wrong.