From 0cbc42e356e1444172b9844f41949e9ce0857f4b Mon Sep 17 00:00:00 2001 From: Rohan Yadav Date: Sat, 24 Apr 2021 11:17:05 -0700 Subject: [PATCH] index_notation: fix a bug where windows would be dropped through `+=` Fixes #451. This commit fixes a bug where windows applied to index variables would be dropped when assigned to via the `+=` operator. --- include/taco/index_notation/index_notation.h | 6 ++++-- src/index_notation/index_notation.cpp | 14 +++++++++++--- src/tensor.cpp | 6 ++++-- test/tests-windowing.cpp | 12 ++++++++++++ 4 files changed, 31 insertions(+), 7 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index c706c7d4b..cc8c34dea 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -714,9 +714,11 @@ class Assignment : public IndexStmt { Assignment(Access lhs, IndexExpr rhs, IndexExpr op = IndexExpr()); /// Create an assignment. Can specify an optional operator `op` that turns the - /// assignment into a compound assignment, e.g. `+=`. + /// assignment into a compound assignment, e.g. `+=`. Additionally, specify + /// any modifers on reduction index variables (windows, index sets, etc.). Assignment(TensorVar tensor, std::vector indices, IndexExpr rhs, - IndexExpr op = IndexExpr()); + IndexExpr op = IndexExpr(), + const std::map>& modifiers = {}); /// Return the assignment's left-hand side. Access getLhs() const; diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index e14f746ed..12544f62e 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -899,7 +899,14 @@ Assignment Access::operator=(const TensorVar& var) { Assignment Access::operator+=(const IndexExpr& expr) { TensorVar result = getTensorVar(); - Assignment assignment = Assignment(result, getIndexVars(), expr, Add()); + Assignment assignment = Assignment( + result, + getIndexVars(), + expr, + Add(), + // Include any windows on LHS index vars. + getNode(*this)->packageModifiers() + ); // check(assignment); TODO: fix check for precompute const_cast(getNode(*this))->setAssignment(assignment); return assignment; @@ -1686,8 +1693,9 @@ Assignment::Assignment(Access lhs, IndexExpr rhs, IndexExpr op) } Assignment::Assignment(TensorVar tensor, vector indices, - IndexExpr rhs, IndexExpr op) - : Assignment(Access(tensor, indices), rhs, op) { + IndexExpr rhs, IndexExpr op, + const std::map>& modifiers) + : Assignment(Access(tensor, indices, modifiers), rhs, op) { } Access Assignment::getLhs() const { diff --git a/src/tensor.cpp b/src/tensor.cpp index 4877ebc0c..fab437ff1 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -760,8 +760,10 @@ vector packArguments(const TensorBase& tensor) { // Pack any index sets on the result tensor at the front of the arguments list. auto lhs = getNode(tensor.getAssignment().getLhs()); - if (isa(lhs)) { - auto indexSetModes = to(lhs)->indexSetModes; + // We check isa rather than isa to catch cases + // where the underlying access is represented with the base AccessNode class. + if (isa(lhs)) { + auto indexSetModes = to(lhs)->indexSetModes; for (auto& it : indexSetModes) { arguments.push_back(it.second.tensor.getStorage()); } diff --git a/test/tests-windowing.cpp b/test/tests-windowing.cpp index 014619a35..6199c525e 100644 --- a/test/tests-windowing.cpp +++ b/test/tests-windowing.cpp @@ -540,3 +540,15 @@ TEST(windowing, lhsIndexSet) { a.evaluate(); ASSERT_TRUE(equals(a, expected)) << a << endl << expected << endl; } + +// See issue #451 for details about this test. +TEST(windowing, compoundAssign) { + Tensor A("A", {8}, Format{Dense}); + Tensor B("B", {3}, Format{Dense}); + IndexVar i("i"); + A(i(0,3)) += B(i); + // We explicitly call assemble here as the += makes evaluate not call it. + A.compile(); A.assemble(); A.compute(); + A(i({1, 3, 6})) += B(i); + A.evaluate(); +} \ No newline at end of file