forked from neo-ai/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
operation.h
641 lines (588 loc) · 24.8 KB
/
operation.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/te/operation.h
* \brief Operation node can generate one or multiple Tensors
*/
#ifndef TVM_TE_OPERATION_H_
#define TVM_TE_OPERATION_H_
#include <tvm/arith/analyzer.h>
#include <tvm/te/schedule.h>
#include <tvm/te/tensor.h>
#include <tvm/tir/buffer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <string>
#include <unordered_map>
#include <vector>
namespace tvm {
/*! \brief Tensor expression language DSL. */
namespace te {
/*!
* \brief Temporary data structure to store union
* of bounds of each axis of Tensor.
*/
struct TensorDom {
// constructor
explicit TensorDom(int ndim) : data(ndim) {}
/*! \brief The domain data */
std::vector<std::vector<IntSet> > data;
};
/*!
* \brief Base class of all operation nodes.
*
* Holds the data shared by every operation (name, tag, attrs) and declares
* the pure-virtual interface that the concrete operation nodes in this
* header implement.
*/
class OperationNode : public tir::FunctionBaseNode {
public:
/*! \brief optional name of the operation */
std::string name;
/*! \brief optional tag of the operation */
std::string tag;
/*! \brief additional attributes of the operation */
Map<String, ObjectRef> attrs;
/*!
* \return name of the operation (a reference to the name field).
* \note Declared final without being introduced as virtual here, so it
* overrides a virtual of the base class (declared outside this header).
*/
const std::string& func_name() const final { return name; }
/*!
* \return The list of iteration variables at root.
* \note root_iter_vars decides the shape of the outputs.
*/
virtual Array<IterVar> root_iter_vars() const = 0;
/*!
* \brief Get the data type of the i-th output tensor.
* \param i The output index.
* \return type of i-th output.
*/
virtual DataType output_dtype(size_t i) const = 0;
/*!
* \brief Get shape of i-th output tensor.
* \param i The output index.
* \return shape of i-th output.
*/
virtual Array<PrimExpr> output_shape(size_t i) const = 0;
/*!
* \brief List all the input Tensors.
* \return List of input tensors.
*/
virtual Array<Tensor> InputTensors() const = 0;
/*!
* \brief Replace the inputs of the operation according to the pattern specified by rmap.
*
* \param self The reference to self.
* \param rmap The replacement map.
* \return self if nothing is replaced, otherwise return replaced op.
*/
virtual Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const = 0;
/*!
* \brief Propagate the bounds to inputs.
* \param self The reference to self.
* \param analyzer The analyzer to be used in the function.
* \param dom_map The domain map of Variables (corresponds to root_iter_vars).
* \param out_dom_map The output domain.
* The function is only asked to fill the bounds for Tensors that
* are already in the out_dom_map.
*/
virtual void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const = 0;
/*!
* \brief Gather the bound from output tensor.
* Set the range of each root_iter_vars in the op to out_dom_map.
*
* \param self The reference to self.
* \param tensor_dom Domain map of Tensor->access set of each dimension.
* \param out_dom_map The output domain map of each IterVar to be set.
*/
virtual void GatherBound(const Operation& self,
const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const = 0;
/*!
* \brief Build the Realize statement that realizes
* the op's output tensors.
* \param stage The op's stage.
* \param realize_map The realization domain map of the operators.
* \param body The body statement that the returned realization wraps.
* \return A realization statement that wraps body.
*/
virtual Stmt BuildRealize(const Stage& stage,
const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const = 0;
/*!
* \brief Build the statement that provides the output tensors.
* \param stage The schedule stage of the op.
* \param dom_map The domain map of all iteration domains.
* \param debug_keep_trivial_loop Whether to keep trivial loops with extent of 1.
* \return A statement that adds production and wraps consumer.
*/
virtual Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const = 0;
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "Operation";
TVM_DECLARE_BASE_OBJECT_INFO(OperationNode, Object);
};
/*!
* \brief A placeholder op represents an input placeholder.
*
* Besides the common OperationNode fields it carries only the shape
* and the data type of the input.
*/
class PlaceholderOpNode : public OperationNode {
public:
/*! \brief The shape of the input */
Array<PrimExpr> shape;
/*! \brief The data type of the input. */
DataType dtype;
// Overrides of the operation interface. They are only declared here;
// this header contains no definitions for them.
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
/*!
* \brief Expose the fields of this node to an attribute visitor.
* \param v The visitor; Visit is called once per field with the field's
* string key and a pointer to the field, in the order listed below.
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("shape", &shape);
v->Visit("dtype", &dtype);
}
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "PlaceholderOp";
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
};
/*!
* \brief Managed reference to PlaceholderOpNode
* \sa PlaceholderOpNode
*/
class PlaceholderOp : public Operation {
public:
/*!
* \brief Constructor (declared here; the definition is not in this header).
* \param name The name of the operation.
* \param shape The shape of the input.
* \param dtype The data type of the input.
* \note Presumably these populate the same-named fields of
* PlaceholderOpNode; confirm against the definition.
*/
TVM_DLL PlaceholderOp(std::string name, Array<PrimExpr> shape, DataType dtype);
TVM_DEFINE_OBJECT_REF_METHODS(PlaceholderOp, Operation, PlaceholderOpNode);
};
/*!
* \brief A Compute op that computes a tensor on a certain domain.
* This is the base class for ComputeOp (operating on a scalar at a time) and
* TensorComputeOp (operating on a TensorSlice at a time).
*
* It finalizes the parts of the operation interface that are declared
* below and leaves num_schedulable_dims to its subclasses.
*/
class TVM_DLL BaseComputeOpNode : public OperationNode {
public:
/*! \brief IterVar on each axis */
Array<IterVar> axis;
/*! \brief IterVar on each reduction axis, if the body is a Reduce */
Array<IterVar> reduce_axis;
// override functions (declared final here, so subclasses cannot override them)
Array<IterVar> root_iter_vars() const final;
Array<PrimExpr> output_shape(size_t idx) const final;
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
/*!
* \return The number of schedulable dimensions of the op.
* \note Pure virtual; ComputeOpNode and TensorComputeOpNode in this
* header declare the overrides.
*/
virtual size_t num_schedulable_dims() const = 0;
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "BaseComputeOp";
TVM_DECLARE_BASE_OBJECT_INFO(BaseComputeOpNode, OperationNode);
};
/*!
* \brief A Compute op that computes a tensor on a certain domain.
*/
class TVM_DLL ComputeOpNode : public BaseComputeOpNode {
public:
/*!
* \brief the compute expression
* \note Stored as an array; presumably one entry per output tensor --
* confirm against the definition of num_outputs().
*/
Array<PrimExpr> body;
/*! \brief constructor */
ComputeOpNode() {}
// override functions (declared only; this header contains no definitions for them)
int num_outputs() const final;
DataType output_dtype(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
size_t num_schedulable_dims() const final;
/*!
* \brief Expose the fields of this node to an attribute visitor.
* \param v The visitor; Visit is called once per field with the field's
* string key and a pointer to the field, in the order listed below.
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("axis", &axis);
v->Visit("reduce_axis", &reduce_axis);
v->Visit("body", &body);
}
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "ComputeOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeOpNode, BaseComputeOpNode);
};
/*!
* \brief Managed reference to ComputeOpNode
* \sa ComputeOpNode
*/
class ComputeOp : public Operation {
public:
/*!
* \brief Constructor (declared here; the definition is not in this header).
* \param name The name of the operation.
* \param tag The tag of the operation.
* \param attrs Additional attributes of the operation.
* \param axis IterVar on each axis.
* \param body The compute expression(s).
* \note reduce_axis is not a constructor parameter; how the node's
* reduce_axis field gets filled is not visible here.
*/
TVM_DLL ComputeOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<IterVar> axis, Array<PrimExpr> body);
TVM_DEFINE_OBJECT_REF_METHODS(ComputeOp, Operation, ComputeOpNode);
};
/*!
 * \brief A TensorCompute op that computes a tensor with a tensor intrinsic.
 */
class TensorComputeOpNode : public BaseComputeOpNode {
 public:
  /*!
   * \brief number of axes that can be scheduled
   * \note Zero-initialized: the constructor below is user-provided and empty,
   *  so without an initializer a default-constructed node would carry an
   *  indeterminate value that VisitAttrs then hands to the visitor.
   */
  int schedulable_ndim{0};
  /*! \brief TensorIntrin used to compute */
  TensorIntrin intrin;
  /*! \brief input tensors of intrin */
  Array<Tensor> inputs;
  /*! \brief region of input tensors */
  Array<Region> input_regions;
  /*! \brief scalar expression inputs */
  Array<PrimExpr> scalar_inputs;
  /*! \brief constructor */
  TensorComputeOpNode() {}
  // override functions (declared only; this header contains no definitions for them)
  int num_outputs() const final;
  DataType output_dtype(size_t i) const final;
  Array<Tensor> InputTensors() const final;
  Operation ReplaceInputs(const Operation& self,
                          const std::unordered_map<Tensor, Tensor>& rmap) const final;
  void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
                         const std::unordered_map<const VarNode*, IntSet>& dom_map,
                         std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
  Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                    bool debug_keep_trivial_loop) const final;
  size_t num_schedulable_dims() const final;

  /*!
   * \brief Expose the fields of this node to an attribute visitor.
   * \param v The visitor; Visit is called once per field with the field's
   *  string key and a pointer to the field, in the order listed below.
   * \note NOTE(review): unlike the other op nodes in this header, "attrs" is
   *  not visited here. Left as is because adding a key changes the set of
   *  visited fields; confirm whether the omission is intentional.
   */
  void VisitAttrs(AttrVisitor* v) {
    v->Visit("name", &name);
    v->Visit("tag", &tag);
    v->Visit("axis", &axis);
    v->Visit("reduce_axis", &reduce_axis);
    v->Visit("schedulable_ndim", &schedulable_ndim);
    v->Visit("intrin", &intrin);
    v->Visit("inputs", &inputs);
    v->Visit("input_regions", &input_regions);
    v->Visit("scalar_inputs", &scalar_inputs);
  }

  /*! \brief Type key string of this node class. */
  static constexpr const char* _type_key = "TensorComputeOp";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorComputeOpNode, BaseComputeOpNode);
};
/*!
* \brief Managed reference to TensorComputeOpNode
* \sa TensorComputeOpNode
*/
class TensorComputeOp : public Operation {
public:
/*!
* \brief Constructor (declared here; the definition is not in this header).
* \param name The name of the operation.
* \param tag The tag of the operation.
* \param axis IterVar on each axis.
* \param reduce_axis IterVar on each reduction axis.
* \param schedulable_ndim Number of axes that can be scheduled.
* \param intrin TensorIntrin used to compute.
* \param tensors Presumably the input tensors of intrin (the node's
* inputs field) -- confirm against the definition.
* \param regions Presumably the regions of the input tensors (the node's
* input_regions field) -- confirm against the definition.
* \param scalar_inputs Scalar expression inputs.
* \note There is no attrs parameter, unlike the other op constructors
* in this header.
*/
TVM_DLL TensorComputeOp(std::string name, std::string tag, Array<IterVar> axis,
Array<IterVar> reduce_axis, int schedulable_ndim, TensorIntrin intrin,
Array<Tensor> tensors, Array<Region> regions,
Array<PrimExpr> scalar_inputs);
TVM_DEFINE_OBJECT_REF_METHODS(TensorComputeOp, Operation, TensorComputeOpNode);
};
/*!
* \brief Symbolic scan.
*
* Described by a scan axis together with init, update and
* state_placeholder tensors (see the fields below).
*/
class ScanOpNode : public OperationNode {
public:
/*! \brief IterVar to scan over */
IterVar scan_axis;
/*! \brief the initialization tensors */
Array<Tensor> init;
/*! \brief the update function represented by tensor */
Array<Tensor> update;
/*! \brief The placeholder to refer as states in update. */
Array<Tensor> state_placeholder;
/*!
* \brief The inputs to the scan. These are optional, but providing
* them can serve as a hint that speeds up retrieval of the scan body.
*/
Array<Tensor> inputs;
/*!
* \brief Spatial axis to indicate spatial dimension of each output.
* They correspond to the flattened spatial axes of the outputs.
*
* [output[0].axis[1], output[0].axis[2]... output[k].axis[j]...]
* These are an auxiliary data structure for storing the result of bound
* inference. They do not correspond to splittable iterations, thus the
* name comes with an underscore.
*/
Array<IterVar> spatial_axis_;
/*! \brief constructor */
ScanOpNode() {}
// override behavior (declared only; this header contains no definitions for them).
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
/*!
* \brief Expose the fields of this node to an attribute visitor.
* \param v The visitor; Visit is called once per field with the field's
* string key and a pointer to the field, in the order listed below.
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("scan_axis", &scan_axis);
v->Visit("init", &init);
v->Visit("update", &update);
v->Visit("state_placeholder", &state_placeholder);
v->Visit("inputs", &inputs);
v->Visit("spatial_axis_", &spatial_axis_);
}
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "ScanOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ScanOpNode, OperationNode);
};
/*!
* \brief Managed reference to ScanOpNode
* \sa ScanOpNode
*/
class ScanOp : public Operation {
public:
/*!
* \brief Constructor (declared here; the definition is not in this header).
* \param name The name of the operation.
* \param tag The tag of the operation.
* \param attrs Additional attributes of the operation.
* \param axis Presumably the IterVar to scan over (the node's scan_axis
* field) -- confirm against the definition.
* \param init The initialization tensors.
* \param update The update function represented by tensor.
* \param state_placeholder The placeholder to refer as states in update.
* \param input Presumably the inputs to the scan (the node's inputs
* field) -- confirm against the definition.
*/
TVM_DLL ScanOp(std::string name, std::string tag, Map<String, ObjectRef> attrs, IterVar axis,
Array<Tensor> init, Array<Tensor> update, Array<Tensor> state_placeholder,
Array<Tensor> input);
TVM_DEFINE_OBJECT_REF_METHODS(ScanOp, Operation, ScanOpNode);
};
/*!
* \brief External computation that cannot be split.
*
* The computation is given directly as a statement (body) that refers to
* buffer placeholders for its inputs and outputs.
*/
class ExternOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of inputs */
Array<Buffer> input_placeholders;
/*! \brief Symbolic placeholder representation of outputs */
Array<Buffer> output_placeholders;
/*! \brief the statement that generates the computation. */
Stmt body;
/*! \brief constructor */
ExternOpNode() {}
// override functions (declared only; this header contains no definitions for them)
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
/*!
* \brief Expose the fields of this node to an attribute visitor.
* \param v The visitor; Visit is called once per field with the field's
* string key and a pointer to the field, in the order listed below.
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("input_placeholders", &input_placeholders);
v->Visit("output_placeholders", &output_placeholders);
v->Visit("body", &body);
}
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "ExternOp";
TVM_DECLARE_FINAL_OBJECT_INFO(ExternOpNode, OperationNode);
};
/*!
* \brief Managed reference to ExternOpNode
* \sa ExternOpNode
*/
class ExternOp : public Operation {
public:
/*!
* \brief Constructor (declared here; the definition is not in this header).
* \param name The name of the operation.
* \param tag The tag of the operation.
* \param attrs Additional attributes of the operation.
* \param inputs The input tensors.
* \param input_placeholders Symbolic placeholder representation of inputs.
* \param output_placeholders Symbolic placeholder representation of outputs.
* \param body The statement that generates the computation.
*/
TVM_DLL ExternOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Buffer> input_placeholders,
Array<Buffer> output_placeholders, Stmt body);
TVM_DEFINE_OBJECT_REF_METHODS(ExternOp, Operation, ExternOpNode);
};
/*!
* \brief A computation operator that is generated by hybrid script.
*/
class HybridOpNode : public OperationNode {
public:
/*! \brief The input tensors */
Array<Tensor> inputs;
/*! \brief Symbolic placeholder representation of outputs */
Array<Tensor> outputs;
/*! \brief The axis of iterations */
Array<IterVar> axis;
/*! \brief the statement that generates the computation. This is
* slightly different from the body in ExternOpNode. All the output
* tensors keep the names specified by users in the script.
* However, during compilation, these tensors will be replaced by the
* actual output tensors. */
Stmt body;
/*! \brief constructor */
HybridOpNode() {}
// override functions (declared only; this header contains no definitions for them)
int num_outputs() const final;
Array<IterVar> root_iter_vars() const final;
DataType output_dtype(size_t i) const final;
Array<PrimExpr> output_shape(size_t i) const final;
Array<Tensor> InputTensors() const final;
Operation ReplaceInputs(const Operation& self,
const std::unordered_map<Tensor, Tensor>& rmap) const final;
void PropBoundToInputs(const Operation& self, arith::Analyzer* analyzer,
const std::unordered_map<const VarNode*, IntSet>& dom_map,
std::unordered_map<Tensor, TensorDom>* out_dom_map) const final;
void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
std::unordered_map<IterVar, Range>* out_dom_map) const final;
Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
const Stmt& body) const final;
Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
bool debug_keep_trivial_loop) const final;
/*!
* \brief Expose the fields of this node to an attribute visitor.
* \param v The visitor; Visit is called once per field with the field's
* string key and a pointer to the field, in the order listed below.
*/
void VisitAttrs(AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("tag", &tag);
v->Visit("attrs", &attrs);
v->Visit("inputs", &inputs);
v->Visit("outputs", &outputs);
v->Visit("axis", &axis);
v->Visit("body", &body);
}
/*! \brief Type key string of this node class. */
static constexpr const char* _type_key = "HybridOp";
TVM_DECLARE_FINAL_OBJECT_INFO(HybridOpNode, OperationNode);
};
/*!
* \brief Managed reference to HybridOpNode
* \sa HybridOpNode
*/
class HybridOp : public Operation {
public:
/*!
* \brief Constructor (declared here; the definition is not in this header).
* \param name The name of the operation.
* \param tag The tag of the operation.
* \param attrs Additional attributes of the operation.
* \param inputs The input tensors.
* \param outputs Symbolic placeholder representation of outputs.
* \param body The statement that generates the computation.
* \note The node's axis field is not a constructor parameter; how it is
* filled is not visible here.
*/
TVM_DLL HybridOp(std::string name, std::string tag, Map<String, ObjectRef> attrs,
Array<Tensor> inputs, Array<Tensor> outputs, Stmt body);
TVM_DEFINE_OBJECT_REF_METHODS(HybridOp, Operation, HybridOpNode);
};
/*!
* \brief Construct a new Var expression
* \param name_hint The name hint for the expression
* \param t The type of the expression; defaults to DataType::Int(32).
* \return The newly constructed Var.
*/
TVM_DLL Var var(std::string name_hint, DataType t = DataType::Int(32));
/*!
* \brief Create a new IterVar that represents an axis in thread.
*
* \param dom Optional, domain of the thread axis.
* \param tag The thread tag of the axis.
* \return The newly created IterVar.
*/
TVM_DLL IterVar thread_axis(Range dom, std::string tag);
/*!
* \brief Create a new IterVar for reduction operations.
*
* \param dom The domain of the reduction axis.
* \param name The name of the reduction axis; defaults to "rv".
* \return The newly created IterVar.
*/
TVM_DLL IterVar reduce_axis(Range dom, std::string name = "rv");
/*!
 * \brief The compute function to specify the input source of a Tensor:
 *  it receives the index variables and returns a single expression.
 * \note NOTE(review): std::function comes from <functional>, which this
 *  header does not include directly; it relies on a transitive include.
 */
using FCompute = std::function<PrimExpr(const Array<Var>& indices)>;
/*!
 * \brief The compute function to specify the input sources of several Tensors:
 *  it receives the index variables and returns an array of expressions.
 */
using FBatchCompute = std::function<Array<PrimExpr>(const Array<Var>& indices)>;
/*!
* \brief Create a placeholder tensor.
* \param shape The shape of the tensor.
* \param dtype The data type of the tensor; defaults to DataType::Float(32).
* \param name The name of the Tensor; defaults to "placeholder".
* \return The created Tensor.
*/
TVM_DLL Tensor placeholder(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
std::string name = "placeholder");
/*!
* \brief Construct a new tensor by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensor.
* \param name The optional name of the tensor; defaults to "tensor".
* \param tag The optional tag of the tensor; defaults to the empty string.
* \param attrs Optional additional attributes of the compute.
* \return The constructed Tensor.
*/
TVM_DLL Tensor compute(Array<PrimExpr> shape, FCompute fcompute, std::string name = "tensor",
std::string tag = "", Map<String, ObjectRef> attrs = {});
/*!
* \brief Construct new tensors by computing over shape,
* using the computation rule: result_tensor[axis] = fcompute(axis)
* Batch variant: fcompute returns an array of expressions and the
* function returns an array of tensors.
* \param shape Shape of the tensor.
* \param fcompute The compute function to create the tensors.
* \param name The optional name of the tensor; defaults to "tensor".
* \param tag The optional tag of the tensor; defaults to the empty string.
* \param attrs Optional additional attributes of the compute.
* \return The constructed Tensors.
*/
TVM_DLL Array<Tensor> compute(Array<PrimExpr> shape, FBatchCompute fcompute,
std::string name = "tensor", std::string tag = "",
Map<String, ObjectRef> attrs = {});
/*!
* \brief Construct new tensors by scan.
*
* \param init The initial tensor of the first K steps.
* \param update The update tensor indicating the updated result after each timestamp.
* \param state_placeholder The placeholder for the states.
* \param inputs The inputs to the scan body. This is optional (defaults to an
* empty array), but recommended to provide concrete information about scan body.
* \param name The optional name of the tensor; defaults to "scan".
* \param tag The optional tag of the tensor; defaults to the empty string.
* \param attrs Optional additional attributes of the compute.
* \return The tensors constructed by the scan.
*/
TVM_DLL Array<Tensor> scan(Array<Tensor> init, Array<Tensor> update,
Array<Tensor> state_placeholder, Array<Tensor> inputs = Array<Tensor>(),
std::string name = "scan", std::string tag = "",
Map<String, ObjectRef> attrs = {});
// Convenience overload of compute for a callable taking exactly one index variable.
// The callable is wrapped into an FCompute that forwards element 0 of the index
// array, and the work is delegated to the FCompute-based compute.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  const FCompute unpack_indices = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0]);
  };
  return compute(shape, unpack_indices, name, tag, attrs);
}
// Convenience overload of compute for a callable taking exactly two index variables.
// The callable is wrapped into an FCompute that forwards elements 0 and 1 of the
// index array, and the work is delegated to the FCompute-based compute.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  const FCompute unpack_indices = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0], indices[1]);
  };
  return compute(shape, unpack_indices, name, tag, attrs);
}
// Convenience overload of compute for a callable taking exactly three index variables.
// The callable is wrapped into an FCompute that forwards elements 0..2 of the
// index array, and the work is delegated to the FCompute-based compute.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  const FCompute unpack_indices = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0], indices[1], indices[2]);
  };
  return compute(shape, unpack_indices, name, tag, attrs);
}
// Convenience overload of compute for a callable taking exactly four index variables.
// The callable is wrapped into an FCompute that forwards elements 0..3 of the
// index array, and the work is delegated to the FCompute-based compute.
inline Tensor compute(Array<PrimExpr> shape, std::function<PrimExpr(Var, Var, Var, Var)> f,
                      std::string name = "tensor", std::string tag = "",
                      Map<String, ObjectRef> attrs = {}) {
  const FCompute unpack_indices = [f](const Array<Var>& indices) -> PrimExpr {
    return f(indices[0], indices[1], indices[2], indices[3]);
  };
  return compute(shape, unpack_indices, name, tag, attrs);
}
// Inline definition of Operation's arrow operator: the pointer returned by
// get() is statically cast to a const OperationNode pointer.
inline const OperationNode* Operation::operator->() const {
  const auto* raw_node = get();
  return static_cast<const OperationNode*>(raw_node);
}
} // namespace te
} // namespace tvm
#endif // TVM_TE_OPERATION_H_