diff --git a/include/tvm/tir/sparse.h b/include/tvm/tir/sparse.h index 935da5a8f259..6be1062c3c86 100644 --- a/include/tvm/tir/sparse.h +++ b/include/tvm/tir/sparse.h @@ -143,10 +143,8 @@ class DenseVariableAxisNode : public DenseAxisNode { v->Visit("indptr", &indptr); } - bool SEqualReduce(const DenseVariableAxisNode* other, - SEqualReducer equal) const { - return equal(name, other->name) && equal(length, other->length) && - equal(indptr, other->indptr); + bool SEqualReduce(const DenseVariableAxisNode* other, SEqualReducer equal) const { + return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr); } void SHashReduce(SHashReducer hash_reduce) const { @@ -165,11 +163,9 @@ class DenseVariableAxisNode : public DenseAxisNode { */ class DenseVariableAxis : public DenseAxis { public: - TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, - Buffer indptr); + TVM_DLL explicit DenseVariableAxis(String name, PrimExpr length, Buffer indptr); - TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, - DenseVariableAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(DenseVariableAxis, DenseAxis, DenseVariableAxisNode); }; /*! @@ -206,8 +202,7 @@ class SparseFixedAxisNode : public SparseAxisNode { v->Visit("num_cols", &num_cols); } - bool SEqualReduce(const SparseFixedAxisNode* other, - SEqualReducer equal) const { + bool SEqualReduce(const SparseFixedAxisNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(length, other->length) && equal(indices, other->indices) && equal(num_cols, other->num_cols); } @@ -229,11 +224,9 @@ class SparseFixedAxisNode : public SparseAxisNode { */ class SparseFixedAxis : public SparseAxis { public: - TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, - PrimExpr num_cols); + TVM_DLL explicit SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols); - TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, - SparseFixedAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(SparseFixedAxis, SparseAxis, SparseFixedAxisNode); }; /*! @@ -251,8 +244,7 @@ class SparseVariableAxisNode : public SparseAxisNode { v->Visit("indices", &indices); } - bool SEqualReduce(const SparseVariableAxisNode* other, - SEqualReducer equal) const { + bool SEqualReduce(const SparseVariableAxisNode* other, SEqualReducer equal) const { return equal(name, other->name) && equal(length, other->length) && equal(indptr, other->indptr) && equal(indices, other->indices); } @@ -274,11 +266,9 @@ class SparseVariableAxisNode : public SparseAxisNode { */ class SparseVariableAxis : public SparseAxis { public: - TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, - Buffer indptr, Buffer indices); + TVM_DLL explicit SparseVariableAxis(String name, PrimExpr length, Buffer indptr, Buffer indices); - TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, - SparseVariableAxisNode); + TVM_DEFINE_OBJECT_REF_METHODS(SparseVariableAxis, SparseAxis, SparseVariableAxisNode); }; /*! @@ -287,12 +277,9 @@ class SparseVariableAxis : public SparseAxis { class AxisTreeNode : public Object { public: // unordered map that stores the parent relationship between axes. - std::unordered_map, ObjectPtrHash, ObjectPtrEqual> - parent; + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> parent; // unordered map that stores the children relationship between axes. - std::unordered_map, Array, ObjectPtrHash, - ObjectPtrEqual> - children; + std::unordered_map, Array, ObjectPtrHash, ObjectPtrEqual> children; void VisitAttrs(AttrVisitor* v) {} @@ -306,8 +293,7 @@ class AxisTreeNode : public Object { */ class AxisTree : public ObjectRef { public: - TVM_DLL AxisTree(Array axis_names, - Array> axis_parent_names); + TVM_DLL AxisTree(Array axis_names, Array> axis_parent_names); TVM_DEFINE_OBJECT_REF_METHODS(AxisTree, ObjectRef, AxisTreeNode); }; @@ -333,8 +319,7 @@ class SparseBufferNode : public Object { } bool SEqualReduce(const SparseBufferNode* other, SEqualReducer equal) const { - return equal(axes, other->axes) && equal(data, other->data) && - equal(name, other->name); + return equal(axes, other->axes) && equal(data, other->data) && equal(name, other->name); } void SHashReduce(SHashReducer hash_reduce) const { @@ -386,8 +371,8 @@ class SpIterVarNode : public Object { bool SEqualReduce(const SpIterVarNode* other, SEqualReducer equal) const { return equal(var, other->var) && equal(max_extent, other->max_extent) && - equal(axis, other->axis) && - equal(is_reduction, other->is_reduction) && equal(kind, other->kind); + equal(axis, other->axis) && equal(is_reduction, other->is_reduction) && + equal(kind, other->kind); } void SHashReduce(SHashReducer hash_reduce) const { @@ -406,8 +391,8 @@ class SpIterVarNode : public Object { class SpIterVar : public ObjectRef { public: - TVM_DLL explicit SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, - bool is_reduction, Optional axis = NullOpt); + TVM_DLL explicit SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, + Optional axis = NullOpt); /*! * \return the corresponding var in the IterVar. diff --git a/python/tvm/script/context_maintainer.py b/python/tvm/script/context_maintainer.py index 5938a1da6285..2a84c3d896d2 100644 --- a/python/tvm/script/context_maintainer.py +++ b/python/tvm/script/context_maintainer.py @@ -23,7 +23,6 @@ import tvm from tvm.ir import Span from tvm.ir.expr import Range -from tvm.script.tir.sparse import MatchSparseBuffer from tvm.tir import Var, Buffer, PrimExpr, Stmt, MatchBufferRegion from tvm.runtime import Object from tvm.tir.expr import IterVar @@ -76,10 +75,6 @@ def example_func(a: T.handle, b: T.handle, c: T.handle) -> None: """List[Buffer]: list of T.alloc_buffer statements in the block signature""" match_buffers: List[MatchBufferRegion] = [] """List[MatchBufferRegion]: list of T.match_buffer statements in the block signature""" - axes: List[Axis] = [] - """List[Axis]: list of sparse axis created in the block signature.""" - match_sparse_buffers: List[MatchSparseBuffer] - """List[MatchSparseBuffer]: list of T.match_sparse_buffer statements in the block signature.""" iter_values: List[PrimExpr] = [] """List[PrimExpr]: list of binding values for iter vars""" iter_vars: List[IterVar] = [] @@ -217,7 +212,9 @@ def exit_block_scope(self): # Pop block_info self.block_info_stack.pop() - def update_symbol(self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node): + def update_symbol( + self, name: str, symbol: Union[Buffer, Var, SparseBuffer, Axis], node: synr.ast.Node + ): """Append a symbol into current scope""" if isinstance(symbol, (Buffer, Var, SparseBuffer, Axis)): if name in self.symbols[0]: diff --git a/python/tvm/script/tir/sparse.py b/python/tvm/script/tir/sparse.py deleted file mode 100644 index 3a565545f575..000000000000 --- a/python/tvm/script/tir/sparse.py +++ /dev/null @@ -1,207 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""TVM Script Interface for Sparse TIR""" -import synr -import tvm -from synr import ast -from tvm.ir.base import Span -from tvm.ir.expr import PrimExpr, Range - -from tvm.script.tir.node import BufferSlice -from tvm.script.tir.utils import buffer_slice_to_region -from tvm.tir.expr import PrimExprWithOp -from .scope_handler import ScopeHandler, LoopInfo -from .intrin import Intrin -from ..context_maintainer import BlockInfo, ContextMaintainer -from .special_stmt import SpecialStmt -from tvm.tir.sparse import Axis, AxisTree, DenseFixedAxis, DenseVariableAxis, SpIterVar, SparseFixedAxis, SparseVariableAxis -from typing import List, Mapping, Optional, Tuple, Any -from tvm.runtime.object import Object -from tvm.script.registry import register -from ..utils import ( - tvm_span_from_synr, - call_with_error_reporting, -) - - -@register -class DenseFixed(SpecialStmt): - """Special Stmt for creating dense fixed axis. - """ - - def __init__(self): - def dense_fixed( - name: str, - length: PrimExpr, - idtype: str = 'int32', - span: Optional[Span] = None - ): - var_name = self.node.lhs[0].id.name - axis = DenseFixedAxis(name, length, idtype=idtype) - self.context.update_symbol(var_name, axis, self.node) - super().__init__(dense_fixed, def_symbol=True) - - -@register -class DenseVariable(SpecialStmt): - """Special Stmt for creating dense variable axis. - """ - - def __init__(self): - def dense_variable( - name: str, - shape: Tuple[PrimExpr, PrimExpr], - indptr: tvm.tir.Var, - idtype: str = 'int32', - span: Optional[Span] = None - ): - indptr_len, length = shape - var_name = self.node.lhs[0].id.name - indptr_buf = tvm.tir.decl_buffer( - (indptr_len,), - dtype=idtype, - name=name + "_indptr", - span=span - ) - axis = DenseVariableAxis(name, length, indptr_buf, idtype=idtype) - self.context.func_buffer_map[indptr] = indptr_buf - self.context.update_symbol(var_name, axis, self.node) - super().__init__(dense_variable, def_symbol=True) - - -@register -class SparseFixed(SpecialStmt): - """Special Stmt for creating sparse fixed axis. - """ - - def __init__(self): - def sparse_fixed( - name: str, - shape: Tuple[PrimExpr, PrimExpr, PrimExpr], - indices: tvm.tir.Var, - idtype: str = 'int32', - span: Optional[Span] = None - ): - var_name = self.node.lhs[0].id.name - length, nnz, nnz_cols = shape - indices_buf = tvm.tir.decl_buffer( - (nnz,), - dtype=idtype, - name=name+"_indices", - span=span - ) - axis = SparseFixedAxis(name, length, indices_buf, nnz_cols, idtype=idtype) - self.context.func_buffer_map[indices] = indices_buf - self.context.update_symbol(var_name, axis, self.node) - super().__init__(sparse_fixed, def_symbol=True) - - -@register -class SparseVariable(SpecialStmt): - """Special Stmt for creating sparse variable axis: - """ - - def __init__(self): - def sparse_variable( - name: str, - shape: Tuple[PrimExpr, PrimExpr], - data: Tuple[tvm.tir.Var, tvm.tir.Var], - idtype: str = 'int32', - span: Optional[Span] = None - ): - var_name = self.node.lhs[0].id.name - length, indptr_len, nnz = shape - indptr, indices = data - indptr_buf = tvm.tir.decl_buffer( - (indptr_len,), - dtype=idtype, - name=name+"_indptr", - span=span - ) - indices_buf = tvm.tir.decl_buffer( - (nnz,), - dtype=idtype, - name=name+"_indices", - span=span - ) - axis = SparseVariableAxis(name, length, indptr_buf, indices_buf, idtype=idtype) - self.context.func_buffer_map[indices] = indices_buf - self.context.func_buffer_map[indptr] = indptr_buf - self.context.update_symbol(var_name, axis, self.node) - super().__init__(sparse_variable, def_symbol=True) - - -@register -class MatchSparseBuffer(SpecialStmt): - """Special Stmt match_sparse_buffer() - """ - - def __init__(self): - def match_sparse_buffer( - param: tvm.tir.Var, - axes: List[Axis], - dtype: str = 'float32', - span: Optional[Span] = None, - ): - if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: - self.context.report_error( - "`match_sparse_buffer` must be assigned to a single sparse buffer, " - "e.g. A = match_sparse_buffer(...)" - ) - - buffer_name: str = self.node.lhs[0].id.name - if not isinstance(param, tvm.tir.Var): - self.context.report_error( - "The source of match_sparse_buffer expected Var, but got" - + str(type(param)), - self.node.rhs.params[0].span - ) - - if param in self.context.func_params: - buffer = tvm.tir.sparse.decl_buffer( - axes, - param, - buffer_name, - dtype, - span=span - ) - self.context.func_sparse_buffer_map[param] = buffer - self.context.update_symbol(buffer_name, buffer, self.node) - else: - self.context.report_error( - "Can not bind non-input param to sparse buffer", self.node.rhs.params[0].span - ) - - super().__init__(match_sparse_buffer, def_symbol=True) - - -@register -def to_dense(axis: Axis, span: Optional[Span] = None): - if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): - return DenseFixedAxis(axis.name, axis.length, axis.idtype) - else: - return axis - - -@register -def cord(axis: Axis, span: Optional[Span] = None): - return 'cord', axis - - -@register -def pos(axis: Axis, span: Optional[Span] = None): - return 'pos', axis diff --git a/python/tvm/script/tir/special_stmt.py b/python/tvm/script/tir/special_stmt.py index 20161ad106c1..ff4837696db8 100644 --- a/python/tvm/script/tir/special_stmt.py +++ b/python/tvm/script/tir/special_stmt.py @@ -29,6 +29,14 @@ from tvm.target import Target from tvm.ir import Span from tvm.tir import IntImm, IterVar +from tvm.tir.sparse import ( + Axis, + DenseFixedAxis, + DenseVariableAxis, + SpIterVar, + SparseFixedAxis, + SparseVariableAxis, +) from .node import BufferSlice from .utils import buffer_slice_to_region @@ -885,3 +893,169 @@ def __call__(self, target_config): f"T.target expected a config dict or string, but got {type(target_config)}" ) return Target(target_config) + + +@register +class DenseFixed(SpecialStmt): + """Special Stmt for creating dense fixed axis.""" + + def __init__(self): + def dense_fixed(name: str, length: PrimExpr, span: Optional[Span] = None): + var_name = self.node.lhs[0].id.name + axis = DenseFixedAxis(name, length) + self.context.update_symbol(var_name, axis, self.node) + + super().__init__(dense_fixed, def_symbol=True) + + +@register +class DenseVariable(SpecialStmt): + """Special Stmt for creating dense variable axis.""" + + def __init__(self): + def dense_variable( + name: str, + shape: Tuple[PrimExpr, PrimExpr], + indptr_var: tvm.tir.Var, + idtype: str = "int32", + span: Optional[Span] = None, + ): + indptr_len, length = shape + var_name = self.node.lhs[0].id.name + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), dtype=idtype, name=name + "_indptr", span=span + ) + axis = DenseVariableAxis(name, length, indptr_buf) + self.context.func_buffer_map[indptr_var] = indptr_buf + self.context.update_symbol(var_name, axis, self.node) + self.context.update_symbol(name + "_indptr", indptr_buf, self.node) + + super().__init__(dense_variable, def_symbol=True) + + +@register +class SparseFixed(SpecialStmt): + """Special Stmt for creating sparse fixed axis.""" + + def __init__(self): + def sparse_fixed( + name: str, + shape: Tuple[PrimExpr, PrimExpr, PrimExpr], + indices_var: tvm.tir.Var, + idtype: str = "int32", + span: Optional[Span] = None, + ): + var_name = self.node.lhs[0].id.name + length, nnz, nnz_cols = shape + indices_buf = tvm.tir.decl_buffer( + (nnz,), dtype=idtype, name=name + "_indices", span=span + ) + axis = SparseFixedAxis(name, length, indices_buf, nnz_cols) + self.context.func_buffer_map[indices_var] = indices_buf + self.context.update_symbol(var_name, axis, self.node) + self.context.update_symbol(name + "_indices", indices_buf, self.node) + + super().__init__(sparse_fixed, def_symbol=True) + + +@register +class SparseVariable(SpecialStmt): + """Special Stmt for creating sparse variable axis:""" + + def __init__(self): + def sparse_variable( + name: str, + shape: Tuple[PrimExpr, PrimExpr, PrimExpr], + data: Tuple[tvm.tir.Var, tvm.tir.Var], + idtype: str = "int32", + span: Optional[Span] = None, + ): + var_name = self.node.lhs[0].id.name + length, indptr_len, nnz = shape + indptr_var, indices_var = data + indptr_buf = tvm.tir.decl_buffer( + (indptr_len,), dtype=idtype, name=name + "_indptr", span=span + ) + indices_buf = tvm.tir.decl_buffer( + (nnz,), dtype=idtype, name=name + "_indices", span=span + ) + axis = SparseVariableAxis(name, length, indptr_buf, indices_buf) + self.context.func_buffer_map[indices_var] = indices_buf + self.context.func_buffer_map[indptr_var] = indptr_buf + self.context.update_symbol(var_name, axis, self.node) + self.context.update_symbol(name + "_indptr", indptr_buf, self.node) + self.context.update_symbol(name + "_indices", indices_buf, self.node) + + super().__init__(sparse_variable, def_symbol=True) + + +@register +class MatchSparseBuffer(SpecialStmt): + """Special Stmt match_sparse_buffer()""" + + def __init__(self): + def match_sparse_buffer( + param: tvm.tir.Var, + axes: List[Axis], + nnz: PrimExpr, + dtype: str = "float32", + span: Optional[Span] = None, + ): + if not isinstance(self.node, ast.Assign) or not len(self.node.lhs) == 1: + self.context.report_error( + "`match_sparse_buffer` must be assigned to a single sparse buffer, " + "e.g. A = match_sparse_buffer(...)" + ) + + buffer_name: str = self.node.lhs[0].id.name + if not isinstance(param, tvm.tir.Var): + self.context.report_error( + "The source of match_sparse_buffer expected Var, but got" + str(type(param)), + self.node.rhs.params[0].span, + ) + + if param in self.context.func_params: + data = tvm.tir.decl_buffer(nnz, dtype, buffer_name + "_data", span=span) + buffer = tvm.tir.sparse.SparseBuffer(axes, data, buffer_name) + self.context.func_buffer_map[param] = data + self.context.func_sparse_buffer_map[param] = buffer + self.context.update_symbol(buffer_name + "_data", data, self.node) + self.context.update_symbol(buffer_name, buffer, self.node) + else: + self.context.report_error( + "Can not bind non-input param to sparse buffer", self.node.rhs.params[0].span + ) + + super().__init__(match_sparse_buffer, def_symbol=True) + + +@register +def to_dense(axis: Axis, span: Optional[Span] = None): + if isinstance(axis, (SparseFixedAxis, SparseVariableAxis)): + return DenseFixedAxis(axis.name + "_dense", axis.length) + else: + return axis + + +@register +def cord(axis: Axis, span: Optional[Span] = None): + # The field `var` and `is_reduction` will be updated in SparseBlock scope handler + var_temp = tvm.te.var() + if isinstance(axis, DenseVariableAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) + else: + return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False) + + +@register +def pos(axis: Axis, span: Optional[Span] = None): + # The field `var` and `is_reduction` will be updated in SparseBlock scope handler + var_temp = tvm.te.var() + if isinstance(axis, DenseFixedAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.DenseFixed, False) + elif isinstance(axis, DenseVariableAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.DenseVariable, False, axis) + elif isinstance(axis, SparseFixedAxis): + return SpIterVar(var_temp, axis.length, SpIterVar.SparseFixed, False, axis) + else: + return SpIterVar(var_temp, axis.length, SpIterVar.SparseVariable, False, axis) diff --git a/python/tvm/tir/sparse.py b/python/tvm/tir/sparse.py index 4b0b857a8e6e..574ccc2352a6 100644 --- a/python/tvm/tir/sparse.py +++ b/python/tvm/tir/sparse.py @@ -28,6 +28,7 @@ class Axis(Object): """Base class of all the sparse axes.""" + @property def name(self): return _ffi_api.GetAxisName(self) @@ -66,9 +67,7 @@ class DenseFixedAxis(DenseAxis): length: PrimExpr def __init__(self, name, length): - self.__init_handle_by_constructor__( - _ffi_api.DenseFixedAxis, name, length # type: ignore - ) + self.__init_handle_by_constructor__(_ffi_api.DenseFixedAxis, name, length) # type: ignore @tvm._ffi.register_object("tir.sparse.DenseVariableAxis") @@ -198,9 +197,7 @@ class SparseBuffer(Object): name: str def __init__(self, axes, data, name): - self.__init_handle_by_constructor__( - _ffi_api.SparseBuffer, axes, data, name # type: ignore - ) + self.__init_handle_by_constructor__(_ffi_api.SparseBuffer, axes, data, name) # type: ignore @tvm._ffi.register_object("tir.sparse.SpIterVar") @@ -225,6 +222,7 @@ class SpIterVar(Object): The axis over which the SpIterVar iterates. Required to be defined when `kind` is not `DenseFixed` """ + var: Var max_extent: PrimExpr kind: int diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc index fec850ecc589..21812fdba593 100644 --- a/src/tir/ir/expr.cc +++ b/src/tir/ir/expr.cc @@ -1088,7 +1088,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // SparseBufferLoad SparseBufferLoad::SparseBufferLoad(SparseBuffer buffer, Array indices, Span span) { ObjectPtr node = make_object(); - node->dtype = buffer->dtype; node->buffer = std::move(buffer); node->indices = std::move(indices); node->span = std::move(span); diff --git a/src/tir/ir/sparse.cc b/src/tir/ir/sparse.cc index 95dcfa3d7a2a..c3e118611b22 100644 --- a/src/tir/ir/sparse.cc +++ b/src/tir/ir/sparse.cc @@ -52,14 +52,12 @@ DenseFixedAxis::DenseFixedAxis(String name, PrimExpr length) { TVM_REGISTER_NODE_TYPE(DenseFixedAxisNode); -TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis") - .set_body_typed([](String name, PrimExpr length) { - return DenseFixedAxis(name, length); - }); +TVM_REGISTER_GLOBAL("tir.sparse.DenseFixedAxis").set_body_typed([](String name, PrimExpr length) { + return DenseFixedAxis(name, length); +}); // DenseVariableAxis -DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, - Buffer indptr) { +DenseVariableAxis::DenseVariableAxis(String name, PrimExpr length, Buffer indptr) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); @@ -71,13 +69,11 @@ TVM_REGISTER_NODE_TYPE(DenseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.DenseVariableAxis") .set_body_typed([](String name, PrimExpr length, Buffer indptr) { - return DenseVariableAxis( - name, length, indptr); + return DenseVariableAxis(name, length, indptr); }); // SparseFixedAxis -SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, - PrimExpr num_cols) { +SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); @@ -89,16 +85,14 @@ SparseFixedAxis::SparseFixedAxis(String name, PrimExpr length, Buffer indices, TVM_REGISTER_NODE_TYPE(SparseFixedAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseFixedAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indices, - PrimExpr num_cols) { + .set_body_typed([](String name, PrimExpr length, Buffer indices, PrimExpr num_cols) { return SparseFixedAxis(name, length, indices, num_cols); }); // SparseVariableAxis -SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, - Buffer indptr, Buffer indices) { - ObjectPtr node = - make_object(); +SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, Buffer indptr, + Buffer indices) { + ObjectPtr node = make_object(); node->name = std::move(name); node->length = std::move(length); node->indptr = std::move(indptr); @@ -109,15 +103,12 @@ SparseVariableAxis::SparseVariableAxis(String name, PrimExpr length, TVM_REGISTER_NODE_TYPE(SparseVariableAxisNode); TVM_REGISTER_GLOBAL("tir.sparse.SparseVariableAxis") - .set_body_typed([](String name, PrimExpr length, Buffer indptr, - Buffer indices) { - return SparseVariableAxis( - name, length, indptr, indices); + .set_body_typed([](String name, PrimExpr length, Buffer indptr, Buffer indices) { + return SparseVariableAxis(name, length, indptr, indices); }); // AxisTree -AxisTree::AxisTree(Array axis_names, - Array> axis_parent_names) { +AxisTree::AxisTree(Array axis_names, Array> axis_parent_names) { CHECK_EQ(axis_names.size(), axis_parent_names.size()) << "ValueError: The axis_names array should have the same length as " "axis_parent_names " @@ -142,8 +133,7 @@ AxisTree::AxisTree(Array axis_names, TVM_REGISTER_NODE_TYPE(AxisTreeNode); TVM_REGISTER_GLOBAL("tir.sparse.AxisTree") - .set_body_typed([](Array axis_names, - Array> axis_parent_names) { + .set_body_typed([](Array axis_names, Array> axis_parent_names) { return AxisTree(axis_names, axis_parent_names); }); @@ -164,8 +154,8 @@ TVM_REGISTER_GLOBAL("tir.sparse.SparseBuffer") }); // SpIterVar -SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, - bool is_reduction, Optional axis) { +SpIterVar::SpIterVar(Var var, PrimExpr max_extent, SpIterKind kind, bool is_reduction, + Optional axis) { ObjectPtr node = make_object(); arith::Analyzer ana; @@ -185,7 +175,7 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, } } - node->var = Var(std::move(name)); + node->var = Var(std::move(var)); node->max_extent = std::move(max_extent); node->kind = kind; node->is_reduction = is_reduction; @@ -196,9 +186,9 @@ SpIterVar::SpIterVar(String name, PrimExpr max_extent, SpIterKind kind, TVM_REGISTER_NODE_TYPE(SpIterVarNode); TVM_REGISTER_GLOBAL("tir.sparse.SpIterVar") - .set_body_typed([](String name, PrimExpr max_extent, int kind, bool is_reduction, + .set_body_typed([](Var var, PrimExpr max_extent, int kind, bool is_reduction, Optional axis) { - return SpIterVar(name, max_extent, SpIterKind(kind), is_reduction, axis); + return SpIterVar(var, max_extent, SpIterKind(kind), is_reduction, axis); }); } // namespace tir