Skip to content

Commit

Permalink
Add faithful C++ API
Browse files Browse the repository at this point in the history
Pull Request resolved: #44087

Each op taking a TensorOptions argument now has an additional overload in the C++ frontend where it takes scattered ScalarType, Layout, Device, bool instead of one TensorOptions argument.

If it is a c10-full op, then the scattered version calls into the dispatcher and the gathered version is a proxy calling into the scattered version.
If it is a non-c10-full op, then the gathered version calls into the dispatcher and the scattered version is a proxy calling into the gathered version.

This should minimize the amount of gathering and scattering needed.

This PR is also a prerequisite to remove the re-gathering of arguments that is currently happening in VariableKernel. Currently, VariableKernels gather arguments into a TensorOptions object
to call into the C++ API. In a PR stacked on top of this, VariableKernel will just directly call into the scattered C++ API introduced here and avoid the gathering step.
ghstack-source-id: 113235801

Differential Revision: [D23492188](https://our.internmc.facebook.com/intern/diff/D23492188/)
  • Loading branch information
smessmer committed Sep 30, 2020
1 parent 56af122 commit 100f126
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 47 deletions.
98 changes: 88 additions & 10 deletions tools/codegen/api/cpp.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from tools.codegen.model import *
from tools.codegen.api.types import TensorOptionsArguments, CppArgument, ThisArgument
import tools.codegen.local as local
from tools.codegen.api import dispatcher
from typing import Optional, Sequence, Union, Callable, List
import copy

# This file describes the translation of JIT schema to the public C++
# API, which is what people use when they call functions like at::add.
Expand Down Expand Up @@ -152,6 +154,7 @@ def returns_type(rs: Sequence[Return]) -> str:
'[]': '{}',
'[0,1]': '{0,1}', # TODO: stop special casing
'contiguous_format': 'MemoryFormat::Contiguous',
'long': 'at::kLong',
}

# Convert a JIT default into C++ expression representing the default
Expand Down Expand Up @@ -191,9 +194,73 @@ def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArg
else:
assert_never(a)

def group_arguments(
func: FunctionSchema, *, method: bool = False
) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]:
@dataclass(frozen=True)
class CppSignature:
returns: Sequence[Return]
arguments: Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]

def cpp_arguments(self) -> Sequence[CppArgument]:
return list(map(argument, self.arguments))

# Return arguments as a comma separated list, i.e. like they would be in a C++
# function signature. Include default values for arguments.
def cpp_arguments_str(self, with_defaults: bool) -> str:
args_without_this = [argument(a) for a in self.arguments if not isinstance(a, ThisArgument)]
if with_defaults:
return ', '.join(map(str, args_without_this))
else:
return ', '.join(map(lambda s: s.str_no_default(), args_without_this))

# Return a string with a comma separated list of expressions that could be used
# to call this operator. This can be used to generate code that wraps operators
# and calls back into them. The process_tensoroptions argument determines how
# tensor options should be treated. They can be
# - PASS_THROUGH: Don't do anything, just handle them as regular arguments
# - SCATTER: Expect a `TensorOptions options` in the scope and scatter it into `options.dtype, ...`
# - GATHER: Expect `dtype, ...` in the scope and gather them into a TensorOptions for calling
def exprs_str(self,
process_tensoroptions: dispatcher.ProcessTensoroptions = dispatcher.ProcessTensoroptions.PASS_THROUGH,
exclude_this: bool = False,
) -> str:
args = self.arguments
if exclude_this:
args = [a for a in args if not isinstance(a, ThisArgument)]
cpp_args = list(map(argument, args))
exprs = dispatcher.cpparguments_exprs(cpp_args, process_tensoroptions=process_tensoroptions)
return ', '.join(map(lambda a: a.expr, exprs))

def types_str(self) -> str:
args = self.cpp_arguments()
exprs = dispatcher.cpparguments_exprs(args, process_tensoroptions=dispatcher.ProcessTensoroptions.PASS_THROUGH)
return ', '.join(map(lambda a: a.type, exprs))


@dataclass(frozen=True)
class CppSignatureGroup:
# arguments contains the arguments for the C++ signature as it is represented
# in the JIT schema.
signature: CppSignature

# gathered_signature is an alternative C++ signature in which TensorOptions are
# gathered into one TensorOptions object instead of being scattered into
# ScalarType, Layout, Device. This is only present for factory operators,
# other operators have this set to None. This can be used to generate a
# convenience API in the C++ frontend so users can call using TensorOptions objects.
gathered_signature: Optional[CppSignature]

# If it is a factory op, this returns the arguments for the convenience API
# that takes TensorOptions. If it is not a factory op and doesn't have
# a gathered signature, then this returns the regular signature instead.
def signature_prefer_gathered(self) -> CppSignature:
if self.gathered_signature is not None:
return self.gathered_signature
else:
return self.signature


def signature_group(
func: FunctionSchema, *, method: bool = False,
) -> CppSignatureGroup:
args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = []
args.extend(func.out_arguments)

Expand All @@ -202,8 +269,9 @@ def group_arguments(
else:
args.extend(func.arguments)

# group up arguments for tensor options
gathered_args = copy.deepcopy(args)

# group up arguments for tensor options
def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
predicates = [ # order matters
Expand All @@ -213,26 +281,36 @@ def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
pred('pin_memory', Type.parse('bool')),
]

has_tensoroptions_argument = False
i = 0
while i < len(func.kwarg_only_arguments):
# If there is enough space...
if i <= len(func.kwarg_only_arguments) - len(predicates):
# And the next len(predicates) arguments look like TensorOptions arguments
if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])):
has_tensoroptions_argument = True
# Group them together as one argument
args.append(TensorOptionsArguments(
gathered_args.append(TensorOptionsArguments(
dtype=func.kwarg_only_arguments[i],
layout=func.kwarg_only_arguments[i + 1],
device=func.kwarg_only_arguments[i + 2],
pin_memory=func.kwarg_only_arguments[i + 3],
))
i += len(predicates)
continue
args.append(func.kwarg_only_arguments[i])
gathered_args.append(func.kwarg_only_arguments[i])
i += 1

return args
args.extend(func.kwarg_only_arguments)

# Convert arguments to C++ API form
def arguments(func: FunctionSchema, *, method: bool = False) -> Sequence[CppArgument]:
return list(map(argument, group_arguments(func, method=method)))
if has_tensoroptions_argument:
return CppSignatureGroup(
signature=CppSignature(arguments=args, returns=func.returns),
gathered_signature=CppSignature(arguments=gathered_args, returns=func.returns),
)
else:
assert gathered_args == args
return CppSignatureGroup(
signature=CppSignature(arguments=args, returns=func.returns),
gathered_signature=None,
)
50 changes: 41 additions & 9 deletions tools/codegen/api/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from tools.codegen.api.types import CppArgument, DispatcherExpr, TensorOptionsArguments, \
DispatcherArgument, ThisArgument, LegacyDispatcherArgument
import tools.codegen.api.cpp as cpp
from tools.codegen.api import cpp
import tools.codegen.api.legacy_dispatcher as legacy_dispatcher
import tools.codegen.local as local

Expand Down Expand Up @@ -75,20 +75,36 @@ def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
for la in legacy_dispatcher.arguments(func)
]

# TODO GATHER is only needed for non-c10-full ops, remove later.
ProcessTensoroptions = Enum('ProcessTensoroptions', ('GATHER', 'SCATTER', 'PASS_THROUGH'))


# Given a set of CppArguments in scope, return a sequence of dispatcher
# expressions that translate the cpp API into dispatcher API
def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument]) -> Sequence[DispatcherExpr]:
def cppargument_exprs(a: CppArgument,
*,
tensor_options: Optional[CppArgument],
process_tensoroptions: ProcessTensoroptions = ProcessTensoroptions.PASS_THROUGH
) -> Sequence[DispatcherExpr]:
if isinstance(a.argument, TensorOptionsArguments):
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
if process_tensoroptions == ProcessTensoroptions.SCATTER:
ta = a.argument
return [
DispatcherExpr(type=argument_type(ta.dtype), expr=f'optTypeMetaToScalarType({a.name}.dtype_opt())'),
DispatcherExpr(type=argument_type(ta.layout), expr=f'{a.name}.layout_opt()'),
DispatcherExpr(type=argument_type(ta.device), expr=f'{a.name}.device_opt()'),
DispatcherExpr(type=argument_type(ta.pin_memory), expr=f'{a.name}.pinned_memory_opt()'), # weird discrep
]
elif process_tensoroptions == ProcessTensoroptions.GATHER:
return [
DispatcherExpr(
type='const TensorOptions &',
expr="TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)")]
else:
assert process_tensoroptions == ProcessTensoroptions.PASS_THROUGH
return [DispatcherExpr(type='const TensorOptions &', expr=a.name)]
elif isinstance(a.argument, ThisArgument):
return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)]
elif isinstance(a.argument, Argument):
if a.name == 'memory_format' and tensor_options is not None and local.use_c10_dispatcher() is UseC10Dispatcher.full:
return [DispatcherExpr(
Expand All @@ -97,19 +113,35 @@ def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument])
]
else:
return [DispatcherExpr(type=argument_type(a.argument), expr=a.name)]
elif isinstance(a.argument, ThisArgument):
return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)]
else:
assert_never(a.argument)

def cpparguments_exprs(args: Sequence[CppArgument]) -> Sequence[DispatcherExpr]:
def cpparguments_exprs(args: Sequence[CppArgument], process_tensoroptions: ProcessTensoroptions) -> Sequence[DispatcherExpr]:
tensor_options = next((a for a in args if isinstance(a.argument, TensorOptionsArguments)), None)
return [r for a in args for r in cppargument_exprs(a, tensor_options=tensor_options)]
return [r for a in args for r in cppargument_exprs(a,
tensor_options=tensor_options,
process_tensoroptions=process_tensoroptions)]

# I don't think this is entirely sound, but it should be reasonably
# close
def legacydispatcherarguments_exprs(args: Sequence[LegacyDispatcherArgument]) -> Sequence[DispatcherExpr]:
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args])
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
process_tensoroptions = ProcessTensoroptions.SCATTER
else:
process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
return cpparguments_exprs([CppArgument(type=a.type,
name=a.name,
default=None,
argument=a.argument) for a in args],
process_tensoroptions=process_tensoroptions)

def exprs(args: Sequence[DispatcherArgument]) -> Sequence[DispatcherExpr]:
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args])
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
process_tensoroptions = ProcessTensoroptions.SCATTER
else:
process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
return cpparguments_exprs([CppArgument(type=a.type,
name=a.name,
default=None,
argument=a.argument) for a in args],
process_tensoroptions=process_tensoroptions)
4 changes: 3 additions & 1 deletion tools/codegen/api/legacy_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,6 @@ def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> Legacy
assert_never(a)

def arguments(func: FunctionSchema) -> Sequence[LegacyDispatcherArgument]:
return list(map(argument, cpp.group_arguments(func)))
signature_group = cpp.signature_group(func)
args = signature_group.signature_prefer_gathered().arguments
return list(map(argument, args))

0 comments on commit 100f126

Please sign in to comment.