[inductor] Clean up AOTInductor runtime ABI (#109678)
Summary: Change the AOTInductor runtime interface to avoid referring to aten data structures directly, mostly at::Tensor and ProxyExecutor. This is a combination of #109436, #109498, #109450, #109606, plus a few internal build changes.

Reviewed By: frank-wei

Differential Revision: D49374820

Pull Request resolved: #109678
Approved by: https://github.com/frank-wei, https://github.com/chenyang78
desertfire authored and pytorchmergebot committed Sep 21, 2023
1 parent 4e3b032 commit 9c2715b
Showing 17 changed files with 590 additions and 357 deletions.
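
For orientation, here is a condensed caller-side sketch of the new C-style ABI, assembled from the updated test/cpp/aot_inductor/test.cpp shown below. The function name run_compiled_model and the single-output assumption are illustrative only; the handle types, entry points, and the AOTI_RUNTIME_ERROR_CODE_CHECK macro all appear in this commit's diff.

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aot_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/torch.h>

#include <vector>

// Illustrative caller: run one model from the compiled container on CUDA.
void run_compiled_model(std::vector<torch::Tensor>& inputs) {
  // Create the model container exported by the AOTInductor-compiled library.
  AOTInductorModelContainerHandle container_handle;
  AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerCreate(
      &container_handle, 1 /*num_models*/, false /*is_cpu*/, nullptr /*cubin_dir*/));

  // Query the maximum output shape through plain C types and preallocate.
  const int64_t* max_output_sizes;
  int64_t max_output_dim;
  AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerGetMaxOutputShape(
      container_handle, 0 /*output_idx*/, &max_output_sizes, &max_output_dim));
  std::vector<torch::Tensor> outputs{at::zeros(
      c10::IntArrayRef(max_output_sizes, max_output_dim),
      at::dtype(at::kFloat).device(at::kCUDA))};

  // Convert at::Tensor objects into opaque AtenTensorHandle values.
  // The run call below takes ownership of ("steals") these handles.
  std::vector<AtenTensorHandle> input_handles =
      torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
  std::vector<AtenTensorHandle> output_handles =
      torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(outputs);

  // Actual output shapes are reported back through C arrays, not aten types.
  std::vector<const int64_t*> output_sizes(outputs.size());
  std::vector<int64_t> output_ndims(outputs.size());

  AOTInductorStreamHandle stream_handle =
      reinterpret_cast<AOTInductorStreamHandle>(
          c10::cuda::getCurrentCUDAStream().stream());
  AOTIProxyExecutorHandle proxy_executor_handle = nullptr;

  AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerRun(
      container_handle,
      input_handles.data(),
      inputs.size(),
      output_handles.data(),
      outputs.size(),
      stream_handle,
      proxy_executor_handle,
      output_sizes.data(),
      output_ndims.data()));

  AOTI_RUNTIME_ERROR_CODE_CHECK(
      AOTInductorModelContainerDelete(container_handle));
}

The key difference from the old interface is that no at::Tensor, AOTInductorParamShape, or ProxyExecutor type crosses the library boundary; only opaque handles and int64_t arrays do.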
1 change: 1 addition & 0 deletions build_variables.bzl
@@ -502,6 +502,7 @@ lazy_tensor_core_python_sources = [

inductor_core_resources = [
"torch/csrc/inductor/aoti_torch/shim_common.cpp",
"torch/csrc/inductor/aoti_torch/tensor_converter.cpp",
"torch/csrc/inductor/inductor_ops.cpp",
]

1 change: 1 addition & 0 deletions setup.py
@@ -1294,6 +1294,7 @@ def main():
"include/torch/csrc/dynamo/*.h",
"include/torch/csrc/inductor/*.h",
"include/torch/csrc/inductor/aot_runtime/*.h",
"include/torch/csrc/inductor/aoti_torch/*.h",
"include/torch/csrc/inductor/aoti_torch/c/*.h",
"include/torch/csrc/jit/*.h",
"include/torch/csrc/jit/backends/*.h",
59 changes: 33 additions & 26 deletions test/cpp/aot_inductor/test.cpp
@@ -4,6 +4,7 @@

#include <c10/cuda/CUDAStream.h>
#include <torch/csrc/inductor/aot_runtime/interface.h>
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/torch.h>

namespace torch {
@@ -49,14 +50,17 @@ TEST(AotInductorTest, BasicTest) {
inputs.push_back(y);

AOTInductorModelContainerHandle container_handle;
AOT_INDUCTOR_ERROR_CHECK(
AOTInductorModelContainerCreate(&container_handle, 1 /*num_models*/))
AOTInductorParamShape max_output_shape;
AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerGetMaxOutputShape(
container_handle, 0 /*output_idx*/, &max_output_shape));

c10::IntArrayRef array_size(
max_output_shape.shape_data, max_output_shape.ndim);
AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerCreate(
&container_handle,
1 /*num_models*/,
false /*is_cpu*/,
nullptr /*cubin_dir*/));
const int64_t* max_output_sizes;
int64_t max_output_dim;
AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerGetMaxOutputShape(
container_handle, 0 /*output_idx*/, &max_output_sizes, &max_output_dim));

c10::IntArrayRef array_size(max_output_sizes, max_output_dim);
torch::Tensor output_tensor =
at::zeros(array_size, at::dtype(at::kFloat).device(at::kCUDA));
std::vector<torch::Tensor> outputs;
@@ -66,31 +70,34 @@
const auto stream_id = cuda_stream.stream();
AOTInductorStreamHandle stream_handle =
reinterpret_cast<AOTInductorStreamHandle>(stream_id);
AOTInductorTensorHandle inputs_handle =
reinterpret_cast<AOTInductorTensorHandle>(inputs.data());
AOTInductorTensorHandle outputs_handle =
reinterpret_cast<AOTInductorTensorHandle>(outputs.data());
std::vector<AOTInductorParamShape> output_shapes(
outputs.size(), AOTInductorParamShape());
std::vector<AtenTensorHandle> input_handles =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(inputs);
std::vector<AtenTensorHandle> output_handles =
torch::aot_inductor::unsafe_alloc_new_handles_from_tensors(outputs);

AOTInductorProxyExecutorHandle proxy_executor_handle = nullptr;
std::vector<const int64_t*> output_sizes(outputs.size());
std::vector<int64_t> output_ndims(outputs.size());

AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerRun(
AOTIProxyExecutorHandle proxy_executor_handle = nullptr;

AOTI_RUNTIME_ERROR_CODE_CHECK(AOTInductorModelContainerRun(
container_handle,
inputs_handle,
input_handles.data(),
inputs.size(),
outputs_handle,
output_handles.data(),
outputs.size(),
output_shapes.data(),
stream_handle,
proxy_executor_handle));

ASSERT_EQ(output_shapes.size(), 1);
ASSERT_EQ(output_shapes[0].ndim, 2);
ASSERT_EQ(output_shapes[0].shape_data[0], 32);
ASSERT_EQ(output_shapes[0].shape_data[1], 10);
proxy_executor_handle,
output_sizes.data(),
output_ndims.data()));

ASSERT_EQ(output_sizes.size(), 1);
ASSERT_EQ(output_ndims[0], 2);
ASSERT_EQ(output_sizes[0][0], 32);
ASSERT_EQ(output_sizes[0][1], 10);
ASSERT_TRUE(torch::allclose(results_ref, outputs[0]));
AOT_INDUCTOR_ERROR_CHECK(AOTInductorModelContainerDelete(container_handle));
AOTI_RUNTIME_ERROR_CODE_CHECK(
AOTInductorModelContainerDelete(container_handle));
}

} // namespace aot_inductor
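
The test above relies on the new torch::aot_inductor::unsafe_alloc_new_handles_from_tensors helper, whose implementation (torch/csrc/inductor/aoti_torch/tensor_converter.cpp) is not part of this excerpt. The sketch below only illustrates the ownership model the test implies, assuming AtenTensorHandle is an opaque pointer that the shim internally treats as a heap-allocated at::Tensor*; the stand-in names alloc_handles_from_tensors_sketch and AtenTensorHandleSketch are hypothetical.

#include <torch/torch.h>

#include <vector>

// Hypothetical stand-in for the opaque ABI handle type.
using AtenTensorHandleSketch = void*;

// Sketch: wrap each tensor in a heap-allocated at::Tensor that shares the
// same storage; the consumer of the handle then becomes responsible for
// deleting it ("stealing" the handle), which is why the test never frees
// the handles itself after AOTInductorModelContainerRun.
std::vector<AtenTensorHandleSketch> alloc_handles_from_tensors_sketch(
    std::vector<at::Tensor>& tensors) {
  std::vector<AtenTensorHandleSketch> handles;
  handles.reserve(tensors.size());
  for (auto& t : tensors) {
    handles.push_back(new at::Tensor(t));
  }
  return handles;
}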
9 changes: 2 additions & 7 deletions test/inductor/test_aot_inductor.py
@@ -280,17 +280,14 @@ def forward(self, x, y):
torch.randn(10285, 96, device="cuda"),
torch.randn(96, 1, device="cuda"),
)
expected = model(*example_inputs)
actual = AOTInductorModelRunner.run(
self.check_model(
model,
example_inputs,
expected,
options={
"max_autotune": True,
"max_autotune_gemm_backends": "TRITON",
},
)
self.assertTrue(same(actual, expected))

def test_addmm(self):
class Model(torch.nn.Module):
@@ -309,9 +306,7 @@ def forward(self, a):
batch = 2
a = torch.randn(batch, M, K, device="cuda")
example_inputs = (a,)
expected = model(*example_inputs)
actual = AOTInductorModelRunner.run(model, example_inputs, expected)
self.assertTrue(same(actual, expected))
self.check_model(model, example_inputs)

def test_aliased_buffer_reuse(self):
class Repro(torch.nn.Module):
133 changes: 65 additions & 68 deletions torch/_inductor/codegen/aot_runtime/interface.cpp
@@ -1,48 +1,46 @@
#include <torch/csrc/inductor/aot_runtime/interface.h>
#include <torch/csrc/inductor/aot_runtime/model_container.h>
#include <torch/csrc/inductor/aot_runtime/proxy_executor.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <iostream>
#include <stdexcept>
#include <vector>

#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
try { \
__VA_ARGS__ \
} catch (const std::exception& e) { \
std::cerr << "Error: " << e.what() << std::endl; \
return AOTInductorError::Failure; \
} catch (...) { \
std::cerr << "Unknown exception occurred." << std::endl; \
return AOTInductorError::Failure; \
} \
return AOTInductorError::Success;
#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \
try { \
__VA_ARGS__ \
} catch (const std::exception& e) { \
std::cerr << "Error: " << e.what() << std::endl; \
return AOTI_RUNTIME_FAILURE; \
} catch (...) { \
std::cerr << "Unknown exception occurred." << std::endl; \
return AOTI_RUNTIME_FAILURE; \
} \
return AOTI_RUNTIME_SUCCESS;

extern "C" {

AOTInductorError AOTInductorModelContainerCreate(
AOTIRuntimeError AOTInductorModelContainerCreate(
AOTInductorModelContainerHandle* container_handle,
size_t num_models,
bool is_cpu,
const char* cubin_dir) {
if (num_models == 0) {
std::cerr << "Error: num_models must be positive, but got 0" << std::endl;
return AOTInductorError::Failure;
return AOTI_RUNTIME_FAILURE;
}
CONVERT_EXCEPTION_TO_ERROR_CODE({
std::optional<std::string> cubin_dir_opt;
if (cubin_dir != nullptr) {
cubin_dir_opt.emplace(cubin_dir);
}
auto* container =
new torch::aot_inductor::AOTInductorModelContainer(num_models, is_cpu, cubin_dir_opt);
auto* container = new torch::aot_inductor::AOTInductorModelContainer(
num_models, is_cpu, cubin_dir_opt);
*container_handle =
reinterpret_cast<AOTInductorModelContainerHandle>(container);
})
}

AOTInductorError AOTInductorModelContainerDelete(
AOTIRuntimeError AOTInductorModelContainerDelete(
AOTInductorModelContainerHandle container_handle) {
CONVERT_EXCEPTION_TO_ERROR_CODE({
auto* container =
@@ -52,138 +50,137 @@ AOTInductorError AOTInductorModelContainerDelete(
});
}

AOTInductorError AOTInductorModelContainerRun(
AOTIRuntimeError AOTInductorModelContainerRun(
AOTInductorModelContainerHandle container_handle,
AOTInductorTensorHandle inputs_handle,
// Array of raw AtenTensorHandle for input tensors. Handles will be stolen
AtenTensorHandle* input_handles,
size_t num_inputs,
AOTInductorTensorHandle outputs_handle,
// Array of raw AtenTensorHandle for output tensors. Handles will be stolen
AtenTensorHandle* output_handles,
size_t num_outputs,
AOTInductorParamShape* output_shapes,
AOTInductorStreamHandle stream_handle,
AOTInductorProxyExecutorHandle proxy_executor_handle) {
AOTIProxyExecutorHandle proxy_executor_handle,
const int64_t** ret_output_sizes,
int64_t* ret_output_ndims) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);

auto* inputs = reinterpret_cast<at::Tensor*>(inputs_handle);
std::vector<at::Tensor> input_tensors;
input_tensors.reserve(num_inputs);
for (size_t i = 0; i < num_inputs; i++) {
input_tensors.push_back(inputs[i]);
}

auto* outputs = reinterpret_cast<at::Tensor*>(outputs_handle);
std::vector<at::Tensor> output_tensors;
output_tensors.reserve(num_outputs);
for (size_t i = 0; i < num_outputs; i++) {
output_tensors.push_back(outputs[i]);
}
auto input_unique_handles =
torch::aot_inductor::steal_from_raw_handles_to_raii_handles(input_handles, num_inputs);
auto output_unique_handles =
torch::aot_inductor::steal_from_raw_handles_to_raii_handles(output_handles, num_outputs);

auto stream = reinterpret_cast<cudaStream_t>(stream_handle);

torch::aot_inductor::ProxyExecutor* proxy_executor = reinterpret_cast<torch::aot_inductor::ProxyExecutor*>(proxy_executor_handle);

CONVERT_EXCEPTION_TO_ERROR_CODE({
std::vector<std::vector<int64_t>> *shapes;
container->run(input_tensors, output_tensors, &shapes, stream, proxy_executor);
std::vector<std::vector<int64_t>>* shapes;
container->run(
input_unique_handles,
output_unique_handles,
&shapes,
stream,
proxy_executor_handle);
for (size_t i = 0; i < num_outputs; i++) {
output_shapes[i] =
AOTInductorParamShape((shapes->at(i)).data(), (shapes->at(i)).size());
ret_output_sizes[i] = shapes->at(i).data();
ret_output_ndims[i] = shapes->at(i).size();
}
})
}

AOTInductorError AOTInductorModelContainerGetNumInputs(
AOTIRuntimeError AOTInductorModelContainerGetNumInputs(
AOTInductorModelContainerHandle container_handle,
size_t* num_inputs_out) {
size_t* ret_num_inputs) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *num_inputs_out = container->num_inputs(); })
{ *ret_num_inputs = container->num_inputs(); })
}

AOTInductorError AOTInductorModelContainerGetInputName(
AOTIRuntimeError AOTInductorModelContainerGetInputName(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** input_name_out) {
const char** ret_input_names) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *input_name_out = container->input_name(input_idx); })
{ *ret_input_names = container->input_name(input_idx); })
}

AOTInductorError AOTInductorModelContainerGetInputDtype(
AOTIRuntimeError AOTInductorModelContainerGetInputDtype(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
const char** input_dtype_out) {
const char** ret_input_dtypes) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *input_dtype_out = container->get_input_dtype(input_idx); })
{ *ret_input_dtypes = container->get_input_dtype(input_idx); })
}

AOTInductorError AOTInductorModelContainerGetNumOutputs(
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
AOTInductorModelContainerHandle container_handle,
size_t* num_outputs_out) {
size_t* ret_num_outputs) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *num_outputs_out = container->num_outputs(); })
{ *ret_num_outputs = container->num_outputs(); })
}

AOTInductorError AOTInductorModelContainerGetOutputName(
AOTIRuntimeError AOTInductorModelContainerGetOutputName(
AOTInductorModelContainerHandle container_handle,
size_t output_idx,
const char** output_name_out) {
const char** ret_output_names) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *output_name_out = container->output_name(output_idx); })
{ *ret_output_names = container->output_name(output_idx); })
}

AOTInductorError AOTInductorModelContainerGetOutputDtype(
AOTIRuntimeError AOTInductorModelContainerGetOutputDtype(
AOTInductorModelContainerHandle container_handle,
size_t output_idx,
const char** output_dtype_out) {
const char** ret_output_dtypes) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE(
{ *output_dtype_out = container->get_output_dtype(output_idx); })
{ *ret_output_dtypes = container->get_output_dtype(output_idx); })
}

AOTInductorError AOTInductorModelContainerGetMaxInputShape(
AOTIRuntimeError AOTInductorModelContainerGetMaxInputShape(
AOTInductorModelContainerHandle container_handle,
size_t input_idx,
AOTInductorParamShape* input_shape) {
const int64_t** ret_input_sizes,
int64_t* ret_input_ndim) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
const std::vector<int64_t>& max_input_shape =
container->max_input_shape(input_idx);
*input_shape =
AOTInductorParamShape(max_input_shape.data(), max_input_shape.size());
*ret_input_sizes = max_input_shape.data();
*ret_input_ndim = max_input_shape.size();
})
}

AOTInductorError AOTInductorModelContainerGetMaxOutputShape(
AOTIRuntimeError AOTInductorModelContainerGetMaxOutputShape(
AOTInductorModelContainerHandle container_handle,
size_t output_idx,
AOTInductorParamShape* output_shape) {
const int64_t** ret_output_sizes,
int64_t* ret_output_ndim) {
auto* container =
reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
container_handle);
CONVERT_EXCEPTION_TO_ERROR_CODE({
const std::vector<int64_t>& max_output_shape =
container->max_output_shape(output_idx);
*output_shape =
AOTInductorParamShape(max_output_shape.data(), max_output_shape.size());
*ret_output_sizes = max_output_shape.data();
*ret_output_ndim = max_output_shape.size();
})
}

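
Callers are expected to check the AOTIRuntimeError code returned by every entry point; the test above does so with AOTI_RUNTIME_ERROR_CODE_CHECK, whose definition is not shown in this excerpt. A minimal check macro along those lines might look like the hypothetical AOTI_ERROR_CHECK_SKETCH below; AOTIRuntimeError and AOTI_RUNTIME_SUCCESS come from the new interface, and the macro name and error message are assumptions.

#include <stdexcept>
#include <string>

// Hypothetical sketch; assumes the AOTInductor interface header defining
// AOTIRuntimeError and AOTI_RUNTIME_SUCCESS has already been included.
#define AOTI_ERROR_CHECK_SKETCH(call)                                 \
  do {                                                                \
    AOTIRuntimeError err_ = (call);                                   \
    if (err_ != AOTI_RUNTIME_SUCCESS) {                               \
      throw std::runtime_error(                                       \
          std::string("AOTInductor runtime call failed: ") + #call);  \
    }                                                                 \
  } while (0)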
