[AOTInductor] Simplified AOTInductor interface and model class (#110411)
Summary:
This PR removes several APIs from the AOTInductor interface that are not
used by the client.

It also simplifies AOTInductor's model class by removing the dim info for
input/output tensors. That dim info was originally included so the runtime
could report max output shapes, which the client then used to allocate
memory for output tensors. Now that output tensor memory is allocated from
within the .so, we no longer need to maintain this information. Deleting
the dim info from the model class also simplifies the codegen considerably.
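
For illustration, a minimal sketch (not taken from this PR) of the client-side pattern the removed max-shape APIs supported: query the max output shape, then size a buffer from it. The function name alloc_for_output0 is hypothetical, the container handle and the actual allocator are assumed to exist elsewhere, and treating a nonzero return code as failure is an assumption; the signature of AOTInductorModelContainerGetMaxOutputShape matches the one deleted in this diff.

AOTIRuntimeError alloc_for_output0(
    AOTInductorModelContainerHandle container_handle) {
  const int64_t* max_sizes = nullptr;
  int64_t ndim = 0;
  // Query the per-dimension upper bounds for output 0 (API removed by this PR).
  AOTIRuntimeError err = AOTInductorModelContainerGetMaxOutputShape(
      container_handle, /*output_idx=*/0, &max_sizes, &ndim);
  if (err != 0) {  // assumption: nonzero signals failure
    return err;
  }
  int64_t numel = 1;
  for (int64_t i = 0; i < ndim; ++i) {
    numel *= max_sizes[i];  // upper bound on element count for output 0
  }
  // ... allocate a buffer of `numel` elements for output 0 here ...
  return err;
}

With this PR, the .so performs the output allocation itself, so the client no longer needs this dance.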

Test Plan: ci

Reviewed By: khabinov

Differential Revision: D49835430

Pull Request resolved: #110411
Approved by: https://github.com/khabinov, https://github.com/desertfire, https://github.com/jansel
chenyang78 authored and pytorchmergebot committed Oct 4, 2023
1 parent baa9af1 commit 46a5558
Showing 6 changed files with 27 additions and 334 deletions.
20 changes: 20 additions & 0 deletions test/inductor/test_aot_inductor.py
@@ -670,6 +670,25 @@ def forward(self, x, y):
        ):
            self.check_model(Repro(), example_inputs)

    def test_dynamic_cat(self):
        class Model(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x1, x2):
                return torch.cat([x1, x2], dim=0)

        a = torch.randn(2, 4, device=self.device)
        b = torch.randn(3, 4, device=self.device)
        constraints = [
            torch._export.dynamic_dim(a, 0) >= 1,
            torch._export.dynamic_dim(a, 0) <= 10,
            torch._export.dynamic_dim(b, 0) >= 1,
            torch._export.dynamic_dim(b, 0) <= 20,
        ]
        example_inputs = (a, b)
        self.check_model(Model(), example_inputs, constraints=constraints)

    @unittest.skipIf(
        torch.cuda.device_count() < 2, "The test requires multiple devices"
    )
@@ -719,6 +738,7 @@ class AOTInductorTestABICompatibleCpu(TestCase):
{
"test_addmm_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
"test_bmm_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
"test_dynamic_cat": TestFailure(("abi_compatible_cpu",)),
"test_dynamic_smem_above_default_limit": TestFailure(("abi_compatible_cpu",)),
"test_foreach_multiple_dynamic": TestFailure(("abi_compatible_cpu",)),
# TODO: test_freezing_abi_compatible_cpu somehow fails on CI but not locally,
54 changes: 0 additions & 54 deletions torch/_inductor/codegen/aoti_runtime/interface.cpp
@@ -108,17 +108,6 @@ AOTIRuntimeError AOTInductorModelContainerGetInputName(
      { *ret_input_names = container->input_name(input_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetInputDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const char** ret_input_dtypes) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_input_dtypes = container->get_input_dtype(input_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
    AOTInductorModelContainerHandle container_handle,
    size_t* ret_num_outputs) {
@@ -140,49 +129,6 @@ AOTIRuntimeError AOTInductorModelContainerGetOutputName(
      { *ret_output_names = container->output_name(output_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetOutputDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const char** ret_output_dtypes) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE(
      { *ret_output_dtypes = container->get_output_dtype(output_idx); })
}

AOTIRuntimeError AOTInductorModelContainerGetMaxInputShape(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const int64_t** ret_input_sizes,
    int64_t* ret_input_ndim) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    const std::vector<int64_t>& max_input_shape =
        container->max_input_shape(input_idx);
    *ret_input_sizes = max_input_shape.data();
    *ret_input_ndim = max_input_shape.size();
  })
}

AOTIRuntimeError AOTInductorModelContainerGetMaxOutputShape(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const int64_t** ret_output_sizes,
    int64_t* ret_output_ndim) {
  auto* container =
      reinterpret_cast<torch::aot_inductor::AOTInductorModelContainer*>(
          container_handle);
  CONVERT_EXCEPTION_TO_ERROR_CODE({
    const std::vector<int64_t>& max_output_shape =
        container->max_output_shape(output_idx);
    *ret_output_sizes = max_output_shape.data();
    *ret_output_ndim = max_output_shape.size();
  })
}

AOTIRuntimeError AOTInductorModelCreate(
    AOTInductorModelHandle* model_handle,
    AOTInductorConstantMapHandle constant_map_handle) {
69 changes: 4 additions & 65 deletions torch/_inductor/codegen/wrapper.py
@@ -299,11 +299,6 @@ def __init__(self):
        self.last_seen_device_guard_index = None
        self.supports_intermediate_hooks = True
        self.expr_printer = pexpr
        # Not all the dynamic symbols will be used in the generated code. This
        # set contains those actually being defined by something like
        # "{self.declare_shape} s0 = ...". It ensures that we are not going to
        # emit queries for undefined symbols.
        self.defined_symbols = set()

        self.write_header()
        self.write_prefix()
@@ -602,7 +597,6 @@ def is_expr(x):
        for name, shape in graph_inputs_expr:
            shape = V.graph.sizevars.simplify(shape)
            if shape in needed:
                self.defined_symbols.add(shape)
                needed.remove(shape)
                code.writeline(f"{self.declare}{shape} = {name}{self.ending}")
@@ -611,7 +605,6 @@ def is_expr(x):
        for dim, shape in enumerate(shapes):
            shape = V.graph.sizevars.simplify(shape)
            if shape in needed:
                self.defined_symbols.add(shape)
                needed.remove(shape)
                code.writeline(
                    f"{self.declare}{shape} = {sizeof(name)}[{dim}]{self.ending}"
@@ -622,7 +615,6 @@ def is_expr(x):
        for dim, shape in enumerate(shapes):
            shape = V.graph.sizevars.simplify(shape)
            if shape in needed:
                self.defined_symbols.add(shape)
                needed.remove(shape)
                code.writeline(
                    f"{self.declare}{shape} = {strideof(name)}[{dim}]{self.ending}"
@@ -1098,26 +1090,8 @@ def write_input_output_info(
        info_kind: str,
        idx: int,
        name: str,
        dtype: str,
        sizes: List[sympy.Expr],
    ):
        self.prefix.writeline(f"""{info_kind}[{idx}].name = "{name}";""")
        self.prefix.writeline(f"""{info_kind}[{idx}].dtype = "{dtype}";""")
        self.prefix.writeline(f"{info_kind}[{idx}].shape.reserve({len(sizes)});")
        for size in sizes:
            if isinstance(size, sympy.Integer):
                self.prefix.writeline(
                    f"{info_kind}[{idx}].shape.push_back(make_static_dim({size}));"
                )
            else:
                size = V.graph.sizevars.simplify(size)
                # FIXME: handle non-Symbol cases later.
                assert isinstance(
                    size, sympy.Symbol
                ), f"expected {size=} to be a Symbol"
                self.prefix.writeline(
                    f"{info_kind}[{idx}].shape.push_back({size.name});"
                )

    def write_wrapper_decl(self):
        inputs_len = len(V.graph.graph_inputs.keys())
@@ -1210,16 +1184,6 @@ def write_wrapper_decl(self):

        if V.graph.aot_mode:
            self.prefix.writeline("inputs.clear();")
            dynamic_symbols = [
                s
                for s in V.graph.sizevars.free_symbols()
                if s in self.defined_symbols
            ]
            for dim in dynamic_symbols:
                self.prefix.writeline(
                    f'auto dim_{dim} = find_dynamic_dim("{dim}");'
                )
                self.prefix.writeline(f"dim_{dim}->set_value({dim});")

    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
        if config.aot_inductor.abi_compatible:
@@ -1244,14 +1208,8 @@ def codegen_model_constructor(self):
        // Generated code example
        AOTInductorModel::AOTInductorModel()
            : AOTInductorModelBase(4, 1) {
        auto s0 = make_dynamic_dim("s0", 2048);
        inputs_info_[0].name = "input0";
        inputs_info_[0].shape.reserve(2);
        inputs_info_[0].shape.push_back(make_static_dim(10));
        inputs_info_[0].shape.push_back(make_static_dim(64));
        inputs_info_[1].shape.reserve(2);
        inputs_info_[1].shape.push_back(s0);
        inputs_info_[1].shape.push_back(make_static_dim(64));
        inputs_info_[0].dtype = "torch.float16";
        ...
        constants_info_[0].name = "L__self___weight";
        constants_info_[0].dtype = at::kFloat;
@@ -1261,24 +1219,10 @@ def codegen_dynamic_dims():
        constants_info_[0].stride = {32, 1};
        ...
        outputs_info_[0].name = "output0";
        outputs_info_[0].shape.reserve(2);
        outputs_info_[0].shape.push_back(s0);
        outputs_info_[0].shape.push_back(make_static_dim(10));
        outputs_info_[0].dtype = "torch.float16";
        }
        """

        def codegen_dynamic_dims():
            dynamic_symbols = V.graph.sizevars.free_symbols()
            for dim in dynamic_symbols:
                var_to_range = V.graph.sizevars.shape_env.var_to_range
                dim_range = var_to_range.get(dim, None)
                assert (
                    dim_range is not None
                ), f"Could not find dim_range for {dim=} from {var_to_range=}"
                self.prefix.writeline(
                    f'auto {dim.name} = make_dynamic_dim("{dim.name}", {dim_range.lower}, {dim_range.upper});'
                )

        num_inputs = len(V.graph.graph_inputs)
        num_outputs = len(V.graph.graph_outputs)
        num_constants = len(V.graph.constants)
@@ -1290,14 +1234,11 @@ def codegen_dynamic_dims():
        )

        with self.prefix.indent():
            codegen_dynamic_dims()
            for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
                assert not isinstance(
                    inp, sympy.Expr
                ), f"input {name=} cannot be symbolic"
                sizes = inp.get_size()
                dtype = V.graph.graph_inputs[name].get_dtype()
                self.write_input_output_info("inputs_info_", idx, name, dtype, sizes)
                self.write_input_output_info("inputs_info_", idx, name)

            for idx, (name, tensor) in enumerate(V.graph.constants.items()):
                assert isinstance(tensor, torch.Tensor)
@@ -1326,10 +1267,8 @@ def codegen_dynamic_dims():
                assert not isinstance(
                    output, sympy.Expr
                ), f"output {name=} cannot be symbolic"
                sizes = output.get_size()
                name = f"output{idx}"
                dtype = output.get_dtype()
                self.write_input_output_info("outputs_info_", idx, name, dtype, sizes)
                self.write_input_output_info("outputs_info_", idx, name)

        self.prefix.writeline("}")

26 changes: 0 additions & 26 deletions torch/csrc/inductor/aoti_runtime/interface.h
@@ -81,12 +81,6 @@ AOTIRuntimeError AOTInductorModelContainerGetInputName(
    size_t input_idx,
    const char** ret_input_names);

// Retrieves the input dtype at the given index.
AOTIRuntimeError AOTInductorModelContainerGetInputDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const char** ret_input_dtypes);

// Retrieves the number of outputs for the model.
AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(
    AOTInductorModelContainerHandle container_handle,
@@ -98,26 +92,6 @@ AOTIRuntimeError AOTInductorModelContainerGetOutputName(
    size_t output_idx,
    const char** ret_output_names);

// Retrieves the output dtype at the given index.
AOTIRuntimeError AOTInductorModelContainerGetOutputDtype(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const char** ret_output_dtypes);

// Retrieves the input shape with the maximum dimension size for each dimension.
AOTIRuntimeError AOTInductorModelContainerGetMaxInputShape(
    AOTInductorModelContainerHandle container_handle,
    size_t input_idx,
    const int64_t** ret_input_sizes,
    int64_t* ret_input_ndim);

// Retrieves the output shape with the maximum dimension size for each dimension.
AOTIRuntimeError AOTInductorModelContainerGetMaxOutputShape(
    AOTInductorModelContainerHandle container_handle,
    size_t output_idx,
    const int64_t** ret_output_sizes,
    int64_t* ret_output_ndim);
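
As a hedged sketch of how these accessors compose (two of them removed by this PR): the function name log_output_metadata is hypothetical, the container handle is assumed to be created elsewhere, and zero-as-success is an assumption; only signatures visible in this diff are used.

void log_output_metadata(
    AOTInductorModelContainerHandle container_handle) {
  size_t num_outputs = 0;
  if (AOTInductorModelContainerGetNumOutputs(container_handle, &num_outputs) != 0) {
    return;  // assumption: nonzero signals failure
  }
  for (size_t i = 0; i < num_outputs; ++i) {
    const char* name = nullptr;
    const char* dtype = nullptr;
    AOTInductorModelContainerGetOutputName(container_handle, i, &name);
    // The dtype accessor below is one of the APIs removed by this PR.
    AOTInductorModelContainerGetOutputDtype(container_handle, i, &dtype);
    // ... log or validate `name` and `dtype` here ...
  }
}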

// Creates an AOTInductorModel instance. This is a thin and light wrapper
// around the compiled model; it doesn't handle concurrency, queueing, device
// management, etc. Use this if bare-metal performance is needed and you are
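Although the comment above is truncated in this view, the signature kept in interface.cpp suggests the following minimal usage sketch; the wrapper name create_model is hypothetical, the constant map handle is assumed to be produced elsewhere, and nonzero-as-failure is an assumption.

AOTIRuntimeError create_model(
    AOTInductorConstantMapHandle constant_map_handle,
    AOTInductorModelHandle* out_model) {
  // Thin creation call per the signature shown in interface.cpp above;
  // assumption: a nonzero AOTIRuntimeError signals failure.
  return AOTInductorModelCreate(out_model, constant_map_handle);
}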
