diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index 7091e4f78aef9..4afc7619c2ebd 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -191,8 +191,8 @@ ScalarType result_type(const Scalar& scalar1, const Scalar& scalar2) { return result_type(state); } -bool can_cast(const at::ScalarType from, const at::ScalarType to) { - return at::canCast(from, to); +bool can_cast(const at::ScalarType from_, const at::ScalarType to) { + return at::canCast(from_, to); } ScalarType promote_types(ScalarType type1, ScalarType type2) { diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 8cf229c69c238..10d8b1ad79cad 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -7714,7 +7714,7 @@ - func: result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType -- func: can_cast(ScalarType from, ScalarType to) -> bool +- func: can_cast(ScalarType from_, ScalarType to) -> bool variants: function - func: promote_types(ScalarType type1, ScalarType type2) -> ScalarType diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 285e410a79edc..81b85a4fe42f9 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -140,6 +140,8 @@ ("onednn::qconv2d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv3d_pointwise", datetime.date(2024, 12, 31)), ("onednn::qconv2d_pointwise.binary", datetime.date(2024, 12, 31)), + # BC-breaking change in can_cast signature: 'from' -> 'from_' + ("aten::can_cast", datetime.date(2024, 5, 31)), ] ALLOW_LIST_COMPILED = [ diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ba8e899dc9437..6d22f9dcf9845 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -2195,13 +2195,13 @@ def merge_dicts(*dicts): add_docstr( torch.can_cast, r""" -can_cast(from, to) -> bool +can_cast(from_, to) -> bool Determines if a type conversion is allowed under PyTorch casting rules described in the type promotion :ref:`documentation `. Args: - from (dtype): The original :class:`torch.dtype`. + from\_ (dtype): The original :class:`torch.dtype`. to (dtype): The target :class:`torch.dtype`. Example::