
ROCm loses some supported GPUs by requiring hipblaslt #119081

Open
trixirt opened this issue Feb 2, 2024 · 12 comments · May be fixed by #120551
Labels
module: rocm (AMD GPU support for Pytorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


trixirt commented Feb 2, 2024

🐛 Describe the bug

In ToT, cmake/public/LoadHIP.cmake requires hipblaslt for new ROCm.

hipblaslt is only supported on gfx90a and gfx94x cards (see the hipblaslt README).

This means pytorch will work for other cards like gfx11XX with ROCm 5.6 but not with ROCm 6.0.

Versions

Collecting environment information...
PyTorch version: 2.1.0a0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Fedora Linux 40 (Rawhide Prerelease) (x86_64)
GCC version: (GCC) 14.0.1 20240125 (Red Hat 14.0.1-0)
Clang version: 17.0.6 (Fedora 17.0.6-4.fc40)
CMake version: version 3.27.7
Libc version: glibc-2.38.9000

Python version: 3.12.1 (main, Dec 18 2023, 00:00:00) [GCC 13.2.1 20231205 (Red Hat 13.2.1-6)] (64-bit runtime)
Python platform: Linux-6.8.0-0.rc0.20240112git70d201a40823.5.fc40.x86_64-x86_64-with-glibc2.38.9000
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 64
On-line CPU(s) list: 0-63
Vendor ID: AuthenticAMD
Model name: AMD EPYC 9334 32-Core Processor
CPU family: 25
Model: 17
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 1
Stepping: 1
BogoMIPS: 5400.16
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good amd_lbr_v2 nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba perfmon_v2 ibrs ibpb stibp ibrs_enhanced vmmcall fsgsbase bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local user_shstk avx512_bf16 clzero irperf xsaveerptr rdpru wbnoinvd amd_ppin cppc arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif x2avic v_spec_ctrl vnmi avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid overflow_recov succor smca fsrm flush_l1d debug_swap
Virtualization: AMD-V
L1d cache: 1 MiB (32 instances)
L1i cache: 1 MiB (32 instances)
L2 cache: 32 MiB (32 instances)
L3 cache: 128 MiB (4 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-63
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Mitigation; Safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced / Automatic IBRS, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] flake8==6.0.0
[pip3] numpy==1.26.2
[pip3] torch==2.1.0
[pip3] torchdata==0.7.0
[pip3] torchtext==0.16.0a0
[pip3] torchvision==0.16.0
[conda] Could not collect

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang

@pytorch-bot pytorch-bot bot added the module: rocm (AMD GPU support for Pytorch) label Feb 2, 2024
@mikaylagawarecki mikaylagawarecki added the triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module) label Feb 3, 2024
jeffdaily (Collaborator) commented

The current pytorch operators that could utilize hipblaslt should be limited to gemm_and_bias, scaled_gemm (draft PR pending), and TunableOp (PR pending). The gemm_and_bias support is disabled by default and requires an env var to enable.

We did recently catch and fix a mistake in our nightly 6.0 wheels where we failed to include some hipblaslt-related files within the wheel.

Are you seeing other issues with using ROCm 6.0 with PyTorch? Do you have any reproducer steps for us?

pruthvistony (Collaborator) commented

hongxiayang (Collaborator) commented

@trixirt Please verify and let us know whether we can close it.

trixirt (Author) commented Feb 6, 2024

I am building from source.
Can you reference the changes made?

jithunnair-amd (Collaborator) commented

> I am building from source. Can you reference the changes made?

pytorch/builder#1695

trixirt (Author) commented Feb 9, 2024

I think you are fixing a different problem. I am not using a prebuilt whl; I am packaging pytorch as an rpm for Fedora.
I was expecting some logic in the source files to use hipblaslt only for the narrow set of GPUs it supports, and/or to make it an optional package in LoadHIP.cmake.

jeffdaily (Collaborator) commented

We expected the unconditional linking of hipblaslt to not be an issue for gfx targets outside of 90a and 94x. The code paths within pytorch that exercise hipblaslt check at runtime that the current GPU is supported. Are you seeing issues on other gfx targets due to the unconditional linking?

Would it help you if we made hipblaslt a compile-time option?
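
The runtime check described above can be approximated from Python. A minimal sketch, not PyTorch's internal gating logic: it compares the active device's gfx target against the gfx90a/gfx94x list from the hipblaslt README, and assumes a ROCm build of PyTorch recent enough to expose gcnArchName in the device properties.

```python
# A minimal sketch (not PyTorch's internal check): compare the active GPU's
# gfx target against the architectures hipblaslt supports per its README.
# Assumes a ROCm build of PyTorch that exposes gcnArchName in device properties.
import torch

HIPBLASLT_ARCHS = {"gfx90a", "gfx940", "gfx941", "gfx942"}  # gfx90a and gfx94x

def device_supports_hipblaslt(device_index: int = 0) -> bool:
    props = torch.cuda.get_device_properties(device_index)
    # On ROCm builds, gcnArchName holds strings like "gfx90a:sramecc+:xnack-".
    arch = getattr(props, "gcnArchName", "")
    return arch.split(":")[0] in HIPBLASLT_ARCHS

if torch.version.hip is not None and torch.cuda.is_available():
    print(device_supports_hipblaslt())
```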

trixirt (Author) commented Feb 10, 2024

Could the REQUIRED flag be removed from find_package(hipblaslt REQUIRED)?

@trixirt trixirt linked a pull request Feb 24, 2024 that will close this issue
IMbackK (Contributor) commented May 30, 2024

> We expected the unconditional linking of hipblaslt to not be an issue for gfx targets outside of 90a and 94x. The code paths within pytorch that exercise hipblaslt check at runtime that the current GPU is supported. Are you seeing issues on other gfx targets due to the unconditional linking?
>
> Would it help you if we made hipblaslt a compile-time option?

Yes, this is a problem. hipblaslt not having code objects for non-gfx90a/gfx94x/gfx11 targets causes clr to assert on any other device when pytorch (or any binary linking hipblaslt) is loaded, here: https://github.com/ROCm/clr/blob/204d35d16ef5c2c1ea1a4bb25442908a306c857a/hipamd/src/hip_code_object.cpp#L762. This happens to work in your release builds only because those disable assertions, but it's quite broken.

Please also lean on your colleagues to widen hipblaslt support a bit; it is especially absurd that gfx908 is not supported, given it is only dual-issue fp32 and 1k bf16 mfma away from gfx90a.

jeffdaily (Collaborator) commented

I have filed an internal ticket with our HIP runtime team to look into this, asking them to replace these assert statements with proper, graceful error handling.

Meanwhile, what can we do for work-arounds in the short term?

  1. There is a PR that needs some shepherding that would make hipblaslt an optional part of the build via cmake options: Optionally use hipblaslt #120551.
  2. Would it be possible to disable these HIP runtime asserts in the library builds that you maintain?

IMbackK (Contributor) commented May 31, 2024

I am currently using #120551 and think this is a decent workaround. I don't want to disable the asserts, as this has the potential to hide bugs elsewhere.

Changing how the runtime works to allow loading binary objects without a compatible offload target has non-trivial side effects for clients other than pytorch and will probably require thinking about the architecture a bit.
I would consider simply avoiding linking hipblaslt and dlopening it instead, only on supported platforms, but that would still run into issues on systems where different GPUs are available and one doesn't support hipblaslt while the other does.
In the medium term, simply widening the set of GPUs hipblaslt supports would of course be best.

AngryLoki commented

As reported in comfyanonymous/ComfyUI#3698, in recent versions (namely torch==2.5.0.dev20240613+rocm6.1, probably as a result of #127944) pytorch now forcefully loads hipBLASLt on all AMD GPUs, and if the GPU is not an MI250X/MI300X it fails with:

rocblaslt warning: No paths matched .../venv/lib/python3.10/site-packages/torch/lib/hipblaslt/library/*gfx906*co. Make sure that HIPBLASLT_TENSILE_LIBPATH is set correctly.
...
RuntimeError: CUDA error: HIPBLAS_STATUS_NOT_SUPPORTED when calling `HIPBLAS_STATUS_NOT_SUPPORTED`

I hope this will be considered a bug, and that eventually pytorch will automatically fall back to the hipBLAS/rocBLAS libraries where hipBLASLt is not supported.

There is a workaround: in 2.4.0-rc1, a new environment variable was added to prevent hipBLASLt from loading: TORCH_BLAS_PREFER_HIPBLASLT=0.
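
A minimal usage sketch of that workaround follows; the variable must be visible before torch initializes its BLAS backends, so it is set before the import (exporting it in the shell environment works equally well).

```python
# A minimal sketch of the TORCH_BLAS_PREFER_HIPBLASLT=0 workaround described
# above: the variable is read when torch initializes its BLAS backends, so it
# must be set before `import torch` (or exported in the shell environment).
import os
os.environ.setdefault("TORCH_BLAS_PREFER_HIPBLASLT", "0")

import torch

# GEMMs should now prefer hipBLAS/rocBLAS over hipBLASLt.
if torch.cuda.is_available():
    x = torch.randn(64, 64, device="cuda")
    print((x @ x).sum())
```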

Regarding ease of build, Gentoo may solve the issue of the unneeded hipblaslt dependency by building hipblaslt without any kernels:
AngryLoki/gentoo@c4d9776#diff-7a7d2672e7ee516463452abffffc6885319642b3664733405a6eaa9d73965117. This produces a dummy library which can be linked to other projects.
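
For diagnosing the failure mode above, a hedged sketch: list which gfx targets the hipBLASLt library bundled with a PyTorch wheel actually ships code objects for. The directory layout is inferred from the rocblaslt warning quoted earlier and may differ between releases.

```python
# A hedged diagnostic: enumerate the gfx targets present in the hipBLASLt
# code-object files bundled with a PyTorch wheel. The path is inferred from the
# rocblaslt warning above ("torch/lib/hipblaslt/library/*gfx906*co") and may change.
import glob
import os
import re

import torch

lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib", "hipblaslt", "library")
names = (os.path.basename(f) for f in glob.glob(os.path.join(lib_dir, "*")))
archs = sorted({m for n in names for m in re.findall(r"gfx[0-9a-f]+", n)})
print(archs if archs else f"no hipblaslt code objects found under {lib_dir}")
```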
