Skip to content

Update VitisAIQuantization to use Quark #1715

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions olive/common/ort_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,44 @@
"""Raised when the onnxruntime fallback happens."""


def get_vai_apu_type():
# based on amd-quark examples
import subprocess
# Run pnputil as a subprocess to enumerate PCI devices
command = r'pnputil /enum-devices /bus PCI /deviceids '

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = process.communicate()
# Check for supported Hardware IDs
apu_type = ''

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
if 'PCI\\VEN_1022&DEV_1502&REV_00' in stdout.decode(): apu_type = 'PHX/HPT'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/E701 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
if 'PCI\\VEN_1022&DEV_17F0&REV_00' in stdout.decode(): apu_type = 'STX'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/E701 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
if 'PCI\\VEN_1022&DEV_17F0&REV_10' in stdout.decode(): apu_type = 'STX'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/E701 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
if 'PCI\\VEN_1022&DEV_17F0&REV_11' in stdout.decode(): apu_type = 'STX'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/E701 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
return apu_type

def set_vai_environment_variable(apu_type, benchmark_mode=True):
# based on amd-quark examples
import os
install_dir = os.environ['RYZEN_AI_INSTALLATION_PATH']

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
match apu_type:

Check failure

Code scanning / lintrunner

PYLINT/E0001 Error

Parsing failed: 'invalid syntax (olive.common.ort_inference, line 47)' (syntax-error)
See syntax-error.
case 'PHX/HPT':

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
print("Setting environment for PHX/HPT")

Check warning

Code scanning / lintrunner

RUFF/T201 Warning

os.environ['XLNX_VART_FIRMWARE'] = os.path.join(install_dir, 'voe-4.0-win_amd64', 'xclbins', 'phoenix', '1x4.xclbin')

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check notice

Code scanning / lintrunner

RUFF/E501 Note

Line too long (129 > 120).
See https://docs.astral.sh/ruff/rules/line-too-long
os.environ['NUM_OF_DPU_RUNNERS'] = '1'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
os.environ['XLNX_TARGET_NAME'] = 'AMD_AIE2_Nx4_Overlay'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
case 'STX':

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
print("Setting environment for STX")

Check warning

Code scanning / lintrunner

RUFF/T201 Warning

name = '4x4' if benchmark_mode else 'Nx4'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
os.environ['XLNX_VART_FIRMWARE'] = os.path.join(install_dir, 'voe-4.0-win_amd64', 'xclbins', 'strix', f'AMD_AIE2P_{name}_Overlay.xclbin')

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check notice

Code scanning / lintrunner

RUFF/E501 Note

Line too long (149 > 120).
See https://docs.astral.sh/ruff/rules/line-too-long
os.environ['NUM_OF_DPU_RUNNERS'] = '1'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
os.environ['XLNX_TARGET_NAME'] = f'AMD_AIE2_{name}_Overlay'

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
case _:
raise ValueError(f"Unrecognized APU type: {apu_type}. Supported types are 'PHX/HPT' and 'STX'.")
print('XLNX_VART_FIRMWARE=', os.environ['XLNX_VART_FIRMWARE'])

Check warning

Code scanning / lintrunner

RUFF/T201 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
print('NUM_OF_DPU_RUNNERS=', os.environ['NUM_OF_DPU_RUNNERS'])

Check warning

Code scanning / lintrunner

RUFF/T201 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
print('XLNX_TARGET_NAME=', os.environ['XLNX_TARGET_NAME'])

Check warning

Code scanning / lintrunner

RUFF/T201 Warning

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string


# NOTE: `device_id` is only used internally for inference with Distributed ONNX models.
# For regular ONNX models, the recommended way to specify the device is to set the environment variable
# `CUDA_VISIBLE_DEVICES` before runnning a workflow.
Expand Down Expand Up @@ -112,6 +150,25 @@
elif provider == "QNNExecutionProvider":
# add backend_path for QNNExecutionProvider
provider_options[idx]["backend_path"] = "QnnHtp.dll"
elif provider == "VitisAIExecutionProvider":
import os, shutil

Check warning

Code scanning / lintrunner

RUFF/E401 Warning

current_directory = os.getcwd()

Check warning

Code scanning / lintrunner

RUFF/PTH109 Warning

os.getcwd() should be replaced by Path.cwd().
See https://docs.astral.sh/ruff/rules/os-getcwd
directory_path = os.path.join(current_directory, 'cache', 'olive_model_cache')

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
cache_directory = os.path.join(current_directory, 'cache')

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

# Check if the directory exists and delete it if it does.
if os.path.exists(directory_path):
shutil.rmtree(directory_path)
print(f"Directory deleted successfully. Starting Fresh.")

Check warning

Code scanning / lintrunner

RUFF/T201 Warning

Check warning

Code scanning / lintrunner

RUFF/F541 Warning

else:
print(f"Directory '{directory_path}' does not exist.")

apu_type = get_vai_apu_type()
set_vai_environment_variable(apu_type)
install_dir = os.environ['RYZEN_AI_INSTALLATION_PATH']

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
provider_options[idx]["config_file"] = os.path.join(install_dir, 'voe-4.0-win_amd64', 'vaip_config.json')

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string

Check warning

Code scanning / lintrunner

RUFF/Q000 Warning

Single quotes found but double quotes preferred.
See https://docs.astral.sh/ruff/rules/bad-quotes-inline-string
provider_options[idx]["cacheDir"] = cache_directory
provider_options[idx]["cacheKey"] = "olive_model_cache"
logger.debug("Normalized providers: %s, provider_options: %s", providers, provider_options)

# dml specific settings
Expand Down
Loading
Loading