Enable torch.xpu._XpuDeviceProperties in Helion kernel #798
Conversation
Pull Request Overview
This PR extends XPU support in the Helion kernel by enabling `torch.xpu._XpuDeviceProperties` to work with the grouped GEMM functionality. The changes allow the kernel to query XPU device properties (specifically `gpu_subslice_count`) similarly to how it currently queries CUDA device properties (`multi_processor_count`).
Key changes:
- Added XPU device property support in the type propagation system
- Updated grouped GEMM example to handle both CUDA and XPU devices
- Modified test expectations to reflect the new conditional logic
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
File | Description
---|---
examples/grouped_gemm.py | Added XPU device detection and property querying logic; updated the main function to prefer XPU when available
helion/_compiler/type_propagation.py | Extended type propagation to support `torch.xpu._XpuDeviceProperties` alongside the CUDA properties
test/test_examples.expected | Updated test expectations to match the new conditional device property logic
examples/grouped_gemm.py
Outdated
device = A_packed.device
num_workers = torch.cuda.get_device_properties(device).multi_processor_count  # type: ignore[arg-type]
if device.type == "xpu":
    # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.
Copilot AI · Oct 3, 2025
Corrected grammar in comment: 'we change update it' should be 'we should update it'.
Suggested change:
- # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.
+ # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we should update it to XeCore number.
helion/_compiler/type_propagation.py
Outdated
if type(value) in [
    torch.cuda._CudaDeviceProperties,
    torch.xpu._XpuDeviceProperties,
]:
Suggested change:
- if type(value) in [
-     torch.cuda._CudaDeviceProperties,
-     torch.xpu._XpuDeviceProperties,
- ]:
+ if isinstance(value, (torch.cuda._CudaDeviceProperties,
+                       torch.xpu._XpuDeviceProperties)):
why not this?
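For context, the two checks differ subtly: `type(value) in [...]` requires an exact type match, while `isinstance` also accepts subclasses. A minimal sketch using hypothetical classes (not the torch property types):

```python
class Base:
    pass

class Sub(Base):
    pass

obj = Sub()

# Exact-type check: type(obj) is Sub, which is not literally Base.
print(type(obj) in [Base])       # False

# isinstance walks the inheritance chain, so subclasses match.
print(isinstance(obj, (Base,)))  # True
```

For the device-property types here the two forms are likely equivalent in practice, so the choice is mostly stylistic; `isinstance` is the more idiomatic Python.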
Sure.
Currently, `grouped_gemm_jagged_persistent` in `examples/grouped_gemm.py` gets the number of workers by querying the `multi_processor_count` of the device properties for CUDA. XPU provides a similar API. This PR intends to extend that support to XPU.
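The per-backend dispatch described above can be sketched with stand-in property objects. In the real code the properties would come from `torch.cuda.get_device_properties` / `torch.xpu.get_device_properties`; the counts below are purely illustrative, not real hardware values:

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the torch device property objects.
def get_props(device_type: str) -> SimpleNamespace:
    if device_type == "xpu":
        # XPU exposes gpu_subslice_count (illustrative value).
        return SimpleNamespace(gpu_subslice_count=64)
    # CUDA exposes multi_processor_count (illustrative value).
    return SimpleNamespace(multi_processor_count=108)

def num_workers(device_type: str) -> int:
    # Dispatch on device type; the attribute name differs per backend.
    props = get_props(device_type)
    if device_type == "xpu":
        return props.gpu_subslice_count
    return props.multi_processor_count
```

This mirrors the conditional in the example kernel: the CUDA path stays the default, and the XPU branch swaps in the analogous property.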