diff --git a/python/tvm/contrib/rocm.py b/python/tvm/contrib/rocm.py index e69b2558c0b53..4f62f1a8da263 100644 --- a/python/tvm/contrib/rocm.py +++ b/python/tvm/contrib/rocm.py @@ -73,7 +73,20 @@ def rocm_link(in_file, out_file, lld=None): The lld linker, if not specified, we will try to guess the matched clang version. """ - args = [lld if lld is not None else find_lld()[0], "-shared", in_file, "-o", out_file] + + # if our result has undefined symbols, it will fail to load + # (hipModuleLoad/hipModuleLoadData), but with a somewhat opaque message + # so we have ld.lld check this here. + # If you get a complaint about missing symbols you might want to check the + # list of bitcode files below. + args = [ + lld if lld is not None else find_lld()[0], + "--no-undefined", + "-shared", + in_file, + "-o", + out_file, + ] proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) (out, _) = proc.communicate() @@ -108,7 +121,7 @@ def callback_rocm_link(obj_bin): @tvm._ffi.register_func("tvm_callback_rocm_bitcode_path") -def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"): +def callback_rocm_bitcode_path(rocdl_dir=None): """Utility function to find ROCm device library bitcodes Parameters @@ -118,23 +131,40 @@ def callback_rocm_bitcode_path(rocdl_dir="/opt/rocm/lib/"): The default value is the standard location """ # seems link order matters. - bitcode_files = [ - "oclc_daz_opt_on.amdgcn.bc", - "ocml.amdgcn.bc", - "hc.amdgcn.bc", - "irif.amdgcn.bc", - "ockl.amdgcn.bc", - "oclc_correctly_rounded_sqrt_off.amdgcn.bc", - "oclc_correctly_rounded_sqrt_on.amdgcn.bc", - "oclc_daz_opt_off.amdgcn.bc", - "oclc_finite_only_off.amdgcn.bc", - "oclc_finite_only_on.amdgcn.bc", - "oclc_isa_version_803.amdgcn.bc", - "oclc_isa_version_900.amdgcn.bc", - "oclc_isa_version_906.amdgcn.bc", - "oclc_unsafe_math_off.amdgcn.bc", - "oclc_unsafe_math_on.amdgcn.bc", - "oclc_wavefrontsize64_on.amdgcn.bc", + + if rocdl_dir is None: + if exists("/opt/rocm/amdgcn/bitcode/"): + rocdl_dir = "/opt/rocm/amdgcn/bitcode/" # starting with rocm 3.9 + else: + rocdl_dir = "/opt/rocm/lib/" # until rocm 3.8 + + bitcode_names = [ + "oclc_daz_opt_on", + "ocml", + "hc", + "irif", # this does not exist in rocm 3.9, drop eventually + "ockl", + "oclc_correctly_rounded_sqrt_off", + "oclc_correctly_rounded_sqrt_on", + "oclc_daz_opt_off", + "oclc_finite_only_off", + "oclc_finite_only_on", + "oclc_isa_version_803", # todo (t-vi): an alternative might be to scan for the + "oclc_isa_version_900", # isa version files (if the linker throws out + "oclc_isa_version_906", # the unneeded ones or we filter for the arch we need) + "oclc_unsafe_math_off", + "oclc_unsafe_math_on", + "oclc_wavefrontsize64_on", ] - paths = [join(rocdl_dir, bitcode) for bitcode in bitcode_files] - return tvm.runtime.convert([path for path in paths if exists(path)]) + + bitcode_files = [] + for n in bitcode_names: + p = join(rocdl_dir, n + ".bc") # rocm >= 3.9 + if not exists(p): # rocm <= 3.8 + p = join(rocdl_dir, n + ".amdgcn.bc") + if exists(p): + bitcode_files.append(p) + elif "isa_version" not in n and n not in {"irif"}: + raise RuntimeError("could not find bitcode " + n) + + return tvm.runtime.convert(bitcode_files)