[AOTI] Refactor some fallback op util functions (#126182)
Summary: Move some utility functions for C++ kernel naming and missing-argument filling from FallbackKernel to ExternKernel, since they are generally useful for ExternKernel.
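
For reference, a minimal usage sketch of the moved naming helper (illustrative only; it assumes get_aten_cpp_kernel_name is importable from torch._inductor.ir once this change lands):

    # Hypothetical usage sketch, not part of the diff.
    import torch
    from torch._inductor.ir import get_aten_cpp_kernel_name

    aten = torch.ops.aten
    get_aten_cpp_kernel_name(aten.relu.default)      # "at::_ops::relu::call"
    get_aten_cpp_kernel_name(aten.sum.dim_IntList)   # "at::_ops::sum_dim_IntList::call"
    get_aten_cpp_kernel_name("not an OpOverload")    # None (non-aten / non-OpOverload input)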

Pull Request resolved: #126182
Approved by: https://github.com/chenyang78
ghstack dependencies: #126181
desertfire authored and pytorchmergebot committed May 16, 2024
1 parent c5f926a commit 5792bc3
Showing 2 changed files with 73 additions and 105 deletions.
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -1120,7 +1120,7 @@ def g(args):
)

def get_c_shim_func_name(self, kernel):
if not config.abi_compatible:
if not config.abi_compatible or kernel.startswith("aoti_torch_"):
return kernel

assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'"
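A quick illustration of the one-line change above (the mapping logic below the assert is collapsed in this view):

    # With config.abi_compatible on, a name that already starts with "aoti_torch_" is
    # returned unchanged -- it is already a C shim function. Any other name must be
    # fully qualified ("ns::op", per the assert) and is mapped to its C shim name by
    # the collapsed code. With config.abi_compatible off, every name passes through.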
176 changes: 72 additions & 104 deletions torch/_inductor/ir.py
@@ -3930,6 +3930,21 @@ def should_allocate(self):
return True


def get_aten_cpp_kernel_name(kernel):
# Calling with the default kernel name can lead to ambiguous behavior like the following example.
# repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
# repeat_interleave(const at::Tensor & self, int64_t repeats,
# c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
if not isinstance(kernel, torch._ops.OpOverload) or kernel.namespace != "aten":
return None
opname = (
kernel.__name__.split(".")[0]
if kernel._overloadname == "default"
else kernel.__name__.replace(".", "_")
)
return f"at::_ops::{opname}::call"


@dataclasses.dataclass
class ExternKernel(InputsKernel):
constant_args: Tuple[Any, ...] = ()
@@ -3973,7 +3988,8 @@ def __init__(
self.kwargs = kwargs if kwargs else {}
self.output_view = output_view
self.python_kernel_name = python_kernel_name
self.cpp_kernel_name = cpp_kernel_name
# If cpp_kernel_name is None, we will try to construct it from op_overload
self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload)
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
self.op_overload = op_overload
self.collect_arg_kwarg_properties()
@@ -4016,6 +4032,40 @@ def collect_arg_kwarg_properties(self):
else {}
)

def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
# Previously, we wanted to maintain forward compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values to be set.
# Discussed with Sherlock offline and we decided to allow serializing
# default args into the C++ wrapper code for now. We will refine this
# part if we see a real FC requirement. More details related to FC
# can be found at:
# https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
assert isinstance(args, (list, tuple))
if isinstance(args, tuple):
args = list(args)
assert self.arg_properties, "ExternKernel.arg_properties should not be empty"

n_args = len(args)
n_pos_args = len(self.arg_properties)
# For cpp wrapper, if some positional args are not provided, we need to check
# if they're in the kwargs or use their default value
if n_args < n_pos_args:
log.debug(
"%s has %d unprovided positional arguments. "
"Will check if they are in the keyword arguments or will use default values.",
self.op_overload,
n_pos_args - n_args,
)
for i in range(n_args, n_pos_args):
arg_name = self.arg_properties[i]["name"]
args.append(
kwargs[arg_name]
if arg_name in kwargs
else self.arg_properties[i]["default_value"]
)
return args
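# Illustration (hypothetical buffer and values, not part of the diff): if
# self.arg_properties describes positional parameters ("self", "dim", "keepdim")
# with "keepdim" defaulting to False, then
#   self.fill_non_provided_args([buf0], {"dim": 1})
# returns [buf0, 1, False] -- "dim" is taken from kwargs and "keepdim" falls back
# to its recorded default_value.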

def decide_layout(self):
if isinstance(self.layout, FlexibleLayout):
self.apply_constraint()
@@ -4030,7 +4080,15 @@ def codegen(self, wrapper):
raise NotImplementedError

def get_kernel_name(self):
return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name
return (
(
V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined]
if config.abi_compatible
else self.cpp_kernel_name
)
if V.graph.cpp_wrapper
else self.python_kernel_name
)
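# Net effect of the branches above: the Python wrapper uses python_kernel_name;
# the C++ wrapper uses the C shim name (via get_c_shim_func_name) in
# abi_compatible mode, and cpp_kernel_name (e.g. "at::_ops::<op>::call") otherwise.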

@staticmethod
def copy_input(x):
@@ -5128,25 +5186,7 @@ class ExternKernelNode:
}


def get_aten_cpp_kernel_name(kernel):
# Calling with the default kernel name can lead to ambiguous behavior like the following example.
# repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
# repeat_interleave(const at::Tensor & self, int64_t repeats,
# c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
assert (
isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
), "Invalid aten kernel"
opname = (
kernel.__name__.split(".")[0]
if kernel._overloadname == "default"
else kernel.__name__.replace(".", "_")
)
return f"at::_ops::{opname}::call"


class FallbackKernel(ExternKernelAlloc):
args_default_value: List[Dict[str, Any]]

def __init__(
self,
layout,
@@ -5158,12 +5198,23 @@ def __init__(
*,
unbacked_bindings=None,
):
if (
kernel == aten.mul.Tensor
and len(tensor_args) == 1
and len(nontensor_args) == 1
):
# When aten.mul.Tensor's second arg is constant, cpp wrapper expects
# to call mul_Scalar. A more proper fix would be to do this in decomposition.
# See https://github.com/pytorch/pytorch/issues/123478
kernel = aten.mul.Scalar
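# Illustrative trigger (not taken from the diff): when aten.mul.Tensor reaches
# the fallback path with a Python number as its second argument, it arrives here
# with one tensor input and one constant arg, so the kernel is rewritten to
# aten.mul.Scalar and the C++ wrapper emits the mul_Scalar shim call instead.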

super().__init__(
layout,
tuple(tensor_args),
tuple(nontensor_args),
op_overload=kernel,
)

# We need output buffers for generating kernel arguments in the
# abi-compatible mode, where we retrieve outputs by passing each individual
# output through the abi-compatible interface.
@@ -5179,7 +5230,6 @@ def __init__(
),
), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
self.op_overload = kernel

self.unflatten_args = unflatten_args
self.kwargs = {} if kwargs is None else kwargs
V.graph.warn_fallback(self.python_kernel_name)
@@ -5341,41 +5391,6 @@ def is_not_write(arg):
self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr]

self.cpp_op_schema = get_cpp_op_schema(kernel)
self.init_args_default_value(kernel._schema)

def init_args_default_value(self, schema):
self.args_default_value = [
{
"name": x.name,
"type": x.real_type,
"value": x.default_value,
}
for x in schema.arguments
if not x.kwarg_only
]

def get_pos_arg_value(self, pos, kwargs):
# positional args may be provided in kwargs
pos_arg_name = self.args_default_value[pos]["name"]
if pos_arg_name in kwargs:
log.debug(
"Found argument %s with value %s from kwargs",
pos_arg_name,
kwargs[pos_arg_name],
)
return kwargs[pos_arg_name]

assert hasattr(
self, "args_default_value"
), "self.args_default_value has to be provided"
assert pos < len(
self.args_default_value
), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}"
arg_default_value = self.args_default_value[pos]["value"]
log.debug(
"Use default value %s for argument %s", arg_default_value, pos_arg_name
)
return arg_default_value

def codegen_args(self):
@dataclasses.dataclass
@@ -5388,24 +5403,14 @@ def __repr__(self):
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
args = self.fill_non_provided_args(args, kwargs)
args = [
V.graph.wrapper_code.val_to_cpp_arg_str(param.real_type, x)
for param, x in zip(self.op_overload._schema.arguments, args)
]
else:
args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]

# Previously, we want to maintain forward-compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values being set.
# Discussed with Sherlock offline and we decided to allow serializing
# default args into the C++ wrapper code for now. We will refine this
# part if we see real FC requirement. More details related to FC
# can be found at:
# https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
if V.graph.cpp_wrapper and hasattr(self, "args_default_value"):
self.fill_non_provided_args(args, kwargs, convert_val_to_str=True)

# let self.codegen_kwargs handle kwargs
self.kwargs.update(kwargs)
return args
@@ -5441,30 +5446,6 @@ def get_mutation_names(self):
assert len(self.mutation_names) <= 1
return self.mutation_names

def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
assert isinstance(args, (list, tuple))
if isinstance(args, tuple):
args = list(args)
assert hasattr(self, "args_default_value")
n_args = len(args)
n_pos_args = len(self.args_default_value)
# For cpp wrapper, if some positional args are not provided, we need to check
# if they're in the kwargs or use their default value
if n_args < n_pos_args:
log.debug(
"%s has %d unprovided positional arguments. "
"Will check if they are in the keyword arguments or will use default values.",
self.op_overload,
n_pos_args - n_args,
)
pos_args = [
self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args)
]
if convert_val_to_str:
pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args]
args.extend(pos_args)
return args

# ProxyExecutor Design Note
# We export the ExternFallbackNodes (for custom ops) into a serialized file
# and run it with a host side proxy executor to address the ABI problem
@@ -5539,15 +5520,6 @@ def codegen(self, wrapper):
if kernel.namespace == "aten": # type: ignore[union-attr]
# Aten Fallback Ops
assert isinstance(kernel, torch._ops.OpOverload)

if (
kernel == aten.mul.Tensor
and len(self.inputs) == 1
and len(self.constant_args) == 1
):
# When aten.mul.Tensor's second arg is constant, cpp wrapper expects to call mul_Scalar
kernel = aten.mul.Scalar

if V.graph.cpp_wrapper:
if (
config.is_fbcode()
@@ -5562,10 +5534,6 @@
)
self.use_runtime_dispatch = True
self.set_cpp_kernel(kernel)
else:
self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel)
schema = kernel._schema # type: ignore[union-attr]
self.init_args_default_value(schema)
else:
self.python_kernel_name = str(kernel)
elif kernel.namespace == "_quantized": # type: ignore[union-attr]
