diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index 78eaaf6d039..1ed7db80d84 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant data associated with the tensor value, then returns nullptr. */ const uint8_t* getConstantDataPtr( - const fb_xnnpack::XNNTensorValue* tensor_value, + uint32_t buffer_idx, GraphPtr flatbuffer_graph, const uint8_t* constant_data_ptr, const NamedDataMap* named_data_map, std::vector& freeable_buffers, XNNWeightsCache* weights_cache) { - auto buffer_idx = tensor_value->constant_buffer_idx(); if (buffer_idx) { if (!constant_data_ptr) { // TODO(T172265611): Remove constant_buffer in flatbuffer path after BC @@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr( return nullptr; } +const uint8_t* getConstantDataPtr( + const fb_xnnpack::XNNTensorValue* tensor_value, + GraphPtr flatbuffer_graph, + const uint8_t* constant_data_ptr, + const NamedDataMap* named_data_map, + std::vector& freeable_buffers, + XNNWeightsCache* weights_cache) { + return getConstantDataPtr( + tensor_value->constant_buffer_idx(), + flatbuffer_graph, + constant_data_ptr, + named_data_map, + freeable_buffers, + weights_cache); +} + /** Define serialized tensor value into the subgraph. While also keeping track of the remapped ids from @@ -434,22 +449,15 @@ Error defineTensor( const float* scale = qparams->scale()->data(); if (qparams->scale_buffer_idx() != 0) { - // if scales are stored in named data, then retrieve it - ConstantDataOffsetPtr scale_buffer_offset = - flatbuffer_graph->constant_data()->Get( - qparams->scale_buffer_idx()); - const std::string& data_name = - scale_buffer_offset->named_key()->str(); - Result scale_buffer = - named_data_map->get_data(data_name.c_str()); + scale = reinterpret_cast(getConstantDataPtr( + qparams->scale_buffer_idx(), + flatbuffer_graph, + constant_data_ptr, + named_data_map, + freeable_buffers, + weights_cache)); ET_CHECK_OR_RETURN_ERROR( - scale_buffer.ok(), - Internal, - "Failed to get constant data for key %s from named_data_map. Error code: %u", - data_name.c_str(), - static_cast(scale_buffer.error())); - scale = reinterpret_cast(scale_buffer.get().data()); - freeable_buffers.push_back(std::move(scale_buffer.get())); + scale != nullptr, Internal, "Failed to load scale data."); } status = xnn_define_channelwise_quantized_tensor_value_v2( /*subgraph=*/subgraph_ptr, @@ -483,22 +491,15 @@ Error defineTensor( // Block scales are preferably serialized as bf16 but can also be // serialized as fp32 for backwards compatability. if (qparams->scale_buffer_idx() != 0) { - ConstantDataOffsetPtr scale_buffer_offset = - flatbuffer_graph->constant_data()->Get( - qparams->scale_buffer_idx()); - const std::string& data_name = - scale_buffer_offset->named_key()->str(); - Result scale_buffer = - named_data_map->get_data(data_name.c_str()); + scale_data = reinterpret_cast(getConstantDataPtr( + qparams->scale_buffer_idx(), + flatbuffer_graph, + constant_data_ptr, + named_data_map, + freeable_buffers, + weights_cache)); ET_CHECK_OR_RETURN_ERROR( - scale_buffer.ok(), - Internal, - "Failed to get constant data for key %s from named_data_map. Error code: %u", - data_name.c_str(), - static_cast(scale_buffer.error())); - scale_data = - reinterpret_cast(scale_buffer.get().data()); - freeable_buffers.push_back(std::move(scale_buffer.get())); + scale_data != nullptr, Internal, "Failed to load scale data."); scale_numel = qparams->num_scales(); } else { // Read fp32 scales, convert to bf16.