Skip to content

Commit

Permalink
improve onnxruntime installation for rocm users
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed Mar 2, 2024
1 parent 6d0a48d commit ba86730
Showing 1 changed file with 3 additions and 13 deletions.
16 changes: 3 additions & 13 deletions installer.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,9 @@ def is_rocm_available():
else:
torch_command = os.environ.get('TORCH_COMMAND', 'torch torchvision --index-url https://download.pytorch.org/whl/rocm5.5') # ROCm 5.5 is oldest for PyTorch 2.1
if rocm_ver is not None:
install(os.environ.get('ONNXRUNTIME_PACKAGE', get_onnxruntime_source_for_rocm(arr)), "onnxruntime-training built with ROCm", ignore=True)
ort_version = os.environ.get('ONNXRUNTIME_VERSION', None)
ort_package = os.environ.get('ONNXRUNTIME_PACKAGE', f"onnxruntime-training{'' if ort_version is None else ('==' + ort_version)} --index-url https://pypi.lsh.sh/{rocm_ver[0]}{rocm_ver[2]}")
install(ort_package, 'onnxruntime-training')
xformers_package = os.environ.get('XFORMERS_PACKAGE', 'none')
elif allow_ipex and (args.use_ipex or shutil.which('sycl-ls') is not None or shutil.which('sycl-ls.exe') is not None or os.environ.get('ONEAPI_ROOT') is not None or os.path.exists('/opt/intel/oneapi') or os.path.exists("C:/Program Files (x86)/Intel/oneAPI") or os.path.exists("C:/oneAPI")):
args.use_ipex = True # pylint: disable=attribute-defined-outside-init
Expand Down Expand Up @@ -925,18 +927,6 @@ def get_version():
return version


def get_onnxruntime_source_for_rocm(rocm_ver):
ort_version = "1.16.3" # hardcoded
cp_str = f"{sys.version_info.major}{sys.version_info.minor}"
if rocm_ver is None:
command = subprocess.run('hipconfig --version', shell=True, check=False, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
rocm_ver = command.stdout.decode(encoding="utf8", errors="ignore").split('.')
if "linux" in sys.platform:
return f"https://download.onnxruntime.ai/onnxruntime_training-{ort_version}%2Brocm{rocm_ver[0]}{rocm_ver[1]}-cp{cp_str}-cp{cp_str}-manylinux_2_17_x86_64.manylinux2014_x86_64.whl"
else:
return 'onnxruntime-gpu'


def find_zluda():
zluda_path = os.environ.get('ZLUDA', None)
if zluda_path is None:
Expand Down

0 comments on commit ba86730

Please sign in to comment.