Register DefaultBackend implementations for functional/inplace structured operators

As remarked in #52277, it is easy to give an (inefficient, due to extra
redispatches) DefaultBackend implementation of foo and foo_ in terms of
foo_out. This patch enables code generation for DefaultBackend in these
cases by default for all structured kernels. You can see the payoff in
the MSNPU extension: it only has to register a kernel for add.out, and
it gets the add and add_ kernels automatically.
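
For intuition, here is a minimal hand-written sketch of what such a
DefaultBackend functional wrapper amounts to (illustration only, not the
literal generated code; the real wrapper sizes the output via the
structured kernel's meta function rather than assuming it matches self):

#include <ATen/ATen.h>

// Sketch: a functional add defined in terms of the out= kernel. The two
// dispatcher round-trips (at::empty and at::add_out) are the extra
// redispatches that make this default implementation inefficient.
at::Tensor add_via_out(const at::Tensor& self, const at::Tensor& other,
                       const at::Scalar& alpha) {
  at::Tensor out = at::empty(self.sizes(), self.options());  // redispatch 1
  at::add_out(out, self, other, alpha);                      // redispatch 2
  return out;
}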

The actual code changes are very modest:
- For DefaultBackend, call the dispatched functions (not the direct
  native:: ones) to allocate tensors, set the device guard, etc.
- Don't call impl() for DefaultBackend (it doesn't exist); instead,
  directly generate a call to at::foo_out to do the actual work.
- Do NOT generate a DefaultBackend implementation for foo_out. There is
  actually a case to be made that doing so would be a good idea with
  more infrastructure; see the comments inside.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

ghstack-source-id: e45111af6bf9c908c894c0576739a9770e1123b2
Pull Request resolved: #53037
ezyang committed Mar 1, 2021
1 parent b8861bb commit c1bf7a5
Showing 3 changed files with 72 additions and 10 deletions.
6 changes: 3 additions & 3 deletions test/cpp_extensions/msnpu_extension.cpp
@@ -26,9 +26,9 @@ Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::op
   return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
 }
 
-Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
+Tensor& add_out_override(const Tensor & a, const Tensor & b , Scalar c, Tensor & out) {
   test_int = 1;
-  return get_tensor(a.dtype(), a.sizes());
+  return out;
 }

 Tensor fake_convolution(
@@ -54,7 +54,7 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
 
 TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
   m.impl("empty.memory_format", empty_override);
-  m.impl("add.Tensor", add_override);
+  m.impl("add.out", add_out_override);
   m.impl("convolution_overrideable", fake_convolution);
   m.impl("convolution_backward_overrideable", fake_convolution_backward);
 }
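
With only add.out registered, the functional and inplace variants now come
for free. A hedged usage sketch (assuming the extension is loaded and the
MSNPU device type is reachable as at::kMSNPU):

#include <ATen/ATen.h>

// Illustration: MSNPU registers nothing for add.Tensor or add_.Tensor,
// yet both work, because the generated DefaultBackend kernels forward to
// the add.out kernel registered above.
void msnpu_demo() {
  at::Tensor a = at::empty({5, 5}, at::device(at::kMSNPU).dtype(at::kFloat));
  at::Tensor b = at::empty({5, 5}, at::device(at::kMSNPU).dtype(at::kFloat));
  at::Tensor c = at::add(a, b);  // DefaultBackend kernel -> add.out
  a.add_(b);                     // likewise for the inplace variant
}
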
6 changes: 6 additions & 0 deletions tools/codegen/api/types.py
@@ -188,6 +188,12 @@ class CppSignatureGroup:
     signature: CppSignature
     faithful_signature: Optional[CppSignature]
 
+    def most_faithful_signature(self) -> CppSignature:
+        if self.faithful_signature:
+            return self.faithful_signature
+        else:
+            return self.signature
+
     @staticmethod
     def from_native_function(f: NativeFunction, *, method: bool, fallback_binding: bool = False) -> 'CppSignatureGroup':
         func = f.func
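
For context, the "faithful" signature is the C++ overload whose argument
order mirrors the JIT schema (with the out argument last); not every
operator has one, so most_faithful_signature falls back to the default
signature. An illustrative pair of declarations for add.out (assumed
shapes, not taken from this patch):

#include <ATen/ATen.h>

// Default C++ API: the out parameter comes first.
at::Tensor& add_out(at::Tensor& out, const at::Tensor& self,
                    const at::Tensor& other, const at::Scalar& alpha);
// Faithful variant: argument order matches the schema, out comes last.
at::Tensor& add_outf(const at::Tensor& self, const at::Tensor& other,
                     const at::Scalar& alpha, at::Tensor& out);
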
70 changes: 63 additions & 7 deletions tools/codegen/dest/register_dispatch_key.py
@@ -64,6 +64,10 @@ def gen_structured(self, g: StructuredNativeFunctions) -> List[str]:
             assert self.dispatch_key not in g.out.dispatch, \
                 "Do not explicitly specify Meta dispatch key on structured " \
                 "functions, they will be automatically generated for you"
+        elif self.dispatch_key == DispatchKey.DefaultBackend:
+            assert self.dispatch_key not in g.out.dispatch, \
+                "Do not explicitly specify DefaultBackend dispatch key on structured " \
+                "functions, they will be automatically generated for you"
         elif not is_structured_dispatch_key(self.dispatch_key):
             return list(mapMaybe(self.gen_unstructured, g.functions()))
         elif self.dispatch_key not in g.out.dispatch:
@@ -226,14 +230,14 @@ def gen_class_set_output(self, k: SchemaKind, parent_class: str, generate_super:
 """
 
     def gen_class_set_output_body(self, k: SchemaKind) -> str:
-        if self.dispatch_key == DispatchKey.CUDA:
+        if self.dispatch_key in [DispatchKey.CUDA, DispatchKey.DefaultBackend]:
             maybe_set_guard = """
 auto current_device = guard_.current_device();
 if (C10_UNLIKELY(current_device.has_value())) {
   TORCH_INTERNAL_ASSERT(*current_device == options.device(),
     "structured kernels don't support multi-device outputs");
 } else {
-  guard_.set_device(options.device());
+  guard_.reset_device(options.device());
 }
 """
         else:
@@ -257,6 +261,9 @@ def gen_class_set_output_body(self, k: SchemaKind) -> str:
         elif self.dispatch_key == DispatchKey.CUDA:
             empty_impl = "at::native::empty_cuda"
             empty_strided_impl = "at::native::empty_strided_cuda"
+        elif self.dispatch_key == DispatchKey.DefaultBackend:
+            empty_impl = "at::empty"
+            empty_strided_impl = "at::empty_strided"
         else:
             raise AssertionError("unsupported dispatch key")
         return f"""
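
The choice of at::empty over at::native::empty_* is the crux: the
dispatched call goes back through the dispatcher, so whichever backend the
options name gets to run its own allocator. A minimal sketch of the call
being substituted (function name is a placeholder):

#include <ATen/ATen.h>

// Dispatched allocation: works for any backend named by `options`, at the
// cost of one extra redispatch versus a direct at::native:: call.
at::Tensor make_output(at::IntArrayRef sizes, const at::TensorOptions& options) {
  return at::empty(sizes, options);
}
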
@@ -314,6 +321,8 @@ def gen_class(
             guard_field = 'c10::hip::OptionalHIPGuardMasqueradingAsCUDA guard_;'
         else:
             guard_field = 'c10::cuda::OptionalCUDAGuard guard_;'
+    elif self.dispatch_key == DispatchKey.DefaultBackend:
+        guard_field = 'c10::OptionalDeviceGuard guard_;'
     else:
         guard_field = ''

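c10::OptionalDeviceGuard is the device-generic counterpart of
OptionalCUDAGuard, which is what lets the set_output body above call
guard_.reset_device() for any backend. A minimal sketch of its behavior:

#include <c10/core/DeviceGuard.h>

// The guard starts out inactive: current_device() is nullopt until
// reset_device() first switches to (and remembers) a device; the original
// device is restored when the guard goes out of scope.
void guard_demo(c10::Device target) {
  c10::OptionalDeviceGuard guard;
  if (!guard.current_device().has_value()) {
    guard.reset_device(target);
  }
}
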
@@ -336,6 +345,22 @@ def gen_one(self, f: NativeFunction) -> Optional[str]:
         if self.target is Target.REGISTRATION and not self.selector.is_native_function_selected(f):
             return None
 
+        # TODO: Now, there is something interesting going on here. In the code below,
+        # we generate DefaultBackend implementations of functional and inplace
+        # based on the out implementation. But in fact, out is definable by
+        # functional too (just not very efficiently), and this is honestly the
+        # MORE likely situation for a backend implementor. How do we pick?
+        # Well, taking a page from Haskell type classes and default methods,
+        # we could conceivably register a circular definition (out in terms
+        # of functional, and functional in terms of out) and just require
+        # someone to implement one or the other. We'd have to do a little bit
+        # of work to not register one of these "weak" definitions unless there
+        # is a strong definition somewhere in the DAG! So it's not implemented yet.
+        if self.dispatch_key == DispatchKey.DefaultBackend and f.func.kind() is SchemaKind.out:
+            # Never generate a default implementation for out; that's what you
+            # have to define as a backend implementor.
+            return None
+
         # Note [Direct dispatch bindings]
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Signature of the non-dispatched function we'll expose in a header
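
The Haskell-style circular default mentioned in the comment above can be
pictured loosely in C++ (an analogy only, not PyTorch code): each
operation has a default defined in terms of the other, and a backend must
override at least one to break the cycle.

#include <ATen/ATen.h>

// Analogy: add() and add_out() each default to the other, so a backend
// overriding neither would recurse forever. That is the "weak" vs
// "strong" definition check the comment says codegen would need.
struct BackendOps {
  virtual ~BackendOps() = default;
  virtual at::Tensor add(const at::Tensor& a, const at::Tensor& b) {
    at::Tensor out = at::empty(a.sizes(), a.options());
    add_out(a, b, out);
    return out;
  }
  virtual at::Tensor& add_out(const at::Tensor& a, const at::Tensor& b,
                              at::Tensor& out) {
    out.copy_(add(a, b));
    return out;
  }
};
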
@@ -377,9 +402,13 @@ def generate_defn(cpp_sig: CppSignature) -> str:
 
             # Initialize the class corresponding to this structured
             # operator; feeding it the output argument(s) if it is known
-            if self.dispatch_key == DispatchKey.Meta:
+            if self.dispatch_key is DispatchKey.Meta:
                 class_name = f"structured_{meta.name(self.g)}_meta_{k.name}"
                 parent_class = f"at::meta::{meta.name(self.g)}"
+            elif self.dispatch_key is DispatchKey.DefaultBackend:
+                # TODO: dedup this branch
+                class_name = f"structured_{meta.name(self.g)}_default_backend_{k.name}"
+                parent_class = f"at::meta::{meta.name(self.g)}"
             else:
                 class_name = f"structured_{self.g.out.dispatch[self.dispatch_key]}_{k.name}"
                 parent_class = f"at::native::structured_{self.g.out.dispatch[self.dispatch_key]}"
@@ -407,14 +436,43 @@ def generate_defn(cpp_sig: CppSignature) -> str:
             # After running meta, op.outputs_ is guaranteed to be valid;
             # add it to the context
             # TODO: handle multi-return
+            assert ConstRefCType(BaseCType("Tensor", structured.out_arguments(self.g)[0].ctype.name)) == \
+                structured.out_arguments(self.g)[0].ctype
             context.append(Expr(
                 expr="op.outputs_[0]",
-                type=structured.out_arguments(self.g)[0].ctype,
+                # TODO: Stop hardcoding that the output type is a Tensor. Note
+                # that for the codegen here this is fine because outputs_ is
+                # hardcoded to be tensor already
+                type=MutRefCType(BaseCType("Tensor", structured.out_arguments(self.g)[0].ctype.name)),
             ))

             # With the expanded context, do the impl call (if not a meta
             # function)
-            if self.dispatch_key != DispatchKey.Meta:
+            if self.dispatch_key == DispatchKey.DefaultBackend:
+                # TODO: https://github.com/pytorch/pytorch/issues/53023
+                out_sig_group = CppSignatureGroup.from_native_function(
+                    self.g.out, method=False, fallback_binding=f.manual_cpp_binding)
+                out_sig = out_sig_group.most_faithful_signature()
+                api_name = out_sig.name()
+                out_exprs = ', '.join(
+                    e.expr for e in translate(
+                        context,
+                        out_sig.arguments(),
+                        method=False
+                    )
+                )
+                # TODO: I think this means structured won't work with method
+                # only functions (but maybe you're saved by faithful? iunno.)
+                # NB: Originally I wrote this as an at::redispatch call, but
+                # I got in trouble because that meant I needed a DispatchKeySet
+                # in the wrapper function, which meant I needed a DispatchKeySet
+                # in the DispatchKeyFunctions declarations, but the defined API
+                # there does NOT permit a dispatch key set. I think you can
+                # probably unwind this by calling some function to do the TLS
+                # fetch and get the DispatchKeySet when you don't have it, but
+                # I didn't do it for this version
+                sig_body.append(f"at::{api_name}({out_exprs});")
+            elif self.dispatch_key != DispatchKey.Meta:
                 impl_exprs = ', '.join(
                     e.expr for e in translate(
                         context,
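
Putting the pieces together, the DefaultBackend branch for add.Tensor
emits a wrapper of roughly the following shape. This is a hedged
reconstruction, not the literal codegen output: the structured class is
normally generated elsewhere, so a stub with just the members the wrapper
touches is included to make the sketch self-contained, and the faithful
name at::add_outf is assumed.

#include <ATen/ATen.h>
#include <utility>

// Stand-in for the codegen'd structured class; the real meta() computes
// the broadcasted output shape, this stub assumes same-shaped inputs.
struct structured_add_default_backend_functional {
  at::Tensor outputs_[1];
  void meta(const at::Tensor& self, const at::Tensor& other,
            const at::Scalar& alpha) {
    outputs_[0] = at::empty(self.sizes(), self.options());
  }
};

// Assumed shape of the generated wrapper: run meta() to allocate the
// output, then forward the real work to the backend's out= kernel
// through the dispatcher.
at::Tensor wrapper_add_Tensor(const at::Tensor& self, const at::Tensor& other,
                              const at::Scalar& alpha) {
  structured_add_default_backend_functional op;
  op.meta(self, other, alpha);
  at::add_outf(self, other, alpha, op.outputs_[0]);
  return std::move(op.outputs_[0]);
}
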
@@ -453,8 +511,6 @@ def generate_defn(cpp_sig: CppSignature) -> str:
 """
 
         elif self.target is Target.REGISTRATION:
-            dispatcher_sig = DispatcherSignature.from_schema(f.func)
-
             assert local.use_c10_dispatcher() is UseC10Dispatcher.full
             return f'm.impl("{f.func.name}", TORCH_FN({sig.name()}));'
         else:
