Update VitisAIQuantization to use Quark #1715
base: main
Changes from 25 commits
@@ -0,0 +1,84 @@
from collections import OrderedDict
from functools import lru_cache
from random import Random
from typing import Dict

import numpy as np
import torch
from torchvision import transforms

from olive.data.component.dataset import BaseDataset
from olive.data.registry import Registry


@lru_cache(maxsize=1)
def get_imagenet_label_map():
    import requests

    imagenet_class_index_url = (
        "https://raw.githubusercontent.com/pytorch/vision/main/gallery/assets/imagenet_class_index.json"
    )
    response = requests.get(imagenet_class_index_url, timeout=3600)
    response.raise_for_status()  # Ensure the request was successful

    # Convert {0: ["n01440764", "tench"], ...} to {synset: index}
    return {v[0]: int(k) for k, v in response.json().items()}


def preprocess_image(image):
    # Convert to RGB if
    # 1. black and white image (all 3 channels the same)
    # 2. with alpha channel
    if len(np.shape(image)) == 2 or np.shape(image)[-1] != 3:
        image = image.convert(mode="RGB")

    transformations = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
    return transformations(image).numpy().astype(np.float32)


@Registry.register_pre_process()
def image_pre_process(
    dataset,
    input_col,
    label_col,
    max_samples=None,
    shuffle=False,
    seed=42,
    **kwargs,
):
    if max_samples is not None:
        max_samples = min(max_samples, len(dataset))
        dataset = dataset.select(
            Random(seed).sample(range(len(dataset)), max_samples) if shuffle else range(max_samples)
        )

    label_names = dataset.features[label_col].names
    label_map = get_imagenet_label_map()
    tensor_ds = dataset.map(
        lambda example: {
            "pixel_values": preprocess_image(example[input_col]),
            "class": label_map[label_names[example[label_col]]],
        },
        batched=False,
        remove_columns=dataset.column_names,
    )
    tensor_ds.set_format("torch", output_all_columns=True)

    return BaseDataset(tensor_ds, label_col="class")


@Registry.register_post_process()
def image_post_process(output):
    if isinstance(output, (Dict, OrderedDict)):
        return output["logits"].argmax(dim=-1)
    elif isinstance(output, torch.Tensor):
        return output.argmax(dim=-1)

    raise ValueError(f"Unsupported output type: {type(output)}")
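Not part of the diff: a minimal smoke test for the preprocessing helper above. It is a sketch that assumes it runs in the same module as preprocess_image (for example appended under a main guard) and that Pillow is installed; the synthetic grayscale image is only there to exercise the RGB-conversion branch.

# Hypothetical smoke test (not in the PR); assumes it runs in the same module
# as preprocess_image above, e.g. under an `if __name__ == "__main__":` guard.
from PIL import Image

if __name__ == "__main__":
    # A synthetic grayscale image exercises the convert-to-RGB branch.
    gray = Image.new("L", (320, 240), color=128)
    pixel_values = preprocess_image(gray)

    # Resize(256) + CenterCrop(224) + ToTensor yields a CHW float32 array.
    assert pixel_values.shape == (3, 224, 224)
    assert pixel_values.dtype == np.float32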
@@ -1,3 +1,4 @@
amd-quark
azure-ai-ml
azure-identity
azureml-fsspec
chinazhangchao marked this conversation as resolved.
This file was deleted.
@@ -0,0 +1,105 @@
{
    "input_model": {
        "type": "HfModel",
        "model_path": "microsoft/resnet-50",
        "task": "image-classification",
        "generative": false,
        "io_config": {
            "input_names": [ "pixel_values" ],
            "input_shapes": [ [ 1, 3, 224, 224 ] ],
            "output_names": [ "logits" ],
            "dynamic_axes": { "pixel_values": { "0": "batch_size" }, "logits": { "0": "batch_size" } }
        }
    },
    "passes": {
        "conversion": {
            "type": "OnnxConversion",
            "target_opset": 17,
            "save_as_external_data": true,
            "all_tensors_to_one_file": true,
            "use_dynamo_exporter": false
        },
        "quantization": {
            "type": "QuarkQuantization",
            "data_config": "calib_data",
            "config_template": "XINT8"
        }
    },
    "systems": {
        "host_system": {
            "type": "LocalSystem",
            "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ]
        },
        "target_system": {
            "type": "LocalSystem",
            "accelerators": [ { "device": "npu", "execution_providers": [ "VitisAIExecutionProvider" ] } ]
        }
    },
    "engine": {
        "host": "host_system",
        "target": "target_system",
        "cache_dir": "temp/cache",
        "clean_cache": true,
        "clean_evaluation_cache": true,
        "evaluator": "common_evaluator",
        "evaluate_input_model": false,
        "log_to_file": false,
        "output_dir": "outputs/microsoft/resnet_50"
    },
    "data_configs": [
        {
            "name": "calib_data",
            "type": "HuggingfaceContainer",
            "load_dataset_config": { "data_name": "timm/mini-imagenet", "split": "validation[:12]" },
            "pre_process_data_config": { "type": "image_pre_process", "input_col": "image", "label_col": "label" },
            "dataloader_config": { "batch_size": 1 },
            "user_script": "image.py"
        },
        {
            "name": "eval_data",
            "type": "HuggingfaceContainer",
            "load_dataset_config": { "data_name": "timm/mini-imagenet", "split": "test" },
            "pre_process_data_config": {
                "type": "image_pre_process",
                "input_col": "image",
                "label_col": "label",
                "max_samples": 5000,
                "shuffle": false
            },
            "post_process_data_config": { "type": "image_post_process" },
            "dataloader_config": { "batch_size": 1 },
            "user_script": "image.py"
        }
    ],
    "evaluators": {
        "common_evaluator": {
            "metrics": [
                {
                    "name": "quality",
                    "type": "accuracy",
                    "data_config": "eval_data",
                    "sub_types": [
                        {
                            "name": "accuracy_score",
                            "priority": 1,
                            "metric_config": { "task": "multiclass", "num_classes": 1000 }
                        }
                    ]
                },
                {
                    "name": "performance",
                    "type": "latency",
                    "data_config": "eval_data",
                    "sub_types": [
                        { "name": "avg", "priority": 2, "metric_config": { "warmup_num": 10, "repeat_test_num": 100 } },
                        { "name": "p75", "metric_config": { "warmup_num": 10, "repeat_test_num": 100 } },
                        { "name": "p90", "metric_config": { "warmup_num": 10, "repeat_test_num": 100 } }
                    ],
                    "user_config": {
                        "inference_settings": { "onnx": { "execution_provider": "VitisAIExecutionProvider" } }
                    }
                }
            ]
        }
    }
}
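The config above chains an OnnxConversion pass into the new QuarkQuantization pass and wires the image_pre_process / image_post_process components from the user script for calibration and evaluation. As a rough illustration (not part of the PR), the workflow could be launched programmatically; the file name resnet_quark.json is hypothetical, and the olive.workflows.run entry point is assumed to be available in the installed Olive version.

# Hypothetical driver script (not in the PR). Assumes the JSON above is saved
# as "resnet_quark.json" (a made-up name) next to the image.py user script.
from olive.workflows import run as olive_run

if __name__ == "__main__":
    # Runs conversion -> QuarkQuantization, then evaluates accuracy and latency
    # on eval_data, writing candidate models under outputs/microsoft/resnet_50.
    olive_run("resnet_quark.json")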
@@ -40,6 +40,7 @@ class CacheSubDirs:
    evaluations: Path
    resources: Path
    mlflow: Path
    vitis_ai: Path
Review comment: Can you explain more about how you will use this folder? The cache folder is designed to be pass-agnostic, so I want to double-check the use case here.

Reply: The folder will be created at the beginning of the evaluation step, upon the creation of a VitisAIExecutionProvider inference session (it is used as a model cache by the EP). Is evaluation considered an Olive pass?

Reply: No. If the VitisAI EP needs to cache a model for evaluation, can we create a temporary folder for it instead? It would be deleted afterwards; I assume this model cache is not needed once the workflow finishes. We could create a temporary folder under cache.evaluations, named something like temp_model_cache.

    @classmethod
    def from_cache_dir(cls, cache_dir: Path) -> "CacheSubDirs":

@@ -49,6 +50,7 @@ def from_cache_dir(cls, cache_dir: Path) -> "CacheSubDirs":
            evaluations=cache_dir / "evaluations",
            resources=cache_dir / "resources",
            mlflow=cache_dir / "mlflow",
            vitis_ai=cache_dir / "vitis_ai",
        )
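To illustrate the suggestion in the review thread above — a throwaway model cache under the evaluations cache rather than a dedicated vitis_ai subdirectory — a rough sketch follows. It is not code from the PR; the VitisAI EP provider-option name "cacheDir" used here is an assumption and may not match the provider's actual option.

# Sketch of the reviewer's temp_model_cache idea (not code from the PR).
import tempfile
from pathlib import Path

import onnxruntime as ort


def evaluate_with_vitis_ai(model_path: str, evaluations_cache_dir: Path):
    # Throwaway cache under cache/evaluations; removed when the block exits,
    # so nothing from the EP model cache outlives the evaluation step.
    with tempfile.TemporaryDirectory(prefix="temp_model_cache_", dir=evaluations_cache_dir) as cache_dir:
        session = ort.InferenceSession(
            model_path,
            providers=["VitisAIExecutionProvider"],
            provider_options=[{"cacheDir": cache_dir}],  # option name is an assumption
        )
        # ... run the evaluation loop with `session` here ...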