From df9c6ca62af994a2246feac4a4b27d18e813a8b8 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Fri, 2 Feb 2024 18:17:34 +0000 Subject: [PATCH] Fix internal failure D53291154 (#118907) Fix internal failure D53291154 from alban: the change is breaking because the alpha argument is now kwarg only (via the * marker) while it was ok for it to be positional before for the rsub.Scalar overload ``` _wrapped_call_impl return self._call_impl(*args, **kwargs) File "torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "torch/_dynamo/eval_frame.py", line 453, in _fn return fn(*args, **kwargs) File "torch/nn/modules/module.py", line 1511, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "torch/nn/modules/module.py", line 1520, in _call_impl return forward_call(*args, **kwargs) File "torch/_dynamo/eval_frame.py", line 615, in catch_errors return callback(frame, cache_entry, hooks, frame_state) File "torch/_dynamo/convert_frame.py", line 390, in _convert_frame_assert return _compile( File "python3.10/contextlib.py", line 79, in inner return func(*args, **kwds) File "torch/_dynamo/convert_frame.py", line 650, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "torch/_dynamo/utils.py", line 248, in time_wrapper r = func(*args, **kwargs) File "torch/_dynamo/convert_frame.py", line 531, in compile_inner out_code = transform_code_object(code, transform) File "torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object transformations(instructions, code_options) File "torch/_dynamo/convert_frame.py", line 155, in _fn return fn(*args, **kwargs) File "torch/_dynamo/convert_frame.py", line 496, in transform tracer.run() File "torch/_dynamo/symbolic_convert.py", line 2125, in run super().run() File "torch/_dynamo/symbolic_convert.py", line 787, in run and self.step() File "torch/_dynamo/symbolic_convert.py", line 750, in step getattr(self, inst.opname)(inst) File "torch/_dynamo/symbolic_convert.py", line 469, in wrapper return inner_fn(self, inst) File "torch/_dynamo/symbolic_convert.py", line 1249, in CALL_FUNCTION_KW self.call_function(fn, args, kwargs) File "torch/_dynamo/symbolic_convert.py", line 651, in call_function self.push(fn.call_function(self, args, kwargs)) File "torch/_dynamo/variables/torch.py", line 614, in call_function tensor_variable = wrap_fx_proxy( File "torch/_dynamo/variables/builder.py", line 1285, in wrap_fx_proxy return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) File "torch/_dynamo/variables/builder.py", line 1370, in wrap_fx_proxy_cls example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) File "torch/_dynamo/utils.py", line 1653, in get_fake_value raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None File "torch/_dynamo/utils.py", line 1599, in get_fake_value ret_val = wrap_fake_exception( File "torch/_dynamo/utils.py", line 1140, in wrap_fake_exception return fn() File "torch/_dynamo/utils.py", line 1600, in lambda: run_node(tx.output, node, args, kwargs, nnmodule) File "torch/_dynamo/utils.py", line 1720, in run_node raise RuntimeError(fn_str + str(e)).with_traceback(e.__traceback__) from e File "torch/_dynamo/utils.py", line 1699, in run_node return node.target(*args, **kwargs) File "torch/utils/_stats.py", line 20, in wrapper return fn(*args, **kwargs) File "torch/_subclasses/fake_tensor.py", line 1637, in __torch_dispatch__ return self.dispatch(func, types, args, kwargs) File "torch/_subclasses/fake_tensor.py", line 1975, in dispatch return self._dispatch_impl(func, types, args, kwargs) File "torch/_subclasses/fake_tensor.py", line 2190, in _dispatch_impl r = func(*args, **kwargs) File "torch/_ops.py", line 571, in __call__ return self_._op(*args, **kwargs) File "torch/_prims_common/wrappers.py", line 252, in _fn result = fn(*args, **kwargs) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/118907 Approved by: https://github.com/lezcano (cherry picked from commit 3a1ae86a9307495baf967c6a73f1535640b77654) --- torch/_refs/__init__.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py index 7a78c08dd0c24..1dc43d67028cc 100644 --- a/torch/_refs/__init__.py +++ b/torch/_refs/__init__.py @@ -1666,8 +1666,7 @@ def remainder(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType: def rsub( a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType], - *, - alpha: Optional[NumberType] = None, + alpha: NumberType = 1, ): if isinstance(a, Number): msg = "Received a Number for the first argument, but expected a Tensor"