Skip to content

Commit

Permalink
codegen: Resolve overload ambiguities created by defaulted arguments
Browse files Browse the repository at this point in the history
ghstack-source-id: 5cd0c2544230a366b684a1f6116ea4a392b7829b
Pull Request resolved: pytorch#45666
  • Loading branch information
peterbell10 committed Oct 5, 2020
1 parent 6852e5c commit a103754
Show file tree
Hide file tree
Showing 5 changed files with 203 additions and 64 deletions.
36 changes: 36 additions & 0 deletions aten/src/ATen/native/TestOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ScalarOps.h>

namespace at {
namespace native {
Expand Down Expand Up @@ -50,5 +51,40 @@ Tensor _test_string_default(const Tensor& dummy, std::string a, std::string b) {
return dummy;
}

// Test that overloads with ambiguity created by defaulted parameters work.
// The operator declared first should have priority always.
//
// Each overload accepts, per argument, either the schema default for that
// overload (e.g. a == 1, b == "1" for overload a) or a distinct test value
// (-1, "a", ...), and returns a scalar tensor identifying which overload
// ran, so callers can assert exactly which overload a call resolved to.

// Overload a: (Tensor, int, str, str); returns scalar tensor 1
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b, std::string c) {
  TORCH_CHECK(a == 1 || a == -1);
  TORCH_CHECK(b == "1" || b == "a");
  TORCH_CHECK(c == "1" || c == "a");
  return c10::scalar_to_tensor(1);
}

// Overload b: (Tensor, int, str, int); returns scalar tensor 2
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b, int64_t c) {
  TORCH_CHECK(a == 2 || a == -2);
  TORCH_CHECK(b == "2" || b == "b");
  TORCH_CHECK(c == 2 || c == -2);
  return c10::scalar_to_tensor(2);
}

// Overload c: (Tensor, str, str); returns scalar tensor 3
Tensor _test_ambiguous_defaults(const Tensor& dummy, std::string a, std::string b) {
  TORCH_CHECK(a == "3" || a == "c");
  TORCH_CHECK(b == "3" || b == "c");
  return c10::scalar_to_tensor(3);
}

// Overload d: (Tensor, int, str, str, str); returns scalar tensor 4
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, std::string b, std::string c, std::string d) {
  TORCH_CHECK(a == 4 || a == -4);
  TORCH_CHECK(b == "4" || b == "d");
  TORCH_CHECK(c == "4" || c == "d");
  TORCH_CHECK(d == "4" || d == "d");
  return c10::scalar_to_tensor(4);
}

} // namespace native
} // namespace at
20 changes: 20 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8293,3 +8293,23 @@
- func: _test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor
use_c10_dispatcher: full
python_module: nn

# Note: the _test_ambiguous_defaults overloads below are only for testing.
# Once their defaults are applied, a call such as `fn(dummy)` could match any
# of these overloads; codegen must resolve the ambiguity in favor of the
# overload declared first (see aten/src/ATen/native/TestOps.cpp).
- func: _test_ambiguous_defaults.a(Tensor dummy, int a=1, str b="1", str c="1") -> Tensor
  use_c10_dispatcher: full
  python_module: nn

# Note: this function is only for testing; see overload `.a` above.
- func: _test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2", int c=2) -> Tensor
  use_c10_dispatcher: full
  python_module: nn

# Note: this function is only for testing; see overload `.a` above.
- func: _test_ambiguous_defaults.c(Tensor dummy, str a="3", str b="3") -> Tensor
  use_c10_dispatcher: full
  python_module: nn

# Note: this function is only for testing; see overload `.a` above.
- func: _test_ambiguous_defaults.d(Tensor dummy, int a=4, str b="4", str c="4", str d="4") -> Tensor
  use_c10_dispatcher: full
  python_module: nn
26 changes: 26 additions & 0 deletions test/test_native_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,32 @@ def f(x):
scripted_fn = torch.jit.script(f)
scripted_fn(dummy)

def test_ambiguous_defaults(self):
    # Overloads of _test_ambiguous_defaults become ambiguous once their
    # default arguments are applied; the first-declared overload must win.
    # The op returns a scalar tensor (1-4) identifying which overload ran.
    def call_variants(x):
        fn = torch._C._nn._test_ambiguous_defaults
        results = [
            # Overload a: (int, str, str) -> 1
            fn(x),
            fn(x, -1),
            fn(x, -1, "a"),
            fn(x, -1, "a", "a"),
            fn(x, b="a"),
            # Overload b: (int, str, int) -> 2
            fn(x, -2, "b", -2),
            fn(x, c=-2),
            # Overload c: (str, str) -> 3
            fn(x, "c"),
            # Overload d: (int, str, str, str) -> 4
            fn(x, -4, "d", "d", "d"),
            fn(x, d="d"),
        ]
        return results

    dummy = torch.rand(1)
    expect = [1, 1, 1, 1, 1, 2, 2, 3, 4, 4]
    # Eager mode must pick the first-declared overload...
    self.assertEqual(expect, call_variants(dummy))
    # ...and TorchScript's overload resolution must agree with eager.
    self.assertEqual(expect, torch.jit.script(call_variants)(dummy))


if __name__ == '__main__':
run_tests()
9 changes: 6 additions & 3 deletions tools/codegen/api/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,16 @@ class CppSignature:
returns: Tuple[Return, ...]
arguments: Tuple[Union[Argument, TensorOptionsArguments, ThisArgument], ...]

def cpp_arguments(self) -> Sequence[CppArgument]:
return list(map(argument, self.arguments))
def cpp_arguments(self, exclude_this: bool = False) -> Sequence[CppArgument]:
    """Return this signature's arguments converted to their C++ form.

    With exclude_this=True the implicit `this` argument is filtered out
    (useful when emitting Tensor method declarations, where `this` is
    implicit in C++).
    """
    source_args = self.arguments
    if exclude_this:
        source_args = [a for a in source_args if not isinstance(a, ThisArgument)]
    return [argument(a) for a in source_args]

# Return arguments as a comma separated list, i.e. like they would be in a C++
# function signature. Include default values for arguments.
def cpp_arguments_str(self, with_defaults: bool) -> str:
args_without_this = [argument(a) for a in self.arguments if not isinstance(a, ThisArgument)]
args_without_this = self.cpp_arguments(exclude_this=True)
if with_defaults:
return ', '.join(map(str, args_without_this))
else:
Expand Down
176 changes: 115 additions & 61 deletions tools/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@ def cpp_string(s: str) -> str:
s = s.replace('\t', '\\t')
return f'"{s}"'

def value_or(first: Optional[T], second: T) -> T:
    """None-coalescing helper: return `first` unless it is None, else `second`."""
    if first is None:
        return second
    return first

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
# C++ CODE GENERATION
Expand Down Expand Up @@ -325,9 +328,7 @@ def exprs_str(signature: CppSignature,
process_tensoroptions: dispatcher.ProcessTensoroptions = dispatcher.ProcessTensoroptions.PASS_THROUGH,
exclude_this: bool = False,
) -> str:
args = signature.cpp_arguments()
if exclude_this:
args = [a for a in args if not isinstance(a.argument, ThisArgument)]
args = signature.cpp_arguments(exclude_this=exclude_this)
exprs = dispatcher.cpparguments_exprs(args, process_tensoroptions=process_tensoroptions)
return ', '.join(map(lambda a: a.expr, exprs))

Expand All @@ -336,10 +337,37 @@ def types_str(signature: CppSignature) -> str:
exprs = dispatcher.cpparguments_exprs(args, process_tensoroptions=dispatcher.ProcessTensoroptions.PASS_THROUGH)
return ', '.join(map(lambda a: a.type, exprs))

# An argument of either the public C++ API or the legacy dispatcher API.
ArgT = TypeVar('ArgT', CppArgument, LegacyDispatcherArgument)

# Resolve any overload ambiguities introduced by default values.
# Always prefers the overload declared first, since this will be picked by
# the PythonArgParser as well.
def unambiguous_defaults(
        name: str, args: Sequence[ArgT],
        seen_functions: Dict[str, List[Sequence[ArgT]]]) -> List[bool]:
    """Decide which arguments of an overload may keep their C++ defaults.

    Given the arguments of the overload currently being emitted and the
    registry of overloads already emitted under the same C++ ``name``,
    return one bool per argument: True if that argument may carry its
    default value, False if the default must be suppressed to keep calls
    unambiguous.  Also appends ``args`` to ``seen_functions[name]`` as a
    side effect, so callers must invoke this in declaration order.
    """
    overloads: List[Sequence[ArgT]] = seen_functions.get(name, [])
    # n = number of leading arguments whose defaults must be suppressed.
    n = 0
    for o_args in overloads:
        for i, (a, b) in enumerate(zip(args, o_args)):
            if b.default is not None:
                # The earlier overload can already be called with argument i
                # omitted, so this overload must require it explicitly.
                n = max(n, i + 1)
            if a.type != b.type:
                # Types diverge at position i: from here on a call that
                # supplies the argument is disambiguated by its type.
                break
        else:
            if len(o_args) < len(args):
                # The earlier overload is a matching strict prefix of this
                # one; a call supplying exactly len(o_args) arguments must
                # resolve to it, so this overload needs at least one more
                # explicit argument.
                n = max(n, len(o_args) + 1)

    overloads.append(args)
    seen_functions[name] = overloads
    return [False] * n + [True] * (len(args) - n)

# Generates Function.cpp and Function.h. These files provide the
# functional public C++ API, and the scaffolding to call into
# the dispatcher from these functions. See also compute_tensor_method.
def compute_function(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
# the dispatcher from these functions. See also compute_tensor_methods.
def compute_functions(native_functions: List[NativeFunction], *, target: Target) -> List[str]:
# Map function name to a list of its overloads C++ arguments
seen_functions: Dict[str, List[Sequence[CppArgument]]] = {}

@with_native_function
def go(f: NativeFunction) -> Optional[str]:
if f.manual_kernel_registration:
Expand All @@ -352,22 +380,27 @@ def go(f: NativeFunction) -> Optional[str]:
signature_group = cpp.signature_group(f.func, method=False)

if target is Target.DECLARATION:
if signature_group.gathered_signature is None:
# There's no TensorOptions
return f"""
CAFFE2_API {cpp_returns_type} {cpp_name}({signature_group.signature.cpp_arguments_str(with_defaults=True)});
"""
else:
# There's TensorOptions in the API. Create 2 APIs - one taking the TensorOptions object ("gathered_signature"),
# and one taking a scattered signature with ScalarType, Layout, Device separately ("signature").
# The gathered_signature already exists in several older PyTorch versions and had default arguments.
# For backward compatibility, we left it unchanged and added the scattered API on top of it.
# Note that the scattered API cannot have default arguments or calls will be ambiguous.
return f"""
CAFFE2_API {cpp_returns_type} {cpp_name}({signature_group.gathered_signature.cpp_arguments_str(with_defaults=True)});
CAFFE2_API {cpp_returns_type} {cpp_name}({signature_group.signature.cpp_arguments_str(with_defaults=False)});
primary_signature = value_or(signature_group.gathered_signature, signature_group.signature)
cpp_args = primary_signature.cpp_arguments()
use_defaults = unambiguous_defaults(cpp_name, cpp_args, seen_functions)
cpp_args_str = ', '.join(
str(arg) if use_def else arg.str_no_default()
for use_def, arg in zip(use_defaults, cpp_args))

str_signatures = f"""
CAFFE2_API {cpp_returns_type} {cpp_name}({cpp_args_str});
"""

# If there's TensorOptions in the API, create 2 APIs - one taking the TensorOptions object ("gathered_signature"),
# and one taking a scattered signature with ScalarType, Layout, Device separately ("signature").
# The gathered_signature already exists in several older PyTorch versions and had default arguments.
# For backward compatibility, we left it unchanged and added the scattered API on top of it.
# Note that the scattered API cannot have default arguments or calls will be ambiguous.
if signature_group.gathered_signature is not None:
str_signatures += f"CAFFE2_API {cpp_returns_type} {cpp_name}({signature_group.signature.cpp_arguments_str(with_defaults=False)});\n"

return str_signatures

assert target is Target.DEFINITION

dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
Expand Down Expand Up @@ -414,12 +447,15 @@ def go(f: NativeFunction) -> Optional[str]:
}}
"""

return go
return list(mapMaybe(go, native_functions))

# Generates TensorBody.h (sic) and TensorMethods.cpp. These files provide the
# object-oriented (method-based) public C++ API, and the scaffolding to call into
# the dispatcher from these functions. See also compute_function.
def compute_tensor_method(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
# the dispatcher from these functions. See also compute_functions.
def compute_tensor_methods(native_functions: List[NativeFunction], *, target: Target) -> List[str]:
# Map function name to list of functions and their C++ arguments
seen_functions: Dict[str, List[Sequence[CppArgument]]] = {}

@with_native_function
def go(f: NativeFunction) -> Optional[str]:
if Variant.method not in f.variants:
Expand All @@ -434,20 +470,27 @@ def go(f: NativeFunction) -> Optional[str]:
signature_group = cpp.signature_group(f.func, method=True)

if target is Target.DECLARATION:
if signature_group.gathered_signature is None:
# There's no TensorOptions. Just create the API without concern for TensorOptions.
return f"{cpp_returns_type} {cpp_name}({signature_group.signature.cpp_arguments_str(with_defaults=True)}) const;"
else:
# There's TensorOptions in the API. Create 2 APIs - one taking the TensorOptions object ("gathered_signature"),
# and one taking a scattered signature with ScalarType, Layout, Device separately ("signature").
# The gathered_signature already exists in several older PyTorch versions and had default arguments.
# For backward compatibility, we left it unchanged and added the scattered API on top of it.
# Note that the scattered API cannot have default arguments or calls will be ambiguous.
return f"""
{cpp_returns_type} {cpp_name}({signature_group.gathered_signature.cpp_arguments_str(with_defaults=True)}) const;
{cpp_returns_type} {cpp_name}({signature_group.signature.cpp_arguments_str(with_defaults=False)}) const;
primary_signature = value_or(signature_group.gathered_signature, signature_group.signature)
cpp_args = primary_signature.cpp_arguments(exclude_this=True)
use_defaults = unambiguous_defaults(cpp_name, cpp_args, seen_functions)
cpp_args_str = ', '.join(
str(arg) if use_def else arg.str_no_default()
for use_def, arg in zip(use_defaults, cpp_args))

str_signatures = f"""
{cpp_returns_type} {cpp_name}({cpp_args_str}) const;
"""

# If there's TensorOptions in the API, create 2 APIs - one taking the TensorOptions object ("gathered_signature"),
# and one taking a scattered signature with ScalarType, Layout, Device separately ("signature").
# The gathered_signature already exists in several older PyTorch versions and had default arguments.
# For backward compatibility, we left it unchanged and added the scattered API on top of it.
# Note that the scattered API cannot have default arguments or calls will be ambiguous.
if signature_group.gathered_signature is not None:
str_signatures += f"{cpp_returns_type} {cpp_name}({signature_group.signature.cpp_arguments_str(with_defaults=False)}) const;\n"

return str_signatures

assert target is Target.DEFINITION

dispatcher_returns_type = dispatcher.returns_type(f.func.returns)
Expand Down Expand Up @@ -504,7 +547,7 @@ def go(f: NativeFunction) -> Optional[str]:
}}
"""

return go
return list(mapMaybe(go, native_functions))

# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
Expand All @@ -517,28 +560,39 @@ def compute_aten_op(f: NativeFunction) -> str:

# Generates NativeFunctions.h, a list of forward declarations of all
# actual kernel definitions we keep in aten/src/ATen/native/
@with_native_function
def compute_native_function_declaration(f: NativeFunction) -> List[str]:
if f.dispatch is None:
ns = [cpp.name(f.func)]
else:
ns = list(f.dispatch.values())

rs = []
# Sometimes a function name shows up multiple times; only generate
# it once!
seen = set()
for n in ns:
if n in seen:
continue
if "legacy::" in n:
continue
seen.add(n)
returns_type = legacy_dispatcher.returns_type(f.func.returns)
args = legacy_dispatcher.arguments(f.func)
rs.append(f"CAFFE2_API {returns_type} {n}({', '.join(map(lambda a: a.str_with_default(), args))});")
def compute_native_function_declarations(native_functions: List[NativeFunction]) -> List[str]:
# Map function name to a list of its overloads C++ arguments
seen_functions: Dict[str, List[Sequence[LegacyDispatcherArgument]]] = {}

return rs
@with_native_function
def go(f: NativeFunction) -> List[str]:
if f.dispatch is None:
ns = [cpp.name(f.func)]
else:
ns = list(f.dispatch.values())

rs = []
# Sometimes a function name shows up multiple times; only generate
# it once!
seen = set()
for n in ns:
if n in seen:
continue
if "legacy::" in n:
continue
seen.add(n)
returns_type = legacy_dispatcher.returns_type(f.func.returns)
args = legacy_dispatcher.arguments(f.func)

use_defaults = unambiguous_defaults(n, args, seen_functions)
args_str = ', '.join(
arg.str_with_default() if use_def else str(arg)
for use_def, arg in zip(use_defaults, args))
rs.append(f"CAFFE2_API {returns_type} {n}({args_str});")

return rs

return list(concatMap(go, native_functions))

# Generates BackendSelectRegister.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
Expand Down Expand Up @@ -1126,22 +1180,22 @@ def make_file_manager(install_dir: str) -> FileManager:
native_functions)),
})
cpu_fm.write('Functions.h', lambda: {
'function_declarations': list(mapMaybe(compute_function(target=Target.DECLARATION), native_functions)),
'function_declarations': compute_functions(native_functions, target=Target.DECLARATION),
})
cpu_fm.write('Functions.cpp', lambda: {
'function_definitions': list(mapMaybe(compute_function(target=Target.DEFINITION), native_functions)),
'function_definitions': compute_functions(native_functions, target=Target.DEFINITION),
})
core_fm.write('TensorBody.h', lambda: {
'tensor_method_declarations': list(mapMaybe(compute_tensor_method(target=Target.DECLARATION), native_functions)),
'tensor_method_declarations': compute_tensor_methods(native_functions, target=Target.DECLARATION),
})
core_fm.write('TensorMethods.cpp', lambda: {
'tensor_method_definitions': list(mapMaybe(compute_tensor_method(target=Target.DEFINITION), native_functions)),
'tensor_method_definitions': compute_tensor_methods(native_functions, target=Target.DEFINITION),
})
core_fm.write('ATenOpList.cpp', lambda: {
'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
})
cpu_fm.write('NativeFunctions.h', lambda: {
'native_function_declarations': list(concatMap(compute_native_function_declaration, native_functions)),
'native_function_declarations': compute_native_function_declarations(native_functions),
})
cpu_fm.write('BackendSelectRegister.cpp', lambda: {
'backend_select_method_definitions':
Expand Down

0 comments on commit a103754

Please sign in to comment.