Skip to content

Commit

Permalink
[aotinductor] reland: return a copy of any constant (pytorch#112370)
Browse files Browse the repository at this point in the history
When the model returns a constant, we cannot "release" its handle,
because the constant doesn't have any handle at all. Instead,
we should allocate a new tensor and then return a copy of the constant.

Pull Request resolved: pytorch#112370
Approved by: https://github.com/hl475, https://github.com/desertfire
  • Loading branch information
chenyang78 authored and xuhancn committed Nov 8, 2023
1 parent 66919a3 commit 129ed1b
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 11 deletions.
13 changes: 13 additions & 0 deletions test/inductor/test_aot_inductor.py
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,19 @@ def forward(self, x, weight, bias):

self.assertTrue(same(actual, expected))

def test_return_constant(self):
    """Compile a model whose output tuple includes a (cloned) graph constant.

    Regression test: AOTInductor must return a *copy* of any constant it
    hands back to the caller, since constants have no releasable handle.
    """

    class Model(torch.nn.Module):
        def __init__(self, device):
            super().__init__()
            # Constant captured by the exported graph (not an input).
            self.cst = torch.randn(5, 5, device=device)

        def forward(self, x):
            cloned = self.cst.clone()
            return (x, cloned)

    example_input = torch.randn(5, device=self.device)
    self.check_model(Model(self.device), (example_input,))


class AOTInductorTestABICompatibleCpu(TestCase):
device = "cpu"
Expand Down
41 changes: 30 additions & 11 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,25 +1505,44 @@ def define_kernel(

def generate_return(self, output_refs):
    """Emit the C++ statements that hand the graph outputs back to the caller.

    In AOT mode each output is written into ``output_handles[idx]``.  Outputs
    that are graph *constants* get special treatment: constants are owned by
    the Model instance and persist across inference runs, so we must never
    release or transfer ownership of the original tensor — instead we return
    a fresh copy (clone) of it.

    NOTE(review): reconstructed from a garbled diff paste in which old and
    new lines were interleaved; this is the post-change (new) side.
    """
    if V.graph.aot_mode:
        cst_names = V.graph.constants.keys()
        for idx, output in enumerate(output_refs):
            if output in cst_names:
                # In some rare cases where we return a constant, we
                # have to return a copy of this constant, because
                # (1) constants are not owned by the Model instance
                # (2) constants remain the same across inference runs,
                #     assuming they are not updated at runtime
                # Basically, we cannot release or transfer the ownership
                # of any original constant to the user.
                if config.aot_inductor.abi_compatible:
                    self.wrapper_call.writeline(
                        f"aoti_torch_clone({output}, &output_handles[{idx}]);"
                    )
                else:
                    self.wrapper_call.writeline(
                        f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
                        + f"new at::Tensor(std::move({output}.clone())));"
                    )
            else:
                if config.aot_inductor.abi_compatible:
                    if output in self.cached_thread_locals:
                        # Thread-local outputs have no owning handle; wrap a
                        # new tensor object around the cached storage.
                        self.wrapper_call.writeline(
                            f"aoti_torch_new_uninitialized_tensor(&output_handles[{idx}]);"
                        )
                        self.wrapper_call.writeline(
                            f"aoti_torch_assign_tensors({output}, output_handles[{idx}]);"
                        )
                    else:
                        # Normal case: transfer ownership of the RAII handle.
                        self.wrapper_call.writeline(
                            f"output_handles[{idx}] = {output}.release();"
                        )
                else:
                    self.wrapper_call.writeline(
                        f"output_handles[{idx}] = reinterpret_cast<AtenTensorHandle>("
                        + f"new at::Tensor({output}));"
                    )
    else:
        self.wrapper_call.writeline(f"return {{{', '.join(output_refs)}}};\n}}")

Expand Down
7 changes: 7 additions & 0 deletions torch/csrc/inductor/aoti_torch/c/shim.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,13 @@ aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst);

// This function will create a new tensor object and its pointer is returned
// through *ret. The caller is responsible for wrapping the tensor pointer
// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object
// when going out of scope.
AOTI_TORCH_EXPORT AOTITorchError
aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret);

AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out(
AtenTensorHandle out,
AtenTensorHandle self,
Expand Down
9 changes: 9 additions & 0 deletions torch/csrc/inductor/aoti_torch/shim_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,15 @@ AOTITorchError aoti_torch_assign_tensors(
});
}

AOTITorchError aoti_torch_clone(AtenTensorHandle self, AtenTensorHandle* ret) {
  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
    // Deep-copy the source tensor and hand the caller ownership of the
    // heap-allocated copy through *ret; the caller is expected to wrap it
    // in RAIIAtenTensorHandle so aoti_torch_delete_tensor_object runs.
    at::Tensor* src = tensor_handle_to_tensor_pointer(self);
    *ret = tensor_pointer_to_tensor_handle(new at::Tensor(src->clone()));
  });
}

// TODO: implement a more efficient version instead of calling into aten
AOTITorchError aoti_torch_addmm_out(
AtenTensorHandle out,
Expand Down

0 comments on commit 129ed1b

Please sign in to comment.