[REFACTOR][API-Change] Migrate all Object construction to constructor. (apache#5784)

This PR migrates all remaining object constructions to the new constructor style, making them consistent with the rest of the codebase, and updates the affected files accordingly. A minimal before/after sketch follows the change list below.

Other changes:

- ThreadScope::make -> ThreadScope::Create
- StorageScope::make -> StorageScope::Create
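
A minimal before/after sketch of the style change (the tuple expression tup is hypothetical; the calls mirror the TupleGetItem change in docs/dev/relay_add_pass.rst below):

// Old style, removed by this PR: static factory method on the *Node class.
relay::Expr proj_old = relay::TupleGetItemNode::make(tup, /*index=*/0);
// New style: constructor on the reference class.
relay::Expr proj_new = relay::TupleGetItem(tup, /*index=*/0);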
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent 1168a57 commit eaef6b3
Showing 49 changed files with 469 additions and 416 deletions.
2 changes: 1 addition & 1 deletion docs/dev/codebase_walkthrough.rst
@@ -84,7 +84,7 @@ This function is mapped to the C++ function in ``include/tvm/schedule.h``.
::

inline Schedule create_schedule(Array<Operation> ops) {
return ScheduleNode::make(ops);
return Schedule(ops);
}

``Schedule`` consists of collections of ``Stage`` and output ``Operation``.
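
For context, a short sketch of how a schedule is created and a ``Stage`` is looked up (assumes a tensor ``C`` produced by ``te::compute``; not part of the walkthrough itself):

te::Schedule sch = te::create_schedule({C->op});
te::Stage stage = sch[C->op];  // each output Operation maps to a Stage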
2 changes: 1 addition & 1 deletion docs/dev/relay_add_pass.rst
@@ -138,7 +138,7 @@ is shown below.
if (g->tuple == t) {
return GetRef<Expr>(g);
} else {
return TupleGetItemNode::make(t, g->index);
return TupleGetItem(t, g->index);
}
}
10 changes: 5 additions & 5 deletions docs/dev/relay_pass_infra.rst
@@ -344,13 +344,13 @@ registration.
.. code:: c++

// Create a simple Relay program.
auto tensor_type = relay::TensorTypeNode::make({}, tvm::Bool());
auto x = relay::VarNode::make("x", relay::Type());
auto f = relay::FunctionNode::make(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});
auto tensor_type = relay::TensorType({}, tvm::Bool());
auto x = relay::Var("x", relay::Type());
auto f = relay::Function(tvm::Array<relay::Var>{ x }, x, relay::Type(), {});

auto y = relay::VarNode::make("y", tensor_type);
auto y = relay::Var("y", tensor_type);
auto call = relay::Call(f, tvm::Array<relay::Expr>{ y });
auto fx = relay::FunctionNode::make(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});
auto fx = relay::Function(tvm::Array<relay::Var>{ y }, call, relay::Type(), {});

// Create a module for optimization.
auto mod = IRModule::FromExpr(fx);
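
// Hypothetical continuation (not part of the original document): with the
// module in hand, a registered pass such as constant folding can be applied.
auto optimized = relay::transform::FoldConstant()(mod);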
4 changes: 2 additions & 2 deletions include/tvm/ir/span.h
@@ -97,14 +97,14 @@ class SpanNode : public Object {
equal(col_offset, other->col_offset);
}

TVM_DLL static Span make(SourceName source, int lineno, int col_offset);

static constexpr const char* _type_key = "Span";
TVM_DECLARE_FINAL_OBJECT_INFO(SpanNode, Object);
};

class Span : public ObjectRef {
public:
TVM_DLL Span(SourceName source, int lineno, int col_offset);

TVM_DEFINE_OBJECT_REF_METHODS(Span, ObjectRef, SpanNode);
};
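
A minimal usage sketch of the new constructor (``SourceName::Get`` interns a source name; the file name here is made up):

SourceName src = SourceName::Get("example.cc");
Span span(src, /*lineno=*/10, /*col_offset=*/4);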

90 changes: 75 additions & 15 deletions include/tvm/te/operation.h
@@ -177,12 +177,22 @@ class PlaceholderOpNode : public OperationNode {
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
static Operation make(std::string name, Array<PrimExpr> shape, DataType dtype);

static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};

/*!
* \brief Managed reference to PlaceholderOpNode
* \sa PlaceholderOpNode
*/
class PlaceholderOp : public Operation {
public:
TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);

TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};
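
A sketch of the new reference class in use (in practice te::placeholder is the usual entry point; shape, dtype, and name are arbitrary):

PlaceholderOp op("data", {128, 128}, DataType::Float(32));
Tensor data = op.output(0);  // the placeholder's single output tensor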

/*!
* \brief A Compute op that computes a tensor on a certain domain.
* This is the base class for ComputeOp (operating on a scalar at a time) and
@@ -237,13 +247,23 @@ class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);

static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};

/*!
* \brief Managed reference to ComputeOpNode
* \sa ComputeOpNode
*/
class ComputeOp : public Operation {
public:
TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);

TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
};
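
For reference, te::compute remains the usual way to build one of these; it creates a ComputeOp under the hood (a sketch; names and shapes are arbitrary):

Tensor A = placeholder({128}, DataType::Float(32), "A");
Tensor B = compute({128}, [&](Var i) { return A(i) + 1; }, "B");
const ComputeOpNode* cop = B->op.as<ComputeOpNode>();  // the op behind B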

/*!
* \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
*/
@@ -285,15 +305,25 @@ class TensorComputeOpNode : public BaseComputeOpNode {
v->Visit("input_regions", &input_regions);
v->Visit("scalar_inputs", &scalar_inputs);
}
static Operation make(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);

static constexpr const char* _type_key = "TensorComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};

/*!
* \brief Managed reference to TensorComputeOpNode
* \sa TensorComputeOpNode
*/
class TensorComputeOp : public Operation {
public:
TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);

TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};

/*!
* \brief Symbolic scan.
*/
@@ -353,14 +383,24 @@ class ScanOpNode : public OperationNode {
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
IterVar axis, Array<Tensor> init, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> input);

static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};

/*!
* \brief Managed reference to ScanOpNode
* \sa ScanOpNode
*/
class ScanOp : public Operation {
public:
TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
Array<Tensor> input);

TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};

/*!
* \brief External computation that cannot be split.
*/
@@ -404,14 +444,24 @@ class ExternOpNode : public OperationNode {
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);

static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};

/*!
* \brief Managed reference to ExternOpNode
* \sa ExternOpNode
*/
class ExternOp : public Operation {
public:
TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};

/*!
* \brief A computation operator that generated by hybrid script.
*/
@@ -459,13 +509,23 @@ class HybridOpNode : public OperationNode {
v->Visit("axis", &axis);
v->Visit("body", &body);
}
TVM_DLL static Operation make(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};

/*!
* \brief Managed reference to HybridOpNode
* \sa HybridOpNode
*/
class HybridOp : public Operation {
public:
TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);

TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};

/*!
* \brief Construct a new Var expression
* \param name_hint The name hint for the expression
68 changes: 51 additions & 17 deletions include/tvm/te/schedule.h
@@ -277,6 +277,12 @@ class Schedule : public ObjectRef {
public:
Schedule() {}
explicit Schedule(ObjectPtr<Object> n) : ObjectRef(n) {}
/*!
* \brief Create a schedule for an array of ops (and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
TVM_DLL explicit Schedule(Array<Operation> ops);
/*!
* \brief Get a copy of current schedule.
* \return The copied schedule.
@@ -553,13 +559,6 @@ class ScheduleNode : public Object {
*/
TVM_DLL bool Contain(const Tensor& tensor) const { return Contain(tensor->op); }

/*!
* \brief Create a schedule for array of ops(and their dependencies).
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
TVM_DLL static Schedule make(Array<Operation> ops);

static constexpr const char* _type_key = "Schedule";
TVM_DECLARE_FINAL_OBJECT_INFO(ScheduleNode, Object);
};
@@ -569,7 +568,7 @@ class ScheduleNode : public Object {
* \param ops The ops to be scheduled.
* \return sch The created Schedule.
*/
inline Schedule create_schedule(Array<Operation> ops) { return ScheduleNode::make(ops); }
inline Schedule create_schedule(Array<Operation> ops) { return Schedule(ops); }
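
A sketch of the two now-equivalent spellings, given any Array<Operation> of output ops named ops:

Schedule s1 = create_schedule(ops);  // thin wrapper over the constructor
Schedule s2(ops);                    // calls the new constructor directly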

/*! \brief node container for IterVar attr */
class IterVarAttrNode : public Object {
@@ -648,13 +647,21 @@ class SplitNode : public IterVarRelationNode {
v->Visit("nparts", &nparts);
}

static IterVarRelation make(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor,
PrimExpr nparts);

static constexpr const char* _type_key = "Split";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to SplitNode
* \sa SplitNode
*/
class Split : public IterVarRelation {
public:
TVM_DLL Split(IterVar parent, IterVar outer, IterVar inner, PrimExpr factor, PrimExpr nparts);

TVM_DEFINE_OBJECT_REF_METHODS(Split, IterVarRelation, SplitNode);
};

/*!
* \brief Fuse two domains into one domain.
*/
@@ -673,12 +680,21 @@ class FuseNode : public IterVarRelationNode {
v->Visit("fused", &fused);
}

static IterVarRelation make(IterVar outer, IterVar inner, IterVar fused);

static constexpr const char* _type_key = "Fuse";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to FuseNode
* \sa FuseNode
*/
class Fuse : public IterVarRelation {
public:
TVM_DLL Fuse(IterVar outer, IterVar inner, IterVar fused);

TVM_DEFINE_OBJECT_REF_METHODS(Fuse, IterVarRelation, FuseNode);
};
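
These relation nodes are normally created indirectly by schedule primitives rather than constructed by hand; a sketch, assuming a schedule s and a compute tensor C:

IterVar outer, inner, fused;
IterVar i = C->op.as<ComputeOpNode>()->axis[0];  // C's first data-parallel axis
s[C].split(i, 32, &outer, &inner);  // records a SplitNode relation on the stage
s[C].fuse(outer, inner, &fused);    // records a FuseNode relation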

/*!
* \brief Rebase the iteration to make the min be 0.
* This is useful to normalize the Schedule
@@ -696,12 +712,21 @@ class RebaseNode : public IterVarRelationNode {
v->Visit("rebased", &rebased);
}

static IterVarRelation make(IterVar parent, IterVar rebased);

static constexpr const char* _type_key = "Rebase";
TVM_DECLARE_FINAL_OBJECT_INFO(RebaseNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to RebaseNode
* \sa RebaseNode
*/
class Rebase : public IterVarRelation {
public:
TVM_DLL Rebase(IterVar parent, IterVar rebased);

TVM_DEFINE_OBJECT_REF_METHODS(Rebase, IterVarRelation, RebaseNode);
};

/*!
* \brief Singleton iterator [0, 1)
*/
@@ -712,12 +737,21 @@ class SingletonNode : public IterVarRelationNode {

void VisitAttrs(AttrVisitor* v) { v->Visit("iter", &iter); }

static IterVarRelation make(IterVar iter);

static constexpr const char* _type_key = "Singleton";
TVM_DECLARE_FINAL_OBJECT_INFO(SingletonNode, IterVarRelationNode);
};

/*!
* \brief Managed reference to SingletonNode
* \sa SingletonNode
*/
class Singleton : public IterVarRelation {
public:
TVM_DLL explicit Singleton(IterVar iter);

TVM_DEFINE_OBJECT_REF_METHODS(Singleton, IterVarRelation, SingletonNode);
};

/*! \brief Container for specialization conditions. */
class SpecializedConditionNode : public Object {
public:
