Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 33 additions & 32 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,12 @@ payload (deprecated) or via offsets to the constant_data_ptr. If no constant
data is associated with the tensor value, then returns nullptr.
*/
const uint8_t* getConstantDataPtr(
const fb_xnnpack::XNNTensorValue* tensor_value,
uint32_t buffer_idx,
GraphPtr flatbuffer_graph,
const uint8_t* constant_data_ptr,
const NamedDataMap* named_data_map,
std::vector<FreeableBuffer>& freeable_buffers,
XNNWeightsCache* weights_cache) {
auto buffer_idx = tensor_value->constant_buffer_idx();
if (buffer_idx) {
if (!constant_data_ptr) {
// TODO(T172265611): Remove constant_buffer in flatbuffer path after BC
Expand Down Expand Up @@ -230,6 +229,22 @@ const uint8_t* getConstantDataPtr(
return nullptr;
}

/**
Convenience overload that looks up constant data for a serialized tensor
value. It reads the constant buffer index stored on `tensor_value` and
forwards it, together with every other argument unchanged, to the
index-based getConstantDataPtr overload, returning that overload's result
as-is.

NOTE(review): `tensor_value` is dereferenced without a null check, so
callers are presumably expected to pass a valid pointer -- confirm.
*/
const uint8_t* getConstantDataPtr(
    const fb_xnnpack::XNNTensorValue* tensor_value,
    GraphPtr flatbuffer_graph,
    const uint8_t* constant_data_ptr,
    const NamedDataMap* named_data_map,
    std::vector<FreeableBuffer>& freeable_buffers,
    XNNWeightsCache* weights_cache) {
  // Resolve the buffer index once, then delegate to the index-based lookup.
  const auto buffer_idx = tensor_value->constant_buffer_idx();
  const uint8_t* data_ptr = getConstantDataPtr(
      buffer_idx,
      flatbuffer_graph,
      constant_data_ptr,
      named_data_map,
      freeable_buffers,
      weights_cache);
  return data_ptr;
}

/**
Define serialized tensor value into
the subgraph. While also keeping track of the remapped ids from
Expand Down Expand Up @@ -434,22 +449,15 @@ Error defineTensor(
const float* scale = qparams->scale()->data();

if (qparams->scale_buffer_idx() != 0) {
// if scales are stored in named data, then retrieve it
ConstantDataOffsetPtr scale_buffer_offset =
flatbuffer_graph->constant_data()->Get(
qparams->scale_buffer_idx());
const std::string& data_name =
scale_buffer_offset->named_key()->str();
Result<FreeableBuffer> scale_buffer =
named_data_map->get_data(data_name.c_str());
scale = reinterpret_cast<const float*>(getConstantDataPtr(
qparams->scale_buffer_idx(),
flatbuffer_graph,
constant_data_ptr,
named_data_map,
freeable_buffers,
weights_cache));
ET_CHECK_OR_RETURN_ERROR(
scale_buffer.ok(),
Internal,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(scale_buffer.error()));
scale = reinterpret_cast<const float*>(scale_buffer.get().data());
freeable_buffers.push_back(std::move(scale_buffer.get()));
scale != nullptr, Internal, "Failed to load scale data.");
}
status = xnn_define_channelwise_quantized_tensor_value_v2(
/*subgraph=*/subgraph_ptr,
Expand Down Expand Up @@ -483,22 +491,15 @@ Error defineTensor(
// Block scales are preferably serialized as bf16 but can also be
// serialized as fp32 for backwards compatibility.
if (qparams->scale_buffer_idx() != 0) {
ConstantDataOffsetPtr scale_buffer_offset =
flatbuffer_graph->constant_data()->Get(
qparams->scale_buffer_idx());
const std::string& data_name =
scale_buffer_offset->named_key()->str();
Result<FreeableBuffer> scale_buffer =
named_data_map->get_data(data_name.c_str());
scale_data = reinterpret_cast<const uint16_t*>(getConstantDataPtr(
qparams->scale_buffer_idx(),
flatbuffer_graph,
constant_data_ptr,
named_data_map,
freeable_buffers,
weights_cache));
ET_CHECK_OR_RETURN_ERROR(
scale_buffer.ok(),
Internal,
"Failed to get constant data for key %s from named_data_map. Error code: %u",
data_name.c_str(),
static_cast<uint32_t>(scale_buffer.error()));
scale_data =
reinterpret_cast<const uint16_t*>(scale_buffer.get().data());
freeable_buffers.push_back(std::move(scale_buffer.get()));
scale_data != nullptr, Internal, "Failed to load scale data.");
scale_numel = qparams->num_scales();
} else {
// Read fp32 scales, convert to bf16.
Expand Down
Loading