forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
/
struct_info.h
451 lines (393 loc) · 14.1 KB
/
struct_info.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_RELAX_STRUCT_INFO_H_
#define TVM_RELAX_STRUCT_INFO_H_
#include <tvm/ir/env_func.h>
#include <tvm/ir/source_map.h>
#include <tvm/node/node.h>
#include <tvm/relax/block_builder.h>
#include <tvm/relax/expr.h>
#include <tvm/relax/type.h>
namespace tvm {
namespace relax {
/*!
* \brief Opaque object.
*
* This node declares no fields of its own; the only reflected attribute is the
* span inherited from StructInfoNode.
*/
class ObjectStructInfoNode : public StructInfoNode {
public:
// Only the span is visited; there is nothing else to expose.
void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
// No fields to compare: any two ObjectStructInfoNodes are structurally equal
// (the span does not take part in the comparison).
bool SEqualReduce(const ObjectStructInfoNode* other, SEqualReducer equal) const { return true; }
// Constant hash contribution, consistent with SEqualReduce always returning true.
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
static constexpr const char* _type_key = "relax.ObjectStructInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectStructInfoNode, StructInfoNode);
};
/*!
* \brief Managed reference to ObjectStructInfoNode.
* \sa ObjectStructInfoNode
*/
class ObjectStructInfo : public StructInfo {
public:
/*!
* \brief Constructor.
* \param span The span of the AST.
*/
TVM_DLL ObjectStructInfo(Span span = Span());
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectStructInfo, StructInfo, ObjectStructInfoNode);
};
/*!
* \brief Primitive value.
*
* Described solely by its data type: only dtype takes part in structural
* equality and hashing, the span does not.
*/
class PrimStructInfoNode : public StructInfoNode {
public:
/*! \brief Underlying data type of the primitive value */
DataType dtype;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("span", &span);
}
// Two primitive struct infos are equal iff their dtypes are equal.
bool SEqualReduce(const PrimStructInfoNode* other, SEqualReducer equal) const {
return equal(dtype, other->dtype);
}
// Hash only dtype, mirroring SEqualReduce.
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(dtype); }
static constexpr const char* _type_key = "relax.PrimStructInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(PrimStructInfoNode, StructInfoNode);
};
/*!
* \brief Managed reference to PrimStructInfoNode.
* \sa PrimStructInfoNode
*/
class PrimStructInfo : public StructInfo {
public:
/*!
* \brief Constructor.
* \param dtype The data type of the primitive value.
* \param span The span of the AST.
*/
TVM_DLL PrimStructInfo(DataType dtype, Span span = Span());
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PrimStructInfo, StructInfo, PrimStructInfoNode);
};
/*!
* \brief StructInfo of shape value.
*
* Stores the (optional) symbolic values of the shape together with its number
* of dimensions. Structural equality and hashing consider values and ndim,
* not the span.
*/
class ShapeStructInfoNode : public StructInfoNode {
public:
/*! \brief optionally stores the symbolic value patterns of the shape */
Optional<Array<PrimExpr>> values;
/*!
* \brief The number of dimension of the shape, can be unknown.
* \sa kUnknownNDim
*/
int ndim;
/*! \return Whether the struct info contains unknown ndim. */
bool IsUnknownNdim() const { return ndim == kUnknownNDim; }
void VisitAttrs(AttrVisitor* v) {
v->Visit("values", &values);
v->Visit("ndim", &ndim);
v->Visit("span", &span);
}
// Compare values first, then ndim; the span is ignored.
bool SEqualReduce(const ShapeStructInfoNode* other, SEqualReducer equal) const {
return equal(values, other->values) && equal(ndim, other->ndim);
}
// Hash the same fields, in the same order, as SEqualReduce compares them.
void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(values);
hash_reduce(ndim);
}
static constexpr const char* _type_key = "relax.ShapeStructInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeStructInfoNode, StructInfoNode);
};
/*!
* \brief Managed reference to ShapeStructInfoNode.
* \sa ShapeStructInfoNode
*/
class ShapeStructInfo : public StructInfo {
public:
/*!
* \brief Construction with known symbolic shape patterns
* \param values The symbolic shape values
* \param span The span of the AST.
*/
TVM_DLL ShapeStructInfo(Array<PrimExpr> values, Span span = Span());
/*!
* \brief Construction when the symbolic shape values are unknown; only the
* number of dimensions is supplied.
* \param ndim Number of dimensions -- can be kUnknownNDim
* \param span The span of the AST.
*/
TVM_DLL ShapeStructInfo(int ndim, Span span = Span());
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeStructInfo, StructInfo, ShapeStructInfoNode);
};
/*!
 * \brief StructInfo of Tensor.
 *
 * Records an optional shape expression, the element data type and the number
 * of dimensions. Structural equality and hashing consider shape, ndim and
 * dtype; the span is excluded.
 */
class TensorStructInfoNode : public StructInfoNode {
 public:
  /*!
   * \brief optionally store the shape expression of the tensor.
   * \note shape must be normalized: it can only be NullOpt or ShapeExpr or Var.
   */
  Optional<Expr> shape;
  /*! \brief The content data type, use void to denote the dtype is unknown. */
  DataType dtype;
  /*!
   * \brief The number of dimension of the tensor, can be unknown.
   * \sa kUnknownNDim
   */
  int ndim;

  /*! \return Whether the struct info contains unknown ndim. */
  bool IsUnknownNdim() const {
    return kUnknownNDim == ndim;
  }

  /*! \return Whether the struct info contains unknown dtype (a void dtype). */
  bool IsUnknownDtype() const {
    return dtype.is_void();
  }

  /*!
   * \brief Look up the symbolic shape values of the tensor.
   * \return NullOpt when no shape expression is attached; otherwise the
   *         values recorded in the ShapeStructInfo of the shape expression
   *         (which may themselves be absent).
   */
  Optional<Array<PrimExpr>> GetShape() const {
    if (shape.defined()) {
      // A normalized shape expression (ShapeExpr or Var) is expected to carry
      // a ShapeStructInfo, so downcast its struct info directly.
      return Downcast<ShapeStructInfo>(shape.value()->struct_info_)->values;
    }
    return NullOpt;
  }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("shape", &shape);
    v->Visit("dtype", &dtype);
    v->Visit("ndim", &ndim);
    v->Visit("span", &span);
  }

  bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) const {
    // Comparison order is kept as shape, ndim, dtype.
    return equal(shape, other->shape) &&
           equal(ndim, other->ndim) &&
           equal(dtype, other->dtype);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    // Hash order (shape, dtype, ndim) determines the hash value; keep it stable.
    hash_reduce(shape);
    hash_reduce(dtype);
    hash_reduce(ndim);
  }

  static constexpr const char* _type_key = "relax.TensorStructInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(TensorStructInfoNode, StructInfoNode);
};
/*!
* \brief Managed reference to TensorStructInfoNode.
* \sa TensorStructInfoNode
*/
class TensorStructInfo : public StructInfo {
public:
/*!
* \brief Construction with a known shape expression.
* \param shape The shape of the tensor.
* \param dtype The data type of tensor's elements.
* \param span The span of the AST.
*
* \note shape must already be normalized (NullOpt, ShapeExpr or Var per
* TensorStructInfoNode::shape).
*/
TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span());
/*!
* \brief Construction with an unknown shape expression.
* \param dtype The data type of tensor's elements.
* \param ndim The number of dimensions -- presumably may be kUnknownNDim, as
* for TensorStructInfoNode::ndim (TODO confirm in the implementation).
* \param span The span of the AST.
*/
TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span());
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, TensorStructInfoNode);
};
/*!
* \brief StructInfo of Tuple.
*
* Holds the struct info of every tuple field. Structural equality and hashing
* consider only the fields, not the span.
*/
class TupleStructInfoNode : public StructInfoNode {
public:
/*! \brief The struct info of tuple fields. */
Array<StructInfo> fields;
void VisitAttrs(AttrVisitor* v) {
v->Visit("fields", &fields);
v->Visit("span", &span);
}
// Tuples are equal iff their field struct infos are equal element-wise.
bool SEqualReduce(const TupleStructInfoNode* other, SEqualReducer equal) const {
return equal(fields, other->fields);
}
// Hash only the fields, mirroring SEqualReduce.
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(fields); }
static constexpr const char* _type_key = "relax.TupleStructInfo";
TVM_DECLARE_FINAL_OBJECT_INFO(TupleStructInfoNode, StructInfoNode);
};
/*!
* \brief Managed reference to TupleStructInfoNode.
* \sa TupleStructInfoNode
*/
class TupleStructInfo : public StructInfo {
public:
/*!
* \brief Constructor
* \param fields Struct info of tuple fields (an empty array yields an empty tuple).
* \param span The span of the AST.
*/
TVM_DLL TupleStructInfo(Array<StructInfo> fields, Span span = Span());
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TupleStructInfo, StructInfo, TupleStructInfoNode);
};
/*!
* \brief custom-defined StructInfo derivation function.
*
* Wrapped in a TypedEnvFunc so it can be stored as a field of
* FuncStructInfoNode (see FuncStructInfoNode::derive_func).
*
* \param call The call expression to be derived.
* \param ctx The builder context.
* \return The derived struct info of the call.
*/
using StructInfoDeriveFunc = TypedEnvFunc<StructInfo(const Call& call, const BlockBuilder& ctx)>;
/*!
 * \brief Structure information about function.
 *
 * This data structure contains enough information for us to
 * do best-effort structure information deduction.
 */
class FuncStructInfoNode : public StructInfoNode {
 public:
  /*!
   * \brief The parameter struct info of the function.
   * \note When params is NullOpt, the function can take an arbitrary number of arguments.
   *       We define such functions as Opaque function.
   */
  Optional<Array<StructInfo>> params;
  /*!
   * \brief The struct info of the function's return value.
   */
  StructInfo ret;
  /*!
   * \brief Derivation function of opaque functions that may take any number of parameters.
   * \note When derive_func is not empty, then params should be NullOpt,
   *       ret should be ObjectStructInfo()
   */
  Optional<StructInfoDeriveFunc> derive_func;
  /*!
   * \brief Whether the function is pure.
   * \note This parameter should be set to true only if the function is pure on all inputs.
   *       If the function _may_ have visible side effects, set it to false.
   */
  bool purity;

  /*!
   * \return Whether the func struct info is opaque.
   * \note We define a function as opaque when we have no constraints on params.
   */
  bool IsOpaque() const { return !params.defined(); }

  void VisitAttrs(AttrVisitor* v) {
    v->Visit("params", &params);
    v->Visit("ret", &ret);
    v->Visit("derive_func", &derive_func);
    v->Visit("span", &span);
    v->Visit("purity", &purity);
  }

  bool SEqualReduce(const FuncStructInfoNode* other, SEqualReducer equal) const {
    // params are compared with DefEqual, presumably so that variables appearing in the
    // parameter struct info are treated as definition points that ret may refer to
    // (NOTE(review): confirm against the SEqualReducer documentation).
    return equal.DefEqual(params, other->params) && equal(ret, other->ret) &&
           equal(purity, other->purity) && equal(derive_func, other->derive_func);
  }

  void SHashReduce(SHashReducer hash_reduce) const {
    // Mirrors SEqualReduce: params hashed as definitions, then ret, purity, derive_func.
    hash_reduce.DefHash(params);
    hash_reduce(ret);
    hash_reduce(purity);
    hash_reduce(derive_func);
  }

  static constexpr const char* _type_key = "relax.FuncStructInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(FuncStructInfoNode, StructInfoNode);
};
/*!
* \brief Managed reference to FuncStructInfoNode.
* \sa FuncStructInfoNode
*/
class FuncStructInfo : public StructInfo {
public:
/*!
* \brief Constructor from parameter struct info and return value struct info.
* \param params The struct info of function parameters.
* \param ret The return value struct info.
* \param purity The purity of the function (true by default).
* \param span The span of the AST.
*
* \note If the ret contains variables (tir::Var and relax::Var), they must be deducible from
* params. If you are unsure, you can always erase ret to static.
*/
TVM_DLL FuncStructInfo(Array<StructInfo> params, StructInfo ret, bool purity = true,
Span span = Span());
/*!
* \brief Construct an opaque function struct info using derive_func.
*
* \param derive_func Derivation function that computes the struct info of a call
* (see StructInfoDeriveFunc).
* \param purity The purity of the function
* (false by default: most external functions are not pure).
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note derive_func has no default in this overload; use the ret-based overload
* when a fixed return struct info (ObjectStructInfo by default) is enough.
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfoDeriveFunc derive_func, bool purity = false,
Span span = Span());
/*!
* \brief Construct an opaque function struct info from the return struct info.
*
* \param ret The struct info of the return value.
* \param purity The purity of the function
* (false by default: most external functions are not pure).
* \param span The span of the AST.
*
* \return The FuncStructInfo for opaque packedfunc.
* \note When ret is not specified, it defaults to ObjectStructInfo().
*/
TVM_DLL static FuncStructInfo OpaqueFunc(StructInfo ret = ObjectStructInfo(), bool purity = false,
Span span = Span());
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(FuncStructInfo, StructInfo, FuncStructInfoNode);
};
/*!
 * \brief Check whether an expression's struct info is of type T and return it if so.
 *
 * Unlike GetStructInfoAs, this performs no ICHECK on whether the struct info
 * has been populated; a non-matching struct info simply yields NullOpt.
 *
 * \tparam T The StructInfo reference type to match against.
 * \param expr The input expression.
 * \return The struct info as T on a match, NullOpt otherwise.
 */
template <typename T>
inline Optional<T> MatchStructInfo(const Expr& expr) {
  using TNode = typename T::ContainerType;
  const TNode* matched = expr->struct_info_.as<TNode>();
  if (matched == nullptr) {
    return NullOpt;
  }
  return GetRef<T>(matched);
}
/*!
 * \brief View the struct info of an expression as a raw `const T*`.
 *
 * \tparam T The StructInfo node type to cast to.
 * \param expr The input expression; its struct info must already be populated.
 * \return A pointer to the struct info viewed as T, or nullptr if it is not a T.
 */
template <typename T>
inline const T* GetStructInfoAs(const Expr& expr) {
  // Reject un-normalized expressions up front so that a nullptr result
  // always means "wrong type" rather than "missing struct info".
  ICHECK(expr->struct_info_.defined())
      << "The struct_info is not populated, check if you have normalized the expr";
  const T* casted = expr->struct_info_.as<T>();
  return casted;
}
/*!
 * \brief Fetch the struct info attached to an expression.
 *
 * \param expr The input expression; it must have been normalized.
 * \return The struct info as a StructInfo reference.
 */
inline StructInfo GetStructInfo(const Expr& expr) {
  const StructInfoNode* ptr = expr->struct_info_.as<StructInfoNode>();
  // A null pointer means struct_info_ is absent (or not a StructInfo): fail loudly.
  ICHECK(ptr) << "The struct_info is not populated, check if you have normalized the expr";
  return GetRef<StructInfo>(ptr);
}
/*!
* \brief Whether the expr has void struct info.
*
* \param expr The input expression.
* \return Whether the expr has void struct info.
*/
inline bool HasVoidStructInfo(const Expr& expr) {
auto* ptr = expr->struct_info_.as<TupleStructInfoNode>();
return ptr != nullptr && ptr->fields.size() == 0;
}
/*!
* \brief Update the struct info of an Expr.
* \param expr The Expr whose struct info to be updated.
* \param struct_info The struct_info assigned.
* \note We ensure idempotence, that is we can update the struct_info of an Expr only
* if the original one is nullptr (i.e. not yet populated).
* NOTE(review): the enforcement lives in the implementation, which is not part of this header.
*/
TVM_DLL void UpdateStructInfo(Expr expr, StructInfo struct_info);
} // namespace relax
} // namespace tvm
#endif // TVM_RELAX_STRUCT_INFO_H_