From ffae8d28c650d995f91a5be7ea72bce982013a79 Mon Sep 17 00:00:00 2001 From: "Wang, Eikan" Date: Thu, 2 Oct 2025 14:19:06 +0000 Subject: [PATCH] Add XPU support for RNG operations --- helion/_compiler/generate_ast.py | 2 +- test/test_rng.expected | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/helion/_compiler/generate_ast.py b/helion/_compiler/generate_ast.py index e4caf8e02..c5710b952 100644 --- a/helion/_compiler/generate_ast.py +++ b/helion/_compiler/generate_ast.py @@ -81,7 +81,7 @@ def get_rng_seed_buffer_statements(self) -> list[ast.AST]: # Create host-side seed buffer with the required number of seeds seed_buffer_stmt = statement_from_string( - f"_rng_seed_buffer = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.device('cuda'))" + f"_rng_seed_buffer = inductor_prims.seeds({self.device_function.rng_seed_count}, torch.accelerator.current_accelerator())" ) return [import_stmt, seed_buffer_stmt] diff --git a/test/test_rng.expected b/test/test_rng.expected index da0721a6a..38b160bb9 100644 --- a/test/test_rng.expected +++ b/test/test_rng.expected @@ -37,7 +37,7 @@ def _helion_multiple_rng_ops_kernel(rand1, rand2, uniform, normal, randn_a, rand def multiple_rng_ops_kernel(x: torch.Tensor, *, _launcher=_default_launcher): from torch._inductor import inductor_prims - _rng_seed_buffer = inductor_prims.seeds(7, torch.device('cuda')) + _rng_seed_buffer = inductor_prims.seeds(7, torch.accelerator.current_accelerator()) rand1 = torch.zeros_like(x) rand2 = torch.zeros_like(x) uniform = torch.zeros_like(x)