Skip to content

Commit

Permalink
[nnc] Test cases for uneven split + reorder
Browse files Browse the repository at this point in the history
Split with tail followed by reorder causes a segfault in NNC
Split with mask followed by reorder generates invalid code that writes out of
bounds

Differential Revision: [D26746254](https://our.internmc.facebook.com/intern/diff/D26746254/)

ghstack-source-id: 122827308
Pull Request resolved: #53091
  • Loading branch information
bertmaher committed Mar 2, 2021
1 parent 0569f63 commit a45c787
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions test/cpp/tensorexpr/test_loopnest.cpp
Expand Up @@ -3743,5 +3743,63 @@ TEST(LoopNest, InlineFromLoad) {
oss.str());
}

Tensor* colReduce(int M, int N) {
Placeholder a("a", kInt, {M, N});
return Reduce(
"b",
{{N, "n"}},
Sum(),
[&](const VarHandle& n, const VarHandle& m) { return a.load(m, n); },
{{M, "m"}});
}

void splitTailReorder(Tensor* b) {
constexpr int kVectorWidth = 8;
For *outer, *inner, *tail;
LoopNest nest({b});
auto loops = nest.getLoopStmtsFor(b);
nest.splitWithTail(loops[0], kVectorWidth, &outer, &inner, &tail);
loops = nest.getLoopStmtsFor(b);
nest.reorderAxis(loops[1], loops[2]);
}

void splitMaskReorder(Tensor* b) {
constexpr int kVectorWidth = 8;
For *outer, *inner;
LoopNest nest({b});
auto loops = nest.getLoopStmtsFor(b);
nest.splitWithMask(loops[0], kVectorWidth, &outer, &inner);
loops = nest.getLoopStmtsFor(b);
nest.reorderAxis(loops[1], loops[2]);
std::clog << *nest.root_stmt() << "\n";
}

TEST(LoopNest, ColReduceSplitEvenReorder) {
KernelScope kernel_scope;
constexpr int M = 76, N = 128;
Tensor* b = colReduce(M, N);
splitTailReorder(b);
}

TEST(LoopNest, ColReduceSplitUnevenReorder) {
KernelScope kernel_scope;
constexpr int M = 76, N = 100;
Tensor* b = colReduce(M, N);
splitTailReorder(b);
}

TEST(LoopNest, ColReduceSplitMaskEvenReorder) {
KernelScope kernel_scope;
constexpr int M = 76, N = 128;
Tensor* b = colReduce(M, N);
splitMaskReorder(b);
}

TEST(LoopNest, ColReduceSplitMaskUnevenReorder) {
KernelScope kernel_scope;
constexpr int M = 76, N = 100;
Tensor* b = colReduce(M, N);
splitMaskReorder(b);
}
} // namespace jit
} // namespace torch

0 comments on commit a45c787

Please sign in to comment.