diff --git a/helion/runtime/ref_mode.py b/helion/runtime/ref_mode.py index 419c136db..8b68e8a8b 100644 --- a/helion/runtime/ref_mode.py +++ b/helion/runtime/ref_mode.py @@ -264,7 +264,13 @@ def _handle_factory_method( ) -> torch.Tensor: """Handle tensor.new_* factory methods (new_zeros, new_ones, new_full) with RefTile arguments.""" tensor = cast("torch.Tensor", args[0]) - size = convert_size_arg(args[1]) + if "size" in kwargs: + kwargs = dict(kwargs) + size = kwargs.pop("size") + assert len(args) == 1 + else: + size = args[1] + size = convert_size_arg(size) method = getattr(tensor, method_name) extra_args = args[2:] return method(size, *extra_args, **kwargs)