Skip to content

Commit

Permalink
Add opt for eliminating concat->slices (#3688)
Browse files Browse the repository at this point in the history
Summary:
Skip unnecessary patterns where we see a concat followed by slices of its output, where we can eliminate both concat and slices. For example the following graph:

<img width="1226" alt="Screen Shot 2019-10-26 at 12 22 26 AM" src="https://user-images.githubusercontent.com/1198212/67615893-bbcd8f00-f786-11e9-8921-0fbb91eed09c.png">

Is transformed into:

<img width="1226" alt="Screen Shot 2019-10-26 at 12 23 17 AM" src="https://user-images.githubusercontent.com/1198212/67615900-d142b900-f786-11e9-8b95-5d91578013c8.png">
Pull Request resolved: #3688

Differential Revision: D18159045

Pulled By: jfix71

fbshipit-source-id: 1669797d1bfe7498329a13712924f417bcc325ba
  • Loading branch information
jfix71 authored and facebook-github-bot committed Oct 28, 2019
1 parent 59c33b9 commit a13d88d
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 10 deletions.
13 changes: 8 additions & 5 deletions include/glow/Base/Type.h
Expand Up @@ -441,8 +441,9 @@ struct Type final {
return std::make_pair(lowFloat, highFloat);
}

/// \returns true if \p other is the same type.
bool isEqual(const Type &other) const {
/// \returns true if \p other is the same type. If \p allowDifferentShape then
/// shapes will not be considered as part of the equal comparison.
bool isEqual(const Type &other, bool allowDifferentShape = false) const {
// Element type must be the same.
if (elementType_ != other.elementType_) {
return false;
Expand All @@ -452,9 +453,11 @@ struct Type final {
return false;
}
// Sizes must be the same.
for (size_t i = 0; i < numSizes_; i++) {
if (sizes_[i] != other.sizes_[i]) {
return false;
if (!allowDifferentShape) {
for (size_t i = 0; i < numSizes_; i++) {
if (sizes_[i] != other.sizes_[i]) {
return false;
}
}
}

Expand Down
68 changes: 63 additions & 5 deletions lib/Optimizer/GraphOptimizer/GraphOptimizer.cpp
Expand Up @@ -1101,8 +1101,24 @@ static bool findSlicesThatSpanInput(llvm::ArrayRef<SliceNode *> input,

// For each slice:
for (SliceNode *SN : input) {
// Ignore slices of invalid types.
if (lastSlice->getResult().getType() != SN->getResult().getType()) {
// Ignore slices of invalid types. Ignore shapes for now, that's checked
// next while ignoring the axis dimension.
if (!lastSlice->getResult().getType()->isEqual(
*SN->getResult().getType(),
/* allowDifferentShape */ true)) {
continue;
}

// Check if shapes match except for the axis dimension.
bool skip = false;
for (size_t i = 0, e = lastSlice->getResult().dims().size(); i < e; ++i) {
if (i != dimension &&
lastSlice->getResult().dims()[i] != SN->getResult().dims()[i]) {
skip = true;
break;
}
}
if (skip) {
continue;
}

Expand Down Expand Up @@ -1653,6 +1669,43 @@ static NodeValue simplifyConcatNode(Function *F, ConcatNode *CN) {
return NodeValue(nullptr);
}

/// If all of the outputs of \p CN are essentially piped from the inputs of the
/// concat (i.e. same shape, axis, order) then we can get rid of the slices and
/// concat. \returns true if this optimization is successful and changes the
/// Function.
static bool combineConcatSlices(ConcatNode *CN) {
auto inputsToCN = CN->getInputs();
std::vector<SliceNode *> slices;
std::vector<SliceNode *> orderedSlices;
for (auto &user : CN->getUsers()) {
if (SliceNode *SN = dyn_cast<SliceNode>(user.getUser())) {
slices.push_back(SN);
}
}

// Check if the slices span the input value.
bool found = findSlicesThatSpanInput(slices, CN->getDim(), orderedSlices);
if (!found || orderedSlices.size() != slices.size() ||
orderedSlices.size() != inputsToCN.size()) {
return false;
}

// Now verify that all of the inputs to CN have the same shape as all of the
// slices for the result of CN.
for (size_t i = 0, e = orderedSlices.size(); i < e; ++i) {
if (orderedSlices[i]->getResult().dims() != inputsToCN[i].dims()) {
return false;
}
}

// We can now replace all of the inputs to the concat to the result of
// each slice.
for (size_t i = 0, e = inputsToCN.size(); i < e; ++i) {
orderedSlices[i]->getResult().replaceAllUsesOfWith(inputsToCN[i]);
}
return true;
}

/// Optimize Concat nodes.
bool OptimizeConcatNodes::run(Function *F, const CompilationContext &cctx) {
LOG_SCOPE(F->getLogContext(), getName());
Expand All @@ -1666,11 +1719,16 @@ bool OptimizeConcatNodes::run(Function *F, const CompilationContext &cctx) {
continue;
}
NodeValue newCN = simplifyConcatNode(F, CN);
if (!newCN.getNode()) {
if (newCN.getNode()) {
CN->getResult().replaceAllUsesOfWith(newCN);
changed = true;
continue;
}

if (combineConcatSlices(CN)) {
changed = true;
continue;
}
CN->getResult().replaceAllUsesOfWith(newCN);
changed = true;
}
return changed;
}
Expand Down
32 changes: 32 additions & 0 deletions tests/unittests/GraphOptzTest.cpp
Expand Up @@ -2538,6 +2538,38 @@ TEST_F(GraphOptz, concatElim) {
EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 0);
}

/// Check that we are able to eliminate concat followed by slices under certain
/// conditions.
TEST_F(GraphOptz, concatSliceElim) {
constexpr size_t N = 5;
std::array<NodeValue, N> inputs;
for (size_t i = 0; i < N; i++) {
inputs[i] = mod_.createPlaceholder(ElemKind::FloatTy, {1 + i, 10, 20},
"input", true);
}
auto *CN = F_->createConcat("merge", inputs, 0);

// Split the concat to a bunch of slices of the same shape as the concat
// inputs and on the same axis.
for (size_t i = 0; i < N; i++) {
auto *SN = F_->createSlice("extract", CN, {(i * (i + 1)) / 2, 0, 0},
{((i + 1) * (i + 2)) / 2, 10, 20});
F_->createSave("save", SN);
}

// We created a concat followed by N slices of its results.
EXPECT_EQ(countNodeKind(F_, Kinded::Kind::SliceNodeKind), N);
EXPECT_EQ(countNodeKind(F_, Kinded::Kind::ConcatNodeKind), 1);

optimizedF_ = optimizeFunction(F_);

// Check that the concat and slices are gone.
EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::ConcatNodeKind), 0);
EXPECT_EQ(countNodeKind(optimizedF_, Kinded::Kind::SliceNodeKind), 0);

checkNumericalEquivalence();
}

// Check the transformation Concat(Reshape(x) * N) -> Reshape(Concat(x * N)).
TEST_F(GraphOptz, concatReshapes) {
const size_t shape1[] = {2, 5, 2, 1, 20};
Expand Down

0 comments on commit a13d88d

Please sign in to comment.