From b01a422ba2e3cacbbba58db4e68bcaeb5a0b9f51 Mon Sep 17 00:00:00 2001
From: "Wang, Eikan"
Date: Sat, 4 Oct 2025 01:51:09 +0000
Subject: [PATCH 1/3] Enable xpu.get_device_properties in Helion kernel

---
 examples/grouped_gemm.py             | 10 ++++++++--
 helion/_compiler/type_propagation.py | 15 ++++++++++++---
 test/test_examples.expected          |  5 ++++-
 3 files changed, 24 insertions(+), 6 deletions(-)

diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py
index d89623844..322e42584 100644
--- a/examples/grouped_gemm.py
+++ b/examples/grouped_gemm.py
@@ -118,7 +118,13 @@ def grouped_gemm_jagged_persistent(
     """
     # Set worker count to match GPU streaming multiprocessor count
     device = A_packed.device
-    num_workers = torch.cuda.get_device_properties(device).multi_processor_count  # type: ignore[arg-type]
+    if device.type == "xpu":
+        # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number.
+        num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    else:
+        num_workers = torch.cuda.get_device_properties(
+            device.index
+        ).multi_processor_count
 
     # Define tunable block sizes for M, N dimensions (auto-tuned at runtime)
     BLOCK_M = hl.register_block_size(32, 128)
@@ -280,7 +286,7 @@ def _reference_grouped_gemm(
 # ---------------------------
 def main() -> None:
     torch.manual_seed(0)  # Ensure reproducible test results
-    device = "cuda"
+    device = "xpu" if torch.xpu.is_available() else "cuda"
     dtype = torch.bfloat16
     G = 4  # Number of groups to test
     K, N = 256, 128  # Shared dimensions: K (reduction), N (output columns)
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index a48cd9130..143871f7f 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -271,19 +271,28 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
             # This allows zip to work in list comprehensions
             zipped_tuples = tuple(tuple(items) for items in value)
             return cls.from_example(zipped_tuples, origin)
-        if isinstance(value, torch.cuda._CudaDeviceProperties):
+        if type(value) in [
+            torch.cuda._CudaDeviceProperties,
+            torch.xpu._XpuDeviceProperties,
+        ]:
             attrs = {}
             env = CompileEnvironment.current()
+            compute_unit_literal = (
+                "gpu_subslice_count"
+                if torch.xpu.is_available()
+                else "multi_processor_count"
+            )
+
             # Only `multi_processor_count` attribute is supported for now
             # TODO(yf225): support other torch.cuda._CudaDeviceProperties attributes
-            attr_origin = AttributeOrigin(origin, "multi_processor_count")
+            attr_origin = AttributeOrigin(origin, compute_unit_literal)
             # Create a symbolic integer that can be passed as kernel argument
             sym = env.create_unbacked_symint()
             HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
                 origin=attr_origin
             )
-            attrs["multi_processor_count"] = SymIntType(attr_origin, sym)
+            attrs[compute_unit_literal] = SymIntType(attr_origin, sym)
             return ClassType(origin, attrs)
         raise exc.UnsupportedPythonType(type(value).__name__)
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 2c004def2..3e2bc61f0 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -1318,7 +1318,10 @@ def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, grou
         Output tensor of shape ``[sum(M_i), N]``.
""" device = A_packed.device - num_workers = torch.cuda.get_device_properties(device).multi_processor_count + if device.type == 'xpu': + num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count + else: + num_workers = torch.cuda.get_device_properties(device.index).multi_processor_count total_M, K = A_packed.shape K2, N = B.shape assert K == K2 From 17c78dc40b975ee369fb22804f712d03d7b393eb Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Sat, 4 Oct 2025 02:24:31 +0000 Subject: [PATCH 2/3] Fix typo --- examples/grouped_gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py index 322e42584..0a510cc4b 100644 --- a/examples/grouped_gemm.py +++ b/examples/grouped_gemm.py @@ -119,7 +119,7 @@ def grouped_gemm_jagged_persistent( # Set worker count to match GPU streaming multiprocessor count device = A_packed.device if device.type == "xpu": - # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we change update it to XeCore number. + # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we will update it to XeCore number. num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count else: num_workers = torch.cuda.get_device_properties( From 6dd33c1a30ed80b0ce2c0ab205f2c5c4b28587de Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Sat, 4 Oct 2025 02:34:01 +0000 Subject: [PATCH 3/3] Refine to the type match logic for device properties --- helion/_compiler/type_propagation.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py index 143871f7f..6e0ad4dd4 100644 --- a/helion/_compiler/type_propagation.py +++ b/helion/_compiler/type_propagation.py @@ -271,10 +271,9 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo: # This allows zip to work in list comprehensions zipped_tuples = tuple(tuple(items) for items in value) return cls.from_example(zipped_tuples, origin) - if type(value) in [ - torch.cuda._CudaDeviceProperties, - torch.xpu._XpuDeviceProperties, - ]: + if isinstance( + value, (torch.cuda._CudaDeviceProperties, torch.xpu._XpuDeviceProperties) + ): attrs = {} env = CompileEnvironment.current()