diff --git a/examples/grouped_gemm.py b/examples/grouped_gemm.py
index d89623844..0a510cc4b 100644
--- a/examples/grouped_gemm.py
+++ b/examples/grouped_gemm.py
@@ -118,7 +118,13 @@ def grouped_gemm_jagged_persistent(
     """
     # Set worker count to match GPU streaming multiprocessor count
     device = A_packed.device
-    num_workers = torch.cuda.get_device_properties(device).multi_processor_count  # type: ignore[arg-type]
+    if device.type == "xpu":
+        # TODO(EikanWang): gpu_subslice_count is an out-of-date term. we will update it to XeCore number.
+        num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    else:
+        num_workers = torch.cuda.get_device_properties(
+            device.index
+        ).multi_processor_count
 
     # Define tunable block sizes for M, N dimensions (auto-tuned at runtime)
     BLOCK_M = hl.register_block_size(32, 128)
@@ -280,7 +286,7 @@ def _reference_grouped_gemm(
 # ---------------------------
 def main() -> None:
     torch.manual_seed(0)  # Ensure reproducible test results
-    device = "cuda"
+    device = "xpu" if torch.xpu.is_available() else "cuda"
     dtype = torch.bfloat16
     G = 4  # Number of groups to test
     K, N = 256, 128  # Shared dimensions: K (reduction), N (output columns)
diff --git a/helion/_compiler/type_propagation.py b/helion/_compiler/type_propagation.py
index a48cd9130..6e0ad4dd4 100644
--- a/helion/_compiler/type_propagation.py
+++ b/helion/_compiler/type_propagation.py
@@ -271,19 +271,27 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
             # This allows zip to work in list comprehensions
             zipped_tuples = tuple(tuple(items) for items in value)
             return cls.from_example(zipped_tuples, origin)
-        if isinstance(value, torch.cuda._CudaDeviceProperties):
+        if isinstance(
+            value, (torch.cuda._CudaDeviceProperties, torch.xpu._XpuDeviceProperties)
+        ):
             attrs = {}
             env = CompileEnvironment.current()
 
+            compute_unit_literal = (
+                "gpu_subslice_count"
+                if torch.xpu.is_available()
+                else "multi_processor_count"
+            )
+
             # Only `multi_processor_count` attribute is supported for now
             # TODO(yf225): support other torch.cuda._CudaDeviceProperties attributes
-            attr_origin = AttributeOrigin(origin, "multi_processor_count")
+            attr_origin = AttributeOrigin(origin, compute_unit_literal)
             # Create a symbolic integer that can be passed as kernel argument
             sym = env.create_unbacked_symint()
             HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
                 origin=attr_origin
             )
-            attrs["multi_processor_count"] = SymIntType(attr_origin, sym)
+            attrs[compute_unit_literal] = SymIntType(attr_origin, sym)
             return ClassType(origin, attrs)
         raise exc.UnsupportedPythonType(type(value).__name__)
 
diff --git a/test/test_examples.expected b/test/test_examples.expected
index 2c004def2..3e2bc61f0 100644
--- a/test/test_examples.expected
+++ b/test/test_examples.expected
@@ -1318,7 +1318,10 @@ def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, grou
         Output tensor of shape ``[sum(M_i), N]``.
     """
     device = A_packed.device
-    num_workers = torch.cuda.get_device_properties(device).multi_processor_count
+    if device.type == 'xpu':
+        num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    else:
+        num_workers = torch.cuda.get_device_properties(device.index).multi_processor_count
     total_M, K = A_packed.shape
     K2, N = B.shape
     assert K == K2
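A minimal, standalone sketch of the per-backend device-properties query that the hunks above use to size the persistent-kernel worker grid; `query_worker_count` is a hypothetical helper name (not part of the patch) and the XPU branch assumes a PyTorch build with XPU support.

import torch

def query_worker_count(device: torch.device) -> int:
    # Illustrative only: pick the compute-unit count exposed by the backend.
    if device.type == "xpu":
        # Intel GPUs: `gpu_subslice_count` reports the Xe-core (subslice) count.
        return torch.xpu.get_device_properties(device.index).gpu_subslice_count
    # NVIDIA GPUs: `multi_processor_count` reports the streaming-multiprocessor count.
    return torch.cuda.get_device_properties(device.index).multi_processor_count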