From e2925e3ef3245821a03e5bb1754ab79dc3dd3abc Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Fri, 26 Aug 2016 19:30:24 -0700 Subject: [PATCH] Revert "Change function def to Node ref for more flexiblity" (#29) --- nnvm/example/src/operator.cc | 17 +++++++------- nnvm/include/dmlc/base.h | 13 ----------- nnvm/include/dmlc/json.h | 9 +++----- nnvm/include/dmlc/parameter.h | 3 +-- nnvm/include/dmlc/registry.h | 5 ++--- nnvm/include/nnvm/node.h | 7 ++++-- nnvm/include/nnvm/op.h | 26 +++++++++++---------- nnvm/include/nnvm/op_attr_types.h | 36 +++++++++--------------------- nnvm/include/nnvm/pass.h | 2 +- nnvm/include/nnvm/pass_functions.h | 29 ------------------------ nnvm/src/core/graph.cc | 2 +- nnvm/src/core/symbolic.cc | 8 +++---- nnvm/src/pass/infer_shape_type.cc | 2 +- nnvm/src/pass/order_mutation.cc | 4 ++-- nnvm/src/pass/plan_memory.cc | 2 +- 15 files changed, 54 insertions(+), 111 deletions(-) diff --git a/nnvm/example/src/operator.cc b/nnvm/example/src/operator.cc index 045e52a183c30..c0c729cdb4dd7 100644 --- a/nnvm/example/src/operator.cc +++ b/nnvm/example/src/operator.cc @@ -15,13 +15,12 @@ using nnvm::FMutateInputs; using nnvm::FInferShape; using nnvm::FInferType; using nnvm::FInplaceOption; -using nnvm::Node; using nnvm::NodeAttrs; using nnvm::TShape; using nnvm::array_view; // simply return the shape as same -inline bool SameShape(const Node& n, +inline bool SameShape(const NodeAttrs& attrs, std::vector *ishape, std::vector *oshape) { if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false; @@ -34,7 +33,7 @@ inline bool SameShape(const Node& n, return true; } -inline std::vector > InplaceIn0Out0(const Node& n) { +inline std::vector > InplaceIn0Out0(const NodeAttrs& attrs) { return {{0, 0}}; } @@ -51,11 +50,11 @@ NNVM_REGISTER_OP(reshape) attrs->parsed = std::move(target); }) .attr( - "FInferShape", [] (const Node& n, + "FInferShape", [] (const NodeAttrs& attrs, std::vector *ishape, std::vector *oshape) { // get parsed attribute - const TShape& target = nnvm::get(n.attrs.parsed); + const TShape& target = nnvm::get(attrs.parsed); (*oshape)[0] = target; if ((*ishape)[0].ndim() == 0) return false; CHECK_EQ((*ishape)[0].Size(), target.Size()) @@ -78,10 +77,10 @@ NNVM_REGISTER_OP(cast) }) .attr("FInferShape", SameShape) .attr( - "FInferType", [](const Node& n, + "FInferType", [](const NodeAttrs& attrs, std::vector *itype, std::vector *otype) { - (*otype)[0] = nnvm::get(n.attrs.parsed); + (*otype)[0] = nnvm::get(attrs.parsed); return true; }); @@ -110,7 +109,7 @@ NNVM_REGISTER_OP(cross_device_copy) NNVM_REGISTER_OP(conv2d) .describe("take conv of input") .set_num_inputs(2) -.attr("FListInputNames", [](const Node& n) { +.attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"data", "weight"}; }); @@ -120,7 +119,7 @@ NNVM_REGISTER_OP(add) NNVM_REGISTER_OP(assign) .set_num_inputs(2) .set_num_outputs(1) -.attr("FMutateInputs", [](const Node& n) { +.attr("FMutateInputs", [](const NodeAttrs& attrs) { return std::vector{0}; }); diff --git a/nnvm/include/dmlc/base.h b/nnvm/include/dmlc/base.h index 9eca4135f1191..5b34fd6b4e345 100644 --- a/nnvm/include/dmlc/base.h +++ b/nnvm/include/dmlc/base.h @@ -58,11 +58,6 @@ __cplusplus >= 201103L || defined(_MSC_VER)) #endif -/*! \brief strict CXX11 support */ -#ifndef DMLC_STRICT_CXX11 -#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER)) -#endif - /// check if g++ is before 4.6 #if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__) #if __GNUC__ == 4 && __GNUC_MINOR__ < 6 @@ -74,7 +69,6 @@ #endif #endif - /*! * \brief Enable std::thread related modules, * Used to disable some module in mingw compile. @@ -88,13 +82,6 @@ #define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER)) #endif -/*! \brief helper macro to supress unused warning */ -#if defined(__GNUC__) -#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused)) -#else -#define DMLC_ATTRIBUTE_UNUSED -#endif - /*! \brief helper macro to generate string concat */ #define DMLC_STR_CONCAT_(__x, __y) __x##__y #define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y) diff --git a/nnvm/include/dmlc/json.h b/nnvm/include/dmlc/json.h index 1934aee6a2ce4..2daa0aaa017f6 100644 --- a/nnvm/include/dmlc/json.h +++ b/nnvm/include/dmlc/json.h @@ -25,9 +25,7 @@ #include #include #include -#if DMLC_STRICT_CXX11 #include "./any.h" -#endif // DMLC_STRICT_CXX11 #endif // DMLC_USE_CXX11 namespace dmlc { @@ -322,8 +320,7 @@ class JSONObjectReadHelper { }; #define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \ - static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \ - __make_AnyJSONType ## _ ## KeyName ## __ + static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __ /*! * \def DMLC_JSON_ENABLE_ANY @@ -478,7 +475,7 @@ struct Handler { } }; -#if DMLC_STRICT_CXX11 +#if DMLC_USE_CXX11 // Manager to store json serialization strategy. class AnyJSONManager { public: @@ -564,7 +561,7 @@ struct Handler { CHECK(!reader->NextArrayItem()) << "invalid any json format"; } }; -#endif // DMLC_STRICT_CXX11 +#endif // DMLC_USE_CXX11 } // namespace json diff --git a/nnvm/include/dmlc/parameter.h b/nnvm/include/dmlc/parameter.h index 2fbab2a44e32f..4ff99f860cc33 100644 --- a/nnvm/include/dmlc/parameter.h +++ b/nnvm/include/dmlc/parameter.h @@ -251,8 +251,7 @@ struct Parameter { static ::dmlc::parameter::ParamManagerSingleton inst(#PType); \ return &inst.manager; \ } \ - static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \ - __make__ ## PType ## ParamManager__ = \ + static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \ (*PType::__MANAGER__()) \ //! \endcond diff --git a/nnvm/include/dmlc/registry.h b/nnvm/include/dmlc/registry.h index 380b31cd3d61e..67fbc43ded682 100644 --- a/nnvm/include/dmlc/registry.h +++ b/nnvm/include/dmlc/registry.h @@ -216,7 +216,7 @@ class FunctionRegEntryBase { * \sa FactoryRegistryEntryBase */ #define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \ - static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ + static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \ ::dmlc::Registry::Get()->__REGISTER__(#Name) \ /*! @@ -272,7 +272,6 @@ class FunctionRegEntryBase { */ #define DMLC_REGISTRY_LINK_TAG(UniqueTag) \ int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \ - static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \ - __dmlc_registry_file_tag_ ## UniqueTag ## __(); + static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __(); } // namespace dmlc #endif // DMLC_REGISTRY_H_ diff --git a/nnvm/include/nnvm/node.h b/nnvm/include/nnvm/node.h index f42e272305cba..470d4d5763813 100644 --- a/nnvm/include/nnvm/node.h +++ b/nnvm/include/nnvm/node.h @@ -17,6 +17,7 @@ namespace nnvm { // Forward declare node. class Node; + /*! * \brief we always used NodePtr for a reference pointer * to the node, so this alias can be changed in case. @@ -47,6 +48,8 @@ struct NodeEntry { struct NodeAttrs { /*! \brief name of the node */ std::string name; + /*! \brief Vector representation of positional attributes */ + std::vector scalars; /*! \brief The dictionary representation of attributes */ std::unordered_map dict; /*! @@ -105,7 +108,7 @@ inline uint32_t Node::num_outputs() const { if (this->op->get_num_outputs == nullptr) { return this->op->num_outputs; } else { - return this->op->get_num_outputs(*this); + return this->op->get_num_outputs(this->attrs); } } @@ -114,7 +117,7 @@ inline uint32_t Node::num_inputs() const { if (this->op->get_num_inputs == nullptr) { return this->op->num_inputs; } else { - return this->op->get_num_inputs(*this); + return this->op->get_num_inputs(this->attrs); } } diff --git a/nnvm/include/nnvm/op.h b/nnvm/include/nnvm/op.h index 150f364716510..721e8e736e09b 100644 --- a/nnvm/include/nnvm/op.h +++ b/nnvm/include/nnvm/op.h @@ -102,16 +102,16 @@ class Op { uint32_t num_outputs = 1; /*! * \brief get number of outputs given information about the node. - * \param n The node + * \param attrs The attribute of the node * \return number of outputs. */ - std::function get_num_outputs = nullptr; + std::function get_num_outputs = nullptr; /*! * \brief get number of inputs given information about the node. - * \param n The node + * \param attrs The attribute of the node * \return number of inputs */ - std::function get_num_inputs = nullptr; + std::function get_num_inputs = nullptr; /*! * \brief Attribute parser to parse the NodeAttrs information. * @@ -136,11 +136,11 @@ class Op { * attrs->parsed = std::move(param); * } * // The other function that can utilize the parsed result. - * TShape SumInferShape(const NodePtr& ptr, + * TShape SumInferShape(const NodeAttrs& attrs, * const std::vector& ishapes) { * // we can use the parsed version of param * // without repeatively parsing the parameter - * const SumParam& param = nnvm::get(ptr->attrs.parsed); + * const SumParam& param = nnvm::get(attrs.parsed); * } * \endcode */ @@ -180,7 +180,7 @@ class Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_inputs(std::function fn); // NOLINT(*) + inline Op& set_num_inputs(std::function fn); // NOLINT(*) /*! * \brief Set the num_outputs * \param n The number of outputs to be set. @@ -192,7 +192,7 @@ class Op { * \param fn The function to be set. * \return reference to self. */ - inline Op& set_num_outputs(std::function fn); // NOLINT(*) + inline Op& set_num_outputs(std::function fn); // NOLINT(*) /*! * \brief Set the attr_parser function. * \param fn The number of outputs to be set. @@ -279,8 +279,10 @@ class OpMap { }; // internal macros to make +#define NNVM_STR_CONCAT_(__x, __y) __x##__y +#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y) #define NNVM_REGISTER_VAR_DEF(OpName) \ - static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName + static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName /*! * \def NNVM_REGISTER_OP @@ -298,7 +300,7 @@ class OpMap { * \endcode */ #define NNVM_REGISTER_OP(OpName) \ - DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ + NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \ ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName) // implementations of template functions after this. @@ -375,7 +377,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_inputs(std::function fn) { // NOLINT(*) this->get_num_inputs = fn; return *this; } @@ -385,7 +387,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*) return *this; } -inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) +inline Op& Op::set_num_outputs(std::function fn) { // NOLINT(*) this->get_num_outputs = fn; return *this; } diff --git a/nnvm/include/nnvm/op_attr_types.h b/nnvm/include/nnvm/op_attr_types.h index b1cb44d64df6e..675b93a6c9d2f 100644 --- a/nnvm/include/nnvm/op_attr_types.h +++ b/nnvm/include/nnvm/op_attr_types.h @@ -12,7 +12,6 @@ #include #include "./base.h" #include "./tuple.h" -#include "./node.h" namespace nnvm { @@ -22,34 +21,34 @@ namespace nnvm { /*! * \brief Return list of input arguments names of each operator. * - * \param n The node. + * \param attrs The attributes of the node. * \return list of inputs * \note Register under "FListInputNames", default return {"data"}. * * FListInputNames enables automatic variable creation for missing arguments. */ -using FListInputNames = std::function (const Node& n)>; +using FListInputNames = std::function (const NodeAttrs& attrs)>; /*! * \brief Return list of output arguments names of each operator. * - * \param n The node. + * \param attrs The attributes of the node. * \return list of inputs * \note Register under "FListOutputNames", default return {"outputs"}. * * FListOutputNames customized naming for operator outputs. */ -using FListOutputNames = std::function (const Node& n)>; +using FListOutputNames = std::function (const NodeAttrs& attrs)>; /*! * \brief Check whether operator will mutate k-th input. - * \param n The node. + * \param attrs The attributes of the node. * \return list of input indices it mutates. * * \note Register under "FMutateInputs", default return false * FMutateInputs enables mutation order handling correctly. */ -using FMutateInputs = std::function (const Node& n)>; +using FMutateInputs = std::function (const NodeAttrs& attrs)>; /*! * \brief Inference function of certain type. @@ -57,9 +56,9 @@ using FMutateInputs = std::function (const Node& n)>; * \return whether all attributes are inferred. */ template -using FInferNodeEntryAttr = std::function *in_ptr, - std::vector *out_ptr)>; +using FInferNodeEntryAttr = std::function *in_attrs, + std::vector *out_attrs)>; /*! * \brief Shape inference function. * Update the shapes given the input shape information. @@ -97,7 +96,7 @@ using TIsBackwardOp = bool; /*! * \brief Get possible inplace options. * This function enables optimization to reuse memory of inputs in output. - * \param n The node + * \param attrs The attributes of the node * \param in_data The input data. * \param out_data The output data. * \return list of pair of that maps input->output, @@ -106,20 +105,7 @@ using TIsBackwardOp = bool; * \note Register under "FInplaceOption", by default no inplace can happen. */ using FInplaceOption = std::function< - std::vector > (const Node& n)>; - -/*! - * \brief Get the gradient node of the op node - * This function generates the backward graph of the node - * \param nodeptr The node to take gradient - * \param out_grads Gradient of current node's outputs - * \return gradients of the inputs - * - * \note Register under "FGradient" - */ -using FGradient = std::function( - const NodePtr& nodeptr, - const std::vector& out_grads)>; + std::vector > (const NodeAttrs& attrs)>; } // namespace nnvm diff --git a/nnvm/include/nnvm/pass.h b/nnvm/include/nnvm/pass.h index fb97cea5ae0bf..438226f5c93f5 100644 --- a/nnvm/include/nnvm/pass.h +++ b/nnvm/include/nnvm/pass.h @@ -23,7 +23,7 @@ namespace nnvm { * \param src The graph to be transformed. * \return The generated graph. */ -using PassFunction = std::function; +typedef std::function PassFunction; /*! * \brief Apply a series of pass transformations on g. diff --git a/nnvm/include/nnvm/pass_functions.h b/nnvm/include/nnvm/pass_functions.h index 25f5b26957f83..8cca33e97cfde 100644 --- a/nnvm/include/nnvm/pass_functions.h +++ b/nnvm/include/nnvm/pass_functions.h @@ -11,11 +11,9 @@ #define NNVM_PASS_FUNCTIONS_H_ #include -#include #include #include "./base.h" #include "./pass.h" -#include "./node.h" #include "./graph_attr_types.h" namespace nnvm { @@ -111,33 +109,6 @@ inline Graph PlaceDevice(Graph graph, return ApplyPass(std::move(graph), {"PlaceDevice"}); } -/*! - * \brief Get the gradient graph whose outputs are gradients of xs wrt to ys. - * \param graph source graph - * \param ys The entries we want to take gradient from. - * \param xs The input we want to - * \param aggregate_fun aggregation function applied to aggregate the inputs - * \param mirror_fun Optional mirror function to do mirror optimization and save memory. - * \return A new graph, whose outputs corresponds to inputs of xs. - */ -inline Graph Gradient( - Graph graph, - std::vector ys, - std::vector xs, - std::function&& inputs)> aggregate_fun = nullptr, - std::function mirror_fun = nullptr) { - graph.attrs["grad_ys"] = std::make_shared(std::move(ys)); - graph.attrs["grad_xs"] = std::make_shared(std::move(xs)); - if (aggregate_fun != nullptr) { - graph.attrs["grad_aggregate_fun"] = std::make_shared(aggregate_fun); - } - if (mirror_fun != nullptr) { - graph.attrs["grad_mirror_fun"] = std::make_shared(mirror_fun); - } - - return ApplyPass(std::move(graph), {"Gradient"}); -} - } // namespace pass } // namespace nnvm #endif // NNVM_PASS_FUNCTIONS_H_ diff --git a/nnvm/src/core/graph.cc b/nnvm/src/core/graph.cc index f6f1faa5e0bb5..2e0160072612f 100644 --- a/nnvm/src/core/graph.cc +++ b/nnvm/src/core/graph.cc @@ -68,7 +68,7 @@ IndexedGraph::IndexedGraph(const Graph &g) { iptr + inputs_rptr[nid], iptr + inputs_rptr[nid + 1]); if (nodes_[nid].source->op != nullptr && fmutate_inputs.count(nodes_[nid].source->op)) { - for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](*(nodes_[nid].source))) { + for (uint32_t i : fmutate_inputs[nodes_[nid].source->op](nodes_[nid].source->attrs)) { mutable_input_nodes_.insert(nodes_[nid].inputs[i].node_id); } } diff --git a/nnvm/src/core/symbolic.cc b/nnvm/src/core/symbolic.cc index d839f2b9bf5b5..d595880aed1ce 100644 --- a/nnvm/src/core/symbolic.cc +++ b/nnvm/src/core/symbolic.cc @@ -38,7 +38,7 @@ inline void UpdateNodeVersion(Node *n) { } } if (fmutate_inputs.count(n->op) != 0) { - for (uint32_t i : fmutate_inputs[n->op](*n)) { + for (uint32_t i : fmutate_inputs[n->op](n->attrs)) { NodeEntry& e = n->inputs[i]; CHECK(e.node->is_variable()) << "Mutation target can only be Variable"; @@ -197,7 +197,7 @@ std::vector Symbol::ListInputNames(ListInputOption option) const { if (node->is_variable()) { vlist.push_back(node.get()); } else if (fmutate_inputs.count(node->op)) { - for (uint32_t i : fmutate_inputs[node->op](*node)){ + for (uint32_t i : fmutate_inputs[node->op](node->attrs)){ mutable_set.insert(node->inputs[i].node.get()); } } @@ -223,7 +223,7 @@ std::vector Symbol::ListOutputNames() const { std::string rname; FListOutputNames fn = flist_ouputs.get(head.node->op, nullptr); if (fn != nullptr) { - rname = fn(*head.node)[head.index]; + rname = fn(head.node->attrs)[head.index]; } else { rname = "output"; if (head.node->num_outputs() != 1) { @@ -279,7 +279,7 @@ void Symbol::Compose(const array_view& args, // switch to keyword argument matching if (args.size() != n_req) { FListInputNames fn = flist_inputs.get(n->op, nullptr); - auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(*n); + auto arg_names = (fn == nullptr) ? std::vector{"data"} : fn(n->attrs); if (arg_names.size() != n_req) { LOG(FATAL) << "Not enough argument to call operator " << outputs[0].node->op->name; } diff --git a/nnvm/src/pass/infer_shape_type.cc b/nnvm/src/pass/infer_shape_type.cc index 13e225744416e..5978ecdb79f2e 100644 --- a/nnvm/src/pass/infer_shape_type.cc +++ b/nnvm/src/pass/infer_shape_type.cc @@ -75,7 +75,7 @@ Graph InferAttr(Graph &&ret, oshape[i] = rshape[idx.entry_id(nid, i)]; } num_unknown += - !(finfer_shape[inode.source->op](*inode.source, &ishape, &oshape)); + !(finfer_shape[inode.source->op](inode.source->attrs, &ishape, &oshape)); for (uint32_t i = 0; i < num_inputs; ++i) { rshape[idx.entry_id(inode.inputs[i])] = ishape[i]; } diff --git a/nnvm/src/pass/order_mutation.cc b/nnvm/src/pass/order_mutation.cc index eddd5727b31b3..3bcfd9922d531 100644 --- a/nnvm/src/pass/order_mutation.cc +++ b/nnvm/src/pass/order_mutation.cc @@ -44,7 +44,7 @@ Graph OrderMutation(const Graph& src) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (!n->is_variable() && fmutate_inputs.count(n->op)) { - mutate_inputs = fmutate_inputs[n->op](*n); + mutate_inputs = fmutate_inputs[n->op](n->attrs); } std::sort(mutate_inputs.begin(), mutate_inputs.end()); @@ -102,7 +102,7 @@ Graph OrderMutation(const Graph& src) { static auto& fmutate_inputs = Op::GetAttr("FMutateInputs"); std::vector mutate_inputs; if (fmutate_inputs.count(kv.first->op)) { - mutate_inputs = fmutate_inputs[kv.first->op](*kv.first); + mutate_inputs = fmutate_inputs[kv.first->op](kv.first->attrs); } std::sort(mutate_inputs.begin(), mutate_inputs.end()); diff --git a/nnvm/src/pass/plan_memory.cc b/nnvm/src/pass/plan_memory.cc index 4541e7de5937c..14a88d217de8b 100644 --- a/nnvm/src/pass/plan_memory.cc +++ b/nnvm/src/pass/plan_memory.cc @@ -169,7 +169,7 @@ Graph PlanMemory(Graph ret) { if (inode.source->is_variable()) continue; // check inplace option if (finplace_option.count(inode.source->op) != 0) { - auto inplace_pairs = finplace_option[inode.source->op](*inode.source); + auto inplace_pairs = finplace_option[inode.source->op](inode.source->attrs); for (auto& kv : inplace_pairs) { uint32_t eid_out = idx.entry_id(nid, kv.second); uint32_t eid_in = idx.entry_id(inode.inputs[kv.first]);