Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 63 additions & 57 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,8 @@ Result<const uint8_t*> getConstantDataPtr(
const uint8_t* constant_data_ptr,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
XNNWeightsCache* weights_cache,
bool use_weight_cache) {
if (buffer_idx) {
if (!constant_data_ptr) {
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
Expand Down Expand Up @@ -230,30 +231,30 @@ Result<const uint8_t*> getConstantDataPtr(
InvalidProgram,
"Named key is null");
const std::string& data_name = constant_data_offset->named_key()->str();
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
Result<const uint8_t*> data_ptr =
weights_cache->load_unpacked_data(data_name);
if (!data_ptr.ok()) {
ET_LOG(Error, "Failed to load weights from cache");
return data_ptr.error();
}
return data_ptr.get();
#else
Result<FreeableBuffer> buffer =
named_data_map->get_data(data_name.c_str());
if (!buffer.ok()) {
ET_LOG(
Error,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(buffer.error()));
return buffer.error();
if (use_weight_cache) {
Result<const uint8_t*> data_ptr =
weights_cache->load_unpacked_data(data_name);
if (!data_ptr.ok()) {
ET_LOG(Error, "Failed to load weights from cache");
return data_ptr.error();
}
return data_ptr.get();
} else {
Result<FreeableBuffer> buffer =
named_data_map->get_data(data_name.c_str());
if (!buffer.ok()) {
ET_LOG(
Error,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(buffer.error()));
return buffer.error();
}
const uint8_t* data_ptr =
static_cast<const uint8_t*>(buffer.get().data());
freeable_buffers.push_back(std::move(buffer.get()));
return data_ptr;
}
const uint8_t* data_ptr =
static_cast<const uint8_t*>(buffer.get().data());
freeable_buffers.push_back(std::move(buffer.get()));
return data_ptr;
#endif
}
}
}
Expand All @@ -267,14 +268,16 @@ Result<const uint8_t*> getConstantDataPtr(
const uint8_t* constant_data_ptr,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
XNNWeightsCache* weights_cache,
bool use_weight_cache) {
return getConstantDataPtr(
tensor_value->constant_buffer_idx(),
flatbuffer_graph,
constant_data_ptr,
named_data_map,
freeable_buffers,
weights_cache);
weights_cache,
use_weight_cache);
}

/**
Expand All @@ -293,7 +296,8 @@ Error defineTensor(
CompileAllocator& allocator,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
XNNWeightsCache* weights_cache,
bool use_weight_cache) {
const fb_xnnpack::XNNTensorValue* tensor_value = nullptr;
const fb_xnnpack::XNNQuantizedTensorValue* qtensor_value = nullptr;

Expand Down Expand Up @@ -347,7 +351,8 @@ Error defineTensor(
constant_data_ptr,
named_data_map,
freeable_buffers,
weights_cache);
weights_cache,
use_weight_cache);
if (!buffer_result.ok()) {
return buffer_result.error();
}
Expand Down Expand Up @@ -502,7 +507,8 @@ Error defineTensor(
constant_data_ptr,
named_data_map,
freeable_buffers,
weights_cache);
weights_cache,
use_weight_cache);
if (!scale_result.ok()) {
return scale_result.error();
}
Expand Down Expand Up @@ -548,7 +554,8 @@ Error defineTensor(
constant_data_ptr,
named_data_map,
freeable_buffers,
weights_cache);
weights_cache,
use_weight_cache);
if (!scale_data_result.ok()) {
return scale_data_result.error();
}
Expand Down Expand Up @@ -1976,7 +1983,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
XNNExecutor* executor,
XNNWeightsCache* weights_cache,
xnn_workspace_t workspace,
const NamedDataMap* named_data_map) {
const NamedDataMap* named_data_map,
bool use_weight_cache) {
Result<XNNHeader> header = XNNHeader::Parse(buffer_pointer, num_bytes);
const uint8_t* flatbuffer_data = nullptr;
const uint8_t* constant_data = nullptr;
Expand Down Expand Up @@ -2086,7 +2094,8 @@ ET_NODISCARD Error XNNCompiler::compileModel(
compile_allocator,
named_data_map,
unpacked_buffers,
weights_cache);
weights_cache,
use_weight_cache);

if (err != Error::Ok) {
return err;
Expand All @@ -2108,19 +2117,16 @@ ET_NODISCARD Error XNNCompiler::compileModel(

xnn_runtime_t runtime_ptr = nullptr;

// XNNWeightsCache if weights cache is not enabled, then XNNWeightsCache
// just manages the unpacked weights until the runtime is created.
#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
ET_CHECK_OR_RETURN_ERROR(
unpacked_buffers.size() == 0,
Internal,
"Weight Cache is enabled, which means unpacked buffers should be owned by the cache");
xnn_weights_cache_t weights_cache_ptr =
weights_cache->get_num_unpacked_data() > 0 ? weights_cache->get()
: nullptr;
#else
xnn_weights_cache_t weights_cache_ptr = nullptr;
#endif
if (use_weight_cache) {
ET_CHECK_OR_RETURN_ERROR(
unpacked_buffers.size() == 0,
Internal,
"Weight Cache is enabled, which means unpacked buffers should be owned by the cache");
weights_cache_ptr = weights_cache->get_num_unpacked_data() > 0
? weights_cache->get()
: nullptr;
}

// NOLINTBEGIN(facebook-hte-NullableDereference) - weights cache is allowed to
// be null
Expand All @@ -2139,25 +2145,25 @@ ET_NODISCARD Error XNNCompiler::compileModel(
"XNN Runtime creation failed with code: %s",
xnn_status_to_string(status));

#ifdef ENABLE_XNNPACK_WEIGHTS_CACHE
auto packed_weights_names = weights_cache->finalize_for_runtime();
ET_CHECK_OR_RETURN_ERROR(
packed_weights_names.ok(),
Internal,
"Failed to finalize weights cache after creating the xnn runtime")
#else
for (auto& buffer : unpacked_buffers) {
buffer.Free();
std::vector<std::string> packed_weights_names;
if (use_weight_cache) {
auto packed_weights_names_result = weights_cache->finalize_for_runtime();
ET_CHECK_OR_RETURN_ERROR(
packed_weights_names_result.ok(),
Internal,
"Failed to finalize weights cache after creating the xnn runtime");
packed_weights_names = std::move(packed_weights_names_result.get());
} else {
for (auto& buffer : unpacked_buffers) {
buffer.Free();
}
}
Result<std::vector<std::string>> packed_weights_names =
std::vector<std::string>();
#endif

err = executor->initialize( // NOLINT: runtime_ptr is non-null
runtime_ptr,
std::move(input_ids),
std::move(output_ids),
std::move(packed_weights_names.get()));
std::move(packed_weights_names));

return err;
};
Expand Down
3 changes: 2 additions & 1 deletion backends/xnnpack/runtime/XNNCompiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class XNNCompiler {
XNNExecutor* executor,
XNNWeightsCache* weights_cache,
xnn_workspace_t workspace,
const NamedDataMap* named_data_map);
const NamedDataMap* named_data_map,
bool use_weight_cache);
};

} // namespace delegate
Expand Down
3 changes: 2 additions & 1 deletion backends/xnnpack/runtime/XNNPACKBackend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ class XnnpackBackend final
executor,
weights_cache_.get(),
workspace_ptr,
named_data_map);
named_data_map,
use_weight_cache);
// This backend does not need its processed data after compiling the model.
processed->Free();

Expand Down
Loading