Skip to content

Commit

Permalink
Fix irregular BMM example. (#44)
Browse files Browse the repository at this point in the history
* upd

* upd

* upd

* upd
  • Loading branch information
yzh119 committed Jan 21, 2022
1 parent b648888 commit 2f54e22
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 35 deletions.
17 changes: 9 additions & 8 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class AxisNode : public Object {

virtual PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const = 0;
virtual PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const = 0;
virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const = 0;
virtual PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const = 0;
std::tuple<PrimExpr, PrimExpr> GetOffsetExtent(SparseCtx* ctx) const;

static constexpr const char* _type_key = "tir.sparse.Axis";
Expand Down Expand Up @@ -156,7 +156,7 @@ class DenseFixedAxisNode : public DenseAxisNode {

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.DenseFixedAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseFixedAxisNode, DenseAxisNode);
Expand Down Expand Up @@ -285,7 +285,7 @@ class DenseVariableAxisNode : public DenseAxisNode {

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.DenseVariableAxis";
TVM_DECLARE_BASE_OBJECT_INFO(DenseVariableAxisNode, DenseAxisNode);
Expand All @@ -309,9 +309,9 @@ class DenseVariableAxis : public DenseAxis {
class AttachedAxisNode : public DenseVariableAxisNode {
public:
/* The original axis before attaching. */
Axis orig_;
DenseVariableAxis orig_;

Axis GetOriginalAxis() const { return orig_; }
PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

PrimExpr Aggregate(SparseCtx* ctx, PrimExpr index) const;

Expand All @@ -325,7 +325,8 @@ class AttachedAxisNode : public DenseVariableAxisNode {
*/
class AttachedAxis : public DenseVariableAxis {
public:
TVM_DLL explicit AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr);
TVM_DLL explicit AttachedAxis(String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz,
Buffer indptr);
TVM_DEFINE_OBJECT_REF_METHODS(AttachedAxis, DenseVariableAxis, AttachedAxisNode);
};

Expand Down Expand Up @@ -366,7 +367,7 @@ class SparseFixedAxisNode : public SparseAxisNode {

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.SparseFixedAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseFixedAxisNode, SparseAxisNode);
Expand Down Expand Up @@ -420,7 +421,7 @@ class SparseVariableAxisNode : public SparseAxisNode {

PrimExpr Compress(SparseCtx* ctx, PrimExpr coordinate) const;

PrimExpr Decompress(SparseCtx* ctx, PrimExpr index) const;
PrimExpr Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const;

static constexpr const char* _type_key = "tir.sparse.SparseVariableAxis";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseVariableAxisNode, SparseAxisNode);
Expand Down
61 changes: 36 additions & 25 deletions src/tir/ir/sparse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ PrimExpr DenseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const
return coordinate;
}

PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; }
PrimExpr DenseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const {
return index;
}

TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode);

Expand Down Expand Up @@ -187,14 +189,16 @@ DenseVariableAxis::DenseVariableAxis(String name, Axis parent, PrimExpr length,
PrimExpr DenseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index));
return BufferLoad(indptr, {std::move(prev_offset)}) + std::move(index);
}

PrimExpr DenseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
return coordinate;
}

PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const { return offset; }
PrimExpr DenseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const {
return index;
}

TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode);

Expand All @@ -213,7 +217,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

/******** AttachedAxis ********/
/*! \brief Default constructor of AttachedAxis */
AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) {
AttachedAxis::AttachedAxis(String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz,
Buffer indptr) {
ObjectPtr<AttachedAxisNode> node = make_object<AttachedAxisNode>();
node->name = std::move(name);
node->parent_ = std::move(parent);
Expand All @@ -225,37 +230,43 @@ AttachedAxis::AttachedAxis(String name, Axis parent, Axis orig, PrimExpr nnz, Bu
}

PrimExpr AttachedAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
PrimExpr parent_offset = ctx->GetOffset(parent_);
PrimExpr base_offset = BufferLoad(indptr, {parent_offset});
PrimExpr accum_offset = Integer(0);
PrimExpr length = Integer(0);
Array<Axis> collect_axes;
PrimExpr root_offset = ctx->GetOffset(orig_->parent_);
PrimExpr accum_offset = BufferLoad(indptr, {root_offset});
Array<DenseVariableAxis> collect_axes;
Array<PrimExpr> collect_coordinates;
Array<PrimExpr> strides;
Axis axis;
PrimExpr stride = Integer(1);
for (axis = GetRef<Axis>(this); axis->kind() == AxisKind::kDenseVariable;
axis = ctx->GetPrevAxis(axis).value()) {
collect_axes.push_back(axis);
DenseVariableAxis dv_axis = Downcast<DenseVariableAxis>(axis);
collect_axes.push_back(dv_axis);
collect_coordinates.push_back(ctx->GetCoordinate(axis));
Buffer indptr;
if (auto att_axis = dv_axis.as<AttachedAxisNode>()) {
indptr = att_axis->orig_->indptr;
} else {
indptr = dv_axis->indptr;
}
strides.push_back(stride);
stride = stride * (BufferLoad(indptr, {root_offset + 1}) - BufferLoad(indptr, {root_offset}));
}
ICHECK(axis.get() == parent_.get())
<< "The root of attached axis should be the same as stored parent axis.";
for (int i = collect_axes.size() - 1; i != 0; --i) {
Axis axis = std::move(collect_axes[i]);
auto* ptr = axis.as<DenseVariableAxisNode>();
ICHECK(ptr != nullptr)
<< "Each attached axis except for the root must be a dense variable axis";
ICHECK(axis == orig_->parent_) << "Root axis mismatch.";
PrimExpr length = Integer(0);
for (int i = collect_axes.size() - 1; i >= 0; --i) {
DenseVariableAxis axis = std::move(collect_axes[i]);
PrimExpr coordinate = std::move(collect_coordinates[i]);
accum_offset = accum_offset * length + coordinate;
length =
BufferLoad(ptr->indptr, {parent_offset + 1}) - BufferLoad(ptr->indptr, {parent_offset});
PrimExpr stride = std::move(strides[i]);
accum_offset = accum_offset + coordinate * stride;
}
return base_offset + accum_offset;
return accum_offset;
}

TVM_REGISTER_NODE_TYPE(AttachedAxisNode);

TVM_REGISTER_GLOBAL("tir.sparse.AttachedAxis")
.set_body_typed([](String name, Axis parent, Axis orig, PrimExpr nnz, Buffer indptr) {
.set_body_typed([](String name, Axis parent, DenseVariableAxis orig, PrimExpr nnz,
Buffer indptr) {
return AttachedAxis(std::move(name), std::move(parent), std::move(orig), std::move(nnz),
std::move(indptr));
});
Expand Down Expand Up @@ -293,7 +304,7 @@ PrimExpr SparseFixedAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) cons
return lower_bound(indices->data, coordinate, lb, ub) - lb;
}

PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const {
PrimExpr SparseFixedAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const {
return BufferLoad(indices, {offset});
}

Expand Down Expand Up @@ -330,7 +341,7 @@ SparseVariableAxis::SparseVariableAxis(String name, Axis parent, PrimExpr length
PrimExpr SparseVariableAxisNode::Aggregate(SparseCtx* ctx, PrimExpr index) const {
Axis prev_axis = ctx->GetPrevAxis(GetRef<Axis>(this)).value();
PrimExpr prev_offset = ctx->GetOffset(prev_axis);
return add(BufferLoad(indptr, {std::move(prev_offset)}), std::move(index));
return BufferLoad(indptr, {std::move(prev_offset)}) + std::move(index);
}

PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) const {
Expand All @@ -339,7 +350,7 @@ PrimExpr SparseVariableAxisNode::Compress(SparseCtx* ctx, PrimExpr coordinate) c
return lower_bound(indices->data, coordinate, lb, ub) - lb;
}

PrimExpr SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset) const {
PrimExpr SparseVariableAxisNode::Decompress(SparseCtx* ctx, PrimExpr offset, PrimExpr index) const {
return BufferLoad(indices, {offset});
}

Expand Down
3 changes: 2 additions & 1 deletion src/tir/transforms/lower_sparse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,10 @@ class SparseBlockCtx : public SparseCtx {
for (size_t i = 0; i < n_iters; ++i) {
SpIterVar sp_iter_var = sp_block->sp_iter_vars[i];
Axis axis = sp_iter_var->axis;

PrimExpr offset = AggregateOffset(this, axis, sp_iter_var->var, ana_);
SetOffset(axis, offset);
PrimExpr coordinate = axis->Decompress(this, offset);
PrimExpr coordinate = axis->Decompress(this, offset, sp_iter_var->var);
SetCoordinate(axis, coordinate);
}
}
Expand Down
3 changes: 2 additions & 1 deletion tests/python/sparsetir/test_tir_sparse_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,7 @@ def test_csr_element_wise():
def test_bmm():
mod = tvm.IRModule.from_expr(bmm)
mod = tvm.tir.transform.LowerSparseTIR()(mod)
print(mod["main"].script())
# TODO


Expand Down Expand Up @@ -754,6 +755,6 @@ def test_square_sum_two_K():
test_csr_element_wise()
test_sddmm()
# test_fused_sddmm()
# test_bmm()
test_bmm()
test_square_sum()
test_square_sum_two_K()

0 comments on commit 2f54e22

Please sign in to comment.