Skip to content

Commit

Permalink
[BugFix][SparseTIR] TVMScript Parser for Axis & SpIterVar (#12)
Browse files Browse the repository at this point in the history
* Update `cord` and `pos`

* Fix `idtype`

* Formatting..

* Bug fix 1

* Move new special stmts

* Parser for Axis and SpIterVar

* Fix context_maintainer.py
  • Loading branch information
MasterJH5574 authored and yzh119 committed Nov 16, 2021
1 parent 094f4a5 commit a0f44d3
Show file tree
Hide file tree
Showing 7 changed files with 218 additions and 282 deletions.
51 changes: 18 additions & 33 deletions include/tvm/tir/sparse.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,8 @@ class DenseVariableAxisNode : public DenseAxisNode {
v->Visit("indptr", &indptr);
}

bool SEqualReduce(const DenseVariableAxisNode* other,
SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr);
bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -165,11 +163,9 @@ class DenseVariableAxisNode : public DenseAxisNode {
*/
class DenseVariableAxis : public DenseAxis {
public:
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length,
Buffer indptr);
TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr);

TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis,
DenseVariableAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode);
};

/*!
Expand Down Expand Up @@ -206,8 +202,7 @@ class SparseFixedAxisNode : public SparseAxisNode {
v->Visit("num_cols", &num_cols);
}

bool SEqualReduce(const SparseFixedAxisNode* other,
SEqualReducer equal) const {
bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indices, other->indices) && equal(num_cols, other->num_cols);
}
Expand All @@ -229,11 +224,9 @@ class SparseFixedAxisNode : public SparseAxisNode {
*/
class SparseFixedAxis : public SparseAxis {
public:
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices,
PrimExpr num_cols);
TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols);

TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis,
SparseFixedAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode);
};

/*!
Expand All @@ -251,8 +244,7 @@ class SparseVariableAxisNode : public SparseAxisNode {
v->Visit("indices", &indices);
}

bool SEqualReduce(const SparseVariableAxisNode* other,
SEqualReducer equal) const {
bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const {
return equal(name, other->name) && equal(length, other->length) &&
equal(indptr, other->indptr) && equal(indices, other->indices);
}
Expand All @@ -274,11 +266,9 @@ class SparseVariableAxisNode : public SparseAxisNode {
*/
class SparseVariableAxis : public SparseAxis {
public:
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length,
Buffer indptr, Buffer indices);
TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices);

TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis,
SparseVariableAxisNode);
TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode);
};

/*!
Expand All @@ -287,12 +277,9 @@ class SparseVariableAxis : public SparseAxis {
class AxisTreeNode : public Object {
public:
// unordered map that stores the parent relationship between axes.
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual>
parent;
std::unordered_map<String, Optional<String>, ObjectPtrHash, ObjectPtrEqual> parent;
// unordered map that stores the children relationship between axes.
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash,
ObjectPtrEqual>
children;
std::unordered_map<Optional<String>, Array<String>, ObjectPtrHash, ObjectPtrEqual> children;

void VisitAttrs(AttrVisitor* v) {}

Expand All @@ -306,8 +293,7 @@ class AxisTreeNode : public Object {
*/
class AxisTree : public ObjectRef {
public:
TVM_DLL AxisTree(Array<String> axis_names,
Array<Optional<String>> axis_parent_names);
TVM_DLL AxisTree(Array<String> axis_names, Array<Optional<String>> axis_parent_names);

TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode);
};
Expand All @@ -333,8 +319,7 @@ class SparseBufferNode : public Object {
}

bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const {
return equal(axes, other->axes) && equal(data, other->data) &&
equal(name, other->name);
return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand Down Expand Up @@ -386,8 +371,8 @@ class SpIterVarNode : public Object {

bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const {
return equal(var, other->var) && equal(max_extent, other->max_extent) &&
equal(axis, other->axis) &&
equal(is_reduction, other->is_reduction) && equal(kind, other->kind);
equal(axis, other->axis) && equal(is_reduction, other->is_reduction) &&
equal(kind, other->kind);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand All @@ -406,8 +391,8 @@ class SpIterVarNode : public Object {

class SpIterVar : public ObjectRef {
public:
TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind,
bool is_reduction, Optional<Axis> axis = NullOpt);
TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction,
Optional<Axis> axis = NullOpt);

/*!
* \return the corresponding var in the IterVar.
Expand Down
9 changes: 3 additions & 6 deletions python/tvm/script/context_maintainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import tvm
from tvm.ir import Span
from tvm.ir.expr import Range
from tvm.script.tir.sparse import MatchSparseBuffer
from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion
from tvm.runtime import Object
from tvm.tir.expr import IterVar
Expand Down Expand Up @@ -76,10 +75,6 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None:
"""List[Buffer]: list of T.alloc_buffer statements in the block signature"""
match_buffers: List[MatchBufferRegion] = []
"""List[MatchBufferRegion]: list of T.match_buffer statements in the block signature"""
axes: List[Axis] = []
"""List[Axis]: list of sparse axis created in the block signature."""
match_sparse_buffers: List[MatchSparseBuffer]
"""List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature."""
iter_values: List[PrimExpr] = []
"""List[PrimExpr]: list of binding values for iter vars"""
iter_vars: List[IterVar] = []
Expand Down Expand Up @@ -217,7 +212,9 @@ def exit_block_scope(self):
# Pop block_info
self.block_info_stack.pop()

def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node):
def update_symbol(
self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node
):
"""Append a symbol into current scope"""
if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)):
if name in self.symbols[0]:
Expand Down
207 changes: 0 additions & 207 deletions python/tvm/script/tir/sparse.py

This file was deleted.

Loading

0 comments on commit a0f44d3

Please sign in to comment.