Skip to content

Commit

Permalink
Flatbuffer export memory optimization for large models.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627417753
  • Loading branch information
sirakiin authored and tensorflower-gardener committed Apr 23, 2024
1 parent 3ab7aa8 commit f6fe996
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 39 deletions.
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/BUILD
Expand Up @@ -1124,6 +1124,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
"@flatbuffers",
"@llvm-project//llvm:Support",
Expand Down
113 changes: 74 additions & 39 deletions tensorflow/compiler/mlir/lite/flatbuffer_export.cc
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
Expand Down Expand Up @@ -161,6 +162,9 @@ ABSL_CONST_INIT const absl::string_view kFlexOpNamePrefix = "Flex";
// used by the TOCO export. (It does not explain rationale for this choice.)
constexpr size_t kInitialBufferSize = 10240;

// Flatbuffer fields to be padded to 16 bytes aligned.
constexpr size_t kFbAlignment = 16;

// Set `isSigned` to false if the `type` is an 8-bit unsigned integer type.
// Since tflite doesn't support unsigned for other types, returns error if
// `isSigned` is set to false for other types.
Expand Down Expand Up @@ -716,7 +720,7 @@ class Translator {

// Append constant and custom op buffers at the end of the flatbuffer and
// calculate the offsets
void AppendBufferData(std::string& result);
void AppendBufferData(absl::Cord& result);

// Update constant & custom op buffer offsets
// Return false if fail to update offset
Expand Down Expand Up @@ -816,7 +820,8 @@ class Translator {
// Maps buffer data to corresponding buffer index
// in the idx map, the value is a pair of offset and size
absl::flat_hash_map<int, std::pair<uint64_t, uint64_t>> buffer_idx_map_;
absl::flat_hash_map<int, std::vector<uint8_t>> buffer_data_map_;
absl::flat_hash_map<int, std::string> buffer_data_map_;
bool buffer_data_exported_ = false;

// Maps custom options data to corresponding node
// Key is set to be the list of input tensor indices and list of output tensor
Expand Down Expand Up @@ -955,7 +960,8 @@ std::optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
}
auto packed_buffer = tflite::PackInt4ValuesDensely(data);
if (use_buffer_offset_) {
buffer_data_map_[index] = packed_buffer;
buffer_data_map_[index] =
std::string(packed_buffer.begin(), packed_buffer.end());
return tflite::CreateBuffer(builder_, 0, 1, 1);
} else {
if (IsModelBiggerThan2GB(packed_buffer.size())) {
Expand Down Expand Up @@ -991,7 +997,8 @@ std::optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(
if (use_buffer_offset_) {
std::vector<uint8_t> buffer_data(tensor_buffer, tensor_buffer + bytes);
free(tensor_buffer);
buffer_data_map_[index] = buffer_data;
buffer_data_map_[index] =
std::string(buffer_data.begin(), buffer_data.end());
return tflite::CreateBuffer(builder_, 0, 1, 1);
} else {
if (IsModelBiggerThan2GB(bytes)) {
Expand All @@ -1007,9 +1014,7 @@ std::optional<BufferOffset<tflite::Buffer>> Translator::BuildBuffer(

absl::string_view tensor_data = tensor.tensor_data();
if (use_buffer_offset_) {
std::vector<uint8_t> buffer_data(tensor_data.data(),
tensor_data.data() + tensor_data.size());
buffer_data_map_[index] = buffer_data;
buffer_data_map_[index] = std::string(tensor_data);
return tflite::CreateBuffer(builder_, 0, 1, 1);
} else {
if (IsModelBiggerThan2GB(tensor_data.size())) {
Expand Down Expand Up @@ -3197,17 +3202,19 @@ std::optional<std::string> Translator::Translate(
op_or_arg_name_mapper = &default_op_or_arg_name_mapper;
if (!UpdateEntryFunction(module)) return std::nullopt;
if (!IsValidTFLiteMlirModule(module)) return std::nullopt;
Translator translator(module, toco_flags, tags, op_or_arg_name_mapper,
metadata, custom_option_alignment);
translator.convert_stablehlo_ = serialize_stablehlo_ops;
auto ret = translator.TranslateInternal();
if (translator.require_use_buffer_offset_) {
auto translator = std::unique_ptr<Translator>(
new Translator(module, toco_flags, tags, op_or_arg_name_mapper, metadata,
custom_option_alignment));
translator->convert_stablehlo_ = serialize_stablehlo_ops;
auto ret = translator->TranslateInternal();
if (translator->require_use_buffer_offset_) {
ret = std::nullopt;
auto new_toco_flags = toco_flags;
new_toco_flags.set_use_buffer_offset(true);
Translator new_translator(module, new_toco_flags, tags,
op_or_arg_name_mapper, metadata,
custom_option_alignment);
return new_translator.TranslateInternal();
translator = std::unique_ptr<Translator>(
new Translator(module, new_toco_flags, tags, op_or_arg_name_mapper,
metadata, custom_option_alignment));
return translator->TranslateInternal();
}
return ret;
}
Expand Down Expand Up @@ -3453,63 +3460,91 @@ std::optional<std::string> Translator::TranslateInternal() {
}
}

auto result =
std::string(reinterpret_cast<const char*>(builder_.GetBufferPointer()),
builder_.GetSize());
absl::Cord result;
auto fbs = absl::string_view(
reinterpret_cast<const char*>(builder_.GetBufferPointer()),
builder_.GetSize());
result.Append(fbs);

// Return serialized string for the built FlatBuffer.
if (use_buffer_offset_) {
// Pad to be 16 bytes aligned
{
std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
result.Append(std::move(pad));
}
AppendBufferData(result);
auto mutable_model = tflite::GetMutableModel(result.data());
std::string result_str = std::string(std::move(result));
auto mutable_model = tflite::GetMutableModel(result_str.data());
bool ret = UpdateBufferOffsets(mutable_model);
if (!ret) {
return std::nullopt;
}
return result;
return result_str;
}
return result;
return std::string(result);
}

void Translator::AppendBufferData(std::string& result) {
void Translator::AppendBufferData(absl::Cord& result) {
std::unordered_map<uint64_t, std::pair<int64_t, int64_t>> hashcode_to_pos;
// Pad to be 16 bytes aligned
while (result.size() % 16 != 0) result += '\0';
for (auto& it : buffer_data_map_) {
auto buffer = std::string(it.second.begin(), it.second.end());
int64_t index = it.first;
// Buffer data should be exported only once.
assert(!buffer_data_exported_);

auto it = buffer_data_map_.begin();
while (it != buffer_data_map_.end()) {
std::string buffer = it->second;
int64_t index = it->first;
int64_t offset = result.size();
int64_t size = it.second.size();
int64_t size = buffer.size();
uint64_t hash = tsl::Fingerprint64(buffer);
if (hashcode_to_pos.find(hash) == hashcode_to_pos.end()) {
hashcode_to_pos[hash] = std::make_pair(offset, size);
buffer_idx_map_[index] = std::make_pair(offset, size);
result += std::string(it.second.begin(), it.second.end());
// Pad to be 16 bytes aligned
while (result.size() % 16 != 0) result += '\0';
result.Append(std::move(buffer));
// Pad to be 16 bytes aligned.
{
std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
result.Append(std::move(pad));
}
} else {
// only update offset/index.
buffer_idx_map_[index] = hashcode_to_pos[hash];
}
buffer_data_map_.erase(it);
it = buffer_data_map_.begin();
buffer_data_exported_ = true;
}
// pad 16 bytes for the last buffer for XNNPack
result += "\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0";
result.Append("\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0");
// pad to be 16 bytes aligned
while (result.size() % 16 != 0) result += '\0';
{
std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
result.Append(std::move(pad));
}

for (auto& it : custom_op_data_map_) {
while (result.size() % 16 != 0) result += '\0';
{
std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
result.Append(std::move(pad));
}
if (custom_option_alignment_.has_value()) {
while (result.size() % custom_option_alignment_.value() != 0)
result += '\0';
{
auto alignment = custom_option_alignment_.value();
std::string pad(alignment - result.size() % alignment, '\0');
result.Append(std::move(pad));
}
}
auto buffer = std::string(it.second.begin(), it.second.end());
int64_t offset = result.size();
int64_t size = it.second.size();
custom_op_idx_map_[it.first] = std::make_pair(offset, size);
result += buffer;
result.Append(std::move(buffer));
}
// pad to be 16 bytes aligned
while (result.size() % 16 != 0) result += '\0';
{
std::string pad(kFbAlignment - result.size() % kFbAlignment, '\0');
result.Append(std::move(pad));
}
}

bool Translator::UpdateBufferOffsets(tflite::Model* mutable_model) {
Expand Down

0 comments on commit f6fe996

Please sign in to comment.