[AOT] BugFix of workspace calculation (apache#10337)
Following an investigation from apache#10022,
it turned out that the workspace calculation
assumed a single lowered PrimFunc would be
produced per primitive Relay function.

However, the CMSIS-NN codegen is an exception:
it produces multiple calls/PrimFuncs in place
of a single call to a single Relay function.

This commit changes the workspace calculation
to be performed on the lowered IRModule.

Additionally, it changes the test utils to not
generate any stack allocator code when USMP is
used, making the tests stricter.

This change also removes the confusing
"run_model", whose semantics are identical
to "__tvm_main__" in TIR.
manupak authored and pfk-beta committed Apr 11, 2022
1 parent dad1041 commit 1c849e8
Showing 17 changed files with 220 additions and 95 deletions.
@@ -86,7 +86,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
void TVMInitialize() { StackMemoryManager_Init(&app_workspace, g_aot_memory, WORKSPACE_SIZE); }

void TVMExecute(void* input_data, void* output_data) {
int ret_val = tvmgen_default_run_model(input_data, output_data);
int ret_val = tvmgen_default___tvm_main__(input_data, output_data);
if (ret_val != 0) {
TVMPlatformAbort(kTvmErrorPlatformCheckFailure);
}
2 changes: 1 addition & 1 deletion apps/microtvm/zephyr_cmsisnn/src/main.c
@@ -34,7 +34,7 @@ extern float output_storage[12];

extern const size_t output_len;

static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE + 512];
static uint8_t g_crt_workspace[TVMGEN_DEFAULT_WORKSPACE_SIZE];
tvm_workspace_t app_workspace;

void TVMLogf(const char* msg, ...) {
2 changes: 0 additions & 2 deletions include/tvm/runtime/module.h
@@ -235,8 +235,6 @@ constexpr const char* tvm_module_main = "__tvm_main__";
constexpr const char* tvm_param_prefix = "__tvm_param__";
/*! \brief A PackedFunc that looks up linked parameters by storage_id. */
constexpr const char* tvm_lookup_linked_param = "_lookup_linked_param";
/*! \brief The main AOT executor function generated from TIR */
constexpr const char* tvm_run_func_suffix = "run_model";
/*! \brief Model entrypoint generated as an interface to the AOT function outside of TIR */
constexpr const char* tvm_entrypoint_suffix = "run";
} // namespace symbol
58 changes: 39 additions & 19 deletions src/relay/backend/aot_executor_codegen.cc
@@ -658,8 +658,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {

// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix);
String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));
dict_attrs.Set(tvm::attr::kTarget, target_host_);
@@ -702,6 +701,35 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
}

/*!
* \brief Calculate workspace sizes for PrimFuncs in the IRModule
*/
Map<String, FunctionInfo> CalculateWorkspaceSizes(
const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
Executor executor_config = lowered_mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
Integer workspace_byte_alignment =
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
Map<String, FunctionInfo> updated_function_metadata;
for (const auto& kv : lowered_mod->functions) {
GlobalVar global_var = kv.first;
BaseFunc base_func = kv.second;
if (base_func->IsInstance<tir::PrimFuncNode>()) {
tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
if (function_metadata.count(global_var->name_hint)) {
updated_function_metadata.Set(global_var->name_hint,
function_metadata[global_var->name_hint]);
updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws);
} else {
FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}};
updated_function_metadata.Set(global_var->name_hint, finfo);
}
}
}
return updated_function_metadata;
}

/*!
* \brief Run USMP to plan memory for lowered IRModule
*/
@@ -710,17 +738,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Integer workspace_byte_alignment =
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
IRModule lowered_mod = mod->ShallowCopy();
function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
// Update workspace size based on the pool allocations.
for (const auto& kv : function_metadata_) {
if (lowered_mod->ContainGlobalVar(kv.first) &&
lowered_mod->Lookup(kv.first)->IsInstance<tir::PrimFuncNode>()) {
tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(lowered_mod->Lookup(kv.first));
Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
kv.second->workspace_sizes.Set(tgt, ws);
}
}
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
backend::FunctionInfo main_func_info =
@@ -752,17 +771,18 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Integer workspace_byte_alignment =
executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
IRModule lowered_mod = mod->ShallowCopy();
function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
// Running StorageRewrite just on the main function
tir::PrimFunc tir_main_func =
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
IRModule main_func_mod;
main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix),
main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
tir_main_func);
main_func_mod = tir::transform::StorageRewrite()(main_func_mod);
lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix),
main_func_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
tir_main_func =
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
// Use the PrimFunc to calculate the workspace required to service the allocates
Integer main_workspace_size_bytes =
CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment);
@@ -920,7 +940,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// function and replacing it with its TIR version. We should try to make this a Pass.
lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), prim_func);
// Parallel for loops are not supported in AoT codegen.
lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);

@@ -960,7 +980,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
std::vector<tir::Var> pool_vars;
tir::PrimFunc tir_main_func =
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
if (allocated_pool_infos) {
2 changes: 1 addition & 1 deletion src/target/source/source_module.cc
@@ -474,7 +474,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
}

void GenerateAOTDescriptor() {
const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix;
const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_module_main;
const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix;
const std::string run_func_mangled =
runtime::get_name_mangled(metadata_->mod_name, run_func_suffix);
4 changes: 2 additions & 2 deletions src/tir/usmp/transform/assign_pool_info.cc
@@ -42,7 +42,7 @@ class PoolInfoAssigner : public StmtExprMutator {
public:
explicit PoolInfoAssigner(const IRModule& module) {
PrimFunc main_func =
Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
ICHECK(main_func.defined()) << "main function is not in the module";
Optional<Target> target_host = main_func->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target_host) << "main function does not have a target attr";
@@ -79,7 +79,7 @@ class PoolInfoAssigner : public StmtExprMutator {
PoolInfo PoolInfoAssigner::CreateDefaultMemoryPool(const tvm::IRModule& module) {
Map<Target, String> target_access;
tir::PrimFunc tir_main_func =
Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
Downcast<tir::PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
Target target_host = tir_main_func->GetAttr<Target>(tvm::attr::kTarget).value();
for (const auto& kv : module->functions) {
BaseFunc func = kv.second;
@@ -331,7 +331,7 @@ PrimExpr PoolAllocationToOffsetConverter::VisitExpr_(const LoadNode* op) {
}

IRModule PoolAllocationToOffsetConverter::operator()() {
GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix);
GlobalVar gv = module_->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main);
PrimFunc main_func = Downcast<PrimFunc>(module_->Lookup(gv));
ScopeInfo si = UpdateFunctionScopeInfo(main_func);
this->scope_stack.push(si);
4 changes: 2 additions & 2 deletions src/tir/usmp/unified_static_memory_planner.cc
@@ -51,7 +51,7 @@ static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(

IRModule PlanMemory(const IRModule& mod, String algo) {
VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
PrimFunc main_func = Downcast<PrimFunc>(mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
BufferInfoAnalysis buffer_info_analysis = ExtractBufferInfo(main_func, mod);
Array<BufferInfo> buffer_info_arr =
CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts);
@@ -63,7 +63,7 @@ IRModule PlanMemory(const IRModule& mod, String algo) {
buffer_info_analysis->buffer_info_stmts, buffer_info_pool_allocations);
IRModule ret = transform::ConvertPoolAllocationsToOffsets(stmt_pool_allocations)(mod);
tir::PrimFunc tir_main_func =
Downcast<tir::PrimFunc>(ret->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
Downcast<tir::PrimFunc>(ret->Lookup(::tvm::runtime::symbol::tvm_module_main));
Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
if (allocated_pool_infos) {
2 changes: 1 addition & 1 deletion src/tir/usmp/utils.cc
@@ -181,7 +181,7 @@ class ModuleWorkspaceSizeCalculator : public StmtExprVisitor {
for (const auto& gv_func : mod_->functions) {
functions_.Set(gv_func.first->name_hint, Downcast<PrimFunc>(gv_func.second));
}
main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_run_func_suffix));
main_func_ = Downcast<PrimFunc>(module->Lookup(::tvm::runtime::symbol::tvm_module_main));
ICHECK(main_func_.defined()) << "main function is not in the module";
Optional<Target> target_host = main_func_->GetAttr<Target>(tvm::attr::kTarget);
ICHECK(target_host) << "main function does not have a target attr";
3 changes: 2 additions & 1 deletion tests/python/contrib/test_ethosu/infra.py
@@ -242,12 +242,13 @@ def build_source(
def verify_source(
models: List[AOTCompiledTestModel],
accel="ethos-u55-256",
enable_usmp=True,
):
"""
This method verifies the generated source from an NPU module by building it and running on an FVP.
"""
interface_api = "c"
test_runner = create_test_runner(accel)
test_runner = create_test_runner(accel, enable_usmp)
run_and_check(
models,
test_runner,
2 changes: 1 addition & 1 deletion tests/python/contrib/test_ethosu/test_networks.py
@@ -71,7 +71,7 @@ def test_forward_mobilenet_v1(accel_type, enable_usmp):
compiled_models = infra.build_source(
mod, input_data, output_data, accel_type, output_tolerance=10, enable_usmp=enable_usmp
)
infra.verify_source(compiled_models, accel_type)
infra.verify_source(compiled_models, accel_type, enable_usmp=enable_usmp)


if __name__ == "__main__":
97 changes: 77 additions & 20 deletions tests/python/relay/aot/aot_test_utils.py
@@ -265,29 +265,56 @@ def emit_data_linkage(output_file, data_linkage):


def emit_main_prologue(
main_file, custom_prologue, workspace_bytes, data_linkage, compiled_models, interface_api
main_file,
custom_prologue,
workspace_bytes,
data_linkage,
compiled_models,
interface_api,
use_stack_allocator=True,
):
# Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment.
workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
if interface_api == "c":
for compiled_model in compiled_models:
model = compiled_model.model
workspace_define += f" + TVMGEN_{model.name.upper()}_WORKSPACE_SIZE"
workspace_define += " + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n"
main_file.write(workspace_define)
emit_data_linkage(main_file, data_linkage)
main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n")
main_file.write("tvm_workspace_t app_workspace;\n")
main_file.write(
"""
if use_stack_allocator:
workspace_define = f"#define WORKSPACE_SIZE ({workspace_bytes}"
if interface_api == "c":
for compiled_model in compiled_models:
model = compiled_model.model
workspace_define += f" + TVMGEN_{model.name.upper()}_WORKSPACE_SIZE"
# Add TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES because of memory alignment.
workspace_define += " + TVM_RUNTIME_ALLOC_ALIGNMENT_BYTES)\n"
main_file.write(workspace_define)
emit_data_linkage(main_file, data_linkage)
main_file.write("static uint8_t g_aot_memory[WORKSPACE_SIZE];\n")
main_file.write("tvm_workspace_t app_workspace;\n")
main_file.write(
"""
tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
return StackMemoryManager_Allocate(&app_workspace, num_bytes, out_ptr);
}
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
return StackMemoryManager_Free(&app_workspace,ptr);
}
"""
)
else:
# An implementation is not needed for these if the stack allocator is not used
main_file.write(
"""
tvm_crt_error_t TVMPlatformMemoryAllocate(size_t num_bytes, DLDevice dev, void** out_ptr) {
return kTvmErrorFunctionCallNotImplemented;
}
tvm_crt_error_t TVMPlatformMemoryFree(void* ptr, DLDevice dev) {
return kTvmErrorFunctionCallNotImplemented;
}
"""
)
main_file.write(
"""
void TVMPlatformAbort(tvm_crt_error_t code) { exit(-1); }
void TVMLogf(const char* msg, ...) {
@@ -296,10 +323,10 @@ def emit_main_prologue(
vfprintf(stdout, msg, args);
va_end(args);
}
TVM_DLL int TVMFuncRegisterGlobal(const char* name, TVMFunctionHandle f, int override) {}
int main(){\n
"""
"""
)
main_file.write(custom_prologue)

@@ -511,6 +538,7 @@ def create_main(
data_linkage,
interface_api,
workspace_bytes,
use_stack_allocator=True,
):
file_path = pathlib.Path(f"{output_path}/" + test_name).resolve()
# create header file
@@ -533,8 +561,10 @@ def create_main(
data_linkage,
compiled_models,
interface_api,
use_stack_allocator,
)
emit_main_init_memory_manager(main_file)
if use_stack_allocator:
emit_main_init_memory_manager(main_file)

if interface_api == "c":
for compiled_model in compiled_models:
@@ -709,11 +739,14 @@ def run_and_check(
t = tarfile.open(tar_file)
t.extractall(base_path)

workspace_bytes = model.extra_memory_in_bytes
use_usmp = runner.pass_config.get("tir.usmp.enable", False)
if interface_api == "packed" and not use_usmp:
# Interface C APIs does not need compiler generated
# workspace to generate the test application, because
# workspace size is codegen'd as a macro to
# tvmgen_<model_name>.h.
if interface_api != "c":
workspace_bytes += mlf_extract_workspace_size_bytes(tar_file)

workspace_bytes += model.extra_memory_in_bytes
for key in model.inputs:
sanitized_tensor_name = re.sub(r"\W", "_", key)
create_header_file(
@@ -738,6 +771,10 @@ def run_and_check(
data_linkage,
)

use_usmp = runner.pass_config.get("tir.usmp.enable", False)
# We only need the stack allocator if USMP is not used
use_stack_allocator = not use_usmp

create_main(
"test.c",
models,
@@ -748,6 +785,7 @@ def run_and_check(
data_linkage,
interface_api,
workspace_bytes,
use_stack_allocator,
)

# Verify that compiles fine
@@ -868,3 +906,22 @@ def generate_ref_data(mod, input_data, params=None, target="llvm"):
output_tensor_names = main.attrs["output_tensor_names"]

return dict(zip(output_tensor_names, out))


def create_relay_module_and_inputs_from_tflite_file(tflite_model_file):
"""A helper function to create a Relay IRModule with inputs
and params from a tflite file"""
with open(tflite_model_file, "rb") as f:
tflite_model_buf = f.read()
mod, params = convert_to_relay(tflite_model_buf)

inputs = dict()
for param in mod["main"].params:
name = str(param.name_hint)
data_shape = [int(i) for i in param.type_annotation.shape]
dtype = str(param.type_annotation.dtype)
in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max)
data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype)
inputs[name] = data

return mod, inputs, params
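
A possible usage of the helper above together with generate_ref_data (sketch
only; the "model.tflite" path is a placeholder):

# Build a Relay module with random inputs from a TFLite file, then
# compute reference outputs on the host for later comparison.
mod, inputs, params = create_relay_module_and_inputs_from_tflite_file("model.tflite")
ref_outputs = generate_ref_data(mod, inputs, params)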