Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#pragma once

#include <cstddef>
#include <vector>
#include <deque>

#include "compiler/optree/operation.hpp"
#include "compiler/optree/program.hpp"

#include "compiler/backend/optree/optimizer/transform.hpp"
Expand All @@ -11,16 +11,16 @@ namespace optree {
namespace optimizer {

class Optimizer {
std::vector<BaseTransform::Ptr> transforms;
size_t iterLimit;
std::deque<BaseTransform::Ptr> transforms;

public:
Optimizer();
Optimizer() = default;
Optimizer(const Optimizer &) = default;
Optimizer(Optimizer &&) = default;
~Optimizer() = default;

void add(const BaseTransform::Ptr &transform);
Optimizer &add(const BaseTransform::Ptr &transform);
void process(const Operation::Ptr &op) const;
void process(Program &program) const;
};

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once

#include <cstddef>
#include <deque>
#include <memory>
#include <string_view>

Expand All @@ -21,6 +23,7 @@ struct BaseTransform {
virtual std::string_view name() const = 0;
virtual bool canRun(const Operation::Ptr &op) const = 0;
virtual void run(const Operation::Ptr &op, OptBuilder &builder) const = 0;
virtual bool recurse() const;
};

template <typename... AdaptorTypes>
Expand All @@ -38,5 +41,29 @@ struct Transform : public BaseTransform {
}
};

class CascadeTransform : public BaseTransform {
std::deque<BaseTransform::Ptr> transforms;
std::string_view commonName;
size_t iterLimit;

CascadeTransform(std::string_view commonName, size_t iterLimit);
CascadeTransform(const CascadeTransform &) = delete;
CascadeTransform(CascadeTransform &&) = default;

public:
using Ptr = std::shared_ptr<CascadeTransform>;

~CascadeTransform() override = default;

std::string_view name() const override;
bool canRun(const Operation::Ptr &op) const override;
void run(const Operation::Ptr &op, OptBuilder &builder) const override;
bool recurse() const override;

CascadeTransform &add(const BaseTransform::Ptr &transform);

static Ptr make(std::string_view commonName, size_t iterLimit = 100U);
};

} // namespace optimizer
} // namespace optree
156 changes: 22 additions & 134 deletions compiler/lib/backend/optree/optimizer/optimizer.cpp
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
#include "optimizer/optimizer.hpp"

#include <cstddef>
#include <unordered_map>
#include <vector>

#include "compiler/optree/operation.hpp"
#include "compiler/optree/program.hpp"
#include "compiler/utils/debug.hpp"
#include "compiler/utils/helpers.hpp"

#include "optimizer/opt_builder.hpp"
#include "optimizer/transform.hpp"
#include "optimizer/transform_factories.hpp"

using namespace optree;
using namespace optree::optimizer;
Expand All @@ -19,143 +15,35 @@ using dbg = utils::DebugPrinter;

namespace {

class OperationSet {
std::vector<Operation::Ptr> data;
std::unordered_map<const Operation *, size_t> positions;

public:
OperationSet() {
constexpr size_t capacity = 64U;
data.reserve(capacity);
}

OperationSet(const OperationSet &) = default;
OperationSet(OperationSet &&) = default;
~OperationSet() = default;

bool empty() const {
return positions.empty();
}

void push(const Operation::Ptr &op) {
if (positions.contains(op.get()))
return;
positions[op.get()] = data.size();
data.emplace_back(op);
}

Operation::Ptr pop() {
while (!data.back())
data.pop_back();
Operation::Ptr op = data.back();
data.pop_back();
positions.erase(op.get());
while (!data.empty() && !data.back())
data.pop_back();
return op;
}

void erase(const Operation::Ptr &op) {
auto it = positions.find(op.get());
if (it == positions.end())
return;
data[it->second].reset();
positions.erase(it);
}

void clear() {
data.clear();
positions.clear();
}
};

class MutationTracker {
Operation *const trackedOp;
bool updatedTag;
bool erasedTag;

public:
MutationTracker(const MutationTracker &) = delete;
MutationTracker(MutationTracker &&) = delete;
~MutationTracker() = default;

explicit MutationTracker(const Operation::Ptr &trackedOp)
: trackedOp(trackedOp.get()), updatedTag(false), erasedTag(false){};

bool updated() const {
return updatedTag;
}
bool erased() const {
return erasedTag;
}

void raiseUpdated(const Operation::Ptr &op) {
if (op.get() == trackedOp)
updatedTag = true;
}
void raiseErased(const Operation::Ptr &op) {
if (op.get() == trackedOp)
erasedTag = true;
void runTransform(const BaseTransform::Ptr &transform, const OptBuilder::Notifier &notifier, const Operation::Ptr &op) {
bool canRun = transform->canRun(op);
if (transform->recurse() || (!canRun && !transform->recurse())) {
for (const auto &childOp : utils::advanceEarly(op->body)) {
runTransform(transform, notifier, childOp);
}
}
};

void pushToSet(const Operation::Ptr &root, OperationSet &ops) {
for (const auto &op : root->body)
pushToSet(op, ops);
ops.push(root);
}

OptBuilder::Notifier makeNotifier(OperationSet &ops, bool &mutated, MutationTracker &tracker) {
OptBuilder::Notifier notifier;
notifier.onInsert = [&ops, &mutated](const Operation::Ptr &op) {
ops.push(op);
mutated = true;
};
notifier.onUpdate = [&ops, &mutated, &tracker](const Operation::Ptr &op) {
ops.push(op);
mutated = true;
tracker.raiseUpdated(op);
};
notifier.onErase = [&ops, &mutated, &tracker](const Operation::Ptr &op) {
ops.erase(op);
mutated = true;
tracker.raiseErased(op);
};
return notifier;
if (!canRun)
return;
OptBuilder builder(notifier);
builder.setInsertPointBefore(op);
COMPILER_DEBUG(dbg::get() << "Run " << transform->name() << " on " << op->dump() << "{\n");
transform->run(op, builder);
COMPILER_DEBUG(dbg::get() << "}\n\n");
}

} // namespace

Optimizer::Optimizer() : iterLimit(100U) {
Optimizer &Optimizer::add(const BaseTransform::Ptr &transform) {
transforms.emplace_back(transform);
return *this;
}

void Optimizer::add(const BaseTransform::Ptr &transform) {
transforms.emplace_back(transform);
void Optimizer::process(const Operation::Ptr &op) const {
OptBuilder::Notifier empty;
for (const auto &transform : transforms)
runTransform(transform, empty, op);
}

void Optimizer::process(Program &program) const {
OperationSet ops;
bool mutated = false;
size_t iter = 0;
do {
mutated = false;
ops.clear();
pushToSet(program.root, ops);
while (!ops.empty()) {
Operation::Ptr op = ops.pop();
MutationTracker tracker(op);
auto notifier = makeNotifier(ops, mutated, tracker);
for (const auto &transform : transforms) {
if (tracker.erased())
break;
if (!transform->canRun(op))
continue;
OptBuilder builder(notifier);
builder.setInsertPointBefore(op);
COMPILER_DEBUG(dbg::get() << "Run " << transform->name() << " on " << op->dump() << "{\n");
transform->run(op, builder);
COMPILER_DEBUG(dbg::get() << "}\n\n");
}
}
} while (mutated && ++iter < iterLimit);
process(program.root);
}
Loading