Skip to content

Commit

Permalink
[graph-optimizer] Replace reshapes of splat nodes by splats with requ…
Browse files Browse the repository at this point in the history
…ired shapes
  • Loading branch information
Chi Zhang authored and ZchiPitt committed Jun 19, 2018
1 parent 911f7ea commit b75b3e2
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 2 deletions.
17 changes: 15 additions & 2 deletions lib/Optimizer/GraphOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1075,15 +1075,28 @@ static void optimizeSliceOfSplat(Function *F) {
}
}

/// Eliminate ReshapeNode when the input is already the correct shape.
/// Optimize reshape nodes.
static void optimizeReshape(Function *F) {
for (auto &node : F->getNodes()) {
auto *reshapeNode = dyn_cast<ReshapeNode>(&node);
if (!reshapeNode)
continue;
auto inputNode = reshapeNode->getInput();
auto &inputNode = reshapeNode->getNthInput(0);
// Eliminate ReshapeNode when the input is already the correct shape.
if (inputNode.dims() == reshapeNode->dims()) {
reshapeNode->getResult().replaceAllUsesOfWith(inputNode);
continue;
}
// Reshape(Splat(args)) -> Splat(args').
auto *splatNode = dyn_cast<SplatNode>(inputNode);
if (splatNode && splatNode->hasOneUse()) {
// Splat node with more than one use can not be transformed, otherwise
// we would increase the number of splats, which may lead to increased
// memory consumption during the execution of the NN model.
auto *newSplatNode = F->createSplat(
splatNode->getName(), reshapeNode->getType(), splatNode->getValue());
reshapeNode->getResult().replaceAllUsesOfWith(newSplatNode);
continue;
}
}
}
Expand Down
53 changes: 53 additions & 0 deletions tests/unittests/graphOptzTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,59 @@ TEST_F(GraphOptz, ReshapeNoop) {
EXPECT_TRUE(SN->getType()->dims().equals(shape));
}

/// Test the Reshape(Splat(args)) -> Splat(args') transformation.
/// Including a positive and a negative test case. In the positive case,
/// the optimization will take place for the splat node (Z2) that has only one
/// use. In the negative case, the optimization will not happen as the splat
/// node (Z1) has more than one use.
TEST_F(GraphOptz, ReshapeAfterSplat) {
const size_t shape[] = {10, 20, 30};
const size_t reshape[] = {1, 6000};
Type t1(ElemKind::FloatTy, shape);
Type t2(ElemKind::FloatTy, reshape);
Node *input =
F_->getParent()->createVariable(ElemKind::FloatTy, shape, "input");
auto *Z1 = F_->createSplat("zero1", &t1, 1.5);
auto *A1 = F_->createAdd("add1", Z1->getType(), input, Z1);
auto *R1 = F_->createReshape("reshape1", Z1, reshape);
// Z1 is used by R1 and A1.
// The reshape optimization will thus NOT be able to remove this reshape node
// (R1).
auto *R2 = F_->createReshape("reshape2", A1, reshape);
auto *A2 = F_->createAdd("add", R1->getType(), R1, R2);
auto *Z2 = F_->createSplat("zero2", &t1, 2.5);
auto *R3 = F_->createReshape("reshape3", Z2, reshape);
// Z2 is only used by R3.
// The Z2,R3 nodes will be replaced by a new splat node with the shape of R3.
auto *A3 = F_->createAdd("add", A2->getType(), A2, R3);
auto *O = F_->createSave("ret", A3);

// Before optimization, we have 9 nodes in the graph.
EXPECT_EQ(F_->getNodes().size(), 9);

::glow::optimize(F_, CompilationMode::Infer);

// After optimization, we expect to see only 8 nodes, as Z2,R2 would be
// replace by a new splat node.
EXPECT_EQ(F_->getNodes().size(), 8);

// The second input of A3 shoule be a splat node with a shape of R3.
auto *SN = llvm::dyn_cast<SplatNode>(
llvm::dyn_cast<SaveNode>(O)->getInput()->getNthInput(1));
EXPECT_TRUE(SN);
EXPECT_TRUE(SN->getType()->dims().equals(reshape));

// R1 should still be in the graph.
EXPECT_TRUE(std::find_if(F_->getNodes().begin(), F_->getNodes().end(),
IsSameNodeAddress(R1)) != F_->getNodes().end());

// R3 and Z2 should not be in the graph any more.
EXPECT_TRUE(std::find_if(F_->getNodes().begin(), F_->getNodes().end(),
IsSameNodeAddress(R3)) == F_->getNodes().end());
EXPECT_TRUE(std::find_if(F_->getNodes().begin(), F_->getNodes().end(),
IsSameNodeAddress(Z2)) == F_->getNodes().end());
}

TEST_F(GraphOptz, DCEPublicVars) {
mod_.createVariable(ElemKind::FloatTy, {4, 320, 200, 3}, "input",
VisibilityKind::Public);
Expand Down

0 comments on commit b75b3e2

Please sign in to comment.