Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 11 additions & 52 deletions src/error/error_checks.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#include "error_checks.h"

#include <functional>
#include <map>
#include <set>
#include <stack>
#include <functional>
#include <tuple>

#include "taco/type.h"
#include "taco/index_notation/index_notation.h"
Expand All @@ -26,61 +26,24 @@ static vector<const AccessNode*> getAccessNodes(const IndexExpr& expr) {
return readNodes;
}

bool dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
const IndexExpr& expr,
const Shape& shape) {

std::map<IndexVar,Dimension> indexVarDims;
for (size_t mode = 0; mode < resultVars.size(); mode++) {
IndexVar var = resultVars[mode];
auto dimension = shape.getDimension(mode);
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
return false;
}
else {
indexVarDims.insert({var, dimension});
}
}

vector<const AccessNode*> readNodes = getAccessNodes(expr);
for (auto& readNode : readNodes) {
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
IndexVar var = readNode->indexVars[mode];
Dimension dimension =
readNode->tensorVar.getType().getShape().getDimension(mode);
if (util::contains(indexVarDims,var) &&
indexVarDims.at(var) != dimension) {
return false;
}
else {
indexVarDims.insert({var, dimension});
}
}
}

return true;
}

static string addDimensionError(const IndexVar& var,
Dimension dimension1, Dimension dimension2) {
return "Index variable " + util::toString(var) + " is used to index "
"modes of different dimensions (" + util::toString(dimension1) +
" and " + util::toString(dimension2) + ").";
}

std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
const IndexExpr& expr,
const Shape& shape) {
std::pair<bool, string> dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
const IndexExpr& expr,
const Shape& shape) {
vector<string> errors;

std::map<IndexVar,Dimension> indexVarDims;
for (size_t mode = 0; mode < resultVars.size(); mode++) {
IndexVar var = resultVars[mode];
auto dimension = shape.getDimension(mode);
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
}
else {
} else {
indexVarDims.insert({var, dimension});
}
}
Expand All @@ -89,20 +52,16 @@ std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
for (auto& readNode : readNodes) {
for (size_t mode = 0; mode < readNode->indexVars.size(); mode++) {
IndexVar var = readNode->indexVars[mode];
Dimension dimension =
readNode->tensorVar.getType().getShape().getDimension(mode);
if (util::contains(indexVarDims,var) &&
indexVarDims.at(var) != dimension) {
errors.push_back(addDimensionError(var, indexVarDims.at(var),
dimension));
}
else {
Dimension dimension = readNode->tensorVar.getType().getShape().getDimension(mode);
if (util::contains(indexVarDims,var) && indexVarDims.at(var) != dimension) {
errors.push_back(addDimensionError(var, indexVarDims.at(var), dimension));
} else {
indexVarDims.insert({var, dimension});
}
}
}

return util::join(errors, " ");
return std::make_pair(errors.empty(), util::join(errors, " "));
}

static void addEdges(vector<IndexVar> indexVars, vector<int> modeOrdering,
Expand Down
16 changes: 7 additions & 9 deletions src/error/error_checks.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <vector>
#include <string>
#include <tuple>

namespace taco {
class IndexVar;
Expand All @@ -12,15 +13,12 @@ class Shape;

namespace error {

/// Check that the dimensions indexed by the same variable are the same
bool dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
const IndexExpr& expr,
const Shape& shape);

/// Returns error strings for index variables that don't typecheck
std::string dimensionTypecheckErrors(const std::vector<IndexVar>& resultVars,
const IndexExpr& expr,
const Shape& shape);
/// Check whether all dimensions indexed by the same variable are the same.
/// If they are not, then the first element of the returned tuple will be false,
/// and a human readable error will be returned in the second component.
std::pair<bool, std::string> dimensionsTypecheck(const std::vector<IndexVar>& resultVars,
const IndexExpr& expr,
const Shape& shape);

/// Returns true iff the index expression contains a transposition.
bool containsTranspose(const Format& resultFormat,
Expand Down
11 changes: 5 additions & 6 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,8 @@ static void check(Assignment assignment) {
auto freeVars = assignment.getLhs().getIndexVars();
auto indexExpr = assignment.getRhs();
auto shape = tensorVar.getType().getShape();
taco_uassert(error::dimensionsTypecheck(freeVars, indexExpr, shape))
<< error::expr_dimension_mismatch << " "
<< error::dimensionTypecheckErrors(freeVars, indexExpr, shape);
auto typecheck = error::dimensionsTypecheck(freeVars, indexExpr, shape);
taco_uassert(typecheck.first) << error::expr_dimension_mismatch << " " << typecheck.second;
}

Assignment Access::operator=(const IndexExpr& expr) {
Expand Down Expand Up @@ -1952,9 +1951,9 @@ static bool isValid(Assignment assignment, string* reason) {
auto result = lhs.getTensorVar();
auto freeVars = lhs.getIndexVars();
auto shape = result.getType().getShape();
if(!error::dimensionsTypecheck(freeVars, rhs, shape)) {
*reason = error::expr_dimension_mismatch + " " +
error::dimensionTypecheckErrors(freeVars, rhs, shape);
auto typecheck = error::dimensionsTypecheck(freeVars, rhs, shape);
if (!typecheck.first) {
*reason = error::expr_dimension_mismatch + " " + typecheck.second;
return false;
}
return true;
Expand Down