Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions test/inductor/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,23 @@ std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>> fn_with_optiona
return {t3, t4, t5};
}

// Kernel for aoti_custom_ops::fn_with_int_output that mixes tensor and plain
// integer outputs. Returns (t1 + t2, t1 - t2, a default-constructed Tensor,
// 0, 0). The integer argument `i1` is accepted but not read.
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>, int64_t, int64_t> fn_with_int_output_impl(Tensor t1, Tensor t2, int64_t i1) {
  Tensor sum = t1 + t2;
  Tensor difference = t1 - t2;
  // Default-constructed tensor handed back through the second optional slot.
  Tensor placeholder;
  // Both integer outputs are fixed at zero.
  int64_t first_int = 0;
  int64_t second_int = 0;
  return {sum, difference, placeholder, first_int, second_int};
}

// Meta-side counterpart of fn_with_int_output_impl. Returns two independent
// clones of t1, a default-constructed Tensor, and two zero integers. Neither
// `t2` nor `i1` is read.
std::tuple<Tensor, std::optional<Tensor>, std::optional<Tensor>, int64_t, int64_t> fn_with_int_output_meta(Tensor t1, Tensor t2, int64_t i1) {
  Tensor first_clone = t1.clone();
  Tensor second_clone = t1.clone();
  // Default-constructed tensor handed back through the second optional slot.
  Tensor placeholder;
  // Both integer outputs are fixed at zero.
  int64_t first_int = 0;
  int64_t second_int = 0;
  return {first_clone, second_clone, placeholder, first_int, second_int};
}

Tensor fn_with_all_inputs_impl(
const Tensor& tensor,
Expand Down Expand Up @@ -381,6 +397,7 @@ TORCH_LIBRARY(aoti_custom_ops, m) {
m.def("custom_add(Tensor t1, Tensor t2) -> Tensor");
m.def("fn_with_optional_tensor_output(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
m.def("fn_with_optional_tensor_output_2(Tensor t1, Tensor t2) -> (Tensor, Tensor?, Tensor?)");
m.def("fn_with_int_output(Tensor t1, Tensor t2, int i) -> (Tensor, Tensor?, Tensor?, int, int)");
m.def(
"fn_with_all_inputs(Tensor tensor, "
"Tensor[] tensors, "
Expand Down Expand Up @@ -428,6 +445,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
m.impl("custom_add", at::custom_add_impl);
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_impl);
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_impl);
m.impl("fn_with_int_output", at::fn_with_int_output_impl);
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_impl);
m.impl("fn_with_default_input", at::fn_with_default_input_impl);
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_impl);
Expand All @@ -441,6 +459,7 @@ TORCH_LIBRARY_IMPL(aoti_custom_ops, CompositeExplicitAutograd, m) {
TORCH_LIBRARY_IMPL(aoti_custom_ops, Meta, m) {
m.impl("fn_with_optional_tensor_output", at::fn_with_optional_tensor_output_meta);
m.impl("fn_with_optional_tensor_output_2", at::fn_with_optional_tensor_output_2_meta);
m.impl("fn_with_int_output", at::fn_with_int_output_meta);
m.impl("fn_with_all_inputs", at::fn_with_all_inputs_meta);
m.impl("fn_with_default_input", at::fn_with_default_input_meta);
m.impl("fn_with_tuple_output", at::fn_with_tuple_output_meta);
Expand Down
14 changes: 14 additions & 0 deletions test/inductor/test_aot_inductor_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,20 @@ def forward(self, x, y):
)
self.check_model(m, args)

def test_fn_with_int_output(self) -> None:
    # Runs a module that calls aoti_custom_ops.fn_with_int_output and folds the
    # op's two int outputs into a tensor result, then hands it to check_model.
    class M(torch.nn.Module):
        def forward(self, x, y):
            # The leading dimension of x is passed to the op as its int input.
            leading_dim = x.shape[0]
            summed, _, _, int_a, int_b = torch.ops.aoti_custom_ops.fn_with_int_output(
                x, y, leading_dim
            )
            scale = int_a + int_b + leading_dim
            return summed, summed * scale

    m = M().to(device=self.device)
    lhs = torch.randn(3, 3, device=self.device)
    rhs = torch.randn(3, 3, device=self.device)
    self.check_model(m, (lhs, rhs))

def test_custom_op_all_inputs(self) -> None:
class MyModel(torch.nn.Module):
# pyre-fixme[3]: Return type must be annotated.
Expand Down
13 changes: 12 additions & 1 deletion torch/_inductor/codegen/cpp_wrapper_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1885,6 +1885,11 @@ def generate_extern_kernel_args_decl_if_needed(
output_args: Optional[list[str]] = None,
raw_outputs: Optional[list[ir.Buffer]] = None,
):
"""
Generates declarations for external kernel arguments if needed, based on the provided
operator and its arguments. It processes both input and output arguments, categorizing
them into tensor and integer arguments for further code generation.
"""
schema = None
if isinstance(op_overload, torch._higher_order_ops.torchbind.CallTorchBind):
obj = raw_args[0]
Expand Down Expand Up @@ -2006,7 +2011,9 @@ def fill_output_arg(arg, return_type, is_mutated_output: bool):

# TODO: Only None, tensor(s), and int returns are supported for now; SymInt is not implemented yet
for return_type in return_types:
if isinstance(return_type, (torch.TensorType, torch.NoneType)):
if isinstance(
return_type, (torch.TensorType, torch.NoneType, torch.IntType)
):
pass
elif isinstance(return_type, torch.OptionalType):
assert isinstance(return_type.getElementType(), torch.TensorType)
Expand All @@ -2021,6 +2028,8 @@ def fill_output_arg(arg, return_type, is_mutated_output: bool):
# None output is supported, but Optional return types are not yet supported
if output_arg is None:
continue
elif isinstance(raw_output_arg, int):
new_int_args.append(str(raw_output_arg))
elif isinstance(output_arg, (list, tuple)):
for out in output_arg:
fill_output_arg(
Expand Down Expand Up @@ -2060,6 +2069,8 @@ def extract_output_name(out):
return mutated_buf_names[0]
elif isinstance(out, (list, tuple)):
return type(out)(extract_output_name(o) for o in out)
elif isinstance(out, int):
return str(out)
else:
raise AssertionError(f"Unexpected output: {type(out)}")

Expand Down
8 changes: 8 additions & 0 deletions torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -6753,6 +6753,12 @@ class ExternKernelNode:


class FallbackKernel(ExternKernelAlloc):
"""
A class that represents a fallback kernel for handling operators that are not
directly supported by inductor. It currently supports functional ops, view ops,
inplace aten ops, and mutating ops that are auto-functionalizable.
"""

def __init__( # type: ignore[no-untyped-def]
self,
layout,
Expand Down Expand Up @@ -7023,6 +7029,8 @@ def handle_single_output(return_type, output): # type: ignore[no-untyped-def]
)
)
)
elif isinstance(return_type, torch.IntType):
return export_schema.Argument.create(as_int=output)
else:
raise RuntimeError(f"Unsupported return type {type(return_type)}")

Expand Down
36 changes: 34 additions & 2 deletions torch/csrc/inductor/aoti_torch/oss_proxy_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,17 @@ void OSSProxyExecutor::get_output_info_from_serialized(
}
break;
}
case c10::TypeKind::IntType: {
TORCH_CHECK(
serialized_output_type == "as_int",
"Expected extern kernel ",
serialized_node["target"],
" to have serialized output type as_int, ",
" but got ",
serialized_output_type);
outputs.emplace_back(output_index, DynamicArgType::IntType, 1);
break;
}
default: {
TORCH_CHECK(
false,
Expand Down Expand Up @@ -800,12 +811,14 @@ void OSSProxyExecutor::call_function(
tensor_id,
", expected num = ",
num_tensors - num_output_tensors);

int num_output_ints = op_kernel->num_output_ints();
TORCH_CHECK(
int_id == num_ints,
int_id == num_ints - num_output_ints,
"Mismatch between ints consumed and num_ints, got int_id = ",
int_id,
", num_ints = ",
num_ints);
num_ints - num_output_ints);

// Call the op with the prepared stack.
op_kernel->run(stack);
Expand Down Expand Up @@ -851,6 +864,18 @@ void OSSProxyExecutor::call_function(
} else {
index++;
}
} else if (schema_return.real_type()->kind() == c10::TypeKind::IntType) {
// need to use real_type() to differentiate between IntType and SymIntType
// for int type, it is already specialized in downstream kernels. So we
// don't need to do anything here.
auto returned_int_value = stack[index++].toInt();
auto serialized_int_value = flatten_int_args[int_id++];
TORCH_CHECK(
returned_int_value == serialized_int_value,
"Expect returned int value to match the serialized int value, but got retured int value: ",
returned_int_value,
" and serialized int value: ",
serialized_int_value);
} else {
TORCH_CHECK(
false,
Expand All @@ -865,6 +890,13 @@ void OSSProxyExecutor::call_function(
tensor_id,
", expected num = ",
num_tensors);

// Every serialized int must have been consumed by this point. The message
// names the int counter so a failure is not mistaken for the tensor check.
TORCH_CHECK(
    int_id == num_ints,
    "Mismatch between ints consumed and num_ints, got int_id = ",
    int_id,
    ", expected num = ",
    num_ints);
}

} // namespace torch::aot_inductor
10 changes: 10 additions & 0 deletions torch/csrc/inductor/aoti_torch/oss_proxy_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@ struct OSSOpKernel {
return num_output_tensors;
}

int num_output_ints() const {
int num_output_ints = 0;
for (const auto& output : outputs_) {
if (output.arg_type == DynamicArgType::IntType) {
num_output_ints += output.length;
}
}
return num_output_ints;
}

virtual void run(std::vector<c10::IValue>& stack) = 0;
virtual c10::FunctionSchema schema() const = 0;
virtual ~OSSOpKernel() = default;
Expand Down
Loading