Merge pull request #3 from reyoung/feature/refactorize_framework_proto
Step 1: Make code compile well.
wangkuiyi committed Aug 8, 2017
2 parents 72e3ba5 + dba618c commit d97a2b4
Showing 35 changed files with 927 additions and 960 deletions.
.gitignore: 3 changes (2 additions, 1 deletion)
@@ -24,4 +24,5 @@ cmake-build-*
 python/paddle/v2/framework/core.so
 CMakeFiles
 cmake_install.cmake
-
+paddle/.timestamp
+python/paddlepaddle.egg-info/
paddle/framework/attribute.cc: 2 changes (1 addition, 1 deletion)
@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
   return STRINGS;
 }
 
-Attribute GetAttrValue(const AttrDesc& attr_desc) {
+Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
   switch (attr_desc.type()) {
     case paddle::framework::AttrType::INT: {
       return attr_desc.i();
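
Note on this change: GetAttrValue keeps its type-tag dispatch, but it now takes OpDesc::Attr, the attribute message nested inside OpDesc in the consolidated framework.pb.h, rather than the old standalone AttrDesc message. Below is a minimal, self-contained sketch of that dispatch pattern, not the real implementation: MockAttr stands in for the generated OpDesc::Attr, std::variant for the variant type attribute.h actually uses, and only the INT case is visible in the hunk (FLOAT and STRING are assumed analogues).

#include <cassert>
#include <stdexcept>
#include <string>
#include <variant>

enum class AttrType { INT, FLOAT, STRING };

// Stand-in for the generated OpDesc::Attr: a type tag plus one value
// field per attribute kind, exposed through protobuf-style accessors.
struct MockAttr {
  AttrType type_ = AttrType::INT;
  int i_ = 0;
  float f_ = 0.0f;
  std::string s_;
  AttrType type() const { return type_; }
  int i() const { return i_; }
  float f() const { return f_; }
  const std::string& s() const { return s_; }
};

using Attribute = std::variant<int, float, std::string>;

// Mirrors the dispatch in attribute.cc: switch on the tag and return the
// matching field as the variant-typed Attribute.
Attribute GetAttrValue(const MockAttr& attr_desc) {
  switch (attr_desc.type()) {
    case AttrType::INT:
      return attr_desc.i();
    case AttrType::FLOAT:
      return attr_desc.f();
    case AttrType::STRING:
      return attr_desc.s();
  }
  throw std::runtime_error("unsupported attribute type");
}

int main() {
  MockAttr attr;
  attr.type_ = AttrType::INT;
  attr.i_ = 10;
  assert(std::get<int>(GetAttrValue(attr)) == 10);
  return 0;
}
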
paddle/framework/attribute.h: 5 changes (2 additions, 3 deletions)
@@ -21,8 +21,7 @@ limitations under the License. */
 #include <unordered_set>
 #include <vector>
 
-#include "paddle/framework/attribute.pb.h"
-#include "paddle/framework/op_desc.pb.h"
+#include "paddle/framework/framework.pb.h"
 #include "paddle/platform/enforce.h"
 
 namespace paddle {
@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
 template <typename T>
 AttrType AttrTypeID();
 
-Attribute GetAttrValue(const AttrDesc& attr_desc);
+Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
 
 // check whether a value(attribute) fit a certain limit
 template <typename T>
paddle/framework/backward.cc: 65 changes (41 additions, 24 deletions)
@@ -20,15 +20,24 @@
 namespace paddle {
 namespace framework {
 
-static bool AllInSet(const std::vector<std::string>& names,
-                     const std::string& suffix,
-                     const std::unordered_set<std::string>& set) {
+template <typename Map, typename T>
+static void ForEachVarName(Map& names, T callback) {
   for (auto& name : names) {
-    if (set.find(name + suffix) == set.end()) {
-      return false;
+    for (auto& n : name.second) {
+      if (callback(n)) break;
     }
   }
-  return true;
+}
+
+static bool AllInSet(
+    const std::unordered_map<std::string, std::vector<std::string>>& names,
+    const std::string& suffix, const std::unordered_set<std::string>& set) {
+  bool ret_val = true;
+  ForEachVarName(names, [&ret_val, &set, &suffix](const std::string& n) {
+    ret_val = set.find(n + suffix) == set.end();
+    return !ret_val;
+  });
+  return ret_val;
 }
 
 static std::shared_ptr<OperatorBase> NOP() {
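
This hunk is the heart of the refactor: an operator's inputs_ and outputs_ are no longer a flat std::vector<std::string> but a map from parameter name to a list of variable names, as the new AllInSet signature shows, so every traversal now goes through the ForEachVarName visitor. Note that a callback returning true only stops the scan of the current parameter's list, not of the whole map. A minimal, self-contained usage sketch (the variable names are hypothetical):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

// The new shape of inputs_/outputs_: parameter name -> variable names.
using VarNameMap = std::unordered_map<std::string, std::vector<std::string>>;

// Copied from the hunk above: visit every variable name; a callback that
// returns true stops the scan of the current parameter's list.
template <typename Map, typename T>
static void ForEachVarName(Map& names, T callback) {
  for (auto& name : names) {
    for (auto& n : name.second) {
      if (callback(n)) break;
    }
  }
}

int main() {
  // Hypothetical operator outputs: two parameters, three variables.
  VarNameMap outputs{{"Out", {"fc_out", "fc_bias_out"}}, {"Hidden", {"h0"}}};
  int visited = 0;
  ForEachVarName(outputs, [&visited](const std::string&) {
    ++visited;
    return false;  // never stop early, mirroring the insert-all usages below
  });
  assert(visited == 3);
  return 0;
}
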
@@ -67,10 +76,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
   // Then all input gradients cannot be computed at all, and we put them into
   // `no_grad_names` set. Return an NOP.
   if (AllInSet(forwardOp.outputs_, kGradVarSuffix, no_grad_names)) {
-    for (auto& name : forwardOp.inputs_) {
-      // Mark all input is not need
-      no_grad_names.insert(name + kGradVarSuffix);
-    }
+    ForEachVarName(forwardOp.inputs_,
+                   [&no_grad_names](const std::string& name) -> bool {
+                     no_grad_names.insert(GradVarName(name));
+                     return false;
+                   });
     return NOP();
   }
 
@@ -92,9 +102,11 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
     auto fwd = *it;
     auto bwd = BackwardRecursive(*fwd, no_grad_names, uniq_id);
     net->AddOp(bwd);
-    for (auto& out : bwd->outputs_) {
-      dup_output_ops[out].emplace_back(local_op_id);
-    }
+    ForEachVarName(bwd->outputs_,
+                   [&dup_output_ops, local_op_id](const std::string& out) {
+                     dup_output_ops[out].emplace_back(local_op_id);
+                     return false;
+                   });
   }
   // Get unique ID for this method.
   auto uid = uniq_id++;
@@ -116,7 +128,7 @@ std::shared_ptr<OperatorBase> BackwardRecursive(
       insert_position.push_back(
           {dup_op.back(),
            OpRegistry::CreateOp(
-               "add", {dup_outputs}, {name},
+               "add", {{"X", {dup_outputs}}}, {{"Out", {name}}},
                {{"input_format",
                  std::vector<int>{0, static_cast<int>(dup_outputs.size())}}})});
     }
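
The CreateOp call shows the same change on the operator-construction side: operand lists are no longer positional vectors but are keyed by parameter name, so {{"X", {dup_outputs}}} is simply a braced initializer for a one-entry name-to-variables map. A small sketch of the shape being built (the variable names are hypothetical):

#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

using VarNameMap = std::unordered_map<std::string, std::vector<std::string>>;

int main() {
  // Before: CreateOp("add", {dup_outputs}, {name}, ...) was positional.
  // After: each operand list is keyed by the op's parameter name.
  std::vector<std::string> dup_outputs{"w@GRAD@0", "w@GRAD@1"};
  VarNameMap inputs{{"X", dup_outputs}};
  VarNameMap outputs{{"Out", {"w@GRAD"}}};
  assert(inputs.at("X").size() == 2);
  assert(outputs.at("Out").front() == "w@GRAD");
  return 0;
}
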
@@ -130,24 +142,29 @@ std::shared_ptr<OperatorBase> BackwardRecursive(

   } else {
     std::shared_ptr<OperatorBase> grad_op = OpRegistry::CreateGradOp(forwardOp);
-    for (std::string& grad_input : grad_op->inputs_) {
+
+    ForEachVarName(grad_op->inputs_, [&no_grad_names,
+                                      &net](std::string& grad_input) {
       if (no_grad_names.count(grad_input)) {
         std::string prefix =
             grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
         grad_input = prefix + kZeroVarSuffix;
 
         // If part of input gradient of that operator is not calculated, fill
         // zero variables to that input gradient.
-        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {prefix},
-                                        {grad_input}, {}));
+        net->AddOp(OpRegistry::CreateOp("fill_zeros_like", {{"Src", {prefix}}},
+                                        {{"Dst", {grad_input}}}, {}));
       }
-    }
+      return false;
+    });
 
-    for (std::string& grad_output : grad_op->outputs_) {
-      if (no_grad_names.count(grad_output)) {
-        grad_output = kEmptyVarName;
-      }
-    }
+    ForEachVarName(grad_op->outputs_,
+                   [&no_grad_names](std::string& grad_output) {
+                     if (no_grad_names.count(grad_output)) {
+                       grad_output = kEmptyVarName;
+                     }
+                     return false;
+                   });
 
     if (net->ops_.empty()) {  // Current no aux op is added to network
       return grad_op;
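
One detail in the last hunk is easy to miss: when a gradient input appears in no_grad_names, its name is rewritten from the gradient form to a zero-variable form, and a fill_zeros_like op (Src = the forward variable, Dst = the zero variable) is prepended so that consumers read zeros instead of a missing gradient. A sketch of just the name rewriting, with assumed suffix values (the real kGradVarSuffix and kZeroVarSuffix constants are defined elsewhere in the framework):

#include <cassert>
#include <string>

// Assumed values for illustration; the real constants live in the
// framework headers.
const std::string kGradVarSuffix = "@GRAD";
const std::string kZeroVarSuffix = "@ZERO";

int main() {
  std::string grad_input = "W@GRAD";  // hypothetical unreachable gradient
  // Strip the gradient suffix to recover the forward variable name...
  std::string prefix =
      grad_input.substr(0, grad_input.size() - kGradVarSuffix.size());
  // ...then redirect the gradient input to a zero-filled stand-in.
  grad_input = prefix + kZeroVarSuffix;
  assert(prefix == "W");
  assert(grad_input == "W@ZERO");
  // backward.cc then adds a fill_zeros_like op with {"Src", {"W"}} as its
  // input and {"Dst", {"W@ZERO"}} as its output.
  return 0;
}
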
