Skip to content

Commit

Permalink
[PASS] add plan memory (apache#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 26, 2018
1 parent a33e9ce commit 2f837ab
Show file tree
Hide file tree
Showing 9 changed files with 411 additions and 3 deletions.
6 changes: 6 additions & 0 deletions nnvm/include/nnvm/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ class IndexedGraph {
inline const std::vector<uint32_t>& arg_nodes() const {
return arg_nodes_;
}
/*! \return list of output entries */
inline const std::vector<NodeEntry>& outputs() const {
return outputs_;
}

private:
friend class Graph;
Expand All @@ -159,6 +163,8 @@ class IndexedGraph {
std::vector<Node> nodes_;
// index to argument nodes
std::vector<uint32_t> arg_nodes_;
// space to store the outputs entries
std::vector<NodeEntry> outputs_;
// mapping from node to index.
std::unordered_map<const nnvm::Node*, uint32_t> node2index_;
// CSR pointer of node entries
Expand Down
17 changes: 16 additions & 1 deletion nnvm/include/nnvm/graph_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using DTypeVector = std::vector<int>;
*
* \code
* Graph g = ApplyPass(src_graph, {"PlaceDevice"});
* const &device = g.GetAttr<DeviceVector>("dtype");
* const &device = g.GetAttr<DeviceVector>("device");
* // get device by node_id
* int device_type = device[g.indexed_graph().node_id(my_node)];
* \endcode
Expand All @@ -75,6 +75,21 @@ using DeviceVector = std::vector<int>;
*/
using DeviceAssignMap = std::unordered_map<std::string, int>;

/*!
* \brief The result holder of storage id of each NodeEntry in the graph.
*
* \note Stored under graph.attrs["storage"], provided by Pass "PlanMemory"
* Storage id is a continuous integer.
* If the storage id is -1 then the storage is not assigned.
*
* \code
* Graph g = ApplyPass(src_graph, {"PlanMemory"});
* const &storage = g.GetAttr<StorageVector>("storage");
* // get storage id by entry
* int storage_id = storage[g.indexed_graph().entry_id(my_entry)];
* \endcode
*/
using StorageVector = std::vector<int>;

} // namespace nnvm

Expand Down
15 changes: 15 additions & 0 deletions nnvm/include/nnvm/op_attr_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <vector>
#include <string>
#include <utility>
#include <functional>
#include "./base.h"
#include "./tuple.h"
Expand Down Expand Up @@ -93,6 +94,20 @@ using FInferType = FInferNodeEntryAttr<int>;
*/
using TIsBackwardOp = bool;

/*!
* \brief Get possible inplace options.
* This function enables optimization to reuse memory of inputs in output.
* \param attrs The attributes of the node
* \param in_data The input data.
* \param out_data The output data.
* \return list of pair of that maps input->output,
* indicating possible in place operations.
*
* \note Register under "FInplaceOption", by default no inplace can happen.
*/
using FInplaceOption = std::function<
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;

} // namespace nnvm

#endif // NNVM_OP_ATTR_TYPES_H_
5 changes: 5 additions & 0 deletions nnvm/src/core/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ IndexedGraph::IndexedGraph(const Graph &g) {
control_rptr.push_back(control_deps_.size());
});

for (const auto& e : g.outputs) {
outputs_.emplace_back(NodeEntry{
node2index_.at(e.node.get()), e.index, e.version});
}

// setup array view
// input_entries_ and control_rptr must not change after this step.
const NodeEntry* iptr = dmlc::BeginPtr(input_entries_);
Expand Down
11 changes: 9 additions & 2 deletions nnvm/src/example/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ using nnvm::FListInputNames;
using nnvm::FMutateInput;
using nnvm::FInferShape;
using nnvm::FInferType;
using nnvm::FInplaceOption;
using nnvm::NodeAttrs;
using nnvm::TShape;
using nnvm::array_view;
Expand All @@ -32,6 +33,10 @@ inline bool SameShape(const NodeAttrs& attrs,
return true;
}

inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs) {
return {{0, 0}};
}

// simple demonstration of reshape.
NNVM_REGISTER_OP(reshape)
.describe("reshape source to target shape")
Expand All @@ -55,7 +60,8 @@ NNVM_REGISTER_OP(reshape)
CHECK_EQ(ishape[0]->Size(), target.Size())
<< "Reshape op: source target shape mismatch";
return true;
});
})
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);


NNVM_REGISTER_OP(cast)
Expand All @@ -82,7 +88,8 @@ NNVM_REGISTER_OP(cast)
NNVM_REGISTER_OP(add)
.describe("add two data together")
.set_num_inputs(2)
.attr<FInferShape>("FInferShape", SameShape);
.attr<FInferShape>("FInferShape", SameShape)
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);

NNVM_REGISTER_OP(__add_symbol__)
.describe("Alias of add")
Expand Down
112 changes: 112 additions & 0 deletions nnvm/src/pass/graph_algorithm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/*!
* Copyright (c) 2016 by Contributors
* \file graph_algorithm.h
* \brief This header contains graph algorithms on StaticGraph.
* It is used compute informations such as whether two
* operations can run in parallel, and helps allocation.
*/
#ifndef NNVM_PASS_GRAPH_ALGORITHM_H_
#define NNVM_PASS_GRAPH_ALGORITHM_H_

#include <nnvm/graph.h>
#include <vector>

namespace nnvm {
namespace pass {

/*!
* \brief Find best path in the DAG, with reward defined
* by sum of reward of each node along the path.
* \param graph the original static graph.
* \param topo_order topo order of the nodes in the graph.
* \param node_reward the reward of each node.
* \param path the output path of nodes.
* \return the total reward of best path.
*/
inline uint32_t FindBestPath(
const IndexedGraph& graph,
const std::vector<uint32_t>& node_reward,
std::vector<uint32_t>* path) {
const uint32_t num_nodes = static_cast<uint32_t>(graph.num_nodes());
CHECK_EQ(num_nodes, node_reward.size());

std::vector<uint32_t> best_reward(node_reward.size(), 0);
std::vector<uint32_t> next_node(node_reward.size(), num_nodes);
uint32_t best_solution = 0, best_start_node = 0;

// traverse in reverse topo order
for (uint32_t i = static_cast<uint32_t>(graph.num_nodes()); i != 0; --i) {
const uint32_t nid = i - 1;
best_reward[nid] += node_reward[nid];
if (best_reward[nid] > best_solution) {
best_solution = best_reward[nid];
best_start_node = nid;
}
for (const auto& e : graph[nid].inputs) {
const uint32_t prev = e.node_id;
if (best_reward[nid] > best_reward[prev]) {
best_reward[prev] = best_reward[nid];
next_node[prev] = nid;
}
}
}
path->clear();
uint32_t reward = 0;
for (uint32_t nid = best_start_node; nid < num_nodes; nid = next_node[nid]) {
path->push_back(nid); reward += node_reward[nid];
}
CHECK_EQ(reward, best_solution);
return best_solution;
}

/*!
* \brief Color the nodes in the graph into index.
* The coloring algorithm tries to assign node group
* such that node in the same group cannot run in parallel.
*
* \param graph the original indexed graph.
* \param node_importance The importance of the node
* \param max_ncolor maximum number of colors allowed.
* \param color the color index of each of the node.
* \return the total number of colors.
*/
inline uint32_t ColorNodeGroup(
const IndexedGraph &graph,
std::vector<uint32_t> node_importance,
uint32_t max_ncolor,
std::vector<uint32_t> *color) {
CHECK_NE(max_ncolor, 0);
CHECK_EQ(graph.num_nodes(), node_importance.size());

color->clear();
color->resize(graph.num_nodes(), max_ncolor);
uint32_t cindex;
// greedy algorithm, every time
// find a path with best reward and assign a new color
// All the nodes in the path cannot run in parallel.
for (cindex = 0; cindex < max_ncolor - 1; ++cindex) {
std::vector<uint32_t> path;
uint32_t reward = FindBestPath(graph, node_importance, &path);
if (reward == 0) break;
for (uint32_t nid : path) {
if (node_importance[nid] != 0) {
CHECK_EQ(color->at(nid), max_ncolor);
color->at(nid) = cindex;
// make the importance 0 after color is decided.
node_importance[nid] = 0;
}
}
}
// assign i for rest of the node
for (uint32_t i = 0; i < graph.num_nodes(); ++i) {
if (color->at(i) == max_ncolor) {
color->at(i) = cindex;
}
}
return cindex + 1;
}

} // namespace pass
} // namespace nnvm

#endif // NNVM_PASS_GRAPH_ALGORITHM_H_
1 change: 1 addition & 0 deletions nnvm/src/pass/infer_shape_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ NNVM_REGISTER_PASS(InferType)

DMLC_JSON_ENABLE_ANY(ShapeVector, list_shape);
DMLC_JSON_ENABLE_ANY(DTypeVector, list_int);
DMLC_JSON_ENABLE_ANY(size_t, size_t);

} // namespace pass
} // namespace nnvm
Loading

0 comments on commit 2f837ab

Please sign in to comment.