Skip to content

Commit

Permalink
Construct CppSignatureGroup from NativeFunction (pytorch#49245)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#49245

This will make it easier to implement the POC in
peterbell10@d534f7d
see also pytorch#45666

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: smessmer

Differential Revision: D25594005

Pulled By: ezyang

fbshipit-source-id: e458d3dc3a765ec77425761b9b17f23769cecf9e
  • Loading branch information
ezyang authored and Spandan Tiwari committed Jan 5, 2021
1 parent 2f1e7f4 commit b6a5fb3
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 11 deletions.
2 changes: 1 addition & 1 deletion tools/autograd/gen_python_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def signature_original(f: NativeFunction) -> str:
opname += '_out'
if f.func.name.name.inplace and pyi:
opname += '_'
args = CppSignatureGroup.from_schema(f.func, method=False).signature.arguments()
args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
# Simply ignore TensorOptionsArguments as it does not exist in deprecated.yaml.
types = ', '.join(argument_type_str(a.argument.type)
for a in args if isinstance(a.argument, Argument))
Expand Down
4 changes: 2 additions & 2 deletions tools/autograd/gen_trace_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def dispatch_trace_input(arg: Union[Argument, TensorOptionsArguments]) -> Sequen
if f.use_c10_dispatcher.dispatcher_uses_new_style():
args = list(f.func.schema_order_arguments())
else:
sig_group = CppSignatureGroup.from_schema(f.func, method=False)
sig_group = CppSignatureGroup.from_native_function(f, method=False)
args = [cpp_args.argument for cpp_args in sig_group.signature.arguments()
if not isinstance(cpp_args.argument, SelfArgument)]

Expand Down Expand Up @@ -380,7 +380,7 @@ def method_definition(f: NativeFunction) -> Optional[str]:
for a in f.func.schema_order_arguments()
)
else:
sig_group = CppSignatureGroup.from_schema(f.func, method=False)
sig_group = CppSignatureGroup.from_native_function(f, method=False)
formals = ', '.join(f'{a.type} {a.name}' for a in sig_group.signature.arguments())

return METHOD_DEFINITION.substitute(
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def process_function(f: NativeFunction) -> Optional[str]:
if Variant.function not in f.variants or not is_factory:
return None

sig = CppSignatureGroup.from_schema(f.func, method=False).signature
sig = CppSignatureGroup.from_native_function(f, method=False).signature
formals: List[str] = []
exprs: List[str] = []
requires_grad = 'false'
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,7 +974,7 @@ def get_decl_signature(declaration: Dict[Any, Any], use_base_variant: bool = Fal

@with_native_function
def get_func_signature(f: NativeFunction) -> str:
args = CppSignatureGroup.from_schema(f.func, method=False).signature.arguments()
args = CppSignatureGroup.from_native_function(f, method=False).signature.arguments()
types = ', '.join(python.argument_type_str(a.argument.type, simple_type=True)
if isinstance(a.argument, Argument) else 'TensorOptions'
for a in args)
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/load_derivatives.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def load_derivatives(derivatives_yaml_path: str, native_yaml_path: str) -> Seque

@with_native_function
def cpp_arguments(f: NativeFunction) -> Sequence[Binding]:
return CppSignatureGroup.from_schema(f.func, method=False).signature.arguments()
return CppSignatureGroup.from_native_function(f, method=False).signature.arguments()

def create_derivative(f: NativeFunction, formula: str, var_names: Tuple[str, ...]) -> Derivative:
arguments = cpp_arguments(f)
Expand Down
2 changes: 1 addition & 1 deletion tools/codegen/api/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ class DispatchLambdaArgumentExprs:
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

def _cpp_signature(f: NativeFunction, *, method: bool = False) -> CppSignature:
return CppSignatureGroup.from_schema(f.func, method=method).signature
return CppSignatureGroup.from_native_function(f, method=method).signature

def has_tensor_options(f: NativeFunction) -> bool:
return f.func.arguments.tensor_options is not None
Expand Down
3 changes: 2 additions & 1 deletion tools/codegen/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ class CppSignatureGroup:
faithful_signature: Optional[CppSignature]

@staticmethod
def from_schema(func: FunctionSchema, *, method: bool, fallback_binding: bool = False) -> 'CppSignatureGroup':
def from_native_function(f: NativeFunction, *, method: bool, fallback_binding: bool = False) -> 'CppSignatureGroup':
func = f.func
faithful_signature: Optional[CppSignature]
if func.arguments.tensor_options is not None or len(func.arguments.out) > 0:
faithful_signature = CppSignature(func=func, faithful=True, method=method, fallback_binding=fallback_binding)
Expand Down
6 changes: 3 additions & 3 deletions tools/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,7 +596,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]:

name = cpp.name(f.func)

sig_group = CppSignatureGroup.from_schema(f.func, method=False, fallback_binding=f.manual_cpp_binding)
sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

if self.target is Target.DECLARATION:
result = f"TORCH_API {sig_group.signature.decl()};\n"
Expand Down Expand Up @@ -650,7 +650,7 @@ def __call__(self, f: NativeFunction) -> Optional[str]:

name = cpp.name(f.func)

sig_group = CppSignatureGroup.from_schema(f.func, method=True, fallback_binding=f.manual_cpp_binding)
sig_group = CppSignatureGroup.from_native_function(f, method=True, fallback_binding=f.manual_cpp_binding)

if self.target is Target.DECLARATION:
result = f"{sig_group.signature.decl()} const;\n"
Expand Down Expand Up @@ -1032,7 +1032,7 @@ def compute_declaration_yaml(f: NativeFunction) -> object:
kwarg_only_set = set(a.name for a in f.func.arguments.flat_kwarg_only)
out_arg_set = set(a.name for a in f.func.arguments.out)

sig_group = CppSignatureGroup.from_schema(f.func, method=False, fallback_binding=False)
sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=False)
cpp_args = sig_group.signature.arguments()
arguments = [
compute_cpp_argument_yaml(
Expand Down

0 comments on commit b6a5fb3

Please sign in to comment.