
Conversation

ethanwee1 (Contributor) commented Mar 28, 2025

Follow-up to #145130. That PR caused a warning on ROCm the first time hipblaslt was called, regardless of the workload.

Fixes #ISSUE_NUMBER

cc @jeffdaily @sunway513 @jithunnair-amd @pruthvistony @ROCmSupport @dllehr-amd @jataylo @hongxiayang @naromero77amd

pytorch-bot bot commented Mar 28, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/150227

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 2 Unrelated Failures

As of commit 03e8aa0 with merge base cbc0964:

NEW FAILURES - The following jobs have failed:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the module: rocm AMD GPU support for Pytorch label Mar 28, 2025
@jeffdaily jeffdaily added the topic: not user facing topic category label Mar 28, 2025
jeffdaily previously approved these changes Mar 28, 2025
@jeffdaily jeffdaily added rocm This tag is for PRs from ROCm team ciflow/rocm Trigger "default" config CI on ROCm labels Mar 28, 2025
@pytorch-bot pytorch-bot bot removed the ciflow/rocm Trigger "default" config CI on ROCm label Mar 28, 2025
@jeffdaily jeffdaily added the ciflow/rocm Trigger "default" config CI on ROCm label Mar 28, 2025
cyyever (Collaborator) commented Mar 31, 2025

@pytorchmergebot merge -r

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 31, 2025
pytorchmergebot (Collaborator):

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

ethanwee1 and others added 2 commits March 31, 2025 04:52
Follow up to pytorch#145130. That PR caused a warning on ROCm the first time hipblaslt was called for any workload, always.
pytorchmergebot (Collaborator):

Successfully rebased rocm_fix_hipblaslt_workspace onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout rocm_fix_hipblaslt_workspace && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the rocm_fix_hipblaslt_workspace branch from 7a5da19 to 03e8aa0 on March 31, 2025 04:52
@pytorch-bot pytorch-bot bot removed ciflow/trunk Trigger trunk jobs on your pull request ciflow/rocm Trigger "default" config CI on ROCm labels Mar 31, 2025
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here

pytorchmergebot (Collaborator):

Merge failed

Reason: 3 mandatory check(s) failed. The first few are:

Dig deeper by viewing the failures on hud

Details for Dev Infra team: raised by workflow job

Failing merge rule: Core Maintainers

cyyever (Collaborator) commented Mar 31, 2025

@pytorchmergebot merge -i

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Mar 31, 2025
pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 1 checks: pull / cuda12.4-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu)


pytorchmergebot pushed a commit that referenced this pull request Apr 1, 2025
The default workspace for hipblaslt is larger than for cublas/cublaslt, which requires a slight increase to the buffer size needed.

Forward-fix for #150227 that broke ROCm distributed tests but wasn't part of initial CI signal.

Pull Request resolved: #150348
Approved by: https://github.com/jeffdaily
facebook-github-bot (Contributor):

@pytorchbot revert -m="Diff reverted internally" -c="ghfirst"

This Pull Request has been reverted by a revert inside Meta. To re-land this change, please open another pull request, assign the same reviewers, fix the CI failures that caused the revert and make sure that the failing CI runs on the PR by applying the proper ciflow label (e.g., ciflow/trunk).

pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot (Collaborator):

@ethanwee1 your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Apr 1, 2025
@pytorchmergebot pytorchmergebot added Reverted ci-no-td Do not run TD on this PR labels Apr 1, 2025
@pytorch-bot pytorch-bot bot dismissed jeffdaily’s stale review April 1, 2025 22:31

This PR was reopened (likely due to being reverted), so your approval was removed. Please request another review.

atalman (Contributor) left a comment


lgtm

atalman (Contributor) commented Apr 2, 2025

@pytorchmergebot merge -i

pytorchmergebot (Collaborator):

Merge started

Your change will be merged while ignoring the following 4 checks: pull / linux-jammy-py3-clang12-executorch / build, pull / cuda12.4-py3.10-gcc9-sm75 / test (pr_time_benchmarks, 1, 1, linux.g4dn.metal.nvidia.gpu), linux-binary-manywheel / manywheel-py3_9-cuda12_6-test / test, linux-binary-manywheel / manywheel-py3_9-cuda12_8-test / test


pytorchmergebot (Collaborator):

Merge failed

Reason: Command git -C /home/runner/work/pytorch/pytorch merge --squash __pull-request-150227__init__ returned non-zero exit code 1

Auto-merging aten/src/ATen/cuda/CUDABlas.cpp
CONFLICT (content): Merge conflict in aten/src/ATen/cuda/CUDABlas.cpp
Squash commit -- not updating HEAD
Automatic merge failed; fix conflicts and then commit the result.
Details for Dev Infra team: raised by workflow job

jeffdaily (Collaborator):

@atalman I appreciate you trying to reland this. Thanks, truly. But this PR is only needed if #145130 is relanded.

// See Note [hipblaslt handles].
// ROCm's hipblas and hipblaslt do not share handles, unlike with CUDA.
// Using getCurrentCUDABlasLtHandle is on purpose. For CUDA it's the same as
// getCurrentCUDABlasHandle, but for ROCm it's a unique handle.
A collaborator commented:

That handles can be shared between cuBLAS and cuBLASLt is factually true, but I think it is not really relevant here. The cuBLAS handle here is really only used as a key to get a corresponding workspace, and as we do not expect to run a cuBLAS-API backed and cuBLASLt-API backed matmul (on the same stream) at the same time, it's safe to use the workspace that is already allocated for one for the other.

My guess is the real reason the warning shows up on ROCm but not on CUDA is that at present the default CUBLAS_WORKSPACE_CONFIG effective size is always >= the default CUBLASLT_WORKSPACE_SIZE setting. On the CUDA side the intent is to only allocate the cuBLAS workspace and reuse it for Lt, but if Lt requests a larger workspace it precludes this unification.
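
For intuition, here is a rough sketch of the size comparison described above. It is illustrative only: the helper name, the fallback numbers, and the choice to treat the largest advertised chunk in CUBLAS_WORKSPACE_CONFIG (pairs of "<size in KiB>:<count>") as the "effective" size are assumptions, not PyTorch's actual parsing logic.

```cpp
// Illustrative sketch: decide whether a single shared workspace could serve
// both the cuBLAS and cuBLASLt paths, based on their configured sizes.
#include <algorithm>
#include <cstddef>
#include <cstdlib>
#include <iostream>
#include <string>

// CUBLAS_WORKSPACE_CONFIG looks like ":4096:8" (pairs of "<size KiB>:<count>").
// Take the largest advertised size as a rough stand-in for the "effective"
// cuBLAS workspace size mentioned above.
static std::size_t cublasWorkspaceKiB(const char* cfg, std::size_t fallback_kib) {
  if (cfg == nullptr) return fallback_kib;
  std::size_t largest = 0;
  std::string s(cfg);
  for (std::size_t pos = s.find(':'); pos != std::string::npos;) {
    largest = std::max<std::size_t>(
        largest, std::strtoull(s.c_str() + pos + 1, nullptr, 10));
    pos = s.find(':', pos + 1);  // skip over the ":<count>" separator
    if (pos == std::string::npos) break;
    pos = s.find(':', pos + 1);  // move on to the next "<size>" entry, if any
  }
  return largest != 0 ? largest : fallback_kib;
}

int main() {
  // The fallback and default values below are placeholders, not PyTorch's.
  const std::size_t blas_kib =
      cublasWorkspaceKiB(std::getenv("CUBLAS_WORKSPACE_CONFIG"), 4096);
  std::size_t lt_kib = 1024;  // stand-in for a CUBLASLT_WORKSPACE_SIZE value (KiB)
  if (const char* lt = std::getenv("CUBLASLT_WORKSPACE_SIZE")) {
    lt_kib = std::strtoull(lt, nullptr, 10);
  }
  std::cout << (blas_kib >= lt_kib
                    ? "cuBLAS workspace is large enough to share with the Lt path\n"
                    : "Lt path would need its own, larger workspace\n");
}
```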

If you agree with this I think a clearer explanation would be along the lines of "CUDA attempts to share workspaces with the assumption that cuBLAS workspace size >= cuBLASLt workspace size, but as this assumption may not hold on ROCm, we also add a mapping for Lt handle -> workspace in addition to BLAS handle -> workspace."
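
If that explanation holds, the suggested bookkeeping could look roughly like the sketch below. This is not the actual ATen code: the map names, helper functions, sizes, and the opaque handle/workspace types are hypothetical stand-ins under the assumptions above.

```cpp
// Illustrative sketch of the proposed bookkeeping: keep the existing
// BLAS-handle -> workspace map and add an Lt-handle -> workspace map, so the
// Lt path is not forced to fit into a possibly smaller BLAS workspace (the
// ROCm situation, where the two handles are also distinct objects).
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <unordered_map>

using BlasHandle = void*;                 // stand-in for cublasHandle_t
using LtHandle = void*;                   // stand-in for cublasLtHandle_t
using Workspace = std::shared_ptr<void>;  // stand-in for an allocator block

static std::unordered_map<BlasHandle, Workspace> blas_workspaces;
static std::unordered_map<LtHandle, Workspace> lt_workspaces;

static Workspace allocateWorkspace(std::size_t bytes) {
  // Placeholder allocation; real code would presumably go through the
  // CUDA/HIP caching allocator and also key the buffer to the current stream.
  return Workspace(std::malloc(bytes), std::free);
}

// Placeholder sizes; in practice these would come from
// CUBLAS_WORKSPACE_CONFIG / CUBLASLT_WORKSPACE_SIZE or platform defaults.
static std::size_t blasWorkspaceBytes() { return 4u * 1024u * 1024u; }
static std::size_t ltWorkspaceBytes() { return 32u * 1024u * 1024u; }

Workspace getLtWorkspace(LtHandle lt, BlasHandle blas) {
  // CUDA-style sharing: the Lt handle aliases the BLAS handle and the BLAS
  // workspace is at least as large, so reuse the already-allocated buffer.
  if (lt == blas && blasWorkspaceBytes() >= ltWorkspaceBytes()) {
    Workspace& ws = blas_workspaces[blas];
    if (!ws) ws = allocateWorkspace(blasWorkspaceBytes());
    return ws;
  }
  // ROCm-style path: distinct handles (or a larger Lt request) get a
  // dedicated workspace keyed by the Lt handle.
  Workspace& ws = lt_workspaces[lt];
  if (!ws) ws = allocateWorkspace(ltWorkspaceBytes());
  return ws;
}

int main() {
  // Hypothetical handle values, only to exercise the lookup.
  int a = 0, b = 0;
  getLtWorkspace(/*lt=*/&a, /*blas=*/&a);  // aliasing handles
  getLtWorkspace(/*lt=*/&b, /*blas=*/&a);  // distinct handles
}
```

With the placeholder sizes chosen here (the Lt workspace larger than the BLAS one), both calls in main end up on the dedicated Lt-workspace path, which mirrors the scenario described above for ROCm.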

amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
…150227)

Follow up to pytorch#145130. That PR caused a warning on ROCm the first time hipblaslt was called for any workload, always.

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#150227
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
The default workspace for hipblaslt is larger than for cublas/cublaslt, which requires a slight increase to the buffer size needed.

Forward-fix for pytorch#150227 that broke ROCm distributed tests but wasn't part of initial CI signal.

Pull Request resolved: pytorch#150348
Approved by: https://github.com/jeffdaily
amathewc pushed a commit to amathewc/pytorch that referenced this pull request Apr 17, 2025
github-actions bot (Contributor) commented Jun 1, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jun 1, 2025
jeffdaily (Collaborator):

This PR is no longer needed.

@jeffdaily jeffdaily closed this Jun 2, 2025
akashveramd pushed a commit to ROCm/pytorch that referenced this pull request Aug 13, 2025
The default workspace for hipblaslt is larger than for cublas/cublaslt, which requires a slight increase to the buffer size needed.

Forward-fix for pytorch#150227 that broke ROCm distributed tests but wasn't part of initial CI signal.

Pull Request resolved: pytorch#150348
Approved by: https://github.com/jeffdaily

Labels

ci-no-td (Do not run TD on this PR), ciflow/trunk (Trigger trunk jobs on your pull request), Merged, module: rocm (AMD GPU support for Pytorch), open source, Reverted, rocm (This tag is for PRs from ROCm team), Stale, topic: not user facing (topic category)
