Skip to content

Commit

Permalink
ModuleOp pass try
Browse files Browse the repository at this point in the history
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
  • Loading branch information
yongtang committed May 30, 2020
1 parent 6e16c94 commit a01f280
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 57 deletions.
1 change: 1 addition & 0 deletions tensorflow_io/core/BUILD
Expand Up @@ -38,6 +38,7 @@ cc_library(
"@llvm-project//mlir:Dialect",
"@llvm-project//mlir:DerivedAttributeOpInterface",
"@llvm-project//mlir:InferTypeOpInterface",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
Expand Down
96 changes: 39 additions & 57 deletions tensorflow_io/core/kernels/io_optimization.cc
Expand Up @@ -23,77 +23,59 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

#include "mlir/InitAllDialects.h"

static bool foo_dialect_registration_once = []() {
mlir::registerAllDialects();
return true;
}();

namespace tensorflow {
namespace io {
namespace {

class AudioOptimizationPass
: public mlir::PassWrapper<AudioOptimizationPass, mlir::FunctionPass> {
: public mlir::PassWrapper<AudioOptimizationPass, mlir::OperationPass<mlir::ModuleOp>> {
public:
void runOnFunction() override {
std::cerr << "XXXXX runOnFunction() ENTER XXXXX" << std::endl;

auto f = getFunction();
f.walk([&](mlir::Operation *op) {
if (op->getName().getStringRef().str() == "tf.IO>AudioEncodeWAV") {
std::cerr << "XXXXX runOnFunction() NAME XXXXX: " << op->getName().getStringRef().str() << std::endl;

std::cerr << "XXXXX runOnFunction() op->getNumOperands() XXXXX: " << op->getNumOperands() << std::endl;
for (auto attr: op->getAttrs()) {
std::cerr << "XXXXX runOnFunction() ATTR XXXXX: " << attr.first.str() << std::endl;
if (attr.first.str() == "codec") {
std::cerr << "XXXXX runOnFunction() ATTR2 XXXXX: " << attr.second.isa<mlir::StringAttr>() << std::endl;
std::cerr << "XXXXX runOnFunction() ATTR3 XXXXX: " << attr.second.dyn_cast<mlir::StringAttr>().getValue().str() << std::endl;
}
}
}
});

/*
mlir::ConversionTarget target(getContext());
target.addDynamicallyLegalDialect<mlir::TF::TensorFlowDialect>(llvm::Optional<mlir::ConversionTarget::DynamicLegalityCallbackFn>([](mlir::Operation *op) {
std::cerr << "XXXXX DynamicallyLegal() NAME XXXXX: " << op->getName().getStringRef().str() << std::endl;
std::cerr << "XXXXX DynamicallyLegal() DIALECT XXXXX: " << op->getName().getDialect().str() << std::endl;
return false;
}));
mlir::OwningRewritePatternList patterns;
patterns.insert<OptimizingAudioOp>(&getContext());
if (failed(applyPartialConversion(getFunction(), target, patterns))) {
signalPassFailure();
}
*/
// Define the dialects that are legal targets.
//target.addLegalDialect<AffineDialect, StandardOpsDialect>();

// Define the Foo dialect as Illegal, so all operatsions are converted.
// Explicitly mark the Foo operations, `foo.print`, as `legal`.
//target.addIllegalDialect<foo::FooDialect>();
//target.addLegalOp<foo::PrintOp>();

// Provide the set of patterns that will lower the Foo operations.
//OwningRewritePatternList patterns;
//patterns.insert<LoweringConstOp, LoweringReturnOp>(&getContext());

// Signal failure if any `illegal` operations were not converted
// successfully.
//if (failed(applyPartialConversion(getFunction(), target, patterns))) {
// signalPassFailure();
//}

void runOnOperation() override {
std::cerr << "XXXXX runOnOperation() ENTER XXXXX" << std::endl;
mlir::ModuleOp module = getOperation();
mlir::MLIRContext* context = module.getContext();

std::cerr << "XXXXX module XXXXX: " << module.getName()->str() << std::endl;

//auto attr = mlir::StringAttr::get("ffmpeg", context);
for (auto function : module.getOps<mlir::FuncOp>()) {
std::cerr << "XXXXX function XXXXX: " << function.getOperation()->getName().getStringRef().str() << std::endl;
//if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();

//llvm::SmallVector<std::string, 4> var_handle_shared_names;
//PromoteVarHandlesToArguments(function, /*add_validation=*/false,
// &var_handle_shared_names);

// Add resource names for each `tf.VarHandleOp` that were promoted to
// resource arguments.
//const int var_handle_args_offset =
// function.getNumArguments() - var_handle_shared_names.size();
//for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
// function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
// kResourceNameArgAttr,
// StringAttr::get(var_name_and_index.value(), context));
}

//op->setAttr("codec", op->getAttr("codec"));

std::cerr << "XXXXX runOnFunction() EXIT XXXXX" << std::endl;
std::cerr << "XXXXX runOnOperation() EXIT XXXXX" << std::endl;

}
};

std::unique_ptr<mlir::Pass> createAudioOptimizationPass() {
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createAudioOptimizationPass() {
return std::make_unique<AudioOptimizationPass>();
}
//std::unique_ptr<mlir::Pass> createAudioOptimizationPass() {
// return std::make_unique<AudioOptimizationPass>();
//}

class MlirIOGraphOptimizationPass : public ::tensorflow::MlirOptimizationPass {
public:
Expand Down

0 comments on commit a01f280

Please sign in to comment.