Skip to content

Commit

Permalink
Fix onnx exporter to remove unnecessary Reshape
Browse files Browse the repository at this point in the history
Summary: We should not simply add the `Reshape` when we see a `Save`, which doesn't make sense.

Reviewed By: tracelogfb

Differential Revision: D18632551

fbshipit-source-id: 4f06a030d5453610d68710c61cae27e45be2f4cb
  • Loading branch information
Yinghai Lu authored and facebook-github-bot committed Dec 2, 2019
1 parent a210f97 commit 0d57af5
Showing 1 changed file with 17 additions and 16 deletions.
33 changes: 17 additions & 16 deletions lib/Exporter/ONNXModelWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,19 @@ void inputsToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
}
}

/// Write the output of the provided node, add SaveNode if necessary
void outputKindToProto(const Node *node, ONNX_NAMESPACE::NodeProto *proto) {
for (const auto &use : node->getUsers()) {
const auto *user = use.getUser();
if (user->getKind() == Kinded::Kind::SaveNodeKind) {
const SaveNode *SN = llvm::cast<SaveNode>(user);
proto->add_output(SN->getPlaceholder()->getName());
} else {
outputsToProto(user, proto);
}
}
}

/// Write the output of the provided type only of node outputs.
bool outputKindToProto(Kinded::Kind kind, const Node *node,
ONNX_NAMESPACE::NodeProto *proto) {
Expand Down Expand Up @@ -757,14 +770,8 @@ Error ONNXModelWriter::writeBatchedReduceMean(const BatchedReduceMeanNode *node,
proto->set_op_type("ReduceMean");
inputsToProto(node, proto);

// Use the output of reshape node.
if (outputKindToProto(Kinded::Kind::ReshapeNodeKind, node, proto)) {
// Add dictionary entries.
addValueAttribute(proto, "keepdims", 1);
} else {
addValueAttribute(proto, "keepdims", 0);
outputsToProto(node, proto);
}
addValueAttribute(proto, "keepdims", 0);
outputKindToProto(node, proto);

return Error::success();
}
Expand All @@ -781,14 +788,8 @@ Error ONNXModelWriter::writeBatchedReduceAdd(const BatchedReduceAddNode *node,
proto->set_op_type("ReduceSum");
inputsToProto(node, proto);

// Use the output of reshape node.
if (outputKindToProto(Kinded::Kind::ReshapeNodeKind, node, proto)) {
// Add dictionary entries.
addValueAttribute(proto, "keepdims", 1);
} else {
addValueAttribute(proto, "keepdims", 0);
outputsToProto(node, proto);
}
addValueAttribute(proto, "keepdims", 0);
outputKindToProto(node, proto);

return Error::success();
}
Expand Down

0 comments on commit 0d57af5

Please sign in to comment.