Skip to content

Commit

Permalink
Revert "Change function def to Node ref for more flexiblity" (apache#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent 98a67d9 commit 803db5d
Show file tree
Hide file tree
Showing 15 changed files with 54 additions and 111 deletions.
17 changes: 8 additions & 9 deletions nnvm/example/src/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@ using nnvm::FMutateInputs;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::FInplaceOption;
using nnvm::Node;
using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;

// simply return the shape as same
inline bool SameShape(const Node& n,
inline bool SameShape(const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
if (ishape->size() == 0 || (*ishape)[0].ndim() == 0) return false;
Expand All @@ -34,7 +33,7 @@ inline bool SameShape(const Node& n,
return true;
}

inline std::vector<std::pair<int, int> > InplaceIn0Out0(const Node& n) {
inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) {
return {{0, 0}};
}

Expand All @@ -51,11 +50,11 @@ NNVM_REGISTER_OP(reshape)
attrs->parsed = std::move(target);
})
.attr<FInferShape>(
"FInferShape", [] (const Node& n,
"FInferShape", [] (const NodeAttrs& attrs,
std::vector<TShape> *ishape,
std::vector<TShape> *oshape) {
// get parsed attribute
const TShape& target = nnvm::get<TShape>(n.attrs.parsed);
const TShape& target = nnvm::get<TShape>(attrs.parsed);
(*oshape)[0] = target;
if ((*ishape)[0].ndim() == 0) return false;
CHECK_EQ((*ishape)[0].Size(), target.Size())
Expand All @@ -78,10 +77,10 @@ NNVM_REGISTER_OP(cast)
})
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInferType>(
"FInferType", [](const Node& n,
"FInferType", [](const NodeAttrs& attrs,
std::vector<int> *itype,
std::vector<int> *otype) {
(*otype)[0] = nnvm::get<int>(n.attrs.parsed);
(*otype)[0] = nnvm::get<int>(attrs.parsed);
return true;
});

Expand Down Expand Up @@ -110,7 +109,7 @@ NNVM_REGISTER_OP(cross_device_copy)
NNVM_REGISTER_OP(conv2d)
.describe("take conv of input")
.set_num_inputs(2)
.attr<FListInputNames>("FListInputNames", [](const Node& n) {
.attr<FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "weight"};
});

Expand All @@ -120,7 +119,7 @@ NNVM_REGISTER_OP(add)
NNVM_REGISTER_OP(assign)
.set_num_inputs(2)
.set_num_outputs(1)
.attr<FMutateInputs>("FMutateInputs", [](const Node& n) {
.attr<FMutateInputs>("FMutateInputs", [](const NodeAttrs& attrs) {
return std::vector<uint32_t>{0};
});

Expand Down
13 changes: 0 additions & 13 deletions nnvm/include/dmlc/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@
__cplusplus >= 201103L || defined(_MSC_VER))
#endif

/*! \brief strict CXX11 support */
#ifndef DMLC_STRICT_CXX11
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
#endif

/// check if g++ is before 4.6
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
Expand All @@ -74,7 +69,6 @@
#endif
#endif


/*!
* \brief Enable std::thread related modules,
* Used to disable some module in mingw compile.
Expand All @@ -88,13 +82,6 @@
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
#endif

/*! \brief helper macro to supress unused warning */
#if defined(__GNUC__)
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
#else
#define DMLC_ATTRIBUTE_UNUSED
#endif

/*! \brief helper macro to generate string concat */
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)
Expand Down
9 changes: 3 additions & 6 deletions nnvm/include/dmlc/json.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@
#include <typeindex>
#include <typeinfo>
#include <unordered_map>
#if DMLC_STRICT_CXX11
#include "./any.h"
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11

namespace dmlc {
Expand Down Expand Up @@ -322,8 +320,7 @@ class JSONObjectReadHelper {
};

#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
__make_AnyJSONType ## _ ## KeyName ## __
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __

/*!
* \def DMLC_JSON_ENABLE_ANY
Expand Down Expand Up @@ -478,7 +475,7 @@ struct Handler {
}
};

#if DMLC_STRICT_CXX11
#if DMLC_USE_CXX11
// Manager to store json serialization strategy.
class AnyJSONManager {
public:
Expand Down Expand Up @@ -564,7 +561,7 @@ struct Handler<any> {
CHECK(!reader->NextArrayItem()) << "invalid any json format";
}
};
#endif // DMLC_STRICT_CXX11
#endif // DMLC_USE_CXX11

} // namespace json

Expand Down
3 changes: 1 addition & 2 deletions nnvm/include/dmlc/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,8 +251,7 @@ struct Parameter {
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
return &inst.manager; \
} \
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
__make__ ## PType ## ParamManager__ = \
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
(*PType::__MANAGER__()) \

//! \endcond
Expand Down
5 changes: 2 additions & 3 deletions nnvm/include/dmlc/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ class FunctionRegEntryBase {
* \sa FactoryRegistryEntryBase
*/
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \

/*!
Expand Down Expand Up @@ -272,7 +272,6 @@ class FunctionRegEntryBase {
*/
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
__dmlc_registry_file_tag_ ## UniqueTag ## __();
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
} // namespace dmlc
#endif // DMLC_REGISTRY_H_
7 changes: 5 additions & 2 deletions nnvm/include/nnvm/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ namespace nnvm {

// Forward declare node.
class Node;

/*!
* \brief we always used NodePtr for a reference pointer
* to the node, so this alias can be changed in case.
Expand Down Expand Up @@ -47,6 +48,8 @@ struct NodeEntry {
struct NodeAttrs {
/*! \brief name of the node */
std::string name;
/*! \brief Vector representation of positional attributes */
std::vector<double> scalars;
/*! \brief The dictionary representation of attributes */
std::unordered_map<std::string, std::string> dict;
/*!
Expand Down Expand Up @@ -105,7 +108,7 @@ inline uint32_t Node::num_outputs() const {
if (this->op->get_num_outputs == nullptr) {
return this->op->num_outputs;
} else {
return this->op->get_num_outputs(*this);
return this->op->get_num_outputs(this->attrs);
}
}

Expand All @@ -114,7 +117,7 @@ inline uint32_t Node::num_inputs() const {
if (this->op->get_num_inputs == nullptr) {
return this->op->num_inputs;
} else {
return this->op->get_num_inputs(*this);
return this->op->get_num_inputs(this->attrs);
}
}

Expand Down
26 changes: 14 additions & 12 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,16 +102,16 @@ class Op {
uint32_t num_outputs = 1;
/*!
* \brief get number of outputs given information about the node.
* \param n The node
* \param attrs The attribute of the node
* \return number of outputs.
*/
std::function<uint32_t(const Node& n)> get_num_outputs = nullptr;
std::function<uint32_t(const NodeAttrs& attrs)> get_num_outputs = nullptr;
/*!
* \brief get number of inputs given information about the node.
* \param n The node
* \param attrs The attribute of the node
* \return number of inputs
*/
std::function<uint32_t(const Node& n)> get_num_inputs = nullptr;
std::function<uint32_t(const NodeAttrs& attrs)> get_num_inputs = nullptr;
/*!
* \brief Attribute parser to parse the NodeAttrs information.
*
Expand All @@ -136,11 +136,11 @@ class Op {
* attrs->parsed = std::move(param);
* }
* // The other function that can utilize the parsed result.
* TShape SumInferShape(const NodePtr& ptr,
* TShape SumInferShape(const NodeAttrs& attrs,
* const std::vector<TShape>& ishapes) {
* // we can use the parsed version of param
* // without repeatively parsing the parameter
* const SumParam& param = nnvm::get<SumParam>(ptr->attrs.parsed);
* const SumParam& param = nnvm::get<SumParam>(attrs.parsed);
* }
* \endcode
*/
Expand Down Expand Up @@ -180,7 +180,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_inputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
inline Op& set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the num_outputs
* \param n The number of outputs to be set.
Expand All @@ -192,7 +192,7 @@ class Op {
* \param fn The function to be set.
* \return reference to self.
*/
inline Op& set_num_outputs(std::function<uint32_t (const Node& n)> fn); // NOLINT(*)
inline Op& set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn); // NOLINT(*)
/*!
* \brief Set the attr_parser function.
* \param fn The number of outputs to be set.
Expand Down Expand Up @@ -279,8 +279,10 @@ class OpMap {
};

// internal macros to make
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
#define NNVM_REGISTER_VAR_DEF(OpName) \
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName

/*!
* \def NNVM_REGISTER_OP
Expand All @@ -298,7 +300,7 @@ class OpMap {
* \endcode
*/
#define NNVM_REGISTER_OP(OpName) \
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)

// implementations of template functions after this.
Expand Down Expand Up @@ -375,7 +377,7 @@ inline Op& Op::set_num_inputs(uint32_t n) { // NOLINT(*)
return *this;
}

inline Op& Op::set_num_inputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
inline Op& Op::set_num_inputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_inputs = fn;
return *this;
}
Expand All @@ -385,7 +387,7 @@ inline Op& Op::set_num_outputs(uint32_t n) { // NOLINT(*)
return *this;
}

inline Op& Op::set_num_outputs(std::function<uint32_t (const Node& n)> fn) { // NOLINT(*)
inline Op& Op::set_num_outputs(std::function<uint32_t (const NodeAttrs& attr)> fn) { // NOLINT(*)
this->get_num_outputs = fn;
return *this;
}
Expand Down
36 changes: 11 additions & 25 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
#include <functional>
#include "./base.h"
#include "./tuple.h"
#include "./node.h"

namespace nnvm {

Expand All @@ -22,44 +21,44 @@ namespace nnvm {
/*!
* \brief Return list of input arguments names of each operator.
*
* \param n The node.
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListInputNames", default return {"data"}.
*
* FListInputNames enables automatic variable creation for missing arguments.
*/
using FListInputNames = std::function<std::vector<std::string> (const Node& n)>;
using FListInputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;

/*!
* \brief Return list of output arguments names of each operator.
*
* \param n The node.
* \param attrs The attributes of the node.
* \return list of inputs
* \note Register under "FListOutputNames", default return {"outputs"}.
*
* FListOutputNames customized naming for operator outputs.
*/
using FListOutputNames = std::function<std::vector<std::string> (const Node& n)>;
using FListOutputNames = std::function<std::vector<std::string> (const NodeAttrs& attrs)>;

/*!
* \brief Check whether operator will mutate k-th input.
* \param n The node.
* \param attrs The attributes of the node.
* \return list of input indices it mutates.
*
* \note Register under "FMutateInputs", default return false
* FMutateInputs enables mutation order handling correctly.
*/
using FMutateInputs = std::function<std::vector<uint32_t> (const Node& n)>;
using FMutateInputs = std::function<std::vector<uint32_t> (const NodeAttrs& attrs)>;

/*!
* \brief Inference function of certain type.
* \tparam AttrType The type of the attribute to be infered.
* \return whether all attributes are inferred.
*/
template<typename AttrType>
using FInferNodeEntryAttr = std::function<bool (const Node& n,
std::vector<AttrType> *in_ptr,
std::vector<AttrType> *out_ptr)>;
using FInferNodeEntryAttr = std::function<bool (const NodeAttrs& attrs,
std::vector<AttrType> *in_attrs,
std::vector<AttrType> *out_attrs)>;
/*!
* \brief Shape inference function.
* Update the shapes given the input shape information.
Expand Down Expand Up @@ -97,7 +96,7 @@ using TIsBackwardOp = bool;
/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param n The node
* \param attrs The attributes of the node
* \param in_data The input data.
* \param out_data The output data.
* \return list of pair of that maps input->output,
Expand All @@ -106,20 +105,7 @@ using TIsBackwardOp = bool;
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const Node& n)>;

/*!
* \brief Get the gradient node of the op node
* This function generates the backward graph of the node
* \param nodeptr The node to take gradient
* \param out_grads Gradient of current node's outputs
* \return gradients of the inputs
*
* \note Register under "FGradient"
*/
using FGradient = std::function<std::vector<NodeEntry>(
const NodePtr& nodeptr,
const std::vector<NodeEntry>& out_grads)>;
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;

} // namespace nnvm

Expand Down
2 changes: 1 addition & 1 deletion nnvm/include/nnvm/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace nnvm {
* \param src The graph to be transformed.
* \return The generated graph.
*/
using PassFunction = std::function<Graph (Graph src)>;
typedef std::function<Graph (Graph src)> PassFunction;

/*!
* \brief Apply a series of pass transformations on g.
Expand Down
Loading

0 comments on commit 803db5d

Please sign in to comment.