Skip to content

Commit

Permalink
Convert from higher order functions to classes in tools.codegen.gen (#…
Browse files Browse the repository at this point in the history
…47008)

Summary:
Pull Request resolved: #47008

bhosmer has been complaining about how it is difficult to distinguish
between local variables and closed over variables in the higher order
functions.  Well, closures and objects do basically the same thing, so
just convert all these HOFs into objects.

The decoder ring:
- Higher order function => Constructor for object
- Access to closed over variable => Access to member variable on object
- with_native_function => method_with_native_function (because it's
  hard writing decorators that work for both functions and methods)

I didn't even have to change indentation (much).

When there is no need for closed over variables (a few functions), I
kept them as plain old functions, no need for an object with no
members.

While I was at it, I also deleted the kwargs, since the types are
enough to prevent mistakes.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D24600805

Pulled By: ezyang

fbshipit-source-id: 7e3ce8cb2446e3788f934ddcc17f7da6e9299511
  • Loading branch information
ezyang authored and facebook-github-bot committed Nov 11, 2020
1 parent d478605 commit 0c64f9f
Showing 1 changed file with 86 additions and 68 deletions.
154 changes: 86 additions & 68 deletions tools/codegen/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import pathlib
import functools
import json
from dataclasses import dataclass

from tools.codegen.code_template import CodeTemplate
from tools.codegen.model import *
Expand Down Expand Up @@ -102,13 +103,25 @@ def parse_native_yaml(path: str) -> List[NativeFunction]:
def with_native_function(func: Callable[[NativeFunction], T]) -> Callable[[NativeFunction], T]:
@functools.wraps(func)
def wrapper(f: NativeFunction) -> T:
with context(f'in {f.loc}:\n {f.func}'):
with local.parametrize(
use_c10_dispatcher=f.use_c10_dispatcher,
):
return func(f)
with native_function_manager(f):
return func(f)
return wrapper

def method_with_native_function(func: Callable[[S, NativeFunction], T]) -> Callable[[S, NativeFunction], T]:
@functools.wraps(func)
def wrapper(slf: S, f: NativeFunction) -> T:
with native_function_manager(f):
return func(slf, f)
return wrapper

@contextlib.contextmanager
def native_function_manager(f: NativeFunction) -> Iterator[None]:
with context(f'in {f.loc}:\n {f.func}'):
with local.parametrize(
use_c10_dispatcher=f.use_c10_dispatcher,
):
yield

# These two functions purposely return generators in analogy to map()
# so that you don't mix up when you need to list() them

Expand Down Expand Up @@ -180,49 +193,53 @@ def cpp_string(s: str) -> str:
#
# This function is also used for a secondary purpose: the registration
# logic is also reused to implement per-operator registration.
def compute_type_method(
dispatch: Optional[str], *,
@dataclass(frozen=True)
class ComputeTypeMethod:
dispatch: Optional[str]

# TODO: Give more precise type Union[Literal[Target.DEFINITION,
# Target.REGISTRATION]]; requires Literal from typing_extensions
# which we don't have a dep for yet.
target: Target,
target: Target

# Selector object to determine which operators to generate
# registration code for.
selector: SelectiveBuilder
) -> Callable[[NativeFunction], Optional[str]]:

if dispatch is None:
assert target is Target.REGISTRATION
def __post_init__(self) -> None:
assert self.target is not Target.DECLARATION
if self.dispatch is None:
assert self.target is Target.REGISTRATION

@with_native_function
def func(f: NativeFunction) -> Optional[str]:
# Has to be here as mypy won't transfer asserts into closures
assert target is not Target.DECLARATION
@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
# for mypy type refinement; would be fixed by TODO on target
assert self.target is not Target.DECLARATION

if dispatch is not None:
if dispatch not in f.dispatch:
if self.dispatch is not None:
if self.dispatch not in f.dispatch:
return None

op_name = f"aten::{f.func.name}"
if target is Target.REGISTRATION and not selector.is_operator_selected(op_name):
if self.target is Target.REGISTRATION and not self.selector.is_operator_selected(op_name):
return None

name = native.name(f.func)
returns_type = native.returns_type(f.func.returns)
args = native.arguments(f.func)
args_str = ', '.join(map(str, args))
dispatch_to_all_backends = dispatch is not None and dispatch in KEYWORD_ALL_BACKENDS
dispatch_to_all_backends = self.dispatch is not None and self.dispatch in KEYWORD_ALL_BACKENDS

if target is Target.DEFINITION:
assert dispatch is not None
impl_name = f"at::native::{f.dispatch[dispatch]}"
if self.target is Target.DEFINITION:
assert self.dispatch is not None
impl_name = f"at::native::{f.dispatch[self.dispatch]}"

args_exprs_str = ', '.join(a.name for a in args)

return_kw = " return "

cuda_guard = ""
if dispatch_to_all_backends or 'CUDA' in dispatch:
if dispatch_to_all_backends or 'CUDA' in self.dispatch:
self_args = (a for a in f.func.arguments if a.name == "self")

# There is precedence for which argument we use to do
Expand All @@ -249,7 +266,7 @@ def func(f: NativeFunction) -> Optional[str]:
# works just as well.
if f.device_guard and dispatch_to_all_backends and has_tensor_options:
cuda_guard = cuda_guard_from_tensor_options
elif f.device_guard and dispatch is not None and 'CUDA' in dispatch and has_tensor_options:
elif f.device_guard and self.dispatch is not None and 'CUDA' in self.dispatch and has_tensor_options:
cuda_guard = f"""\
globalContext().lazyInitCUDA();
{cuda_guard_from_tensor_options}
Expand All @@ -269,16 +286,16 @@ def func(f: NativeFunction) -> Optional[str]:
}}
"""

elif target is Target.REGISTRATION:
if dispatch is None:
elif self.target is Target.REGISTRATION:
if self.dispatch is None:
return f'm.def({cpp_string(str(f.func))});\n'
elif f.manual_kernel_registration:
return None
else:
if dispatch_to_all_backends:
type_name = f'TypeDefault::{name}'
else:
type_name = f'{dispatch}Type::{name}'
type_name = f'{self.dispatch}Type::{name}'

dispatcher_sig = DispatcherSignature.from_schema(f.func)

Expand All @@ -302,21 +319,22 @@ def func(f: NativeFunction) -> Optional[str]:
# in a TORCH_LIBRARY_FRAGMENT that does not have an ambient backend. So
# the torch::dispatch specification here is important! See
# Note [Redundancy in registration code is OK] for how we handle redundant info.
if dispatch is not None:
payload = f"torch::dispatch(DispatchKey::{dispatch},\n{payload})\n"
if self.dispatch is not None:
payload = f"torch::dispatch(DispatchKey::{self.dispatch},\n{payload})\n"

return f'm.impl("{f.func.name}",\n{payload});\n'
else:
assert_never(target)

return func
assert_never(self.target)

# Generates Function.cpp and Function.h. These files provide the
# functional public C++ API, and the scaffolding to call into
# the dispatcher from these functions. See also compute_tensor_method.
def compute_function(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
@with_native_function
def go(f: NativeFunction) -> Optional[str]:
@dataclass(frozen=True)
class ComputeFunction:
target: Target

@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
if f.manual_kernel_registration:
return None
if Variant.function not in f.variants:
Expand All @@ -326,13 +344,13 @@ def go(f: NativeFunction) -> Optional[str]:

sig_group = CppSignatureGroup.from_schema(f.func, method=False)

if target is Target.DECLARATION:
if self.target is Target.DECLARATION:
result = f"CAFFE2_API {sig_group.signature.decl()};\n"
if sig_group.faithful_signature is not None:
result += f"CAFFE2_API {sig_group.faithful_signature.decl()};\n"
return result

assert target is Target.DEFINITION
assert self.target is Target.DEFINITION

def generate_defn(sig: CppSignature) -> str:
dispatcher_sig = DispatcherSignature.from_schema(f.func)
Expand All @@ -357,14 +375,15 @@ def generate_defn(sig: CppSignature) -> str:

return result

return go

# Generates TensorBody.h (sic) and TensorMethods.cpp. These files provide the
# object-oriented (method-based) public C++ API, and the scaffolding to call into
# the dispatcher from these functions. See also compute_function.
def compute_tensor_method(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
@with_native_function
def go(f: NativeFunction) -> Optional[str]:
@dataclass(frozen=True)
class ComputeTensorMethod:
target: Target

@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
if Variant.method not in f.variants:
return None

Expand All @@ -376,13 +395,13 @@ def go(f: NativeFunction) -> Optional[str]:

sig_group = CppSignatureGroup.from_schema(f.func, method=True)

if target is Target.DECLARATION:
if self.target is Target.DECLARATION:
result = f"{sig_group.signature.decl()} const;\n"
if sig_group.faithful_signature is not None:
result += f"{sig_group.faithful_signature.decl()} const;\n"
return result

assert target is Target.DEFINITION
assert self.target is Target.DEFINITION

def generate_defn(sig: CppSignature) -> str:
dispatcher_sig = DispatcherSignature.from_schema(f.func)
Expand All @@ -406,8 +425,6 @@ def generate_defn(sig: CppSignature) -> str:

return result

return go

# Generates ATenOpList.cpp, a runtime accessible list of all aten
# operators.
# TODO: This was historically used to help some JIT interop code
Expand Down Expand Up @@ -442,9 +459,12 @@ def compute_native_function_declaration(f: NativeFunction) -> List[str]:
# Generates BackendSelectRegister.cpp, a series of kernels which provide
# specialized computation of dispatch key for operator signatures which cannot
# be easily done automatically using templating.
def compute_backend_select(*, target: Target) -> Callable[[NativeFunction], Optional[str]]:
@with_native_function
def go(f: NativeFunction) -> Optional[str]:
@dataclass(frozen=True)
class ComputeBackendSelect:
target: Target

@method_with_native_function
def __call__(self, f: NativeFunction) -> Optional[str]:
if str(f.func.name.name).endswith('_like') or str(f.func.name.name).startswith('new_'):
return None

Expand All @@ -471,7 +491,7 @@ def go(f: NativeFunction) -> Optional[str]:
dispatcher_exprs = native_sig.dispatcher_exprs()
dispatch_key = "options.computeDispatchKey()"

if target is Target.DEFINITION:
if self.target is Target.DEFINITION:
# I don't think there's actually a good reason to generate
# these two cases differently
# The first case could probably be improved though- it calls dispatchTypeId(),
Expand All @@ -494,7 +514,7 @@ def go(f: NativeFunction) -> Optional[str]:
return op.callWithDispatchKey(_dk, {', '.join(a.expr for a in dispatcher_exprs)});
}}
"""
elif target is Target.REGISTRATION:
elif self.target is Target.REGISTRATION:
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
return f"""m.impl("aten::{f.func.name}", TORCH_FN({name}));"""
elif local.use_c10_dispatcher() is UseC10Dispatcher.hacky_wrapper_for_legacy_signatures:
Expand All @@ -504,11 +524,10 @@ def go(f: NativeFunction) -> Optional[str]:
else:
assert local.use_c10_dispatcher() is UseC10Dispatcher.with_codegenerated_unboxing_wrapper
return f"""m.impl_UNBOXED("aten::{f.func.name}", {name});"""
elif target is Target.DECLARATION:
elif self.target is Target.DECLARATION:
raise AssertionError()
else:
assert_never(target)
return go
assert_never(self.target)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
Expand Down Expand Up @@ -993,12 +1012,11 @@ def make_file_manager(install_dir: str) -> FileManager:
'',
'Backend': dispatch,
'type_derived_method_definitions': list(mapMaybe(
compute_type_method(dispatch, target=Target.DEFINITION, selector=selector),
ComputeTypeMethod(dispatch, Target.DEFINITION, selector),
native_functions
)),
'function_registrations': list(mapMaybe(
compute_type_method(
dispatch, target=Target.REGISTRATION, selector=selector),
ComputeTypeMethod(dispatch, Target.REGISTRATION, selector),
native_functions
)),
})
Expand All @@ -1012,35 +1030,35 @@ def make_file_manager(install_dir: str) -> FileManager:
cpu_fm.write('TypeDefault.cpp', lambda: {
'type_method_definitions':
list(mapMaybe(
compute_type_method('Math', target=Target.DEFINITION, selector=selector),
ComputeTypeMethod('Math', Target.DEFINITION, selector),
native_functions)) +
list(mapMaybe(
compute_type_method('DefaultBackend', target=Target.DEFINITION, selector=selector),
ComputeTypeMethod('DefaultBackend', Target.DEFINITION, selector),
native_functions)),

'function_registrations': list(mapMaybe(
compute_type_method(None, target=Target.REGISTRATION, selector=schema_selector),
ComputeTypeMethod(None, Target.REGISTRATION, schema_selector),
native_functions)),

'math_function_registrations': list(mapMaybe(
compute_type_method('Math', target=Target.REGISTRATION, selector=selector),
ComputeTypeMethod('Math', Target.REGISTRATION, selector),
native_functions)),

'default_backend_function_registrations': list(mapMaybe(
compute_type_method('DefaultBackend', target=Target.REGISTRATION, selector=selector),
ComputeTypeMethod('DefaultBackend', Target.REGISTRATION, selector),
native_functions)),
})
cpu_fm.write('Functions.h', lambda: {
'function_declarations': list(mapMaybe(compute_function(target=Target.DECLARATION), native_functions)),
'function_declarations': list(mapMaybe(ComputeFunction(Target.DECLARATION), native_functions)),
})
cpu_fm.write('Functions.cpp', lambda: {
'function_definitions': list(mapMaybe(compute_function(target=Target.DEFINITION), native_functions)),
'function_definitions': list(mapMaybe(ComputeFunction(Target.DEFINITION), native_functions)),
})
core_fm.write('TensorBody.h', lambda: {
'tensor_method_declarations': list(mapMaybe(compute_tensor_method(target=Target.DECLARATION), native_functions)),
'tensor_method_declarations': list(mapMaybe(ComputeTensorMethod(Target.DECLARATION), native_functions)),
})
core_fm.write('TensorMethods.cpp', lambda: {
'tensor_method_definitions': list(mapMaybe(compute_tensor_method(target=Target.DEFINITION), native_functions)),
'tensor_method_definitions': list(mapMaybe(ComputeTensorMethod(Target.DEFINITION), native_functions)),
})
core_fm.write('ATenOpList.cpp', lambda: {
'aten_ops': list(mapMaybe(compute_aten_op, native_functions)),
Expand All @@ -1050,9 +1068,9 @@ def make_file_manager(install_dir: str) -> FileManager:
})
cpu_fm.write('BackendSelectRegister.cpp', lambda: {
'backend_select_method_definitions':
list(mapMaybe(compute_backend_select(target=Target.DEFINITION), native_functions)),
list(mapMaybe(ComputeBackendSelect(Target.DEFINITION), native_functions)),
'backend_select_function_registrations':
list(mapMaybe(compute_backend_select(target=Target.REGISTRATION), native_functions)),
list(mapMaybe(ComputeBackendSelect(Target.REGISTRATION), native_functions)),
})

cpu_fm.write('Declarations.yaml', lambda: format_yaml([compute_declaration_yaml(f) for f in native_functions]))
Expand Down

0 comments on commit 0c64f9f

Please sign in to comment.