Merge pull request #7 from gongshaotian/drr
[DRR] Replace 'weak_ptr' with pointer
yuanlehome committed Aug 9, 2023
2 parents 727a5c1 + 2d4df1a commit 7279d74
Showing 4 changed files with 49 additions and 50 deletions.
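At a high level, the commit drops the weak_ptr bookkeeping inside the DRR pattern graph: PatternGraph already owns every Tensor and OpCall through shared_ptr containers, so the cross-links between nodes can be plain non-owning raw pointers and the .lock() / use_count() noise disappears. Below is a minimal, self-contained sketch of that ownership model; the class names (Graph, Node, Edge) are illustrative stand-ins, not the actual DRR types.

#include <memory>
#include <string>
#include <vector>

// Illustrative stand-ins for drr::Tensor / drr::OpCall; not the real classes.
struct Node;

struct Edge {
  std::string id;
  Node* producer = nullptr;            // non-owning: the graph owns all nodes
  std::vector<const Node*> consumers;  // non-owning
};

struct Node {
  std::string name;
  std::vector<Edge*> inputs;   // non-owning
  std::vector<Edge*> outputs;  // non-owning
};

// The graph is the single owner; every cross-reference is a raw pointer that
// stays valid for as long as the graph itself is alive.
class Graph {
 public:
  Edge* AddEdge(std::string id) {
    owned_edges_.push_back(std::make_shared<Edge>());
    owned_edges_.back()->id = std::move(id);
    return owned_edges_.back().get();
  }

  Node* AddNode(std::string name, std::vector<Edge*> inputs, Edge* output) {
    owned_nodes_.push_back(std::make_shared<Node>());
    Node* node = owned_nodes_.back().get();
    node->name = std::move(name);
    node->inputs = std::move(inputs);
    node->outputs = {output};
    output->producer = node;
    for (Edge* e : node->inputs) e->consumers.push_back(node);
    return node;
  }

 private:
  std::vector<std::shared_ptr<Edge>> owned_edges_;  // sole owners
  std::vector<std::shared_ptr<Node>> owned_nodes_;
};

int main() {
  Graph g;
  Edge* x = g.AddEdge("x");
  Edge* y = g.AddEdge("y");
  g.AddNode("relu", {x}, y);
  // y->producer and x->consumers remain valid: no lock(), no expiry checks.
  return 0;
}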
13 changes: 6 additions & 7 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.cc
100644 → 100755
@@ -14,7 +14,6 @@

#include "paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h"

-#include <glog/logging.h>
#include "paddle/ir/pattern_rewrite/drr/pattern_graph.h"

namespace ir {
@@ -78,27 +77,27 @@ void DrrPatternContext::RequireEqual(const TensorShape& first,
}

void Op::operator()(const Tensor& arg, const Tensor* out) const {
-std::vector<std::weak_ptr<const Tensor>> inputs{arg.shared_from_this()};
-std::vector<std::weak_ptr<const Tensor>> outputs{out->shared_from_this()};
+std::vector<const Tensor*> inputs{&arg};
+std::vector<const Tensor*> outputs{out};
pattern_graph_->AddOpCall(
std::make_shared<OpCall>(shared_from_this(), inputs, outputs));
}

Tensor& Op::operator()(const Tensor& arg) const {
-std::vector<std::weak_ptr<const Tensor>> inputs{arg.shared_from_this()};
+std::vector<const Tensor*> inputs{&arg};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
"tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_)));
-std::vector<std::weak_ptr<const Tensor>> outputs{out.shared_from_this()};
+std::vector<const Tensor*> outputs{&out};
pattern_graph_->AddOpCall(
std::make_shared<OpCall>(shared_from_this(), inputs, outputs));
return out;
}

Tensor& Op::operator()() const {
-std::vector<std::weak_ptr<const Tensor>> inputs{};
+std::vector<const Tensor*> inputs{};
auto& out = pattern_graph_->AddTmpTensor(std::shared_ptr<Tensor>(new Tensor(
"tmp_" + op_type_name_ + "_" + std::to_string(count++), pattern_graph_)));
-std::vector<std::weak_ptr<const Tensor>> outputs{out.shared_from_this()};
+std::vector<const Tensor*> outputs{&out};
pattern_graph_->AddOpCall(
std::make_shared<OpCall>(shared_from_this(), inputs, outputs));
return out;
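The practical effect at the Op::operator() call sites is visible above: inputs and outputs are now collected as const Tensor* by taking the address of the argument, which is safe because the PatternGraph owns the tensors for the lifetime of the pattern. A hedged, standalone illustration of what this buys the reading code, using a placeholder Tensor type rather than the DRR one:

#include <cassert>
#include <memory>
#include <vector>

struct Tensor { int id = 0; };  // placeholder, not drr::Tensor

// Before: every access has to lock() the weak_ptr and cope with possible expiry.
int SumIdsWeak(const std::vector<std::weak_ptr<const Tensor>>& tensors) {
  int sum = 0;
  for (const auto& weak : tensors) {
    auto locked = weak.lock();
    assert(locked && "tensor unexpectedly destroyed");
    sum += locked->id;
  }
  return sum;
}

// After: the owning graph guarantees lifetime, so a plain pointer suffices.
int SumIdsRaw(const std::vector<const Tensor*>& tensors) {
  int sum = 0;
  for (const Tensor* t : tensors) sum += t->id;
  return sum;
}

int main() {
  auto owned = std::make_shared<Tensor>();  // the "graph" keeps this alive
  owned->id = 7;
  assert(SumIdsWeak({owned}) == SumIdsRaw({owned.get()}));
  return 0;
}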
32 changes: 16 additions & 16 deletions paddle/ir/pattern_rewrite/drr/api/drr_pattern_context.h
100644 → 100755
@@ -159,20 +159,20 @@ class Tensor : public std::enable_shared_from_this<Tensor> {

void set_id(const id_type& id) { tensor_id_ = id; }

-std::weak_ptr<OpCall> producer() const { return producer_; }
+OpCall* producer() const { return producer_; }

-void set_producer(std::weak_ptr<OpCall> producer) { producer_ = producer; }
+void set_producer(OpCall* producer) { producer_ = producer; }

-const std::vector<std::weak_ptr<const OpCall>>& consumers() const {
+const std::vector<const OpCall*>& consumers() const {
return consumers_;
}

void set_consumables(
-const std::vector<std::weak_ptr<const OpCall>>& consumers) {
+const std::vector<const OpCall*>& consumers) {
consumers_ = consumers;
}

-void AddConsumer(std::weak_ptr<const OpCall> consumer) {
+void AddConsumer(const OpCall* consumer) {
consumers_.push_back(consumer);
}

@@ -184,33 +184,33 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
: tensor_id_(tensor_id), pattern_graph_(pattern_graph) {}

id_type tensor_id_;
-std::weak_ptr<OpCall> producer_;
-std::vector<std::weak_ptr<const OpCall>> consumers_;
+OpCall* producer_;
+std::vector<const OpCall*> consumers_;
PatternGraph* pattern_graph_;
};

class OpCall : public std::enable_shared_from_this<OpCall> {
public:
-OpCall(std::weak_ptr<const Op> op,
-const std::vector<std::weak_ptr<const Tensor>>& inputs,
-const std::vector<std::weak_ptr<const Tensor>>& outputs)
+OpCall( Op const * op,
+const std::vector<const Tensor *>& inputs,
+const std::vector<const Tensor *>& outputs)
: op_(op), inputs_(inputs), outputs_(outputs) {}

-const std::string& name() const { return op_.lock()->name(); }
+const std::string& name() const { return op_->name(); }

-const std::vector<std::weak_ptr<const Tensor>>& inputs() const {
+const std::vector<const Tensor*>& inputs() const {
return inputs_;
}

-const std::vector<std::weak_ptr<const Tensor>>& outputs() const {
+const std::vector<const Tensor*>& outputs() const {
return outputs_;
}

private:
id_type op_call_id_;
-std::weak_ptr<const Op> op_;
-std::vector<std::weak_ptr<const Tensor>> inputs_;
-std::vector<std::weak_ptr<const Tensor>> outputs_;
+const Op* op_;
+std::vector<const Tensor*> inputs_;
+std::vector<const Tensor*> outputs_;
};

class ResultPattern {
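Read together, the header changes above boil down to the distilled structure below: Tensor keeps a non-owning producer link plus a list of non-owning consumers, and OpCall keeps non-owning links to its Op and its input/output tensors. This is only a compressed, hedged rendering of the members shown in the diff; the real classes additionally carry ids, attributes, enable_shared_from_this, and the owning PatternGraph pointer.

#include <string>
#include <utility>
#include <vector>

class OpCall;

class Op {
 public:
  explicit Op(std::string name) : name_(std::move(name)) {}
  const std::string& name() const { return name_; }

 private:
  std::string name_;
};

class Tensor {
 public:
  OpCall* producer() const { return producer_; }
  void set_producer(OpCall* producer) { producer_ = producer; }

  const std::vector<const OpCall*>& consumers() const { return consumers_; }
  void AddConsumer(const OpCall* consumer) { consumers_.push_back(consumer); }

 private:
  OpCall* producer_ = nullptr;            // non-owning back-link
  std::vector<const OpCall*> consumers_;  // non-owning
};

class OpCall {
 public:
  OpCall(const Op* op,
         std::vector<const Tensor*> inputs,
         std::vector<const Tensor*> outputs)
      : op_(op), inputs_(std::move(inputs)), outputs_(std::move(outputs)) {}

  const std::string& name() const { return op_->name(); }
  const std::vector<const Tensor*>& inputs() const { return inputs_; }
  const std::vector<const Tensor*>& outputs() const { return outputs_; }

 private:
  const Op* op_;                       // non-owning
  std::vector<const Tensor*> inputs_;  // non-owning
  std::vector<const Tensor*> outputs_;
};

int main() {
  Op relu("relu");
  Tensor x, y;
  OpCall call(&relu, {&x}, {&y});
  y.set_producer(&call);
  x.AddConsumer(&call);
  return call.inputs().size() == 1 ? 0 : 1;
}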
52 changes: 26 additions & 26 deletions paddle/ir/pattern_rewrite/drr/pattern_graph.cc
100644 → 100755
@@ -22,48 +22,48 @@
namespace ir {
namespace drr {

-const drr::OpCall& PatternGraph::AddOpCall(
-const std::shared_ptr<drr::OpCall>& op_call) {
+const drr::OpCall &
+PatternGraph::AddOpCall(const std::shared_ptr<drr::OpCall> &op_call) {
owned_op_call_.push_back(op_call);
-for (const auto& input : op_call->inputs()) {
-const auto& tensor_id = input.lock()->id();
+for (const auto &input : op_call->inputs()) {
+const auto &tensor_id = input->id();
CHECK(id2owned_tensor_.count(tensor_id));
-id2owned_tensor_.at(tensor_id)->AddConsumer(op_call);
+id2owned_tensor_.at(tensor_id)->AddConsumer(op_call.get());

-if (input.lock()->producer().use_count() == 0) {
+if (input->producer() == nullptr) {
input_tensors_.insert(tensor_id);
}
if (output_tensors_.find(tensor_id) != output_tensors_.end()) {
output_tensors_.erase(tensor_id);
}
}
-for (auto& output : op_call->outputs()) {
-const auto& out_tensor_id = output.lock()->id();
+for (auto &output : op_call->outputs()) {
+const auto &out_tensor_id = output->id();
CHECK(id2owned_tensor_.count(out_tensor_id));
-id2owned_tensor_[output.lock()->id()]->set_producer(op_call);
+id2owned_tensor_[output->id()]->set_producer(op_call.get());
}
return *owned_op_call_.back();
}

-const drr::Tensor& PatternGraph::AddTensor(
-const std::shared_ptr<drr::Tensor>& tensor) {
+const drr::Tensor &
+PatternGraph::AddTensor(const std::shared_ptr<drr::Tensor> &tensor) {
if (id2owned_tensor_.find(tensor->id()) == id2owned_tensor_.end()) {
id2owned_tensor_[tensor->id()] = tensor;
output_tensors_.insert(tensor->id());
}
return *id2owned_tensor_[tensor->id()];
}

-drr::Tensor& PatternGraph::AddTmpTensor(
-const std::shared_ptr<drr::Tensor>& tensor) {
+drr::Tensor &
+PatternGraph::AddTmpTensor(const std::shared_ptr<drr::Tensor> &tensor) {
CHECK(id2owned_tensor_.find(tensor->id()) == id2owned_tensor_.end());
id2owned_tensor_[tensor->id()] = tensor;
output_tensors_.insert(tensor->id());
return *id2owned_tensor_[tensor->id()];
}

-void PatternGraph::UpdateTmpTensor(const id_type& tmp_tensor_id,
-const id_type& new_tensor_id) {
+void PatternGraph::UpdateTmpTensor(const id_type &tmp_tensor_id,
+const id_type &new_tensor_id) {
if (input_tensors_.count(tmp_tensor_id)) {
input_tensors_.erase(tmp_tensor_id);
input_tensors_.insert(new_tensor_id);
@@ -83,44 +83,44 @@ void PatternGraph::UpdateTmpTensor(const id_type& tmp_tensor_id,

void PatternGraph::Print() const {
std::cout << "All Tensors:" << std::endl;
-for (const auto& kv : id2owned_tensor_) {
+for (const auto &kv : id2owned_tensor_) {
std::cout << " " << kv.first;
}
std::cout << "\n" << std::endl;

std::cout << "Input Tensors:" << std::endl;
-for (const auto& tensor_id : input_tensors_) {
+for (const auto &tensor_id : input_tensors_) {
std::cout << " " << tensor_id;
}
std::cout << "\n" << std::endl;

std::cout << "Output Tensors:" << std::endl;
-for (const auto& tensor_id : output_tensors_) {
+for (const auto &tensor_id : output_tensors_) {
std::cout << " " << tensor_id;
}
std::cout << "\n" << std::endl;

std::cout << "OpCalls:" << std::endl;
-for (const auto& op_call : owned_op_call_) {
+for (const auto &op_call : owned_op_call_) {
std::cout << " " << op_call->name() << " : ";
std::cout << "inputs[ ";
-for (const auto& input : op_call->inputs()) {
-std::cout << input.lock()->id() << " ";
+for (const auto &input : op_call->inputs()) {
+std::cout << input->id() << " ";
}
std::cout << "], ";

std::cout << "outputs[ ";
-for (const auto& output : op_call->outputs()) {
-std::cout << output.lock()->id() << " ";
+for (const auto &output : op_call->outputs()) {
+std::cout << output->id() << " ";
}
std::cout << "]" << std::endl;
}
std::cout << std::endl;
}

-std::weak_ptr<OpCall> SourcePatternGraph::AnchorNode() const {
+const OpCall *SourcePatternGraph::AnchorNode() const {
return id2owned_tensor_.at(*output_tensors_.begin())->producer();
}

-} // namespace drr
-} // namespace ir
+} // namespace drr
+} // namespace ir
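In pattern_graph.cc the construction logic keeps its shape; what shifts is how ownership and absence are expressed: the graph still holds shared_ptr<OpCall>, hands out op_call.get() to tensors, and "this tensor has no producer yet" becomes a nullptr test instead of use_count() == 0. A small, hedged sketch of that idiom with placeholder types (not the real PatternGraph API):

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

struct Call { std::string name; };  // placeholder for drr::OpCall

struct Slot {                       // placeholder for drr::Tensor
  Call* producer = nullptr;              // was std::weak_ptr<OpCall>
  std::vector<const Call*> consumers;    // was std::vector<std::weak_ptr<const OpCall>>
};

int main() {
  std::vector<std::shared_ptr<Call>> owned_calls;  // the graph stays the sole owner
  std::unordered_map<std::string, Slot> slots;     // analogous to id2owned_tensor_

  assert(slots["x"].producer == nullptr);  // was: producer().use_count() == 0

  owned_calls.push_back(std::make_shared<Call>(Call{"relu"}));
  Call* relu = owned_calls.back().get();

  slots["x"].consumers.push_back(relu);  // was: AddConsumer(op_call) storing a weak_ptr
  slots["y"].producer = relu;            // was: set_producer(op_call)

  assert(slots["y"].producer != nullptr);
  return 0;
}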
2 changes: 1 addition & 1 deletion paddle/ir/pattern_rewrite/drr/pattern_graph.h
100644 → 100755
@@ -52,7 +52,7 @@ class PatternGraph {

class SourcePatternGraph : public PatternGraph {
public:
-std::weak_ptr<OpCall> AnchorNode() const;
+const OpCall* AnchorNode() const;

private:
friend class DrrPatternContext;
