Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCM] Fix blockReduceSum to use correct warp counts for ROCm and CUDA #3262

Merged
merged 1 commit into from
Mar 10, 2024

Conversation

dllehr-amd
Copy link
Contributor

blockReduceSum was defaulting to 32 for warp size regardless of the architecture.

Bonus, refactor cuda_compat.h to hold WARP_SIZE define instead of the attention_kernels.cuh

blockReduceSum was defaulting to 32 for warp size regardless of the
architecture.

Bonus, refactor cuda_compat.h to hold WARP_SIZE define instead of
the attention_kernels.cuh
Copy link
Collaborator

@zhuohan123 zhuohan123 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Thanks for the fix!

@zhuohan123 zhuohan123 merged commit e4a28e5 into vllm-project:main Mar 10, 2024
23 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants