From 665e1a335b1b30f465914e361d05dfe2d13092c9 Mon Sep 17 00:00:00 2001
From: qingqing01
Date: Wed, 9 Aug 2017 20:57:58 +0800
Subject: [PATCH] Update grad_op_builder after refactoring framework proto.

---
 paddle/framework/grad_op_builder.cc           | 68 ++++------------
 paddle/framework/grad_op_builder_test.cc      | 81 +++++++++----------
 paddle/framework/op_registry_test.cc          | 10 ---
 paddle/framework/operator_test.cc             | 19 +----
 .../v2/framework/tests/test_operator.py       |  2 +
 5 files changed, 56 insertions(+), 124 deletions(-)
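
Note: the framework proto refactor replaces an operator's flat inputs_/outputs_
lists (indexed through the "input_format"/"output_format" attributes) with maps
from parameter name to the list of variable names bound to it. A minimal sketch
of the two layouts, using illustrative container types rather than the
framework's exact typedefs:

    #include <map>
    #include <string>
    #include <vector>

    // Before: flat variable list plus an index attribute marking the
    // half-open [begin, end) slice that belongs to each parameter.
    std::vector<std::string> flat_inputs = {"x", "w0", "w1", "b"};
    std::vector<int> input_format = {0, 1, 3, 4};  // x | w0, w1 | b

    // After: parameter name -> variable names; the slice boundaries are
    // implicit, so the index attributes disappear.
    std::map<std::string, std::vector<std::string>> grouped_inputs = {
        {"X", {"x"}}, {"W", {"w0", "w1"}}, {"b", {"b"}}};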
diff --git a/paddle/framework/grad_op_builder.cc b/paddle/framework/grad_op_builder.cc
index da9613e776369..27f37d99232fd 100644
--- a/paddle/framework/grad_op_builder.cc
+++ b/paddle/framework/grad_op_builder.cc
@@ -18,59 +18,32 @@ permissions and limitations under the License. */
 namespace paddle {
 namespace framework {
 
-/**
+
 class OpRegistry;
 
 using VarIndexMap = std::unordered_map<std::string, int>;
 
 enum class OpArgType { IN, OUT };
 
-static std::vector<int>* GetOpFormat(OperatorBase* op, const OpArgType& type) {
-  std::string key = type == OpArgType::IN ? "input_format" : "output_format";
-  return op->attrs_.count(key)
-             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
-             : nullptr;
-}
-
-static const std::vector<int>* GetOpFormat(const OperatorBase* op,
-                                           const OpArgType& type) {
-  std::string key = type == OpArgType::IN ? "input_format" : "output_format";
-  return op->attrs_.count(key)
-             ? &boost::get<std::vector<int>>(op->attrs_.at(key))
-             : nullptr;
-}
-
 static void TransOpArg(const OperatorBase* src_op, OperatorBase* dst_op,
                        const OpArgType& src_type, const OpArgType& dst_type,
-                       int& idx, bool is_grad) {
-  const std::vector<std::string>& src_inout =
+                       bool is_grad) {
+  const auto& src_inout =
       src_type == OpArgType::IN ? src_op->inputs_ : src_op->outputs_;
-  const std::vector<int>* src_format = GetOpFormat(src_op, src_type);
-
-  std::vector<std::string>& dst_inout =
+  auto& dst_inout =
       dst_type == OpArgType::IN ? dst_op->inputs_ : dst_op->outputs_;
-  std::vector<int>* dst_format = GetOpFormat(dst_op, dst_type);
+
   const OpProto& proto = OpRegistry::protos().at(src_op->type_);
   const auto& src_arg_list =
       src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
 
   for (const auto& arg : src_arg_list) {
     std::string src_name = arg.name();
-    std::string dst_name = is_grad ? src_name + kGradVarSuffix : src_name;
-    (*dst_op->in_out_idxs_)[dst_name] = idx++;
-    int src_arg_idx = src_op->in_out_idxs_->at(src_name);
-    int src_begin =
-        src_format == nullptr ? src_arg_idx : src_format->at(src_arg_idx);
-    int src_end = src_format == nullptr ? src_arg_idx + 1
-                                        : src_format->at(src_arg_idx + 1);
-    for (int i = src_begin; i < src_end; ++i) {
-      std::string s =
-          is_grad ? src_inout[i] + kGradVarSuffix
-                  : (arg.ignore_gradient() ? kEmptyVarName : src_inout[i]);
-      dst_inout.emplace_back(s);
-    }
-    if (dst_format != nullptr) {
-      dst_format->push_back(dst_inout.size());
+    std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
+    for (auto& var_name : src_inout.at(src_name)) {
+      std::string s = is_grad ? GradVarName(var_name)
+                              : (arg.no_gradient() ? kEmptyVarName : var_name);
+      dst_inout[dst_name].emplace_back(s);
     }
   }
 }
@@ -80,25 +53,12 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
   OperatorBase* grad_op = OpRegistry::op_creators().at(grad_op_type)();
   grad_op->type_ = grad_op_type;
   grad_op->attrs_ = op->attrs_;
-  grad_op->attrs_.erase("input_format");
-  grad_op->attrs_.erase("output_format");
-  if (GetOpFormat(op, OpArgType::IN) != nullptr) {
-    grad_op->attrs_["output_format"] = std::vector<int>({0});
-  }
-  if (GetOpFormat(op, OpArgType::IN) != nullptr ||
-      GetOpFormat(op, OpArgType::OUT) != nullptr) {
-    grad_op->attrs_["input_format"] = std::vector<int>({0});
-  }
-  grad_op->in_out_idxs_.reset(new VarIndexMap());
-  int in_idx = 0;
-  int out_idx = 0;
-  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, in_idx, false);   // I
-  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, false);  // G
-  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, in_idx, true);   // OG
-  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true);  // IG
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::IN, false);   // I
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, false);  // O
+  TransOpArg(op, grad_op, OpArgType::OUT, OpArgType::IN, true);   // OG
+  TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, true);   // IG
   return grad_op;
 }
-**/
-OperatorBase* BuildGradOp(const OperatorBase* op) { return nullptr; }
+
 }  // namespace framework
 }  // namespace paddle
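
Note: with the map layout, TransOpArg just copies whole per-parameter name
lists, and BuildGradOp composes the gradient op from four such copies: forward
inputs (I) and forward outputs (O) become grad-op inputs, output gradients (OG)
are appended as inputs, and input gradients (IG) become the grad-op outputs. A
usage sketch mirroring the add_two test below (namespace alias f as in the
tests; the comments restate the test's expectations):

    #include "paddle/framework/op_registry.h"
    namespace f = paddle::framework;

    auto add_op = f::OpRegistry::CreateOp(
        "add_two", {{"X", {"x"}}, {"Y", {"y"}}}, {{"Out", {"out"}}}, {});
    auto grad_add_op = f::OpRegistry::CreateGradOp(*add_op);
    // grad inputs:  X -> x, Y -> y, Out -> out, Out@GRAD -> out@GRAD
    // grad outputs: X@GRAD -> x@GRAD, Y@GRAD -> y@GRAD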
EXPECT_EQ(grad_test_op->Inputs("In2_mult"), std::vector({"in2_1", "in2_2", "in2_3"})); @@ -85,36 +84,33 @@ TEST(GradOpBuilder, MutiInOut) { EXPECT_EQ(grad_test_op->Input("Out1"), "out1"); EXPECT_EQ(grad_test_op->Inputs("Out2_mult"), std::vector({"out2_1", "out2_2"})); - EXPECT_EQ(grad_test_op->Input("Out1" + f::kGradVarSuffix), - "out1" + f::kGradVarSuffix); - EXPECT_EQ(grad_test_op->Inputs("Out2_mult" + f::kGradVarSuffix), + EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out1")), + f::GradVarName("out1")); + EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out2_mult")), std::vector( - {"out2_1" + f::kGradVarSuffix, "out2_2" + f::kGradVarSuffix})); + {f::GradVarName("out2_1"), f::GradVarName("out2_2")})); - ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); - EXPECT_EQ(grad_test_op->Output("In1" + f::kGradVarSuffix), - "in1" + f::kGradVarSuffix); - EXPECT_EQ(grad_test_op->Outputs("In2_mult" + f::kGradVarSuffix), - std::vector({"in2_1" + f::kGradVarSuffix, - "in2_2" + f::kGradVarSuffix, - "in2_3" + f::kGradVarSuffix})); - EXPECT_EQ(grad_test_op->Output("In3" + f::kGradVarSuffix), - "in3" + f::kGradVarSuffix); + ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); + EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); + EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), + std::vector({f::GradVarName("in2_1"), + f::GradVarName("in2_2"), + f::GradVarName("in2_3")})); + EXPECT_EQ(grad_test_op->Output(f::GradVarName("In3")), f::GradVarName("in3")); } TEST(GradOpBuilder, IOIgnoredInGradient) { - f::AttributeMap attrs{{"input_format", std::vector{0, 1, 3, 5}}, - {"output_format", std::vector{0, 2, 3}}}; std::shared_ptr test_op(f::OpRegistry::CreateOp( - "io_ignored", {{"In1", {"in1"}}, - {"In2_mult", {"in2_1", "in2_2"}}, - {"In3_mult", {"in3_1", "in3_2"}}}, - {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, attrs)); + "io_ignored", + {{"In1", {"in1"}}, + {"In2_mult", {"in2_1", "in2_2"}}, + {"In3_mult", {"in3_1", "in3_2"}}}, + {{"Out1_mult", {"out1_1", "out1_2"}}, {"Out2", {"out2"}}}, {})); std::shared_ptr grad_test_op = f::OpRegistry::CreateGradOp(*test_op); // 'In2' and 'Out2' are ignored in gradient calculating - ASSERT_EQ(grad_test_op->inputs_.size(), 5UL + 3UL + 3UL); + ASSERT_EQ(grad_test_op->inputs_.size(), 3UL + 2UL + 2UL); EXPECT_EQ(grad_test_op->Input("In1"), "in1"); EXPECT_EQ(grad_test_op->Inputs("In2_mult"), std::vector({f::kEmptyVarName, f::kEmptyVarName})); @@ -123,19 +119,18 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { EXPECT_EQ(grad_test_op->Inputs("Out1_mult"), std::vector({"out1_1", "out1_2"})); EXPECT_EQ(grad_test_op->Input("Out2"), f::kEmptyVarName); - EXPECT_EQ(grad_test_op->Inputs("Out1_mult" + f::kGradVarSuffix), + EXPECT_EQ(grad_test_op->Inputs(f::GradVarName("Out1_mult")), std::vector( - {"out1_1" + f::kGradVarSuffix, "out1_2" + f::kGradVarSuffix})); - EXPECT_EQ(grad_test_op->Input("Out2" + f::kGradVarSuffix), - "out2" + f::kGradVarSuffix); + {f::GradVarName("out1_1"), f::GradVarName("out1_2")})); + EXPECT_EQ(grad_test_op->Input(f::GradVarName("Out2")), + f::GradVarName("out2")); - ASSERT_EQ(grad_test_op->outputs_.size(), 5UL); - EXPECT_EQ(grad_test_op->Output("In1" + f::kGradVarSuffix), - "in1" + f::kGradVarSuffix); - EXPECT_EQ(grad_test_op->Outputs("In2_mult" + f::kGradVarSuffix), + ASSERT_EQ(grad_test_op->outputs_.size(), 3UL); + EXPECT_EQ(grad_test_op->Output(f::GradVarName("In1")), f::GradVarName("in1")); + EXPECT_EQ(grad_test_op->Outputs(f::GradVarName("In2_mult")), std::vector( - {"in2_1" + f::kGradVarSuffix, "in2_2" + 
diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc
index 7eb4de003b413..32861b9f13036 100644
--- a/paddle/framework/op_registry_test.cc
+++ b/paddle/framework/op_registry_test.cc
@@ -131,14 +131,6 @@ TEST(OpRegistry, DefaultValue) {
   ASSERT_EQ(op->GetAttr<float>("scale"), 1.0);
 }
 
-static void SetInputFormat(paddle::framework::OpDesc* desc) {
-  auto attr = desc->add_attrs();
-  attr->set_name("input_format");
-  attr->set_type(paddle::framework::INTS);
-  attr->mutable_ints()->Add(0);
-  attr->mutable_ints()->Add(1);
-}
-
 TEST(OpRegistry, CustomChecker) {
   paddle::framework::OpDesc op_desc;
   op_desc.set_type("my_test_op");
@@ -149,7 +141,6 @@ TEST(OpRegistry, CustomChecker) {
   auto output = op_desc.add_outputs();
   output->set_op_proto_name("output");
   *output->mutable_var_names()->Add() = "oo";
-  SetInputFormat(&op_desc);
 
   // attr 'test_attr' is not set
   bool caught = false;
@@ -189,7 +180,6 @@ TEST(OpRegistry, CustomChecker) {
   attr->set_name("test_attr");
   attr->set_type(paddle::framework::AttrType::INT);
   attr->set_i(4);
-  SetInputFormat(&op_desc);
   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::platform::CPUDeviceContext dev_ctx;
   paddle::framework::Scope scope;
diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc
index cbfbaa56c131d..51039c8fa86a7 100644
--- a/paddle/framework/operator_test.cc
+++ b/paddle/framework/operator_test.cc
@@ -185,11 +185,11 @@ TEST(OpKernel, all) {
   op_desc.set_type("op_with_kernel");
   auto* ipt = op_desc.mutable_inputs()->Add();
   *ipt->mutable_var_names()->Add() = "IN1";
-  ipt->set_op_proto_name("input");
+  ipt->set_op_proto_name("x");
 
   auto* output = op_desc.mutable_outputs()->Add();
   *output->mutable_var_names()->Add() = "OUT1";
-  output->set_op_proto_name("output");
+  output->set_op_proto_name("y");
 
   auto attr = op_desc.mutable_attrs()->Add();
   attr->set_name("scale");
@@ -234,21 +234,6 @@ TEST(OpKernel, multi_inputs) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(3.14);
 
-  auto attr0 = op_desc.mutable_attrs()->Add();
-  attr0->set_name("input_format");
-  attr0->set_type(paddle::framework::AttrType::INTS);
-  auto input_format = attr0->mutable_ints();
-  input_format->Add(0);  // x0
-  input_format->Add(3);  // k
-  input_format->Add(4);  // end
-
-  auto attr1 = op_desc.mutable_attrs()->Add();
-  attr1->set_name("output_format");
-  attr1->set_type(paddle::framework::AttrType::INTS);
-  auto output_format = attr1->mutable_ints();
-  output_format->Add(0);  // y0
-  output_format->Add(2);  // y1
-
   paddle::platform::CPUDeviceContext cpu_device_context;
   paddle::framework::Scope scope;
   scope.NewVar("x0")->GetMutable<paddle::framework::Tensor>();
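
Note: OpDesc now groups var_names under each op_proto_name, so the multi-input
kernel test above no longer needs the "input_format"/"output_format" index
attributes. A hedged sketch of how a duplicable input can be described this
way (the parameter name "xs" is hypothetical; the setter calls are the ones
already used in the test):

    auto* x = op_desc.mutable_inputs()->Add();
    x->set_op_proto_name("xs");             // one parameter...
    *x->mutable_var_names()->Add() = "x0";  // ...bound to several variables
    *x->mutable_var_names()->Add() = "x1";
    *x->mutable_var_names()->Add() = "x2";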
diff --git a/python/paddle/v2/framework/tests/test_operator.py b/python/paddle/v2/framework/tests/test_operator.py
index 4f164e1a69e3f..ef635b464c09d 100644
--- a/python/paddle/v2/framework/tests/test_operator.py
+++ b/python/paddle/v2/framework/tests/test_operator.py
@@ -74,6 +74,7 @@ def test_multiple_input_plain_output(self):
         expected1.inputs.extend(['x', 'w', 'b'])
         expected1.outputs.extend(['y'])
         expected1.type = 'fc'
+        # the input_format can be removed after testing
         attr = expected1.attrs.add()
         attr.name = 'input_format'
         attr.type = attribute_pb2.INTS
@@ -86,6 +87,7 @@ def test_multiple_input_plain_output(self):
         expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b'])
         expected2.outputs.extend(['y'])
         expected2.type = 'fc'
+        # the input_format can be removed after testing
         attr = expected2.attrs.add()
         attr.name = 'input_format'
         attr.type = attribute_pb2.INTS