Skip to content

Commit

Permalink
[AutoScheduler] Fix FLOPS estimation (apache#8695)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored and ylc committed Jan 13, 2022
1 parent 2cdbd54 commit 1748566
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
12 changes: 8 additions & 4 deletions src/auto_scheduler/compute_dag.cc
Original file line number Diff line number Diff line change
Expand Up @@ -611,10 +611,14 @@ class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
}

#define VisitBinary(Node) \
double VisitExpr_(const Node* op) final { \
double base = op->dtype.code() == cur_type_code_ ? 1.0 : 0.0; \
return base + VisitExpr(op->a) + VisitExpr(op->b); \
// Index calculations (e.g., the "i + j" expression in A[i + j]) are not counted in FLOPS.
#define VisitBinary(Node) \
double VisitExpr_(const Node* op) final { \
double base = 1.0; \
if ((op->a->dtype.code() != cur_type_code_) && (op->b->dtype.code() != cur_type_code_)) { \
base = 0.0; \
} \
return base + VisitExpr(op->a) + VisitExpr(op->b); \
}

#define VisitUnary(Node) \
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_auto_scheduler_compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ def test_estimate_flop():
dag = auto_scheduler.ComputeDAG([A, B, F])
assert abs(dag.flop_ct - (2 * N ** 3 + 1234)) < 0.5

A = te.placeholder((N, N), dtype="float32", name="A")
F = te.compute((N, N), lambda i, j: te.if_then_else(A[i, j] > 0, A[i, j], 0))
dag = auto_scheduler.ComputeDAG([A, F])
assert abs(dag.flop_ct - N ** 2) < 0.5


def test_stage_order():
"""Test if the stage order is preserved when recovering a DAG."""
Expand Down

0 comments on commit 1748566

Please sign in to comment.