Skip to content

Commit

Permalink
codegen: Resolve overload ambiguities created by defaulted arguments
Browse files Browse the repository at this point in the history
This is a redux of #45666 post refactor, based off of
peterbell10@d534f7d
Credit goes to peterbell10 for the implementation.

Fixes #43945.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 3ad69bac901e030bb5b8698bda002f7270b28679
Pull Request resolved: #49348
  • Loading branch information
ezyang committed Dec 14, 2020
1 parent 48793d2 commit d070cb9
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 4 deletions.
18 changes: 18 additions & 0 deletions aten/src/ATen/native/TestOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>

namespace at {
namespace native {
Expand Down Expand Up @@ -50,5 +51,22 @@ Tensor _test_string_default(const Tensor& dummy, std::string a, std::string b) {
return dummy;
}

// Test that overloads with ambiguity created by defaulted parameters work.
// The operator declared first should have priority always

// Overload a
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) {
TORCH_CHECK(a == 1);
TORCH_CHECK(b == 1);
return c10::scalar_to_tensor(1);
}

// Overload b
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b) {
TORCH_CHECK(a == 2);
TORCH_CHECK(b == "2");
return c10::scalar_to_tensor(2);
}

} // namespace native
} // namespace at
11 changes: 11 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9646,3 +9646,14 @@
- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor
use_c10_dispatcher: full
python_module: nn

# Note: this function is only for testing.
- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor
use_c10_dispatcher: full
python_module: nn

# Note: this function is only for testing.
- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor
cpp_no_default_args: ['a', 'b']
use_c10_dispatcher: full
python_module: nn
8 changes: 8 additions & 0 deletions test/cpp/api/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ TEST_F(AutogradTest, CanPassCustomGradientInputs) {
z.sum().backward(torch::ones({}) * 2);
ASSERT_TRUE(x.grad().allclose(y * 2));
}

TEST(UtilsTest, AmbiguousOperatorDefaults) {
auto tmp = at::empty({}, at::kCPU);
at::_test_ambiguous_defaults(tmp);
at::_test_ambiguous_defaults(tmp, 1);
at::_test_ambiguous_defaults(tmp, 1, 1);
at::_test_ambiguous_defaults(tmp, 2, "2");
}
30 changes: 26 additions & 4 deletions tools/codegen/api/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tools.codegen.model import *
from dataclasses import dataclass
from typing import Optional, Union, Sequence, TypeVar, List
from typing import Optional, Union, Sequence, TypeVar, List, Set
from enum import Enum

_T = TypeVar('_T')
Expand Down Expand Up @@ -128,6 +128,13 @@ class CppSignature:
# (i.e. with a potential TensorOptions argument and out arguments in the front)
faithful: bool

# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: Set[str]

# Is this a fallback C++ binding? Fallback bindings are enabled by
# manual_cpp_binding: True and are alternate, non-public API that
# lets manual C++ binding implementors access the binding that would
# have been automatically generated
fallback_binding: bool = False

# Return the unpacked argument structure of this signature,
Expand All @@ -145,7 +152,10 @@ def name(self) -> str:
# Render the C++ declaration for this signature
def decl(self) -> str:
returns_type = cpp.returns_type(self.func.returns)
cpp_args_str = ', '.join(a.decl() for a in self.arguments())
cpp_args_str = ', '.join(
a.decl() if a.name not in self.cpp_no_default_args else a.no_default().decl()
for a in self.arguments()
)
return f"{returns_type} {self.name()}({cpp_args_str})"

# Render the C++ definition for this signature, not including
Expand All @@ -172,10 +182,22 @@ def from_native_function(f: NativeFunction, *, method: bool, fallback_binding: b
func = f.func
faithful_signature: Optional[CppSignature]
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = CppSignature(func=func, faithful=True, method=method, fallback_binding=fallback_binding)
faithful_signature = CppSignature(
func=func,
faithful=True,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args
)
else:
faithful_signature = None
signature = CppSignature(func=func, faithful=False, method=method, fallback_binding=fallback_binding)
signature = CppSignature(
func=func,
faithful=False,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args
)
return CppSignatureGroup(
func=func,
signature=signature,
Expand Down
13 changes: 13 additions & 0 deletions tools/codegen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class NativeFunction:
# changes the semantics of set_output to call the parent class.
structured_inherits: Optional[str]

# Argument names whose default should be excluded from the C++ interface.
# Intended for resolving overload ambiguities between signatures.
cpp_no_default_args: Set[str]

# Note [Abstract ATen methods]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# An abstract ATen method is one whose dispatch differs between
Expand Down Expand Up @@ -169,6 +173,10 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
assert isinstance(funcs, str), f'not a str: {funcs}'
func = FunctionSchema.parse(funcs)

cpp_no_default_args_list = e.pop('cpp_no_default_args', [])
assert isinstance(cpp_no_default_args_list, list)
cpp_no_default_args = set(cpp_no_default_args_list)

use_c10_dispatcher_s = e.pop('use_c10_dispatcher', None)
if use_c10_dispatcher_s is None:
use_c10_dispatcher = UseC10Dispatcher.with_codegenerated_unboxing_wrapper
Expand Down Expand Up @@ -258,6 +266,7 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
dispatch=dispatch,
device_guard=device_guard,
loc=loc,
cpp_no_default_args=cpp_no_default_args,
)

def validate_unstructured(self) -> None:
Expand Down Expand Up @@ -293,6 +302,10 @@ def __post_init__(self) -> None:
# happen
assert not (self.structured and self.structured_delegate), \
"Cannot have both structured and structured_delegate on function"
defaulted_arguments = {a.name for a in self.func.schema_order_arguments()
if a.default is not None}
invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
assert len(invalid_args) == 0, f'Invalid cpp_no_default_args: {invalid_args}'

SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out'))

Expand Down

0 comments on commit d070cb9

Please sign in to comment.