diff --git a/src/tir/transforms/inject_copy_intrin.cc b/src/tir/transforms/inject_copy_intrin.cc index f99cbd5b5a05e..9e74b8cd1fdb7 100644 --- a/src/tir/transforms/inject_copy_intrin.cc +++ b/src/tir/transforms/inject_copy_intrin.cc @@ -45,26 +45,35 @@ class CopyIntrinInjector : public StmtMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == pragma_key_) { Stmt ret; - ICHECK(MatchCopyPattern(op->body, &ret)) << "Cannot match copy pattern of " << op->body; + std::string error_info; + ICHECK(MatchCopyPattern(op->body, &ret, &error_info)) + << "Cannot match copy pattern. The error is " << error_info << " The body is " + << op->body; return ret; } return StmtMutator::VisitStmt_(op); } private: - bool MatchCopyPattern(Stmt stmt, Stmt* out) { + bool MatchCopyPattern(Stmt stmt, Stmt* out, std::string* error_info) { using namespace arith; Stmt body = stmt; // strip the loops std::vector loops; while (const ForNode* op = body.as()) { - if (!is_zero(op->min)) return false; + if (!is_zero(op->min)) { + *error_info = "the 'min' value of body 'Fonode' is 0."; + return false; + } loops.push_back(op); body = op->body; } const StoreNode* store = body.as(); - if (store == nullptr) return false; + if (store == nullptr) { + *error_info = "the 'StoreNode' of body is a nullptr."; + return false; + } // Expr sel_cond, sel_true_value, sel_false_value; // match select or if PVar sel_cond, sel_true_value, sel_false_value; @@ -84,7 +93,10 @@ class CopyIntrinInjector : public StmtMutator { if (cast != nullptr) { load = cast->value.as(); } - if (load == nullptr) return false; + if (load == nullptr) { + *error_info = "the 'LoadNode' of body is a nullptr."; + return false; + } if (load->dtype.lanes() != 1) return false; Array loop_vars; for (const ForNode* op : loops) { @@ -109,7 +121,10 @@ class CopyIntrinInjector : public StmtMutator { if (has_cond) { Array clip_bound = arith::DetectClipBound(sel_cond.Eval(), loop_vars); pad_value = sel_false_value.Eval(); - if (clip_bound.size() == 0) return false; + if (clip_bound.size() == 0) { + *error_info = "the size of clip bound is 0."; + return false; + } ICHECK_EQ(src_shape.size(), loop_vars.size()); ICHECK_EQ(clip_bound.size(), loop_vars.size() * 2); for (size_t i = 0; i < src_shape.size(); ++i) { @@ -150,7 +165,10 @@ class CopyIntrinInjector : public StmtMutator { Buffer src = Buffer(load->buffer_var, load->dtype, src_shape, src_strides, src_elem_offset, load->buffer_var->name_hint, 0, 0, kDefault); *out = flower_copy_fromto_(src, dst, pad_before, pad_after, pad_value); - ICHECK(out->defined()) << "flower function did not return correct stmt"; + if (!out->defined()) { + *error_info = "flower function did not return correct stmt"; + return false; + } return true; }