diff --git a/compiler/include/compiler/backend/optree/optimizer/optimizer.hpp b/compiler/include/compiler/backend/optree/optimizer/optimizer.hpp index 0cbb04e8..2e0d3afd 100644 --- a/compiler/include/compiler/backend/optree/optimizer/optimizer.hpp +++ b/compiler/include/compiler/backend/optree/optimizer/optimizer.hpp @@ -1,8 +1,8 @@ #pragma once -#include -#include +#include +#include "compiler/optree/operation.hpp" #include "compiler/optree/program.hpp" #include "compiler/backend/optree/optimizer/transform.hpp" @@ -11,16 +11,16 @@ namespace optree { namespace optimizer { class Optimizer { - std::vector transforms; - size_t iterLimit; + std::deque transforms; public: - Optimizer(); + Optimizer() = default; Optimizer(const Optimizer &) = default; Optimizer(Optimizer &&) = default; ~Optimizer() = default; - void add(const BaseTransform::Ptr &transform); + Optimizer &add(const BaseTransform::Ptr &transform); + void process(const Operation::Ptr &op) const; void process(Program &program) const; }; diff --git a/compiler/include/compiler/backend/optree/optimizer/transform.hpp b/compiler/include/compiler/backend/optree/optimizer/transform.hpp index 84e96fa8..149c11a2 100644 --- a/compiler/include/compiler/backend/optree/optimizer/transform.hpp +++ b/compiler/include/compiler/backend/optree/optimizer/transform.hpp @@ -1,5 +1,7 @@ #pragma once +#include +#include #include #include @@ -21,6 +23,7 @@ struct BaseTransform { virtual std::string_view name() const = 0; virtual bool canRun(const Operation::Ptr &op) const = 0; virtual void run(const Operation::Ptr &op, OptBuilder &builder) const = 0; + virtual bool recurse() const; }; template @@ -38,5 +41,29 @@ struct Transform : public BaseTransform { } }; +class CascadeTransform : public BaseTransform { + std::deque transforms; + std::string_view commonName; + size_t iterLimit; + + CascadeTransform(std::string_view commonName, size_t iterLimit); + CascadeTransform(const CascadeTransform &) = delete; + CascadeTransform(CascadeTransform &&) = default; + + public: + using Ptr = std::shared_ptr; + + ~CascadeTransform() override = default; + + std::string_view name() const override; + bool canRun(const Operation::Ptr &op) const override; + void run(const Operation::Ptr &op, OptBuilder &builder) const override; + bool recurse() const override; + + CascadeTransform &add(const BaseTransform::Ptr &transform); + + static Ptr make(std::string_view commonName, size_t iterLimit = 100U); +}; + } // namespace optimizer } // namespace optree diff --git a/compiler/lib/backend/optree/optimizer/optimizer.cpp b/compiler/lib/backend/optree/optimizer/optimizer.cpp index a34eaf13..d91ad0aa 100644 --- a/compiler/lib/backend/optree/optimizer/optimizer.cpp +++ b/compiler/lib/backend/optree/optimizer/optimizer.cpp @@ -1,16 +1,12 @@ #include "optimizer/optimizer.hpp" -#include -#include -#include - #include "compiler/optree/operation.hpp" #include "compiler/optree/program.hpp" #include "compiler/utils/debug.hpp" +#include "compiler/utils/helpers.hpp" #include "optimizer/opt_builder.hpp" #include "optimizer/transform.hpp" -#include "optimizer/transform_factories.hpp" using namespace optree; using namespace optree::optimizer; @@ -19,143 +15,35 @@ using dbg = utils::DebugPrinter; namespace { -class OperationSet { - std::vector data; - std::unordered_map positions; - - public: - OperationSet() { - constexpr size_t capacity = 64U; - data.reserve(capacity); - } - - OperationSet(const OperationSet &) = default; - OperationSet(OperationSet &&) = default; - ~OperationSet() = default; - - bool empty() const { - return positions.empty(); - } - - void push(const Operation::Ptr &op) { - if (positions.contains(op.get())) - return; - positions[op.get()] = data.size(); - data.emplace_back(op); - } - - Operation::Ptr pop() { - while (!data.back()) - data.pop_back(); - Operation::Ptr op = data.back(); - data.pop_back(); - positions.erase(op.get()); - while (!data.empty() && !data.back()) - data.pop_back(); - return op; - } - - void erase(const Operation::Ptr &op) { - auto it = positions.find(op.get()); - if (it == positions.end()) - return; - data[it->second].reset(); - positions.erase(it); - } - - void clear() { - data.clear(); - positions.clear(); - } -}; - -class MutationTracker { - Operation *const trackedOp; - bool updatedTag; - bool erasedTag; - - public: - MutationTracker(const MutationTracker &) = delete; - MutationTracker(MutationTracker &&) = delete; - ~MutationTracker() = default; - - explicit MutationTracker(const Operation::Ptr &trackedOp) - : trackedOp(trackedOp.get()), updatedTag(false), erasedTag(false){}; - - bool updated() const { - return updatedTag; - } - bool erased() const { - return erasedTag; - } - - void raiseUpdated(const Operation::Ptr &op) { - if (op.get() == trackedOp) - updatedTag = true; - } - void raiseErased(const Operation::Ptr &op) { - if (op.get() == trackedOp) - erasedTag = true; +void runTransform(const BaseTransform::Ptr &transform, const OptBuilder::Notifier ¬ifier, const Operation::Ptr &op) { + bool canRun = transform->canRun(op); + if (transform->recurse() || (!canRun && !transform->recurse())) { + for (const auto &childOp : utils::advanceEarly(op->body)) { + runTransform(transform, notifier, childOp); + } } -}; - -void pushToSet(const Operation::Ptr &root, OperationSet &ops) { - for (const auto &op : root->body) - pushToSet(op, ops); - ops.push(root); -} - -OptBuilder::Notifier makeNotifier(OperationSet &ops, bool &mutated, MutationTracker &tracker) { - OptBuilder::Notifier notifier; - notifier.onInsert = [&ops, &mutated](const Operation::Ptr &op) { - ops.push(op); - mutated = true; - }; - notifier.onUpdate = [&ops, &mutated, &tracker](const Operation::Ptr &op) { - ops.push(op); - mutated = true; - tracker.raiseUpdated(op); - }; - notifier.onErase = [&ops, &mutated, &tracker](const Operation::Ptr &op) { - ops.erase(op); - mutated = true; - tracker.raiseErased(op); - }; - return notifier; + if (!canRun) + return; + OptBuilder builder(notifier); + builder.setInsertPointBefore(op); + COMPILER_DEBUG(dbg::get() << "Run " << transform->name() << " on " << op->dump() << "{\n"); + transform->run(op, builder); + COMPILER_DEBUG(dbg::get() << "}\n\n"); } } // namespace -Optimizer::Optimizer() : iterLimit(100U) { +Optimizer &Optimizer::add(const BaseTransform::Ptr &transform) { + transforms.emplace_back(transform); + return *this; } -void Optimizer::add(const BaseTransform::Ptr &transform) { - transforms.emplace_back(transform); +void Optimizer::process(const Operation::Ptr &op) const { + OptBuilder::Notifier empty; + for (const auto &transform : transforms) + runTransform(transform, empty, op); } void Optimizer::process(Program &program) const { - OperationSet ops; - bool mutated = false; - size_t iter = 0; - do { - mutated = false; - ops.clear(); - pushToSet(program.root, ops); - while (!ops.empty()) { - Operation::Ptr op = ops.pop(); - MutationTracker tracker(op); - auto notifier = makeNotifier(ops, mutated, tracker); - for (const auto &transform : transforms) { - if (tracker.erased()) - break; - if (!transform->canRun(op)) - continue; - OptBuilder builder(notifier); - builder.setInsertPointBefore(op); - COMPILER_DEBUG(dbg::get() << "Run " << transform->name() << " on " << op->dump() << "{\n"); - transform->run(op, builder); - COMPILER_DEBUG(dbg::get() << "}\n\n"); - } - } - } while (mutated && ++iter < iterLimit); + process(program.root); } diff --git a/compiler/lib/backend/optree/optimizer/transform.cpp b/compiler/lib/backend/optree/optimizer/transform.cpp new file mode 100644 index 00000000..43e6faf3 --- /dev/null +++ b/compiler/lib/backend/optree/optimizer/transform.cpp @@ -0,0 +1,181 @@ +#include "optimizer/transform.hpp" + +#include +#include +#include +#include + +#include "compiler/optree/operation.hpp" +#include "compiler/utils/debug.hpp" + +#include "optimizer/opt_builder.hpp" + +using namespace optree; +using namespace optree::optimizer; + +using dbg = utils::DebugPrinter; + +namespace { + +class OperationSet { + std::vector data; + std::unordered_map positions; + + public: + OperationSet() { + constexpr size_t capacity = 64U; + data.reserve(capacity); + } + + OperationSet(const OperationSet &) = default; + OperationSet(OperationSet &&) = default; + ~OperationSet() = default; + + bool empty() const { + return positions.empty(); + } + + void push(const Operation::Ptr &op) { + if (positions.contains(op.get())) + return; + positions[op.get()] = data.size(); + data.emplace_back(op); + } + + Operation::Ptr pop() { + while (!data.back()) + data.pop_back(); + Operation::Ptr op = data.back(); + data.pop_back(); + positions.erase(op.get()); + while (!data.empty() && !data.back()) + data.pop_back(); + return op; + } + + void erase(const Operation::Ptr &op) { + auto it = positions.find(op.get()); + if (it == positions.end()) + return; + data[it->second].reset(); + positions.erase(it); + } + + void clear() { + data.clear(); + positions.clear(); + } +}; + +class MutationTracker { + Operation *const trackedOp; + bool updatedTag; + bool erasedTag; + + public: + MutationTracker(const MutationTracker &) = delete; + MutationTracker(MutationTracker &&) = delete; + ~MutationTracker() = default; + + explicit MutationTracker(const Operation::Ptr &trackedOp) + : trackedOp(trackedOp.get()), updatedTag(false), erasedTag(false){}; + + bool updated() const { + return updatedTag; + } + bool erased() const { + return erasedTag; + } + + void raiseUpdated(const Operation::Ptr &op) { + if (op.get() == trackedOp) + updatedTag = true; + } + void raiseErased(const Operation::Ptr &op) { + if (op.get() == trackedOp) + erasedTag = true; + } +}; + +void pushToSet(const Operation::Ptr &root, OperationSet &ops) { + for (const auto &op : root->body) + pushToSet(op, ops); + ops.push(root); +} + +OptBuilder::Notifier makeNotifier(OperationSet &ops, bool &mutated, MutationTracker &tracker) { + OptBuilder::Notifier notifier; + notifier.onInsert = [&ops, &mutated](const Operation::Ptr &op) { + ops.push(op); + mutated = true; + }; + notifier.onUpdate = [&ops, &mutated, &tracker](const Operation::Ptr &op) { + ops.push(op); + mutated = true; + tracker.raiseUpdated(op); + }; + notifier.onErase = [&ops, &mutated, &tracker](const Operation::Ptr &op) { + ops.erase(op); + mutated = true; + tracker.raiseErased(op); + }; + return notifier; +} + +} // namespace + +bool BaseTransform::recurse() const { + return true; +} + +CascadeTransform::CascadeTransform(std::string_view commonName, size_t iterLimit) + : commonName(commonName), iterLimit(iterLimit) { +} + +std::string_view CascadeTransform::name() const { + return commonName; +} + +bool CascadeTransform::canRun([[maybe_unused]] const Operation::Ptr &op) const { + return true; +} + +void CascadeTransform::run(const Operation::Ptr &op, [[maybe_unused]] OptBuilder &builder) const { + OperationSet ops; + bool mutated = false; + size_t iter = 0; + do { + mutated = false; + ops.clear(); + pushToSet(op, ops); + while (!ops.empty()) { + Operation::Ptr op = ops.pop(); + MutationTracker tracker(op); + auto notifier = makeNotifier(ops, mutated, tracker); + for (const auto &transform : transforms) { + if (tracker.erased()) + break; + if (!transform->canRun(op)) + continue; + OptBuilder builder(notifier); + builder.setInsertPointBefore(op); + COMPILER_DEBUG(dbg::get() << "Cascade run " << transform->name() << " on " << op->dump() << "{\n"); + transform->run(op, builder); + COMPILER_DEBUG(dbg::get() << "}\n\n"); + } + } + } while (mutated && ++iter < iterLimit); +} + +bool CascadeTransform::recurse() const { + return false; +} + +CascadeTransform &CascadeTransform::add(const BaseTransform::Ptr &transform) { + transforms.emplace_back(transform); + return *this; +} + +CascadeTransform::Ptr CascadeTransform::make(std::string_view commonName, size_t iterLimit) { + return Ptr(new CascadeTransform(commonName, iterLimit)); +} diff --git a/compiler/lib/backend/optree/optimizer/transforms/erase_unused_functions.cpp b/compiler/lib/backend/optree/optimizer/transforms/erase_unused_functions.cpp index 02822d87..dd37eb75 100644 --- a/compiler/lib/backend/optree/optimizer/transforms/erase_unused_functions.cpp +++ b/compiler/lib/backend/optree/optimizer/transforms/erase_unused_functions.cpp @@ -62,6 +62,10 @@ struct EraseUnusedFunctions : public Transform { } } } + + bool recurse() const override { + return false; + } }; } // namespace diff --git a/compiler/lib/cli/compiler.cpp b/compiler/lib/cli/compiler.cpp index 348bcacd..3d7c1c9c 100644 --- a/compiler/lib/cli/compiler.cpp +++ b/compiler/lib/cli/compiler.cpp @@ -1,4 +1,5 @@ #include "compiler.hpp" +#include "compiler/backend/optree/optimizer/transform.hpp" #include #include @@ -300,8 +301,10 @@ int Compiler::runOptreeOptimizer() { Timer timer; try { Optimizer optimizer; - optimizer.add(createEraseUnusedOps()); - optimizer.add(createFoldConstants()); + auto canonicalizer = CascadeTransform::make("Canonicalizer"); + canonicalizer->add(createEraseUnusedOps()); + canonicalizer->add(createFoldConstants()); + optimizer.add(canonicalizer); optimizer.add(createEraseUnusedFunctions()); timer.start(); optimizer.process(program); diff --git a/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp b/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp index 148ce61f..a63879f2 100644 --- a/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp +++ b/compiler/tests/backend/optree/optimizer/erase_unused_ops.cpp @@ -1,8 +1,7 @@ #include -#include - #include "compiler/backend/optree/optimizer/optimizer.hpp" +#include "compiler/backend/optree/optimizer/transform.hpp" #include "compiler/backend/optree/optimizer/transform_factories.hpp" #include "compiler/optree/adaptors.hpp" @@ -13,7 +12,9 @@ using namespace optree::optimizer; class EraseUnusedOpsTest : public TransformTestBase { virtual void setupOptimizer(Optimizer &opt) const override { - opt.add(createEraseUnusedOps()); + auto transform = CascadeTransform::make("EraseUnusedOpsTest"); + transform->add(createEraseUnusedOps()); + opt.add(transform); } public: diff --git a/compiler/tests/backend/optree/optimizer/fold_constants.cpp b/compiler/tests/backend/optree/optimizer/fold_constants.cpp index 62353f54..ffbf1665 100644 --- a/compiler/tests/backend/optree/optimizer/fold_constants.cpp +++ b/compiler/tests/backend/optree/optimizer/fold_constants.cpp @@ -1,8 +1,7 @@ #include -#include - #include "compiler/backend/optree/optimizer/optimizer.hpp" +#include "compiler/backend/optree/optimizer/transform.hpp" #include "compiler/backend/optree/optimizer/transform_factories.hpp" #include "compiler/optree/adaptors.hpp" @@ -13,7 +12,9 @@ using namespace optree::optimizer; class FoldConstantsTest : public TransformTestBase { virtual void setupOptimizer(Optimizer &opt) const override { - opt.add(createFoldConstants()); + auto transform = CascadeTransform::make("FoldConstantsTest"); + transform->add(createFoldConstants()); + opt.add(transform); } public: diff --git a/compiler/tests/backend/optree/optimizer/fold_control_flow_ops.cpp b/compiler/tests/backend/optree/optimizer/fold_control_flow_ops.cpp index 3e08f6c7..7617d7bd 100644 --- a/compiler/tests/backend/optree/optimizer/fold_control_flow_ops.cpp +++ b/compiler/tests/backend/optree/optimizer/fold_control_flow_ops.cpp @@ -1,7 +1,5 @@ #include -#include - #include "compiler/backend/optree/optimizer/optimizer.hpp" #include "compiler/backend/optree/optimizer/transform_factories.hpp" #include "compiler/optree/adaptors.hpp" @@ -13,8 +11,10 @@ using namespace optree::optimizer; class FoldControlFlowTest : public TransformTestBase { virtual void setupOptimizer(Optimizer &opt) const override { - opt.add(createFoldConstants()); - opt.add(createFoldControlFlowOps()); + auto transform = CascadeTransform::make("FoldControlFlowTest"); + transform->add(createFoldControlFlowOps()); + transform->add(createFoldConstants()); + opt.add(transform); } public: