General Constant Folding for Onnx and Caffe2 Operators in Loader (#3141)
Summary:
**Summary**
- Add constant folding of operators to ONNXModelLoader.
- Add constant folding of operators to Caffe2ModelLoader.
- Make constant folding of operators an optional feature, controlled by the command-line option const-fold-ops. It can also be enabled programmatically; see the sketch after this list.
- Add a constant-folding test based on Gather and Reshape to OnnxImporterTest and Caffe2ImporterTest.
- Add a general constant-folding test to OnnxImporterTest.
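
The flag can also be flipped programmatically via the new accessors in ProtobufLoader.h (a minimal sketch; only the two accessor calls come from this diff, the wrapper function is illustrative):

```cpp
#include "glow/Importer/ProtobufLoader.h"
#include <cassert>

void enableLoaderConstantFolding() {
  // Equivalent to passing -const-fold-ops on the command line: both toggle
  // the same llvm::cl option defined in ProtobufLoader.cpp.
  glow::setConstantFoldLoaderOpsFlag(true);
  assert(glow::getConstantFoldLoaderOpsFlag());
}
```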

**Documentation**
First, check whether folding is enabled. If so, check whether the current operator is foldable. If it is, build a temporary graph in which all inputs are constants and the current operator is the only operator, run that graph, and, instead of loading the original operator, load a Constant holding the computed output value. Because this happens sequentially as the loader walks the model, it can fold not just a single operator but an entire subgraph, as long as all inputs to the subgraph are constants.
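
Concretely, the per-operator hook added to both loaders' loadNetwork() reduces to the loop below. This mirrors the Caffe2ModelLoader hunk in this diff; the ONNX version is identical up to the protobuf types.

```cpp
for (int i = 0; i < net.op_size(); i++) {
  auto &op = net.op(i);
  if (getConstantFoldLoaderOpsFlag()) {
    auto foldstatus = foldOperator(op);
    if (foldstatus && foldstatus.get()) {
      // Folded successfully: Constants now stand in for the op's outputs.
      continue;
    }
  }
  // Folding disabled, not applicable, or it failed: load the op as usual.
  RETURN_IF_ERR(loadOperator(op));
}
```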

**Optional Fixes**
- Fix an issue with the batchBoxCox.onnxtxt model (needed for the general ONNX operator folding test).

**Test Plan**
(Test constant folding of ONNX operators with OnnxImporterTest.)
First, replace the placeholders in the original ONNX model with constants. After loading the graph, read the graph's output values without running the GraphProto. The test passes if this output matches the output of the original graph after constant folding has run (see the sketch below).
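
The body of such a test looks roughly like this (a sketch only: netFilename, F, and the constructor shape are illustrative scaffolding, not the test's exact code):

```cpp
// Placeholders in the model file were replaced by Constants beforehand, so
// with folding enabled every operator folds away at load time.
glow::setConstantFoldLoaderOpsFlag(true);
glow::ONNXModelLoader loader(netFilename, /*tensorNames*/ {}, /*types*/ {}, *F);
// The output placeholder (loader.getOutputByName(...)) is now fed by a
// Constant through its SaveNode; its payload is the folded result, which is
// compared against the reference obtained by actually running the original
// graph.
```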

A simple network, in both Caffe2 and ONNX, that requires constant folding to work:
The test gatherOpConstantFoldingAndReshape in OnnxImporterTest and Caffe2ImporterTest verifies that the Gather gets constant-folded, so that the shape argument of the Reshape becomes constant.

Related to #2168
Pull Request resolved: #3141

Reviewed By: opti-mix

Differential Revision: D15910349

Pulled By: jfix71

fbshipit-source-id: 32fd3aed51c36fadf0bc6bf461a0a548ed8e0c6e
ksaurabh-cadence authored and facebook-github-bot committed Jul 10, 2019
1 parent cb7c305 commit 99b6571
Showing 16 changed files with 676 additions and 42 deletions.
20 changes: 19 additions & 1 deletion include/glow/Graph/Graph.h
@@ -46,6 +46,15 @@ using PlaceholderList = std::list<Placeholder *>;
using UnsignedArrayRef = llvm::ArrayRef<size_t>;
/// Map from original Nodes to cloned Nodes.
using NodeMap = llvm::DenseMap<Node *, Node *>;
/// State of a function. This can be used to control optimizations which depend
/// on the state of the Function. This is a temporary workaround until GH Issue
/// #3213 is complete.
enum class FunctionState {
/// Indicates that the function has been created but not completely loaded.
FuncCreated,
/// Indicates that the function has been completely loaded.
FuncLoaded,
};

class Module final {
/// Stores the functions in the module.
@@ -239,16 +248,25 @@ class Function final : public Named {
/// The log context associated with this function.
std::shared_ptr<LogContext> logCtx_;

/// The state of this function.
FunctionState state_;

public:
Function(Module *parent, llvm::StringRef Name = {})
: Named(Name), parent_(parent) {
: Named(Name), parent_(parent), state_(FunctionState::FuncCreated) {
logCtx_ = std::make_shared<LogContext>();
logCtx_->setParent(this);
logCtx_->loadModuleLogContext();
}

~Function();

/// Sets the state of the function.
void setState(FunctionState state) { state_ = state; }

/// Gets the state of the function.
FunctionState getState() { return state_; }

std::string getFilename() { return getName().rsplit('/').second.str(); }

/// Return the log context.
10 changes: 10 additions & 0 deletions include/glow/Importer/Caffe2ModelLoader.h
@@ -60,6 +60,9 @@ class Caffe2ModelLoader
/// in the network.
llvm::Error loadOperator(const caffe2::OperatorDef &op);

/// \returns True if the operator \p op is successfully folded.
llvm::Expected<bool> foldOperator(const caffe2::OperatorDef &op);

/// Load the Conv or ConvRelu operators.
llvm::Error loadConv(const caffe2::OperatorDef &op,
ArgumentDictionaryTy &dict);
@@ -107,6 +110,13 @@ class Caffe2ModelLoader

friend class ONNXIFIModelLoader;

/// \returns success if the folding of operator \p op in the loader
/// \p loader is successful. The folding utility uses temporary
/// loader \p tmpLoader, and associated temporary function \p F.
template <class LoaderType, class OpType>
friend llvm::Error constantFoldInLoader(Function *F, LoaderType &tmpLoader,
LoaderType *loader, const OpType &op);

public:
/// Loads the caffe2 model that's represented by a network description file,
/// serialized in \p netDescFilename, and weights file, serialized in
10 changes: 10 additions & 0 deletions include/glow/Importer/ONNXModelLoader.h
@@ -56,6 +56,9 @@ class ONNXModelLoader
/// in the network. \returns Error if operator \p op cannot be loaded.
llvm::Error loadOperator(const ONNX_NAMESPACE::NodeProto &op);

/// \returns True if the operator \p op is successfully folded.
llvm::Expected<bool> foldOperator(const ONNX_NAMESPACE::NodeProto &op);

/// ONNX model ir_version;
size_t irVersion_;

@@ -195,6 +198,13 @@

friend class ONNXIFIModelLoader;

/// \returns success if the folding of operator \p op in the loader
/// \p loader is successful. The folding utility uses temporary
/// loader \p tmpLoader, and associated temporary function \p F.
template <class LoaderType, class OpType>
friend llvm::Error constantFoldInLoader(Function *F, LoaderType &tmpLoader,
LoaderType *loader, const OpType &op);

public:
/// Creates a ONNX model loader to build \p F.
/// If \p errPtr is not null then if an error occurs it will get assigned
62 changes: 62 additions & 0 deletions include/glow/Importer/ProtobufLoader.h
@@ -18,7 +18,9 @@
#define GLOW_IMPORTER_PROTOBUFLOADER_H

#include "glow/Base/Tensor.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"

#include "glow/Support/Error.h"
#include "llvm/ADT/ArrayRef.h"
@@ -38,6 +40,12 @@

namespace glow {

/// Enables or disables constant-folding of Loader Ops with \p flag.
void setConstantFoldLoaderOpsFlag(bool flag);

/// Returns true if constant-folding for loader Ops is enabled.
bool getConstantFoldLoaderOpsFlag();

/// Returns true iff all elements of \p a are the same.
bool isArrayConstant(const llvm::ArrayRef<size_t> a);

@@ -181,8 +189,62 @@ class ProtobufLoader {
/// \returns the Placeholder for the external output with \p name.
/// \pre outputVarsByName_.find(name) != outputVarsByName_.end()
llvm::Expected<Placeholder *> getOutputByName(llvm::StringRef name) const;

/// \returns True if the operator with name \p typeName having input node
/// list as \p inputs is constant foldable.
bool isConstantFoldable(llvm::ArrayRef<NodeValue> inputs,
std::string typeName) const;
};

/// \returns success if the folding of operator \p op in the loader
/// \p loader is successful. The folding utility uses temporary
/// loader \p tmpLoader, and associated temporary function \p F.
template <class LoaderType, class OpType>
llvm::Error constantFoldInLoader(Function *F, LoaderType &tmpLoader,
LoaderType *loader, const OpType &op) {
PlaceholderBindings bindings;
std::vector<Tensor *> outTensors;
Module *mod = F->getParent();

// Register the constant inputs to the current op with the constant folding
// loader.
for (unsigned i = 0; i < op.input_size(); i++) {
Constant *tmpConst = mod->getConstantByName(op.input(i));
RETURN_ERR_IF_NOT(tmpConst, "No constant found");
tmpLoader.nodeValueByName_[op.input(i)] = tmpConst->getOutput();
}

// Using the loader to load the current operator.
RETURN_IF_ERR(tmpLoader.loadOperator(op));

// To collect the folded outputs allocate and add save nodes to the folding
// function.
for (int i = 0; i < op.output_size(); i++) {
const auto &outputName = op.output(i);
NodeValue r;
ASSIGN_VALUE_OR_RETURN_ERR(r, tmpLoader.getNodeValueByName(outputName));
SaveNode *SN = F->createSave("save_" + outputName, r);
auto *result = bindings.allocate(SN->getPlaceholder());
outTensors.push_back(result);
}

// Evaluate the constant outputs using interpreter backend.
std::unique_ptr<Backend> backend(createBackend("Interpreter"));
CompilationContext cctx;
cctx.compMode = CompilationMode::Infer;
cctx.optimizationOpts.enableConstantFolding = false;
cctx.backendOpts.collectConstants = true;
RETURN_IF_ERR(executeConstantFunction(*backend, *F, bindings, cctx));

// Using the graph output, place constant nodes in the original graph.
for (int i = 0; i < op.output_size(); i++) {
RETURN_IF_ERR(loader->createAndRegisterConstant(op.output(i),
std::move(*outTensors[i])));
}

return llvm::Error::success();
}

} // namespace glow

#endif // GLOW_IMPORTER_PROTOBUFLOADER_H
14 changes: 14 additions & 0 deletions include/glow/Support/Error.h
@@ -204,6 +204,20 @@ class GlowErr final : public llvm::ErrorInfo<GlowErr> {
} \
} while (0)

/// Takes an llvm::Expected<T> \p lhsOrErr and if it is an Error then returns
/// false, otherwise takes the value from lhsOrErr and assigns it to \p rhs.
#define ASSIGN_VALUE_OR_RETURN_FALSE(rhs, lhsOrErr) \
do { \
auto lhsOrErrV = (lhsOrErr); \
static_assert(IsLLVMExpected<decltype(lhsOrErrV)>(), \
"Expected value to be a llvm::Expected"); \
if (lhsOrErrV) { \
rhs = std::move(lhsOrErrV.get()); \
} else { \
return false; \
} \
} while (0)

/// Takes an llvm::Error and returns it if it's not success.
// TODO: extend this to work with llvm::Expected as well.
#define RETURN_IF_ERR(err) \
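
A usage sketch for the new ASSIGN_VALUE_OR_RETURN_FALSE macro (computeValue() is a hypothetical function returning llvm::Expected<int>):

bool tryCompute(int &out) {
  // On failure the enclosing function returns false instead of propagating an
  // llvm::Error; on success 'out' holds the unwrapped value.
  ASSIGN_VALUE_OR_RETURN_FALSE(out, computeValue());
  return true;
}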
34 changes: 34 additions & 0 deletions lib/Importer/Caffe2ModelLoader.cpp
@@ -528,6 +528,33 @@ llvm::Error Caffe2ModelLoader::loadConvQuantized(const caffe2::OperatorDef &op,
return llvm::Error::success();
}

llvm::Expected<bool>
Caffe2ModelLoader::foldOperator(const caffe2::OperatorDef &op) {
const unsigned numInputs = op.input_size();
const std::string &typeName = op.type();
llvm::SmallVector<NodeValue, 4> inputs;
inputs.reserve(numInputs);
for (unsigned i = 0; i < numInputs; i++) {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
inputs.push_back(in);
}

if (!isConstantFoldable(inputs, typeName)) {
return false;
}

// Create a temporary lightweight loader to construct function representing
// current Op, and then constant fold the function using Interp backend.
Function *tmpF = G_.getParent()->createFunction("eval_const_fold__");
Caffe2ModelLoader tmpLoader(*tmpF, nullptr);
bool foldStatus =
!errToBool(constantFoldInLoader<Caffe2ModelLoader, caffe2::OperatorDef>(
tmpF, tmpLoader, this, op));
G_.getParent()->eraseFunction(tmpF);
return foldStatus;
}

llvm::Error Caffe2ModelLoader::loadOperator(const caffe2::OperatorDef &op) {
ArgumentDictionaryTy dict = loadArgumentMap(op);
const std::string &typeName = op.type();
@@ -1340,6 +1367,13 @@ llvm::Error Caffe2ModelLoader::loadNetwork(caffe2::NetDef &net) {
/// Load the network operators:
for (int i = 0; i < net.op_size(); i++) {
auto &op = net.op(i);
if (getConstantFoldLoaderOpsFlag()) {
auto foldstatus = foldOperator(op);
if (foldstatus && foldstatus.get()) {
// Folded successfully.
continue;
}
}
RETURN_IF_ERR(loadOperator(op));
}

35 changes: 35 additions & 0 deletions lib/Importer/ONNXModelLoader.cpp
@@ -1152,6 +1152,34 @@ llvm::Error ONNXModelLoader::loadTile(const ONNX_NAMESPACE::NodeProto &op,
return llvm::Error::success();
}

llvm::Expected<bool>
ONNXModelLoader::foldOperator(const ONNX_NAMESPACE::NodeProto &op) {
const unsigned numInputs = op.input_size();
const std::string &typeName = op.op_type();
llvm::SmallVector<NodeValue, 4> inputs;
inputs.reserve(numInputs);
for (unsigned i = 0; i < numInputs; i++) {
NodeValue in;
ASSIGN_VALUE_OR_RETURN_ERR(in, getNodeValueByName(op.input(i)));
inputs.push_back(in);
}

if (!isConstantFoldable(inputs, typeName)) {
return false;
}

// Create a temporary lightweight loader to construct function representing
// current Op, and then constant fold the function using Interp backend.
Function *tmpF = G_.getParent()->createFunction("eval_const_fold__");
ONNXModelLoader tmpLoader(*tmpF);
tmpLoader.opsetVersion_ = opsetVersion_;
bool foldStatus = !errToBool(
constantFoldInLoader<ONNXModelLoader, ONNX_NAMESPACE::NodeProto>(
tmpF, tmpLoader, this, op));
G_.getParent()->eraseFunction(tmpF);
return foldStatus;
}

llvm::Error ONNXModelLoader::loadOperator(const ONNX_NAMESPACE::NodeProto &op) {
ArgumentDictionaryTy dict = loadArgumentMap(op);
const std::string &typeName = op.op_type();
@@ -1256,6 +1284,13 @@ llvm::Error ONNXModelLoader::loadNetwork(ONNX_NAMESPACE::GraphProto &net) {
/// Load the network operators:
for (int i = 0; i < net.node_size(); i++) {
auto &op = net.node(i);
if (getConstantFoldLoaderOpsFlag()) {
auto foldstatus = foldOperator(op);
if (foldstatus && foldstatus.get()) {
// Folded successfully.
continue;
}
}
RETURN_IF_ERR(loadOperator(op));
}

38 changes: 37 additions & 1 deletion lib/Importer/ProtobufLoader.cpp
@@ -15,18 +15,54 @@
*/

#include "glow/Importer/ProtobufLoader.h"

#include "llvm/Support/CommandLine.h"
#include <string>

namespace glow {

llvm::cl::OptionCategory loaderOptCat("Model Loader Options");

static llvm::cl::opt<bool> isConstFoldLoaderOps(
"const-fold-ops",
llvm::cl::desc(
"Performs constant folding on ONNX and Caffe Operators while loading."),
llvm::cl::init(false), llvm::cl::cat(loaderOptCat));

bool isArrayConstant(llvm::ArrayRef<size_t> a) {
for (size_t i = 1; i < a.size(); i++)
if (a[0] != a[i])
return false;
return true;
}

void setConstantFoldLoaderOpsFlag(bool flag) { isConstFoldLoaderOps = flag; }

bool getConstantFoldLoaderOpsFlag() { return isConstFoldLoaderOps; }

bool ProtobufLoader::isConstantFoldable(llvm::ArrayRef<NodeValue> inputs,
std::string typeName) const {
int numInputs = inputs.size();
if (!getConstantFoldLoaderOpsFlag()) {
return false;
}
// foldUnsupportedTypes: List of typenames unsupported for folding.
std::string foldUnsupportedTypes[] = {"Constant"};
std::string *findType = std::find(std::begin(foldUnsupportedTypes),
std::end(foldUnsupportedTypes), typeName);
// Early exit if folding is not supported for current operator.
if (findType != std::end(foldUnsupportedTypes)) {
return false;
}

// If all the inputs to the operator are constant this op can be folded.
for (int i = 0; i < numInputs; i++) {
if (inputs[i].getNode()->getKind() != Kinded::Kind::ConstantKind) {
return false;
}
}
return true;
}

Constant *ProtobufLoader::getConstantByNameOrNull(llvm::StringRef name) const {
auto it = nodeValueByName_.find(name);
if (it == nodeValueByName_.end()) {
20 changes: 19 additions & 1 deletion lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
@@ -62,6 +62,18 @@ static bool shouldDeleteNode(Node *N) {
return true;
}

/// Helper that \returns whether all sibling Functions of \p F (other Functions
/// inside its Module) are Loaded.
static bool shouldDeleteConstants(Function *F) {
Module *mod = F->getParent();
for (auto *MF : mod->getFunctions()) {
if (MF->getState() < FunctionState::FuncLoaded) {
return false;
}
}
return true;
}

bool DCE::run(Function *F) {
LOG_SCOPE(F->getLogContext(), getName());

@@ -99,6 +111,10 @@ bool DCE::run(Function *F) {
}
}

if (!shouldDeleteConstants(F)) {
return changed;
}

// Delete unused Constants.
for (auto it = consts.begin(), e = consts.end(); it != e;) {
if (!shouldDeleteNode(*it)) {
@@ -2757,7 +2773,9 @@ void glow::fold(Function *F, CompilationMode mode) {

void glow::optimize(Function *F, CompilationContext &cctx) {
LOG_SCOPE(F->getLogContext(), "glow::optimize")

// Mark the given function as completely loaded. This is a temporary
// workaround until #3213 is complete.
F->setState(FunctionState::FuncLoaded);
// Optimize may be called after backend specific transformations and some
// nodes may have become unused. It is a good idea to remove them, before
// proceeding with any further optimizations.
1 change: 1 addition & 0 deletions lib/Optimizer/GraphOptimizer/Lower.cpp
@@ -1038,6 +1038,7 @@ static void lowerNode(Function *F, Node *node, CompilationContext &cctx) {
void glow::lower(Function *F, CompilationContext &cctx, const Backend *B,
const KindSet &doNotLowerKinds) {
LOG_SCOPE(F->getLogContext(), "glow::lower")
F->setState(FunctionState::FuncLoaded);

auto &nodes = F->getNodes();
for (auto &N : nodes) {
