Skip to content

Commit

Permalink
[SparseTIR] SparseBlock on C++/Python side (#11)
Browse files Browse the repository at this point in the history
* Fix a bug in the last commit

* SparseBlock on C++ & Python side
  • Loading branch information
MasterJH5574 committed Nov 6, 2021
1 parent 19400c9 commit 863ba59
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 27 deletions.
77 changes: 55 additions & 22 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,28 +327,6 @@ class BufferStore : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode);
};

/*!
* \brief Sparse Block node.
*/
class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse buffers defined in the block. */
Array<SparseBuffer> sp_buffers;
/*! \brief The body of the block */
Stmt body;

static constexpr const char* _type_key = "tir.SparseBlock";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
};

class SparseBlock : public Stmt {
public:
TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
};


/*!
* \brief Store value to the high dimension sparse buffer.
*
Expand Down Expand Up @@ -1300,6 +1278,61 @@ class BlockRealize : public Stmt {
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
};

/*!
* \brief Sparse Block node.
*/
class SparseBlockNode : public StmtNode {
public:
/*! \brief The sparse iteration variables of the block. */
Array<SpIterVar> sp_iter_vars;
/*! \brief The sparse buffers defined in the block. */
Array<SparseBuffer> sp_buffers;
/*! \brief The name of the block */
String name;
/*! \brief The body of the block */
Stmt body;
/*! \brief The init statement of the block */
Optional<Stmt> init;

void VisitAttrs(AttrVisitor* v) {
v->Visit("sp_iter_vars", &sp_iter_vars);
v->Visit("sp_buffers", &sp_buffers);
v->Visit("name", &name);
v->Visit("body", &body);
v->Visit("init", &init);
}

bool SEqualReduce(const SparseBlockNode* other, SEqualReducer equal) const {
return equal(sp_iter_vars, other->sp_iter_vars) && equal(sp_buffers, other->sp_buffers) &&
equal(name, other->name) && equal(body, other->body) && equal(init, other->init);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(sp_iter_vars);
hash_reduce(sp_buffers);
hash_reduce(name);
hash_reduce(body);
hash_reduce(init);
}

static constexpr const char* _type_key = "tir.SparseBlock";
TVM_DECLARE_FINAL_OBJECT_INFO(SparseBlockNode, StmtNode);
};

/*!
* \brief Managed reference to SparseBufferNode
* \sa SparseBufferNode
*/
class SparseBlock : public Stmt {
public:
TVM_DLL explicit SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers,
String name, Stmt body, Optional<Stmt> init = NullOpt,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(SparseBlock, Stmt, SparseBlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(SparseBlockNode);
};

/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ class SpIterVar(Object):
SparseFixed = 2
SparseVariable = 3

def __init__(self, var, max_extent, kind, axis=None):
def __init__(self, var, max_extent, kind, is_reduction, axis=None):
self.__init_handle_by_constructor__(
_ffi_api.SpIterVar, var, max_extent, kind, is_reduction, axis # type: ignore
)
53 changes: 53 additions & 0 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from . import _ffi_api
from .buffer import Buffer
from .expr import IterVar
from .sparse import SpIterVar, SparseBuffer


class Stmt(Object):
Expand Down Expand Up @@ -614,6 +615,58 @@ def __init__(
) # type: ignore


@tvm._ffi.register_object("tir.SparseBlock")
class SparseBlock(Stmt):
"""SparseBlock node.
Parameters
----------
sp_iter_vars : List[SpIterVar]
The sparse iteration variables of the block.
sp_buffers : List[SparseBuffer]
The sparse buffers defined in the block.
name : str
The name of the block.
body : Stmt
The body of the block.
init : Optional[Stmt]
The init statement of the block.
span : Optional[Span]
The location of this block in the source code.
"""

sp_iter_vars: List[SpIterVar]
sp_buffers: List[SparseBuffer]
name: str
body: Stmt
init: Optional[Stmt]
span: Optional[Span]

def __init__(
self,
sp_iter_vars: List[SpIterVar],
sp_buffers: List[SparseBuffer],
name: str,
body: Stmt,
init: Optional[Stmt] = None,
span: Optional[Span] = None,
):
self.__init_handle_by_constructor__(
_ffi_api.SparseBlock, # type: ignore
sp_iter_vars,
sp_buffers,
name,
body,
init,
span,
) # type: ignore


@tvm._ffi.register_object("tir.BlockRealize")
class BlockRealize(Stmt):
"""BlockRealize node.
Expand Down
65 changes: 61 additions & 4 deletions src/tir/ir/stmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -876,17 +876,21 @@ void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) {
}
}

void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
// Print init
if (op->init.defined()) {
void PrintInitStmt(const Optional<Stmt>& init, ReprPrinter* p) {
if (init.defined()) {
p->PrintIndent();
p->stream << "with init() {\n";
p->indent += 2;
p->Print(op->init.value());
p->Print(init.value());
p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
}
}

void PrintBlockBody(const BlockNode* op, ReprPrinter* p) {
// Print init
PrintInitStmt(op->init, p);
// Print body
p->Print(op->body);
}
Expand Down Expand Up @@ -964,6 +968,59 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "}\n";
});

SparseBlock::SparseBlock(Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
Stmt body, Optional<Stmt> init, Span span) {
ObjectPtr<SparseBlockNode> node = make_object<SparseBlockNode>();
node->sp_iter_vars = std::move(sp_iter_vars);
node->sp_buffers = std::move(sp_buffers);
node->name = std::move(name);
node->body = std::move(body);
node->init = std::move(init);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.SparseBlock")
.set_body_typed([](Array<SpIterVar> sp_iter_vars, Array<SparseBuffer> sp_buffers, String name,
Stmt body, Optional<Stmt> init, Span span) {
return SparseBlock(sp_iter_vars, sp_buffers, name, body, init, span);
});

TVM_REGISTER_NODE_TYPE(SparseBlockNode);

void PrintSparseBlockTitle(const SparseBlockNode* op, ReprPrinter* p) {
p->stream << "sparse_block " << op->name << "(";
for (int i = 0; i < static_cast<int>(op->sp_iter_vars.size()); ++i) {
p->Print(op->sp_iter_vars[i]);
if (i < static_cast<int>(op->sp_iter_vars.size()) - 1) {
p->stream << ", ";
}
}
p->stream << ")";
}

void PrintSparseBlockBody(const SparseBlockNode* op, ReprPrinter* p) {
// Print init
PrintInitStmt(op->init, p);
// Print body
p->Print(op->body);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<SparseBlockNode>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const SparseBlockNode*>(node.get());
p->PrintIndent();
PrintSparseBlockTitle(op, p);
p->stream << " {\n";
p->indent += 2;

PrintSparseBlockBody(op, p);

p->indent -= 2;
p->PrintIndent();
p->stream << "}\n";
});

PrimExpr TypeAnnotation(DataType dtype, Span span) {
static auto op = Op::Get("tir.type_annotation");
return tir::Call(dtype, op, {}, span);
Expand Down

0 comments on commit 863ba59

Please sign in to comment.