From fee26d0b66f570e1eedf61843b234b5da844e2d4 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Fri, 14 Jan 2022 02:20:12 -0800 Subject: [PATCH] Fix irregular BMM example. (#44) * upd * upd * upd * upd --- include/tvm/tir/sparse.h | 17 +++--- src/tir/ir/sparse.cc | 61 +++++++++++-------- src/tir/transforms/lower_sparse_tir.cc | 3 +- .../python/sparsetir/test_tir_sparse_lower.py | 3 +- 4 files changed, 49 insertions(+), 35 deletions(-) diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 5a5e1d8b7d89..6a8196c50abc 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -87,7 +87,7 @@ class AxisNode : public Object { virtual PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const = 0; virtual PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const = 0; - virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const = 0; + virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const = 0; std::tuple GetOffsetExtent(SparseCtx* ctx) const; static constexpr const char* _type_key = "tir.sparse.Axis"; @@ -156,7 +156,7 @@ class DenseFixedAxisNode : public DenseAxisNode { PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; - PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const; + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis"; TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode); @@ -285,7 +285,7 @@ class DenseVariableAxisNode : public DenseAxisNode { PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; - PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const; + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis"; TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode); @@ -309,9 +309,9 @@ class DenseVariableAxis : public DenseAxis { class AttachedAxisNode : public 
DenseVariableAxisNode { public: /* The original axis before attaching. */ - Axis orig_; + DenseVariableAxis orig_; - Axis GetOriginalAxis() const { return orig_; } + PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const; @@ -325,7 +325,8 @@ class AttachedAxisNode : public DenseVariableAxisNode { */ class AttachedAxis : public DenseVariableAxis { public: - TVM_DLL explicit AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr); + TVM_DLL explicit AttachedAxis(String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz, + Buffer indptr); TVM_DEFINE_OBJECT_REF_METHODS(AttachedAxis, DenseVariableAxis, AttachedAxisNode); }; @@ -366,7 +367,7 @@ class SparseFixedAxisNode : public SparseAxisNode { PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; - PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const; + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode); @@ -420,7 +421,7 @@ class SparseVariableAxisNode : public SparseAxisNode { PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const; - PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const; + PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const; static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis"; TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode); diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 6adec3fdbf04..6709795dc4cb 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -91,7 +91,9 @@ PrimExpr DenseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const return coordinate; } -PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; } +PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, 
PrimExpr index) const { + return index; +} TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); @@ -187,14 +189,16 @@ DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length, PrimExpr DenseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value(); PrimExpr prev_offset = ctx->GetOffset(prev_axis); - return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index)); + return BufferLoad(indptr, {std::move(prev_offset)}) + std::move(index); } PrimExpr DenseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const { return coordinate; } -PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; } +PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { + return index; +} TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); @@ -213,7 +217,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) /******** AttachedAxis ********/ /*! 
\brief Default constructor of AttachedAxis */ -AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) { +AttachedAxis::AttachedAxis(String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz, + Buffer indptr) { ObjectPtr<AttachedAxisNode> node = make_object<AttachedAxisNode>(); node->name = std::move(name); node->parent_ = std::move(parent); @@ -225,37 +230,43 @@ AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Bu } PrimExpr AttachedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { - PrimExpr parent_offset = ctx->GetOffset(parent_); - PrimExpr base_offset = BufferLoad(indptr, {parent_offset}); - PrimExpr accum_offset = Integer(0); - PrimExpr length = Integer(0); - Array<Axis> collect_axes; + PrimExpr root_offset = ctx->GetOffset(orig_->parent_); + PrimExpr accum_offset = BufferLoad(indptr, {root_offset}); + Array<DenseVariableAxis> collect_axes; Array<PrimExpr> collect_coordinates; + Array<PrimExpr> strides; Axis axis; + PrimExpr stride = Integer(1); for (axis = GetRef<Axis>(this); axis->kind() == AxisKind::kDenseVariable; axis = ctx->GetPrevAxis(axis).value()) { - collect_axes.push_back(axis); + DenseVariableAxis dv_axis = Downcast<DenseVariableAxis>(axis); + collect_axes.push_back(dv_axis); collect_coordinates.push_back(ctx->GetCoordinate(axis)); + Buffer indptr; + if (auto att_axis = dv_axis.as<AttachedAxisNode>()) { + indptr = att_axis->orig_->indptr; + } else { + indptr = dv_axis->indptr; + } + strides.push_back(stride); + stride = stride * (BufferLoad(indptr, {root_offset + 1}) - BufferLoad(indptr, {root_offset})); } - ICHECK(axis.get() == parent_.get()) - << "The root of attached axis should be the same as stored parent axis."; - for (int i = collect_axes.size() - 1; i != 0; --i) { - Axis axis = std::move(collect_axes[i]); - auto* ptr = axis.as<DenseVariableAxisNode>(); - ICHECK(ptr != nullptr) - << "Each attached axis except for the root must be a dense variable axis"; + ICHECK(axis == orig_->parent_) << "Root axis mismatch."; + PrimExpr length = Integer(0); + for (int i = collect_axes.size() - 1; i >= 0; --i) { + 
DenseVariableAxis axis = std::move(collect_axes[i]); PrimExpr coordinate = std::move(collect_coordinates[i]); - accum_offset = accum_offset * length + coordinate; - length = - BufferLoad(ptr->indptr, {parent_offset + 1}) - BufferLoad(ptr->indptr, {parent_offset}); + PrimExpr stride = std::move(strides[i]); + accum_offset = accum_offset + coordinate * stride; } - return base_offset + accum_offset; + return accum_offset; } TVM_REGISTER_NODE_TYPE(AttachedAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis") - .set_body_typed([](String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) { + .set_body_typed([](String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz, + Buffer indptr) { return AttachedAxis(std::move(name), std::move(parent), std::move(orig), std::move(nnz), std::move(indptr)); }); @@ -293,7 +304,7 @@ PrimExpr SparseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) cons return lower_bound(indices->data, coordinate, lb, ub) - lb; } -PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { +PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { return BufferLoad(indices, {offset}); } @@ -330,7 +341,7 @@ SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length PrimExpr SparseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const { Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value(); PrimExpr prev_offset = ctx->GetOffset(prev_axis); - return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index)); + return BufferLoad(indptr, {std::move(prev_offset)}) + std::move(index); } PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const { @@ -339,7 +350,7 @@ PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) c return lower_bound(indices->data, coordinate, lb, ub) - lb; } -PrimExpr SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { +PrimExpr 
SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const { return BufferLoad(indices, {offset}); } diff --git a/src/tir/transforms/lower_sparse_tir.cc b/src/tir/transforms/lower_sparse_tir.cc index 35c9ce7e7c11..987deaad5319 100644 --- a/src/tir/transforms/lower_sparse_tir.cc +++ b/src/tir/transforms/lower_sparse_tir.cc @@ -128,9 +128,10 @@ class SparseBlockCtx : public SparseCtx { for (size_t i = 0; i < n_iters; ++i) { SpIterVar sp_iter_var = sp_block->sp_iter_vars[i]; Axis axis = sp_iter_var->axis; + PrimExpr offset = AggregateOffset(this, axis, sp_iter_var->var, ana_); SetOffset(axis, offset); - PrimExpr coordinate = axis->Decompress(this, offset); + PrimExpr coordinate = axis->Decompress(this, offset, sp_iter_var->var); SetCoordinate(axis, coordinate); } } diff --git a/tests/python/sparsetir/test_tir_sparse_lower.py b/tests/python/sparsetir/test_tir_sparse_lower.py index 318e47040dca..3567c25eaa98 100644 --- a/tests/python/sparsetir/test_tir_sparse_lower.py +++ b/tests/python/sparsetir/test_tir_sparse_lower.py @@ -647,6 +647,7 @@ def test_csr_element_wise(): def test_bmm(): mod = tvm.IRModule.from_expr(bmm) mod = tvm.tir.transform.LowerSparseTIR()(mod) + print(mod["main"].script()) # TODO @@ -754,6 +755,6 @@ def test_square_sum_two_K(): test_csr_element_wise() test_sddmm() # test_fused_sddmm() - # test_bmm() + test_bmm() test_square_sum() test_square_sum_two_K()