[AOTInductor] ProxyExecutor supports List[Tensor] return type #110182

Closed · wants to merge 1 commit
44 changes: 24 additions & 20 deletions torch/_inductor/codegen/wrapper.py
@@ -84,12 +84,15 @@ def convert_arg_type(python_type):


def convert_return_type(python_type):
# TODO: only support Tensor as func return type for now
# TODO: support alias
assert (
python_type == "Tensor"
), f"only support tensor output for cpp_wrapper, but receive type {python_type}"
return f"at::{python_type}"
python_to_cpp = {
"Tensor": "at::Tensor",
"List[Tensor]": "std::vector<at::Tensor>",
}

cpp_type = python_to_cpp.get(python_type, None)
assert cpp_type is not None, f"NYI return type: {python_type}"
return cpp_type


def get_cpp_op_schema(kernel):
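
For context, a minimal usage sketch of the updated convert_return_type, assuming this branch's torch/_inductor/codegen/wrapper.py is importable; the example inputs below are mine, not part of the diff:

from torch._inductor.codegen.wrapper import convert_return_type

# The two schema return types the mapping now covers.
assert convert_return_type("Tensor") == "at::Tensor"
assert convert_return_type("List[Tensor]") == "std::vector<at::Tensor>"

# Anything else still fails fast: convert_return_type("SymInt") raises
# AssertionError("NYI return type: SymInt").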
@@ -1742,30 +1745,31 @@ def fill_output_arg(arg, return_type):
)
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
new_tensor_args.append(f"{arg}.get()")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.TensorType
):
# TODO: handle tensor list return type
raise NotImplementedError("NYI support for return type: List[Tensor]")
elif isinstance(return_type, torch.SymIntType):
raise NotImplementedError("NYI support for return type: SymInt")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.SymIntType
):
raise NotImplementedError("NYI support for return type: List[SymInt]")
else:
raise AssertionError(f"Unsupport return type found: {return_type}")
raise AssertionError(f"Unsupported return type found: {return_type}")

# TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
for return_type in return_types:
if isinstance(return_type, (torch.TensorType)):
pass
elif isinstance(return_type, torch.OptionalType):
assert isinstance(return_type.getElementType(), torch.TensorType)
elif isinstance(return_type, torch.ListType):
assert isinstance(return_type.getElementType(), torch.TensorType)
else:
raise NotImplementedError(
f"return type {return_type} is not yet supported."
)

for output_arg, return_type in zip(output_args, return_types):
for output_arg in output_args:
if output_arg is not None:
if isinstance(return_type, torch.OptionalType):
fill_output_arg(output_arg, return_type.getElementType())
elif isinstance(return_type, torch.TensorType):
fill_output_arg(output_arg, return_type)
else:
raise NotImplementedError(
"Only Tensor and OptionalTensor return type is supported."
)
fill_output_arg(output_arg, torch.TensorType.get())

return new_tensor_args, new_int_args

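The second hunk above first validates every schema return type (Tensor, Optional[Tensor], and List[Tensor] pass; everything else raises) and then fills each output argument as a plain tensor handle. Below is a standalone sketch of that validation loop, using return types taken from real aten schemas; the choice of aten.add.Tensor and aten.split.Tensor is mine, purely for illustration:

import torch

# add returns a single Tensor; split returns Tensor[], i.e. List[Tensor].
return_types = [
    torch.ops.aten.add.Tensor._schema.returns[0].type,
    torch.ops.aten.split.Tensor._schema.returns[0].type,
]

for return_type in return_types:
    if isinstance(return_type, torch.TensorType):
        pass
    elif isinstance(return_type, torch.OptionalType):
        assert isinstance(return_type.getElementType(), torch.TensorType)
    elif isinstance(return_type, torch.ListType):
        assert isinstance(return_type.getElementType(), torch.TensorType)
    else:
        raise NotImplementedError(f"return type {return_type} is not yet supported.")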
13 changes: 12 additions & 1 deletion torch/_inductor/ir.py
@@ -3792,13 +3792,24 @@ def export_extern_kernel_node(self):
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)

# serialize_outputs
if isinstance(self.outputs, (list, tuple)):
if isinstance(self.outputs, tuple):
# For tuple returns, e.g "-> (Tensor, Tensor)"
output_arguments = [
export_schema.Argument.create(
as_tensor=export_schema.TensorArgument(name=output.get_name())
)
for output in self.outputs
]
elif isinstance(self.outputs, list):
Contributor:
Wondering why we are differentiating list from tuple above. Particularly, why do we pass as_tensors for list but as_tensor for tuple?

Contributor Author:
This is by design of the PT2 IR.
TensorList is a first-class data type, whereas a tuple is used only when there are multiple returns.

Consider the following two cases: foo(...) -> (Tensor, Tensor) and bar(...) -> Tensor[]

In the first case, self.outputs is a Python tuple; the output would be serialized as [Argument(as_tensor="buf3"), Argument(as_tensor="buf4")].
In the second case, self.outputs is a Python list; the output would be serialized as [Argument(as_tensors=["buf3", "buf4"])].

For more details, see the description in #110187. (A condensed sketch of both shapes follows the diff below.)

# For list of tensor, e.g. "-> List[Tensor]"
output_arguments = [
export_schema.Argument.create(
as_tensors=[
export_schema.TensorArgument(name=output.get_name())
for output in self.outputs
]
)
]
else:
output_arguments = [
export_schema.Argument.create(
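
As promised in the review thread above, a condensed sketch of the two serialization shapes. It restates the branch logic as a hypothetical standalone helper (serialize_output_arguments) so tuple and list returns can be compared side by side; export_schema stands in for the serde schema module used in ir.py, and the helper name is mine:

def serialize_output_arguments(outputs, export_schema):
    if isinstance(outputs, tuple):
        # "-> (Tensor, Tensor)": one as_tensor Argument per returned buffer.
        return [
            export_schema.Argument.create(
                as_tensor=export_schema.TensorArgument(name=output.get_name())
            )
            for output in outputs
        ]
    if isinstance(outputs, list):
        # "-> Tensor[]": a single as_tensors Argument holding every buffer name.
        return [
            export_schema.Argument.create(
                as_tensors=[
                    export_schema.TensorArgument(name=output.get_name())
                    for output in outputs
                ]
            )
        ]
    raise NotImplementedError(f"unhandled outputs container: {type(outputs)}")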