Skip to content

Commit

Permalink
Add faithful C++ API
Browse files Browse the repository at this point in the history
Pull Request resolved: #44087

Each op taking a TensorOptions argument now has an additional overload in the C++ frontend where it takes scattered ScalarType, Layout, Device, bool instead of one TensorOptions argument.

If it is a c10-full op, then the scattered version calls into the dispatcher and the gathered version is a proxy calling into the scattered version.
If it is a non-c10-full op, then the gathered version calls into the dispatcher and the scattered version is a proxy calling into the gathered version.

This should minimize the amount of gathering and scattering needed.

This PR is also a prerequisite to remove the re-gathering of arguments that is currently happening in VariableKernel. Currently, VariableKernels gather arguments into a TensorOptions object
to call into the C++ API. In a PR stacked on top of this, VariableKernel will just directly call into the scattered C++ API introduced here and avoid the gathering step.
ghstack-source-id: 112810177

Differential Revision: [D23492188](https://our.internmc.facebook.com/intern/diff/D23492188/)
  • Loading branch information
smessmer committed Sep 24, 2020
1 parent 525e212 commit df4349e
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 68 deletions.
68 changes: 36 additions & 32 deletions tools/codegen/api/cpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def returns_type(rs: Sequence[Return]) -> str:
'[]': '{}',
'[0,1]': '{0,1}', # TODO: stop special casing
'contiguous_format': 'MemoryFormat::Contiguous',
'long': 'at::kLong',
}

# Convert a JIT default into C++ expression representing the default
Expand Down Expand Up @@ -194,8 +195,8 @@ def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArg
else:
assert_never(a)

def group_arguments(
func: FunctionSchema, *, method: bool = False
def arguments(
func: FunctionSchema, *, method: bool = False, gathered: bool = False,
) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]:
args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = []
args.extend(func.out_arguments)
Expand All @@ -205,37 +206,40 @@ def group_arguments(
else:
args.extend(func.arguments)

# group up arguments for tensor options

def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
predicates = [ # order matters
pred('dtype', Type.parse('ScalarType')),
pred('layout', Type.parse('Layout')),
pred('device', Type.parse('Device')),
pred('pin_memory', Type.parse('bool')),
]

i = 0
while i < len(func.kwarg_only_arguments):
# If there is enough space...
if i <= len(func.kwarg_only_arguments) - len(predicates):
# And the next len(predicates) arguments look like TensorOptions arguments
if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])):
# Group them together as one argument
args.append(TensorOptionsArguments(
dtype=func.kwarg_only_arguments[i],
layout=func.kwarg_only_arguments[i + 1],
device=func.kwarg_only_arguments[i + 2],
pin_memory=func.kwarg_only_arguments[i + 3],
))
i += len(predicates)
continue
args.append(func.kwarg_only_arguments[i])
i += 1
if gathered:
# group up arguments for tensor options

def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
predicates = [ # order matters
pred('dtype', Type.parse('ScalarType')),
pred('layout', Type.parse('Layout')),
pred('device', Type.parse('Device')),
pred('pin_memory', Type.parse('bool')),
]

i = 0
while i < len(func.kwarg_only_arguments):
# If there is enough space...
if i <= len(func.kwarg_only_arguments) - len(predicates):
# And the next len(predicates) arguments look like TensorOptions arguments
if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])):
# Group them together as one argument
args.append(TensorOptionsArguments(
dtype=func.kwarg_only_arguments[i],
layout=func.kwarg_only_arguments[i + 1],
device=func.kwarg_only_arguments[i + 2],
pin_memory=func.kwarg_only_arguments[i + 3],
))
i += len(predicates)
continue
args.append(func.kwarg_only_arguments[i])
i += 1
else:
args.extend(func.kwarg_only_arguments)

return args

# Convert arguments to C++ API form
def arguments(func: FunctionSchema, *, method: bool = False) -> Sequence[CppArgument]:
return list(map(argument, group_arguments(func, method=method)))
def cpp_arguments(func: FunctionSchema, *, method: bool = False, gathered: bool = False) -> Sequence[CppArgument]:
return list(map(argument, arguments(func, method=method, gathered=gathered)))
32 changes: 24 additions & 8 deletions tools/codegen/api/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,20 +72,30 @@ def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
for la in legacy_dispatcher.arguments(func)
]

class ProcessTensoroptions:
GATHER = 0
SCATTER = 1
PASS_THROUGH = 2

# Given a set of CppArguments in scope, return a sequence of dispatcher
# expressions that translate the cpp API into dispatcher API
def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument]) -> Sequence[DispatcherExpr]:
def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument], process_tensoroptions: ProcessTensoroptions = ProcessTensoroptions.PASS_THROUGH) -> Sequence[DispatcherExpr]:
if isinstance(a.argument, TensorOptionsArguments):
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
if process_tensoroptions == ProcessTensoroptions.SCATTER:
ta = a.argument
return [
DispatcherExpr(type=argument_type(ta.dtype), expr=f'optTypeMetaToScalarType({a.name}.dtype_opt())'),
DispatcherExpr(type=argument_type(ta.layout), expr=f'{a.name}.layout_opt()'),
DispatcherExpr(type=argument_type(ta.device), expr=f'{a.name}.device_opt()'),
DispatcherExpr(type=argument_type(ta.pin_memory), expr=f'{a.name}.pinned_memory_opt()'), # weird discrep
]
elif process_tensoroptions == ProcessTensoroptions.GATHER:
return [DispatcherExpr(type='const TensorOptions &', expr="TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory)")]
else:
assert process_tensoroptions == ProcessTensoroptions.PASS_THROUGH
return [DispatcherExpr(type='const TensorOptions &', expr=a.name)]
elif isinstance(a.argument, ThisArgument):
return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)]
elif isinstance(a.argument, Argument):
if a.name == 'memory_format' and tensor_options is not None and local.use_c10_dispatcher() is UseC10Dispatcher.full:
return [DispatcherExpr(
Expand All @@ -94,19 +104,25 @@ def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument])
]
else:
return [DispatcherExpr(type=argument_type(a.argument), expr=a.name)]
elif isinstance(a.argument, ThisArgument):
return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)]
else:
assert_never(a.argument)

def cpparguments_exprs(args: Sequence[CppArgument]) -> Sequence[DispatcherExpr]:
def cpparguments_exprs(args: Sequence[CppArgument], process_tensoroptions: ProcessTensoroptions) -> Sequence[DispatcherExpr]:
tensor_options = next((a for a in args if isinstance(a.argument, TensorOptionsArguments)), None)
return [r for a in args for r in cppargument_exprs(a, tensor_options=tensor_options)]
return [r for a in args for r in cppargument_exprs(a, tensor_options=tensor_options, process_tensoroptions=process_tensoroptions)]

# I don't think this is entirely sound, but it should be reasonably
# close
def legacydispatcherarguments_exprs(args: Sequence[LegacyDispatcherArgument]) -> Sequence[DispatcherExpr]:
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args])
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
process_tensoroptions = ProcessTensoroptions.SCATTER
else:
process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args], process_tensoroptions=process_tensoroptions)

def exprs(args: Sequence[DispatcherArgument]) -> Sequence[DispatcherExpr]:
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args])
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
process_tensoroptions = ProcessTensoroptions.SCATTER
else:
process_tensoroptions = ProcessTensoroptions.PASS_THROUGH
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args], process_tensoroptions=process_tensoroptions)
2 changes: 1 addition & 1 deletion tools/codegen/api/legacy_dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,4 @@ def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> Legacy
assert_never(a)

def arguments(func: FunctionSchema) -> Sequence[LegacyDispatcherArgument]:
return list(map(argument, cpp.group_arguments(func)))
return list(map(argument, cpp.arguments(func, gathered=True)))

0 comments on commit df4349e

Please sign in to comment.