Skip to content

Commit

Permalink
[Tir]Adding detail error messages when MatchCopyPattern function is f…
Browse files Browse the repository at this point in the history
…ailed. (apache#10244)

There is an error message to show the body when 'MatchCopyPattern' is failed,
but the error message not give the information why this function get failed.
Adding the detail error information to help trouble shooting.
  • Loading branch information
huajsj authored and pfk-beta committed Apr 11, 2022
1 parent cda508c commit 41265be
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/tir/transforms/inject_copy_intrin.cc
Expand Up @@ -45,26 +45,35 @@ class CopyIntrinInjector : public StmtMutator {
Stmt VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == pragma_key_) {
Stmt ret;
ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body;
std::string error_info;
ICHECK(MatchCopyPattern(op->body, &ret, &error_info))
<< "Cannot match copy pattern. The error is " << error_info << " The body is "
<< op->body;
return ret;
}
return StmtMutator::VisitStmt_(op);
}

private:
bool MatchCopyPattern(Stmt stmt, Stmt* out) {
bool MatchCopyPattern(Stmt stmt, Stmt* out, std::string* error_info) {
using namespace arith;
Stmt body = stmt;

// strip the loops
std::vector<const ForNode*> loops;
while (const ForNode* op = body.as<ForNode>()) {
if (!is_zero(op->min)) return false;
if (!is_zero(op->min)) {
*error_info = "the 'min' value of body 'Fonode' is 0.";
return false;
}
loops.push_back(op);
body = op->body;
}
const StoreNode* store = body.as<StoreNode>();
if (store == nullptr) return false;
if (store == nullptr) {
*error_info = "the 'StoreNode' of body is a nullptr.";
return false;
}
// Expr sel_cond, sel_true_value, sel_false_value;
// match select or if
PVar<PrimExpr> sel_cond, sel_true_value, sel_false_value;
Expand All @@ -84,7 +93,10 @@ class CopyIntrinInjector : public StmtMutator {
if (cast != nullptr) {
load = cast->value.as<LoadNode>();
}
if (load == nullptr) return false;
if (load == nullptr) {
*error_info = "the 'LoadNode' of body is a nullptr.";
return false;
}
if (load->dtype.lanes() != 1) return false;
Array<Var> loop_vars;
for (const ForNode* op : loops) {
Expand All @@ -109,7 +121,10 @@ class CopyIntrinInjector : public StmtMutator {
if (has_cond) {
Array<PrimExpr> clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars);
pad_value = sel_false_value.Eval();
if (clip_bound.size() == 0) return false;
if (clip_bound.size() == 0) {
*error_info = "the size of clip bound is 0.";
return false;
}
ICHECK_EQ(src_shape.size(), loop_vars.size());
ICHECK_EQ(clip_bound.size(), loop_vars.size() * 2);
for (size_t i = 0; i < src_shape.size(); ++i) {
Expand Down Expand Up @@ -150,7 +165,10 @@ class CopyIntrinInjector : public StmtMutator {
Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset,
load->buffer_var->name_hint, 0, 0, kDefault);
*out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value);
ICHECK(out->defined()) << "flower function did not return correct stmt";
if (!out->defined()) {
*error_info = "flower function did not return correct stmt";
return false;
}
return true;
}

Expand Down

0 comments on commit 41265be

Please sign in to comment.