Skip to content

Commit

Permalink
codegen: Resolve overload ambiguities created by defaulted arguments
Browse files Browse the repository at this point in the history
This is a redux of #45666 post refactor, based off of
peterbell10@d534f7d
Credit goes to peterbell10 for the implementation.

Fixes #43945.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: 4b349c2deac1c322976fada2afa807ba6d0684f6
Pull Request resolved: #49348
  • Loading branch information
ezyang committed Dec 16, 2020
1 parent fe43a58 commit 6f6747c
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 13 deletions.
18 changes: 18 additions & 0 deletions aten/src/ATen/native/TestOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>

namespace at {
namespace native {
Expand Down Expand Up @@ -50,5 +51,22 @@ Tensor _test_string_default(const Tensor& dummy, std::string a, std::string b) {
return dummy;
}

// Test that overloads with ambiguity created by defaulted parameters work.
// The operator declared first should have priority always

// Overload a
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) {
TORCH_CHECK(a == 1);
TORCH_CHECK(b == 1);
return c10::scalar_to_tensor(1);
}

// Overload b
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b) {
TORCH_CHECK(a == 2);
TORCH_CHECK(b == "2");
return c10::scalar_to_tensor(2);
}

} // namespace native
} // namespace at
11 changes: 11 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10188,3 +10188,14 @@
- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor
use_c10_dispatcher: full
python_module: nn

# Note: this function is only for testing.
- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor
use_c10_dispatcher: full
python_module: nn

# Note: this function is only for testing.
- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor
cpp_no_default_args: ['a', 'b']
use_c10_dispatcher: full
python_module: nn
8 changes: 8 additions & 0 deletions test/cpp/api/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,11 @@ TEST_F(AutogradTest, CanPassCustomGradientInputs) {
z.sum().backward(torch::ones({}) * 2);
ASSERT_TRUE(x.grad().allclose(y * 2));
}

TEST(UtilsTest, AmbiguousOperatorDefaults) {
auto tmp = at::empty({}, at::kCPU);
at::_test_ambiguous_defaults(tmp);
at::_test_ambiguous_defaults(tmp, 1);
at::_test_ambiguous_defaults(tmp, 1, 1);
at::_test_ambiguous_defaults(tmp, 2, "2");
}
30 changes: 22 additions & 8 deletions tools/codegen/api/cpp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tools.codegen.model import *
from tools.codegen.api.types import *
import tools.codegen.local as local
from typing import Optional, Sequence, Union, List
from typing import Optional, Sequence, Union, List, Set

# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
Expand Down Expand Up @@ -237,26 +237,37 @@ def default_expr(d: str, t: Type) -> str:

def argument(
a: Union[Argument, TensorOptionsArguments, SelfArgument],
*, method: bool = False, faithful: bool = False,
has_tensor_options: bool = False
*, cpp_no_default_args: Set[str], method: bool, faithful: bool,
has_tensor_options: bool
) -> List[Binding]:
def sub_argument(a: Union[Argument, TensorOptionsArguments, SelfArgument]) -> List[Binding]:
return argument(
a, cpp_no_default_args=cpp_no_default_args, method=method, faithful=faithful,
has_tensor_options=has_tensor_options)

if isinstance(a, Argument):
binds: ArgName
if a.name == "memory_format" and has_tensor_options:
binds = SpecialArgName.possibly_redundant_memory_format
else:
binds = a.name
default: Optional[str] = None
if a.name not in cpp_no_default_args and a.default is not None:
default = default_expr(a.default, a.type)
return [Binding(
ctype=argument_type(a, binds=binds),
name=a.name,
default=default_expr(a.default, a.type) if a.default is not None else None,
default=default,
argument=a,
)]
elif isinstance(a, TensorOptionsArguments):
if faithful:
return argument(a.dtype) + argument(a.layout) + argument(a.device) + argument(a.pin_memory)
return sub_argument(a.dtype) + sub_argument(a.layout) + \
sub_argument(a.device) + sub_argument(a.pin_memory)
else:
default = None
# Enforced by NativeFunction.__post_init__
assert 'options' not in cpp_no_default_args
if all(x.default == "None" for x in a.all()):
default = '{}'
elif a.dtype.default == "long":
Expand All @@ -272,13 +283,13 @@ def argument(
# Caller is responsible for installing implicit this in context!
return []
else:
return argument(a.argument)
return sub_argument(a.argument)
else:
assert_never(a)

def arguments(
arguments: Arguments,
*, faithful: bool, method: bool
*, faithful: bool, method: bool, cpp_no_default_args: Set[str]
) -> List[Binding]:
args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = []
if faithful:
Expand All @@ -289,5 +300,8 @@ def arguments(
args.extend(arguments.non_out)
return [
r.no_default() if faithful else r for a in args
for r in argument(a, faithful=faithful, method=method, has_tensor_options=arguments.tensor_options is not None)
for r in argument(
a, faithful=faithful, method=method,
has_tensor_options=arguments.tensor_options is not None,
cpp_no_default_args=cpp_no_default_args)
]
29 changes: 25 additions & 4 deletions tools/codegen/api/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tools.codegen.model import *
from dataclasses import dataclass
from typing import Optional, Union, Sequence, TypeVar, List
from typing import Optional, Union, Sequence, TypeVar, List, Set
from enum import Enum

_T = TypeVar('_T')
Expand Down Expand Up @@ -128,13 +128,22 @@ class CppSignature:
# (i.e. with a potential TensorOptions argument and out arguments in the front)
faithful: bool

# The set of C++ arguments which should not have defaults applied to them
cpp_no_default_args: Set[str]

# Is this a fallback C++ binding? Fallback bindings are enabled by
# manual_cpp_binding: True and are alternate, non-public API that
# lets manual C++ binding implementors access the binding that would
# have been automatically generated
fallback_binding: bool = False

# Return the unpacked argument structure of this signature,
# discarding information about which arguments are semantically
# related to each other.
def arguments(self) -> Sequence[Binding]:
return cpp.arguments(self.func.arguments, faithful=self.faithful, method=self.method)
return cpp.arguments(
self.func.arguments, faithful=self.faithful,
method=self.method, cpp_no_default_args=self.cpp_no_default_args)

def name(self) -> str:
n = cpp.name(self.func, faithful_name_for_out_overloads=self.faithful)
Expand Down Expand Up @@ -172,10 +181,22 @@ def from_native_function(f: NativeFunction, *, method: bool, fallback_binding: b
func = f.func
faithful_signature: Optional[CppSignature]
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = CppSignature(func=func, faithful=True, method=method, fallback_binding=fallback_binding)
faithful_signature = CppSignature(
func=func,
faithful=True,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args
)
else:
faithful_signature = None
signature = CppSignature(func=func, faithful=False, method=method, fallback_binding=fallback_binding)
signature = CppSignature(
func=func,
faithful=False,
method=method,
fallback_binding=fallback_binding,
cpp_no_default_args=f.cpp_no_default_args
)
return CppSignatureGroup(
func=func,
signature=signature,
Expand Down
4 changes: 3 additions & 1 deletion tools/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,7 +1047,9 @@ def compute_declaration_yaml(f: NativeFunction) -> object:

cpp_schema_order_types = [
# NB: method here doesn't matter
r.type for a in schema_order_jit_arguments for r in cpp.argument(a, method=False)
r.type for a in schema_order_jit_arguments
for r in cpp.argument(
a, method=False, cpp_no_default_args=set(), faithful=False, has_tensor_options=False)
]

cpp_returns = cpp.returns_type(f.func.returns)
Expand Down
13 changes: 13 additions & 0 deletions tools/codegen/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ class NativeFunction:
# changes the semantics of set_output to call the parent class.
structured_inherits: Optional[str]

# Argument names whose default should be excluded from the C++ interface.
# Intended for resolving overload ambiguities between signatures.
cpp_no_default_args: Set[str]

# Note [Abstract ATen methods]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# An abstract ATen method is one whose dispatch differs between
Expand Down Expand Up @@ -169,6 +173,10 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
assert isinstance(funcs, str), f'not a str: {funcs}'
func = FunctionSchema.parse(funcs)

cpp_no_default_args_list = e.pop('cpp_no_default_args', [])
assert isinstance(cpp_no_default_args_list, list)
cpp_no_default_args = set(cpp_no_default_args_list)

use_c10_dispatcher_s = e.pop('use_c10_dispatcher', None)
if use_c10_dispatcher_s is None:
use_c10_dispatcher = UseC10Dispatcher.with_codegenerated_unboxing_wrapper
Expand Down Expand Up @@ -258,6 +266,7 @@ def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
dispatch=dispatch,
device_guard=device_guard,
loc=loc,
cpp_no_default_args=cpp_no_default_args,
)

def validate_unstructured(self) -> None:
Expand Down Expand Up @@ -293,6 +302,10 @@ def __post_init__(self) -> None:
# happen
assert not (self.structured and self.structured_delegate), \
"Cannot have both structured and structured_delegate on function"
defaulted_arguments = {a.name for a in self.func.schema_order_arguments()
if a.default is not None}
invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments)
assert len(invalid_args) == 0, f'Invalid cpp_no_default_args: {invalid_args}'

SchemaKind = Enum('SchemaKind', ('functional', 'inplace', 'out'))

Expand Down

0 comments on commit 6f6747c

Please sign in to comment.