diff --git a/src/error/error_checks.cpp b/src/error/error_checks.cpp index ed2e63c24..9fc067d2d 100644 --- a/src/error/error_checks.cpp +++ b/src/error/error_checks.cpp @@ -1,9 +1,9 @@ #include "error_checks.h" +#include #include #include -#include -#include +#include #include "taco/type.h" #include "taco/index_notation/index_notation.h" @@ -26,41 +26,6 @@ static vector getAccessNodes(const IndexExpr& expr) { return readNodes; } -bool dimensionsTypecheck(const std::vector& resultVars, - const IndexExpr& expr, - const Shape& shape) { - - std::map indexVarDims; - for (size_t mode = 0; mode < resultVars.size(); mode++) { - IndexVar var = resultVars[mode]; - auto dimension = shape.getDimension(mode); - if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) { - return false; - } - else { - indexVarDims.insert({var, dimension}); - } - } - - vector readNodes = getAccessNodes(expr); - for (auto& readNode : readNodes) { - for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) { - IndexVar var = readNode->indexVars[mode]; - Dimension dimension = - readNode->tensorVar.getType().getShape().getDimension(mode); - if (util::contains(indexVarDims,var) && - indexVarDims.at(var) != dimension) { - return false; - } - else { - indexVarDims.insert({var, dimension}); - } - } - } - - return true; -} - static string addDimensionError(const IndexVar& var, Dimension dimension1, Dimension dimension2) { return "Index variable " + util::toString(var) + " is used to index " @@ -68,19 +33,17 @@ static string addDimensionError(const IndexVar& var, " and " + util::toString(dimension2) + ")."; } -std::string dimensionTypecheckErrors(const std::vector& resultVars, - const IndexExpr& expr, - const Shape& shape) { +std::pair dimensionsTypecheck(const std::vector& resultVars, + const IndexExpr& expr, + const Shape& shape) { vector errors; - std::map indexVarDims; for (size_t mode = 0; mode < resultVars.size(); mode++) { IndexVar var = resultVars[mode]; auto dimension = shape.getDimension(mode); if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) { errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension)); - } - else { + } else { indexVarDims.insert({var, dimension}); } } @@ -89,20 +52,16 @@ std::string dimensionTypecheckErrors(const std::vector& resultVars, for (auto& readNode : readNodes) { for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) { IndexVar var = readNode->indexVars[mode]; - Dimension dimension = - readNode->tensorVar.getType().getShape().getDimension(mode); - if (util::contains(indexVarDims,var) && - indexVarDims.at(var) != dimension) { - errors.push_back(addDimensionError(var, indexVarDims.at(var), - dimension)); - } - else { + Dimension dimension = readNode->tensorVar.getType().getShape().getDimension(mode); + if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) { + errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension)); + } else { indexVarDims.insert({var, dimension}); } } } - return util::join(errors, " "); + return std::make_pair(errors.empty(), util::join(errors, " ")); } static void addEdges(vector indexVars, vector modeOrdering, diff --git a/src/error/error_checks.h b/src/error/error_checks.h index 478e2afee..5211f8f3b 100644 --- a/src/error/error_checks.h +++ b/src/error/error_checks.h @@ -3,6 +3,7 @@ #include #include +#include namespace taco { class IndexVar; @@ -12,15 +13,12 @@ class Shape; namespace error { -/// Check that the dimensions indexed by the same variable are the same -bool dimensionsTypecheck(const std::vector& resultVars, - const IndexExpr& expr, - const Shape& shape); - -/// Returns error strings for index variables that don't typecheck -std::string dimensionTypecheckErrors(const std::vector& resultVars, - const IndexExpr& expr, - const Shape& shape); +/// Check whether all dimensions indexed by the same variable are the same. +/// If they are not, then the first element of the returned tuple will be false, +/// and a human readable error will be returned in the second component. +std::pair dimensionsTypecheck(const std::vector& resultVars, + const IndexExpr& expr, + const Shape& shape); /// Returns true iff the index expression contains a transposition. bool containsTranspose(const Format& resultFormat, diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 2406ae597..e5bf102e5 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -763,9 +763,8 @@ static void check(Assignment assignment) { auto freeVars = assignment.getLhs().getIndexVars(); auto indexExpr = assignment.getRhs(); auto shape = tensorVar.getType().getShape(); - taco_uassert(error::dimensionsTypecheck(freeVars, indexExpr, shape)) - << error::expr_dimension_mismatch << " " - << error::dimensionTypecheckErrors(freeVars, indexExpr, shape); + auto typecheck = error::dimensionsTypecheck(freeVars, indexExpr, shape); + taco_uassert(typecheck.first) << error::expr_dimension_mismatch << " " << typecheck.second; } Assignment Access::operator=(const IndexExpr& expr) { @@ -1952,9 +1951,9 @@ static bool isValid(Assignment assignment, string* reason) { auto result = lhs.getTensorVar(); auto freeVars = lhs.getIndexVars(); auto shape = result.getType().getShape(); - if(!error::dimensionsTypecheck(freeVars, rhs, shape)) { - *reason = error::expr_dimension_mismatch + " " + - error::dimensionTypecheckErrors(freeVars, rhs, shape); + auto typecheck = error::dimensionsTypecheck(freeVars, rhs, shape); + if (!typecheck.first) { + *reason = error::expr_dimension_mismatch + " " + typecheck.second; return false; } return true;