Add support for Composite ops in the TFLite flatbuffer schema
PiperOrigin-RevId: 627588497
turbotoribio authored and tensorflower-gardener committed Apr 24, 2024
1 parent c8ee776 commit 7c5bc63
Showing 13 changed files with 497 additions and 15 deletions.
119 changes: 119 additions & 0 deletions tensorflow/compiler/mlir/lite/flatbuffer_export.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iterator>
#include <limits>
@@ -771,6 +772,11 @@ class Translator {
const std::vector<int32_t>& results,
mlir::VhloToStablehloTypeConverter& vhlo_type_converter);

std::optional<BufferOffset<tflite::Operator>> BuildVhloCompositeV1Op(
mlir::vhlo::CompositeOpV1 composite_op,
const std::vector<int32_t>& operands, const std::vector<int32_t>& results,
std::string op_name);

std::optional<BufferOffset<tflite::Operator>> BuildVhloScatterV1Op(
mlir::vhlo::ScatterOpV1 scatter_op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results,
@@ -1497,6 +1503,43 @@ uint32_t Translator::GetOpcodeIndex(const std::string& op_name,
return it.first->second;
}

void CreateFlexbufferVector(
const std::unique_ptr<flexbuffers::Builder>& flex_builder,
std::string& name, const mlir::Attribute& attr) {
auto start = flex_builder->StartVector(name.c_str());
auto array = attr.cast<mlir::ArrayAttr>().getValue();

for (int i = 0; i < array.size(); i++) {
if (llvm::isa<mlir::BoolAttr>(array[i])) {
flex_builder->Bool(name.c_str(),
array[i].cast<mlir::BoolAttr>().getValue());
} else if (llvm::isa<mlir::StringAttr>(array[i])) {
flex_builder->String(name.c_str(),
array[i].cast<mlir::StringAttr>().getValue().str());
} else if (llvm::isa<mlir::vhlo::BooleanV1Attr>(array[i])) {
flex_builder->Bool(name.c_str(),
array[i].cast<mlir::vhlo::BooleanV1Attr>().getValue());
} else if (llvm::isa<mlir::vhlo::StringV1Attr>(array[i])) {
flex_builder->String(
name.c_str(),
array[i].cast<mlir::vhlo::StringV1Attr>().getValue().str());
} else if (llvm::isa<mlir::vhlo::IntegerV1Attr>(array[i])) {
flex_builder->Int(
name.c_str(),
array[i].cast<mlir::vhlo::IntegerV1Attr>().getValue().getSExtValue());
} else if (llvm::isa<mlir::vhlo::FloatV1Attr>(array[i])) {
flex_builder->Float(
name.c_str(),
array[i].cast<mlir::vhlo::FloatV1Attr>().getValue().convertToFloat());
} else if (llvm::isa<mlir::vhlo::ArrayV1Attr>(array[i])) {
CreateFlexbufferVector(flex_builder, name, array[i]);
}
}

flex_builder->EndVector(start, /*typed=*/false, /*fixed=*/false);
}

std::optional<BufferOffset<tflite::Operator>>
Translator::BuildStablehloOperatorwithoutOptions(
Operation* inst, const std::vector<int32_t>& operands,
@@ -1573,6 +1616,78 @@ Translator::BuildStablehloGatherOp(mlir::stablehlo::GatherOp gather_op,
tflite::BuiltinOptions2_StablehloGatherOptions, gather_option.Union());
}

std::optional<BufferOffset<tflite::Operator>>
Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op,
const std::vector<int32_t>& operands,
const std::vector<int32_t>& results,
std::string op_name) {
uint32_t opcode_index =
GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_COMPOSITE);

int32_t api_version = composite_op.getVersion()
.cast<mlir::vhlo::IntegerV1Attr>()
.getValue()
.getSExtValue();

auto name = builder_.CreateString(
composite_op.getName().cast<mlir::vhlo::StringV1Attr>().getValue().str());

auto composite_attributes = composite_op.getCompositeAttributes()
.cast<mlir::vhlo::DictionaryV1Attr>();
auto flex_builder = std::make_unique<flexbuffers::Builder>();
size_t map_start = flex_builder->StartMap();

for (auto namedAttr : composite_attributes.getValue()) {
auto name =
namedAttr.first.cast<mlir::vhlo::StringV1Attr>().getValue().str();
auto attr = namedAttr.second;

if (llvm::isa<mlir::BoolAttr>(attr))
flex_builder->Bool(name.c_str(), attr.cast<mlir::BoolAttr>().getValue());
else if (llvm::isa<mlir::StringAttr>(attr))
flex_builder->String(name.c_str(),
attr.cast<mlir::StringAttr>().getValue().str());
else if (llvm::isa<mlir::vhlo::BooleanV1Attr>(attr))
flex_builder->Bool(name.c_str(),
attr.cast<mlir::vhlo::BooleanV1Attr>().getValue());
else if (llvm::isa<mlir::vhlo::StringV1Attr>(attr))
flex_builder->String(
name.c_str(), attr.cast<mlir::vhlo::StringV1Attr>().getValue().str());
else if (llvm::isa<mlir::vhlo::IntegerV1Attr>(attr))
flex_builder->Int(
name.c_str(),
attr.cast<mlir::vhlo::IntegerV1Attr>().getValue().getSExtValue());
else if (llvm::isa<mlir::vhlo::FloatV1Attr>(attr))
flex_builder->Float(
name.c_str(),
attr.cast<mlir::vhlo::FloatV1Attr>().getValue().convertToFloat());
// Assumed branch: nested arrays (vhlo::ArrayV1Attr) serialize via the
// recursive flexbuffer-vector helper defined above, mirroring how the
// import path rebuilds them.
else if (llvm::isa<mlir::vhlo::ArrayV1Attr>(attr))
CreateFlexbufferVector(flex_builder, name, attr);
}

flex_builder->EndMap(map_start);
flex_builder->Finish();

int32_t decomposition_subgraph_index =
subgraph_index_map_[composite_op.getDecomposition()
.cast<mlir::vhlo::StringV1Attr>()
.getValue()
.str()];

auto composite_option = tflite::CreateStableHLOCompositeOptions(
builder_, name, decomposition_subgraph_index,
builder_.CreateVector(flex_builder->GetBuffer()),
tflite::CustomOptionsFormat_FLEXBUFFERS, api_version);

return tflite::CreateOperator(
builder_, opcode_index, builder_.CreateVector(operands),
builder_.CreateVector(results), tflite::BuiltinOptions_NONE,
/*builtin_options=*/0, /*custom_options=*/0,
tflite::CustomOptionsFormat_FLEXBUFFERS,
/*mutating_variable_inputs=*/0, /*intermediates=*/0,
/*large_custom_options_offset=*/0, /*large_custom_options_size=*/0,
tflite::BuiltinOptions2_StableHLOCompositeOptions,
composite_option.Union());
}

std::optional<BufferOffset<tflite::Operator>>
Translator::BuildStablehloScatterOp(mlir::stablehlo::ScatterOp scatter_op,
const std::vector<int32_t>& operands,
@@ -2036,6 +2151,10 @@ std::optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
if (auto vhlo_op = llvm::dyn_cast<mlir::vhlo::PadOpV1>(inst)) {
return BuildVhloPadV1Op(vhlo_op, operands, results, vhlo_type_converter);
}
if (auto vhlo_op = llvm::dyn_cast<mlir::vhlo::CompositeOpV1>(inst)) {
return BuildVhloCompositeV1Op(vhlo_op, operands, results,
inst->getName().getStringRef().str());
}
// For ops that don't have kernels, only serialize when conversion is set
// to true.
if (convert_stablehlo_) {
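For readers less familiar with the FlexBuffers API used in BuildVhloCompositeV1Op, here is a minimal, self-contained sketch of how a composite_attributes map gets encoded; the keys and values are hypothetical examples, not part of this commit.

#include <cstdint>
#include <vector>

#include "flatbuffers/flexbuffers.h"

// Sketch: encode a composite_attributes map the same way the export
// path above does. All keys and values here are made-up examples.
std::vector<uint8_t> EncodeCompositeAttributes() {
  flexbuffers::Builder fbb;
  size_t map_start = fbb.StartMap();
  fbb.Bool("test_bool", false);       // from a vhlo::BooleanV1Attr
  fbb.Int("test_int", 2);             // from a vhlo::IntegerV1Attr
  fbb.String("test_string", "test");  // from a vhlo::StringV1Attr
  fbb.Float("test_float", 1.5f);      // from a vhlo::FloatV1Attr
  fbb.EndMap(map_start);
  fbb.Finish();
  return fbb.GetBuffer();  // bytes stored in StableHLOCompositeOptions
}

The resulting buffer is what builder_.CreateVector(flex_builder->GetBuffer()) embeds into the options table in the export code above.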
14 changes: 14 additions & 0 deletions tensorflow/compiler/mlir/lite/flatbuffer_import.cc
@@ -770,6 +770,20 @@ StatusOr<Operation*> ConvertOp(
mlir::BuiltinOptionsToAttributes(op.builtin_options, builder, attrs);
mlir::BuiltinOptions2ToAttributes(op.builtin_options_2, builder, attrs);
}

if (builtin_code == tflite::BuiltinOperator_STABLEHLO_COMPOSITE) {
auto composite_options = op.builtin_options_2.AsStableHLOCompositeOptions();
std::string decomposition = "";
if (composite_options->decomposition_subgraph_index > -1) {
decomposition =
func_names.at(composite_options->decomposition_subgraph_index);
}

attrs.emplace_back(builder.getNamedAttr(
"decomposition",
mlir::vhlo::StringV1Attr::get(builder.getContext(), decomposition)));
}

op_state.addAttributes(attrs);

// Handle the conversion from subgraph index to functions for If and While. We
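The import snippet above resolves the decomposition subgraph index back to a function name. As a hedged illustration, a consumer could read the parsed options the same way; LookupDecomposition and its func_names parameter are hypothetical, with field names taken from the object-API calls shown above.

#include <string>
#include <vector>

#include "tensorflow/lite/schema/schema_generated.h"

// Hypothetical helper: pull composite metadata from a parsed operator.
// Field names follow the accessors used in the import code above.
std::string LookupDecomposition(const tflite::OperatorT& op,
                                const std::vector<std::string>& func_names) {
  const auto* opts = op.builtin_options_2.AsStableHLOCompositeOptions();
  if (opts == nullptr || opts->decomposition_subgraph_index < 0) return "";
  return func_names.at(opts->decomposition_subgraph_index);
}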
99 changes: 97 additions & 2 deletions tensorflow/compiler/mlir/lite/flatbuffer_operator.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/str_cat.h"
@@ -301,6 +302,19 @@ static mlir::Attribute BuildVhloArrayV1Attr(std::vector<mlir::Attribute> value,
return mlir::vhlo::ArrayV1Attr::get(builder.getContext(), value);
}

static mlir::Attribute BuildVhloDictionaryV1Attr(
std::vector<std::pair<mlir::Attribute, mlir::Attribute>> value,
mlir::Builder builder) {
return mlir::vhlo::DictionaryV1Attr::get(builder.getContext(), value);
}

static mlir::Attribute BuildVhloFloatV1Attr(::llvm::APFloat value,
mlir::Type type,
mlir::Builder builder) {
return mlir::vhlo::FloatV1Attr::get(builder.getContext(), type,
std::move(value));
}

static mlir::Attribute BuildRankedTensorAttr(std::vector<int64_t> shape,
std::vector<bool> value,
mlir::Builder builder) {
@@ -416,6 +430,34 @@ static mlir::Attribute BuildTFL_MirrorPaddingAttr(tflite::MirrorPadMode value,
return mlir::TFL::MirrorPaddingTypeAttr::get(builder.getContext(), padding);
}

static std::vector<mlir::Attribute> BuildAttributeVectorFromFlatbuffer(
flexbuffers::Vector flatbuffer_vector, mlir::Builder builder) {
std::vector<mlir::Attribute> mlir_vector;

for (int i = 0; i < flatbuffer_vector.size(); ++i) {
auto value = flatbuffer_vector[i];

if (value.IsBool()) {
mlir_vector.push_back(BuildVhloBooleanV1Attr(value.AsBool(), builder));
} else if (value.IsString()) {
mlir_vector.push_back(
BuildVhloStringV1Attr(value.AsString().str(), builder));
} else if (value.IsInt()) {
mlir_vector.push_back(BuildVhloIntV1Attr(value.AsInt64(), builder));
} else if (value.IsFloat()) {
mlir_vector.push_back(BuildVhloFloatV1Attr(llvm::APFloat(value.AsFloat()),
builder.getF32Type(), builder));
} else if (value.IsVector()) {
std::vector<mlir::Attribute> nested_mlir_vector =
BuildAttributeVectorFromFlatbuffer(value.AsVector(), builder);
mlir_vector.push_back(
BuildVhloArrayV1Attr(std::move(nested_mlir_vector), builder));
}
}

return mlir_vector;
}

static mlir::Attribute BuildTFL_PaddingAttr(tflite::Padding value,
mlir::Builder builder) {
const char* option_name = tflite::EnumNamePadding(value);
@@ -613,8 +655,6 @@ void BuiltinOptions2ToAttributesManual(
bool has_side_effect_set = false;
const flexbuffers::Map& computation_map =
flexbuffers::GetRoot(op->custom_attributes).AsMap();
std::vector<mlir::Attribute> symbol_vec;
symbol_vec.reserve(computation_map.size());
const auto& keys = computation_map.Keys();
for (size_t i = 0; i < keys.size(); ++i) {
const auto key = keys[i].AsKey();
@@ -638,6 +678,61 @@
"has_side_effect", BuildVhloBooleanV1Attr(false, builder)));
return;
}
if (const auto* op = op_union.AsStableHLOCompositeOptions()) {
attributes.emplace_back(
builder.getNamedAttr("name", BuildVhloStringV1Attr(op->name, builder)));

attributes.emplace_back(builder.getNamedAttr(
"version", BuildVhloIntV1Attr(op->version, builder)));

auto composite_attribute_pairs =
std::vector<std::pair<mlir::Attribute, mlir::Attribute>>();

auto composite_attributes =
flexbuffers::GetRoot(op->composite_attributes).AsMap();

const auto& keys = composite_attributes.Keys();
for (size_t i = 0; i < keys.size(); ++i) {
const auto key = keys[i].AsKey();
const auto& value = composite_attributes[key];

std::pair<mlir::Attribute, mlir::Attribute> composite_attribute_pair;
composite_attribute_pair.first = BuildVhloStringV1Attr(key, builder);

if (value.IsBool()) {
composite_attribute_pair.second =
BuildVhloBooleanV1Attr(value.AsBool(), builder);
}
if (value.IsString()) {
composite_attribute_pair.second =
BuildVhloStringV1Attr(value.AsString().str(), builder);
}
if (value.IsInt()) {
composite_attribute_pair.second =
BuildVhloIntV1Attr(value.AsInt64(), builder);
}
if (value.IsFloat()) {
composite_attribute_pair.second = BuildVhloFloatV1Attr(
llvm::APFloat(value.AsFloat()), builder.getF32Type(), builder);
}

if (value.IsVector()) {
std::vector<mlir::Attribute> mlir_vector =
BuildAttributeVectorFromFlatbuffer(value.AsVector(), builder);

composite_attribute_pair.second =
BuildVhloArrayV1Attr(std::move(mlir_vector), builder);
}

composite_attribute_pairs.emplace_back(composite_attribute_pair);
}

attributes.emplace_back(builder.getNamedAttr(
"composite_attributes",
BuildVhloDictionaryV1Attr(std::move(composite_attribute_pairs),
builder)));
return;
}
if (const auto* op = op_union.AsStablehloPadOptions()) {
std::vector<int64_t> shape = {
static_cast<int64_t>(op->edge_padding_low.size())};
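As a counterpart to the encoding sketch earlier, here is a minimal sketch of walking a composite_attributes buffer the way BuiltinOptions2ToAttributesManual does; the buffer contents are assumed to hold a FlexBuffers map, and the function name is hypothetical.

#include <cstdint>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"

// Sketch: dispatch on the runtime type of each map entry, mirroring
// the import loop above (bool/string/int/float; vectors recurse).
void WalkCompositeAttributes(const std::vector<uint8_t>& buffer) {
  auto map = flexbuffers::GetRoot(buffer).AsMap();
  const auto& keys = map.Keys();
  for (size_t i = 0; i < keys.size(); ++i) {
    const char* key = keys[i].AsKey();
    const auto& value = map[key];
    if (value.IsBool()) {
      bool b = value.AsBool();  // becomes a vhlo::BooleanV1Attr
      (void)b;
    } else if (value.IsString()) {
      std::string s = value.AsString().str();  // vhlo::StringV1Attr
      (void)s;
    } else if (value.IsInt()) {
      int64_t n = value.AsInt64();  // vhlo::IntegerV1Attr
      (void)n;
    } else if (value.IsFloat()) {
      float f = value.AsFloat();  // vhlo::FloatV1Attr
      (void)f;
    } else if (value.IsVector()) {
      // Nested vectors become vhlo::ArrayV1Attr via
      // BuildAttributeVectorFromFlatbuffer above.
    }
  }
}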
@@ -41,6 +41,7 @@ limitations under the License.
#include "stablehlo/transforms/Passes.h" // from @stablehlo
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/passes.h"
#include "tensorflow/lite/core/macros.h"

#define DEBUG_TYPE "compat-passes"

@@ -274,11 +275,11 @@ struct LegalizeStablehloToVhloPass
LegalizeStablehloToVhloPass> {
void runOnOperation() override {
ModuleOp module = getOperation();
std::string target_version = "0.14.0";
std::string target_version = tflite_supported_stablehlo_version;
VhloToStablehloTypeConverter to_builtin_converter;

// StableHLO --> VHLO (allow funcs)
// VHLO -> Downgrade to 0.14.0
// VHLO -> Downgrade to 0.19.0 / tflite_supported_stablehlo_version
// VHLO Tensor --> Builtin Tensor
// Remove cast(tensor->vhlo) -> cast(vhlo->tensor) pattern
if (failed(ApplyStablehloToVhloPatterns(module,
1 change: 1 addition & 0 deletions tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/BUILD
@@ -31,6 +31,7 @@ filegroup(
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
"//tensorflow/compiler/mlir/lite:flatbuffer_translate",
"//tensorflow/compiler/mlir/lite:json_to_flatbuffer",
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
"@llvm-project//llvm:FileCheck",
],
)
@@ -0,0 +1,27 @@
// RUN: tf_tfl_translate --enable-hlo-to-tf-conversion --input-mlir %s -o - | flatbuffer_translate --tflite-flatbuffer-to-mlir - -o - | FileCheck %s --check-prefix=CHECK-ROUNDTRIP

module {
func.func public @main(%arg0: tensor<i64>) -> tensor<i64> {
%0 = func.call @test_add_roundtrip(%arg0) : (tensor<i64>) -> tensor<i64>

return %0 : tensor<i64>
}

// CHECK-LABEL: func.func private @test_add_roundtrip
func.func private @test_add_roundtrip(%arg0: tensor<i64>) -> tensor<i64> {
// CHECK-ROUNDTRIP: %0 = stablehlo.composite "stablehlo.add_n" %arg0 {composite_attributes = {test_bool = false, test_int = 2 : i64, test_string = "test"}, decomposition = @add_n.impl} : (tensor<i64>) -> tensor<i64>
%0 = stablehlo.composite "stablehlo.add_n" %arg0 { composite_attributes = { test_int = 2 : i64, test_bool = 0 : i1, test_string = "test"}, decomposition = @add_n.impl } : (tensor<i64>) -> tensor<i64>
return %0 : tensor<i64>
}
func.func private @add_n.impl(%arg0: tensor<i64>) -> tensor<i64> {
%0 = stablehlo.constant dense<2> : tensor<i64>
%1 = stablehlo.add %arg0, %0 : tensor<i64>
return %1 : tensor<i64>
}
}
1 change: 1 addition & 0 deletions tensorflow/lite/builtin_ops.h
@@ -233,6 +233,7 @@ typedef enum {
kTfLiteBuiltinDilate = 203,
kTfLiteBuiltinStablehloRngBitGenerator = 204,
kTfLiteBuiltinReduceWindow = 205,
kTfLiteBuiltinStablehloComposite = 206,
} TfLiteBuiltinOperator;

#ifdef __cplusplus
1 change: 1 addition & 0 deletions tensorflow/lite/core/api/flatbuffer_conversions.cc
@@ -925,6 +925,7 @@ TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_STABLEHLO_SLICE:
case BuiltinOperator_STABLEHLO_BROADCAST_IN_DIM:
case BuiltinOperator_STABLEHLO_CONVOLUTION:
case BuiltinOperator_STABLEHLO_COMPOSITE:
case BuiltinOperator_STABLEHLO_LOGISTIC:
case BuiltinOperator_STABLEHLO_ADD:
case BuiltinOperator_STABLEHLO_DIVIDE:
4 changes: 4 additions & 0 deletions tensorflow/lite/core/kernels/builtin_op_kernels.h
@@ -321,6 +321,10 @@ Register_STABLEHLO_TRANSPOSE(); // WARNING: not implemented, using this
TfLiteRegistration* Register_DILATE();

TfLiteRegistration* Register_REDUCE_WINDOW();

TfLiteRegistration*
Register_STABLEHLO_COMPOSITE(); // WARNING: not implemented, using this
// op will crash the runtime
} // namespace builtin
} // namespace ops
} // namespace tflite
