Skip to content

Commit

Permalink
Update Symbol and C API (apache#22)
Browse files Browse the repository at this point in the history
* Update tuple to be compatible with mshadow

* Move set error message to C API

* simplify with using

* updates to shape inference

* Add unnamed namespace to the implementations

* [SYMBOL] Enable inference of Auxiliary data, rename list_arguments to list_inputs
  • Loading branch information
tqchen committed May 26, 2018
1 parent 753f876 commit 14c4df4
Show file tree
Hide file tree
Showing 19 changed files with 346 additions and 105 deletions.
12 changes: 6 additions & 6 deletions nnvm/README.md
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
# NNVM: Build deep learning system by parts

NNVM is not a deep learning library. It is a modular, decentralized and lightweight library to
help build deep learning libraries efficiently.
NNVM is not a deep learning library. It is a modular, decentralized and lightweight part to
help build deep learning libraries.

## What is it

While most deep learning systems offer end to end solutions,
it is interesting to ask if we can actually assemble a deep learning system by parts.
The goal is to enable hackers can customize optimizations, target platforms and set of operators they care about.
We believe that the decentralized modular system is an interesting direction.

The hope is that effective parts can be assembled together just like you assemble your own desktops.
So the customized deep learning solution can be minimax, minimum in terms of dependencies,
while maxiziming the users' need.

NNVM offers one such part, it provides a generic to do generic
computation graph optimization such as memory reduction, device allocation,
operator fusion while being agnostic to the operator
interface defintion and how operators are executed.
NNVM offers one such part, it provides a generic way to do
computation graph optimization such as memory reduction, device allocation and more
while being agnostic to the operator interface defintion and how operators are executed.
NNVM is inspired by LLVM, aiming to be an intermediate representation library
for neural nets and computation graphs generation and optimizations.

Expand Down
34 changes: 5 additions & 29 deletions nnvm/include/nnvm/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,37 +16,13 @@
namespace nnvm {

/*! \brief any type */
using any = dmlc::any;
using dmlc::any;

/*!
* \brief array_veiw type
* \tparam ValueType The value content of array view.
*/
template<typename ValueType>
using array_view = dmlc::array_view<ValueType>;

/*!
* \brief get reference of type T stored in src.
* \param src The source container
* \return the reference to the type.
* \tparam T The type to be fetched.
*/
template<typename T>
inline T& get(any& src) { // NOLINT(*)
return dmlc::get<T>(src);
}

/*!
* \brief get const reference of type T stored in src.
* \param src The source container
* \return the reference to the type.
* \tparam T The type to be fetched.
*/
/*! \brief array_veiw type */
using dmlc::array_view;

template<typename T>
inline const T& get(const any& src) {
return dmlc::get<T>(src);
}
/*!\brief getter function of any type */
using dmlc::get;

} // namespace nnvm

Expand Down
27 changes: 19 additions & 8 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ typedef void *SymbolHandle;
/*! \brief handle to Graph */
typedef void *GraphHandle;

/*!
* \brief Set the last error message needed by C API
* \param msg The error message to set.
*/
NNVM_DLL void NNAPISetLastError(const char* msg);

/*!
* \brief return str message of the last error
* all function in this file will return 0 when success
Expand Down Expand Up @@ -171,25 +177,30 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
nn_uint *out_size,
const char*** out);
/*!
* \brief List arguments in the symbol.
* \brief List inputs in the symbol.
* \param symbol the symbol
* \param option The option to list the inputs
* option=0 means list all arguments.
* option=1 means list arguments that are readed only by the graph.
* option=2 means list arguments that are mutated by the graph.
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListArguments(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
NNVM_DLL int NNSymbolListInputNames(SymbolHandle symbol,
int option,
nn_uint *out_size,
const char ***out_str_array);
/*!
* \brief List returns in the symbol.
* \brief List returns names in the symbol.
* \param symbol the symbol
* \param out_size output size
* \param out_str_array pointer to hold the output string array
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolListOutputs(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
NNVM_DLL int NNSymbolListOutputNames(SymbolHandle symbol,
nn_uint *out_size,
const char ***out_str_array);
/*!
* \brief Get a symbol that contains all the internals.
* \param symbol The symbol
Expand Down
4 changes: 4 additions & 0 deletions nnvm/include/nnvm/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,9 @@ template<typename ValueType>
inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
const any* ref = GetAttrMap(key);
if (ref == nullptr) {
// update the attribute map of the key by creating new empty OpMap
UpdateAttrMap(key, [key](any* pmap) {
// use callback so it is in lockscope
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = key;
Expand All @@ -304,7 +306,9 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
template<typename ValueType>
inline Op& Op::attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value) {
// update the attribute map of the key by creating new empty if needed.
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
// the callback is in lockscope so is threadsafe.
if (pmap->empty()) {
OpMap<ValueType> pm;
pm.attr_name_ = attr_name;
Expand Down
4 changes: 2 additions & 2 deletions nnvm/include/nnvm/pass_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ inline Graph InferType(Graph graph,
DTypeVector type_args = {},
std::string type_attr_key = "") {
if (type_args.size() != 0) {
graph.attrs["type_args"] = std::make_shared<any>(std::move(type_args));
graph.attrs["dtype_args"] = std::make_shared<any>(std::move(type_args));
}
if (type_attr_key.length() != 0) {
graph.attrs["type_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
graph.attrs["dtype_attr_key"] = std::make_shared<any>(std::move(type_attr_key));
}
return ApplyPass(std::move(graph), {"InferType"});
}
Expand Down
20 changes: 17 additions & 3 deletions nnvm/include/nnvm/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ class Symbol {
/*! \brief only list attributes in current node */
kShallow = 1
};
/*! \brief option passed to ListInputNames */
enum ListInputOption {
/*! \brief list all the arguments */
kAll = 0,
/*! \brief list only read only arguments */
kReadOnlyArgs = 1,
/*!
* \brief List auxiliary states that can be mutated by the graph.
* This excludes the ReadOnly arguments
*/
kAuxiliaryStates = 2
};

/*! \brief output entries contained in the symbol */
std::vector<NodeEntry> outputs;
Expand All @@ -51,18 +63,20 @@ class Symbol {
*/
Symbol operator[] (size_t index) const;
/*!
* \brief List the arguments names.
* \brief List the input names.
* \param option The options to list the arguments.
*
* The position of the returned list also corresponds to calling position in operator()
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
* \sa ListInputOption
*/
std::vector<std::string> ListArguments() const;
std::vector<std::string> ListInputNames(ListInputOption option) const;
/*!
* \brief List the names of outputs for this symbol.
* For normal operators, it is usually symbol node name + "_output"
* \return get the descriptions of outputs for this symbol.
*/
std::vector<std::string> ListOutputs() const;
std::vector<std::string> ListOutputNames() const;
/*!
* \brief Compose the symbol with arguments, this changes the current symbol.
* The kwargs passed in can be in-complete,
Expand Down
Loading

0 comments on commit 14c4df4

Please sign in to comment.