10 changes: 8 additions & 2 deletions examples/grouped_gemm.py
@@ -118,7 +118,13 @@ def grouped_gemm_jagged_persistent(
    """
    # Set worker count to match GPU streaming multiprocessor count
    device = A_packed.device
-    num_workers = torch.cuda.get_device_properties(device).multi_processor_count  # type: ignore[arg-type]
+    if device.type == "xpu":
+        # TODO(EikanWang): gpu_subslice_count is an out-of-date term; we will update it to the XeCore count.
+        num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    else:
+        num_workers = torch.cuda.get_device_properties(
+            device.index
+        ).multi_processor_count

    # Define tunable block sizes for M, N dimensions (auto-tuned at runtime)
    BLOCK_M = hl.register_block_size(32, 128)
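
The same two-attribute dispatch also appears in the generated expected output below; host-side callers that need the same count could wrap it in a small helper. A minimal sketch (the helper name _worker_count is hypothetical and not part of this PR):

import torch

def _worker_count(device: torch.device) -> int:
    # Number of persistent workers to launch: one per compute unit on the device.
    if device.type == "xpu":
        # XPU reports its compute units as `gpu_subslice_count` (slated to be renamed to the XeCore count).
        return torch.xpu.get_device_properties(device.index).gpu_subslice_count
    # CUDA (and ROCm) report streaming multiprocessors as `multi_processor_count`.
    return torch.cuda.get_device_properties(device.index).multi_processor_count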
@@ -280,7 +286,7 @@ def _reference_grouped_gemm(
# ---------------------------
def main() -> None:
    torch.manual_seed(0)  # Ensure reproducible test results
-    device = "cuda"
+    device = "xpu" if torch.xpu.is_available() else "cuda"
    dtype = torch.bfloat16
    G = 4  # Number of groups to test
    K, N = 256, 128  # Shared dimensions: K (reduction), N (output columns)
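
If neither backend is present, the "cuda" fallback only fails later at kernel launch; an explicit check at the top of main() would make the failure mode clearer. A minimal sketch (illustrative only, not part of this PR):

if not (torch.cuda.is_available() or torch.xpu.is_available()):
    raise RuntimeError("The grouped_gemm example requires a CUDA or XPU device.")
device = "xpu" if torch.xpu.is_available() else "cuda"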
14 changes: 11 additions & 3 deletions helion/_compiler/type_propagation.py
@@ -271,19 +271,27 @@ def from_example(cls, value: object, origin: Origin) -> TypeInfo:
            # This allows zip to work in list comprehensions
            zipped_tuples = tuple(tuple(items) for items in value)
            return cls.from_example(zipped_tuples, origin)
-        if isinstance(value, torch.cuda._CudaDeviceProperties):
+        if isinstance(
+            value, (torch.cuda._CudaDeviceProperties, torch.xpu._XpuDeviceProperties)
+        ):
            attrs = {}
            env = CompileEnvironment.current()

+            compute_unit_literal = (
+                "gpu_subslice_count"
+                if torch.xpu.is_available()
+                else "multi_processor_count"
+            )
+
            # Only `multi_processor_count` attribute is supported for now
            # TODO(yf225): support other torch.cuda._CudaDeviceProperties attributes
-            attr_origin = AttributeOrigin(origin, "multi_processor_count")
+            attr_origin = AttributeOrigin(origin, compute_unit_literal)
            # Create a symbolic integer that can be passed as kernel argument
            sym = env.create_unbacked_symint()
            HostFunction.current().expr_to_origin[sym._sympy_()] = SymbolOrigin(
                origin=attr_origin
            )
-            attrs["multi_processor_count"] = SymIntType(attr_origin, sym)
+            attrs[compute_unit_literal] = SymIntType(attr_origin, sym)

            return ClassType(origin, attrs)
        raise exc.UnsupportedPythonType(type(value).__name__)
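
Note that compute_unit_literal is keyed off torch.xpu.is_available() rather than the type of value, so on a build where both backends were usable the XPU attribute name would be chosen even for a _CudaDeviceProperties object. A variant keyed to the value's type is sketched below (illustrative only, not what this PR does):

compute_unit_literal = (
    "gpu_subslice_count"
    if isinstance(value, torch.xpu._XpuDeviceProperties)
    else "multi_processor_count"
)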
5 changes: 4 additions & 1 deletion test/test_examples.expected
@@ -1318,7 +1318,10 @@ def grouped_gemm_jagged_persistent(A_packed: torch.Tensor, B: torch.Tensor, grou
        Output tensor of shape ``[sum(M_i), N]``.
    """
    device = A_packed.device
-    num_workers = torch.cuda.get_device_properties(device).multi_processor_count
+    if device.type == 'xpu':
+        num_workers = torch.xpu.get_device_properties(device.index).gpu_subslice_count
+    else:
+        num_workers = torch.cuda.get_device_properties(device.index).multi_processor_count
    total_M, K = A_packed.shape
    K2, N = B.shape
    assert K == K2