[AOTI] Refactor some fallback op util functions #126182

Closed · wants to merge 4 commits
torch/_inductor/codegen/cpp_wrapper_cpu.py (1 addition, 1 deletion)
@@ -1120,7 +1120,7 @@ def g(args):
)

def get_c_shim_func_name(self, kernel):
if not config.abi_compatible:
if not config.abi_compatible or kernel.startswith("aoti_torch_"):
return kernel

assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'"
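
A minimal standalone sketch of the guard added above, for illustration only (the function name and the fall-through are hypothetical; the real method goes on to map an "ns::op" name to its C-shim counterpart):

def get_c_shim_func_name_sketch(kernel: str, abi_compatible: bool) -> str:
    # Names that are already C-shim entry points ("aoti_torch_*"), or any name
    # when not in ABI-compatible mode, are passed through unchanged.
    if not abi_compatible or kernel.startswith("aoti_torch_"):
        return kernel
    # Otherwise the kernel must be a namespaced C++ op ("ns::op"); the real
    # method rewrites such names into their C-shim counterparts.
    assert "::" in kernel, "Cpp kernel name: " + kernel + " does not contain '::'"
    raise NotImplementedError("shim-name construction elided in this sketch")

# e.g. get_c_shim_func_name_sketch("aoti_torch_foo", True) returns "aoti_torch_foo"
# ("aoti_torch_foo" is a placeholder, not a real shim entry point).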
torch/_inductor/ir.py (72 additions, 104 deletions)
@@ -3930,6 +3930,21 @@ def should_allocate(self):
return True


def get_aten_cpp_kernel_name(kernel):
# Calling with the default kernel name can lead to ambiguous behavior like the following example.
# repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
# repeat_interleave(const at::Tensor & self, int64_t repeats,
# c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
if not isinstance(kernel, torch._ops.OpOverload) or kernel.namespace != "aten":
return None
opname = (
kernel.__name__.split(".")[0]
if kernel._overloadname == "default"
else kernel.__name__.replace(".", "_")
)
return f"at::_ops::{opname}::call"


@dataclasses.dataclass
class ExternKernel(InputsKernel):
constant_args: Tuple[Any, ...] = ()
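
For orientation, the naming scheme implemented by get_aten_cpp_kernel_name above can be illustrated with a small string-only sketch; it assumes OpOverload.__name__ has the "op.overload" form the code relies on:

# String-only sketch of the name construction in get_aten_cpp_kernel_name;
# assumption: __name__ looks like "op.overload".
def aten_cpp_kernel_name_sketch(name: str, overload: str) -> str:
    opname = name.split(".")[0] if overload == "default" else name.replace(".", "_")
    return f"at::_ops::{opname}::call"

# Expected results, following the logic above:
#   aten_cpp_kernel_name_sketch("relu.default", "default")
#       -> "at::_ops::relu::call"
#   aten_cpp_kernel_name_sketch("repeat_interleave.self_int", "self_int")
#       -> "at::_ops::repeat_interleave_self_int::call"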
@@ -3973,7 +3988,8 @@ def __init__(
self.kwargs = kwargs if kwargs else {}
self.output_view = output_view
self.python_kernel_name = python_kernel_name
self.cpp_kernel_name = cpp_kernel_name
# If cpp_kernel_name is None, we will try to construct it from op_overload
self.cpp_kernel_name = cpp_kernel_name or get_aten_cpp_kernel_name(op_overload)
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
self.op_overload = op_overload
self.collect_arg_kwarg_properties()
@@ -4016,6 +4032,40 @@ def collect_arg_kwarg_properties(self):
else {}
)

def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
# Previously, we wanted to maintain forward-compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values to be set.
# After discussing with Sherlock offline, we decided to allow serializing
# default args into the C++ wrapper code for now. We will refine this
# part if we see a real FC requirement. More details related to FC
# can be found at:
# https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
assert isinstance(args, (list, tuple))
if isinstance(args, tuple):
args = list(args)
assert self.arg_properties, "ExternKernel.arg_properties should not be empty"

n_args = len(args)
n_pos_args = len(self.arg_properties)
# For cpp wrapper, if some positional args are not provided, we need to check
# if they're in the kwargs or use their default value
if n_args < n_pos_args:
log.debug(
"%s has %d unprovided positional arguments. "
"Will check if they are in the keyword arguments or will use default values.",
self.op_overload,
n_pos_args - n_args,
)
for i in range(n_args, n_pos_args):
arg_name = self.arg_properties[i]["name"]
args.append(
kwargs[arg_name]
if arg_name in kwargs
else self.arg_properties[i]["default_value"]
)
return args
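
As a rough usage illustration (toy data; only the fields used by the logic above are modeled), missing trailing positional arguments are pulled from kwargs when present and otherwise take their defaults:

# Toy illustration of the fill logic above; names and defaults are made up.
toy_arg_properties = [
    {"name": "self", "default_value": None},
    {"name": "repeats", "default_value": None},
    {"name": "dim", "default_value": None},
]

def fill_sketch(args, kwargs):
    args = list(args)
    for i in range(len(args), len(toy_arg_properties)):
        name = toy_arg_properties[i]["name"]
        args.append(kwargs[name] if name in kwargs else toy_arg_properties[i]["default_value"])
    return args

# fill_sketch(["buf0"], {"repeats": 2}) -> ["buf0", 2, None]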

def decide_layout(self):
if isinstance(self.layout, FlexibleLayout):
self.apply_constraint()
@@ -4030,7 +4080,15 @@ def codegen(self, wrapper):
raise NotImplementedError

def get_kernel_name(self):
return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name
return (
(
V.graph.wrapper_code.get_c_shim_func_name(self.cpp_kernel_name) # type: ignore[attr-defined]
if config.abi_compatible
else self.cpp_kernel_name
)
if V.graph.cpp_wrapper
else self.python_kernel_name
)
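
The nested conditional above is a three-way dispatch; a flat sketch of the same logic, with the graph/config state passed in explicitly for illustration rather than read from V.graph and config:

# Flat restatement of get_kernel_name's dispatch; state is passed in explicitly
# here instead of being read from V.graph / config.
def kernel_name_sketch(node, cpp_wrapper: bool, abi_compatible: bool, wrapper_code):
    if not cpp_wrapper:
        return node.python_kernel_name
    if abi_compatible:
        return wrapper_code.get_c_shim_func_name(node.cpp_kernel_name)
    return node.cpp_kernel_name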

@staticmethod
def copy_input(x):
@@ -5128,25 +5186,7 @@ class ExternKernelNode:
}


def get_aten_cpp_kernel_name(kernel):
# Calling with the default kernel name can lead to ambiguous behavior like the following example.
# repeat_interleave(const at::Tensor & repeats, c10::optional<int64_t> output_size=c10::nullopt)
# repeat_interleave(const at::Tensor & self, int64_t repeats,
# c10::optional<int64_t> dim=c10::nullopt, c10::optional<int64_t> output_size=c10::nullopt)
assert (
isinstance(kernel, torch._ops.OpOverload) and kernel.namespace == "aten"
), "Invalid aten kernel"
opname = (
kernel.__name__.split(".")[0]
if kernel._overloadname == "default"
else kernel.__name__.replace(".", "_")
)
return f"at::_ops::{opname}::call"


class FallbackKernel(ExternKernelAlloc):
args_default_value: List[Dict[str, Any]]

def __init__(
self,
layout,
@@ -5158,12 +5198,23 @@ def __init__(
*,
unbacked_bindings=None,
):
if (
kernel == aten.mul.Tensor
and len(tensor_args) == 1
and len(nontensor_args) == 1
):
# When aten.mul.Tensor's second arg is constant, cpp wrapper expects
# to call mul_Scalar. A cleaner fix would be to handle this in a decomposition.
# See https://github.com/pytorch/pytorch/issues/123478
kernel = aten.mul.Scalar

super().__init__(
layout,
tuple(tensor_args),
tuple(nontensor_args),
op_overload=kernel,
)

# We need output buffers for generating kernel arguments in the
# abi-compatible mode, where we retrieve outputs by passing each individual
# output through the abi-compatible interface.
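
The mul.Tensor special case at the top of this constructor can be summarized as below; a sketch with a hypothetical helper name that only mirrors the check shown above:

# Sketch of the overload swap: one tensor arg plus one constant arg means the
# second operand is a scalar, so the cpp wrapper expects mul_Scalar.
import torch

aten = torch.ops.aten

def maybe_swap_to_mul_scalar(kernel, tensor_args, nontensor_args):
    if kernel == aten.mul.Tensor and len(tensor_args) == 1 and len(nontensor_args) == 1:
        # See https://github.com/pytorch/pytorch/issues/123478 for context.
        return aten.mul.Scalar
    return kernel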
@@ -5179,7 +5230,6 @@ def __init__(
),
), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
self.op_overload = kernel

self.unflatten_args = unflatten_args
self.kwargs = {} if kwargs is None else kwargs
V.graph.warn_fallback(self.python_kernel_name)
@@ -5341,41 +5391,6 @@ def is_not_write(arg):
self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr]

self.cpp_op_schema = get_cpp_op_schema(kernel)
self.init_args_default_value(kernel._schema)

def init_args_default_value(self, schema):
self.args_default_value = [
{
"name": x.name,
"type": x.real_type,
"value": x.default_value,
}
for x in schema.arguments
if not x.kwarg_only
]

def get_pos_arg_value(self, pos, kwargs):
# positional args may be provided in kwargs
pos_arg_name = self.args_default_value[pos]["name"]
if pos_arg_name in kwargs:
log.debug(
"Found argument %s with value %s from kwargs",
pos_arg_name,
kwargs[pos_arg_name],
)
return kwargs[pos_arg_name]

assert hasattr(
self, "args_default_value"
), "self.args_default_value has to be provided"
assert pos < len(
self.args_default_value
), f"expected the index {pos} to be smaller than len(self.args_default_value): {len(self.args_default_value)}"
arg_default_value = self.args_default_value[pos]["value"]
log.debug(
"Use default value %s for argument %s", arg_default_value, pos_arg_name
)
return arg_default_value

def codegen_args(self):
@dataclasses.dataclass
@@ -5388,24 +5403,14 @@ def __repr__(self):
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
args = self.fill_non_provided_args(args, kwargs)
args = [
V.graph.wrapper_code.val_to_cpp_arg_str(param.real_type, x)
for param, x in zip(self.op_overload._schema.arguments, args)
]
else:
args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]

# Previously, we want to maintain forward-compatibility by skipping
# default args in the serialized artifacts in fbcode. However,
# some of our shim interfaces require default values being set.
# Discussed with Sherlock offline and we decided to allow serializing
# default args into the C++ wrapper code for now. We will refine this
# part if we see real FC requirement. More details related to FC
# can be found at:
# https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
if V.graph.cpp_wrapper and hasattr(self, "args_default_value"):
self.fill_non_provided_args(args, kwargs, convert_val_to_str=True)

# let self.codegen_kwargs handle kwargs
self.kwargs.update(kwargs)
return args
@@ -5441,30 +5446,6 @@ def get_mutation_names(self):
assert len(self.mutation_names) <= 1
return self.mutation_names

def fill_non_provided_args(self, args, kwargs, convert_val_to_str=False):
assert isinstance(args, (list, tuple))
if isinstance(args, tuple):
args = list(args)
assert hasattr(self, "args_default_value")
n_args = len(args)
n_pos_args = len(self.args_default_value)
# For cpp wrapper, if some positional args are not provided, we need to check
# if they're in the kwargs or use their default value
if n_args < n_pos_args:
log.debug(
"%s has %d unprovided positional arguments. "
"Will check if they are in the keyword arguments or will use default values.",
self.op_overload,
n_pos_args - n_args,
)
pos_args = [
self.get_pos_arg_value(i, kwargs) for i in range(n_args, n_pos_args)
]
if convert_val_to_str:
pos_args = [V.graph.wrapper_code.val_to_arg_str(x) for x in pos_args]
args.extend(pos_args)
return args

# ProxyExecutor Design Note
# We export the ExternFallbackNodes (for custom ops) into a serialized file
# and run it with a host side proxy executor to address the ABI problem
@@ -5539,15 +5520,6 @@ def codegen(self, wrapper):
if kernel.namespace == "aten": # type: ignore[union-attr]
# Aten Fallback Ops
assert isinstance(kernel, torch._ops.OpOverload)

if (
kernel == aten.mul.Tensor
and len(self.inputs) == 1
and len(self.constant_args) == 1
):
# When aten.mul.Tensor's second arg is constant, cpp wrapper expects to call mul_Scalar
kernel = aten.mul.Scalar

if V.graph.cpp_wrapper:
if (
config.is_fbcode()
@@ -5562,10 +5534,6 @@
)
self.use_runtime_dispatch = True
self.set_cpp_kernel(kernel)
else:
self.cpp_kernel_name = get_aten_cpp_kernel_name(kernel)
schema = kernel._schema # type: ignore[union-attr]
self.init_args_default_value(schema)
else:
self.python_kernel_name = str(kernel)
elif kernel.namespace == "_quantized": # type: ignore[union-attr]