This repository has been archived by the owner on Apr 23, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 258
/
LinalgTraits.h
193 lines (182 loc) · 7.08 KB
/
LinalgTraits.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
//===- LinalgTraits.h - Linalg Traits ---------------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
#define MLIR_DIALECT_LINALG_LINALGTRAITS_H_
#include "mlir/IR/OpDefinition.h"
#include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
#include "mlir/Support/LLVM.h"
namespace mlir {
namespace OpTrait {
namespace linalg {
/// This class provides the API for ops that are known to have a specified
/// number of inputs and outputs, all passed as operands. This is used as a
/// trait like this:
///
/// class DotOp
///     : public Op<DotOp, OpTrait::linalg::NInputsAndOutputs<2, 1>::Impl> {
///
template <unsigned NInputs, unsigned NOutputs> class NInputsAndOutputs {
public:
  /// Trait implementation that gets mixed into the concrete op `ConcreteType`.
  template <typename ConcreteType>
  class Impl
      : public OpTrait::TraitBase<ConcreteType,
                                  NInputsAndOutputs<NInputs, NOutputs>::Impl> {
  public:
    /// Return the statically known number of inputs (`NInputs`).
    static unsigned getNumInputs() { return NInputs; }

    /// Return the statically known number of outputs (`NOutputs`).
    static unsigned getNumOutputs() { return NOutputs; }

    /// Check that `op` carries at least `NInputs + NOutputs` operands; extra
    /// trailing operands are accepted.
    static LogicalResult verifyTrait(Operation *op) {
      const unsigned minOperands = NInputs + NOutputs;
      return OpTrait::impl::verifyAtLeastNOperands(op, minOperands);
    }
  };
};
/// This class provides the API for ops that are known to operate on views. This
/// trait must be used in conjunction with an op definition or a trait that
/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a
/// trait like this:
///
/// class DotOp : public Op<DotOp, OpTrait::linalg::ViewTraits> {
///
template <typename ConcreteType>
class ViewTraits : public OpTrait::TraitBase<ConcreteType, ViewTraits> {
private:
  /// Return the number of input views. For internal use only.
  unsigned nInputs() {
    return cast<ConcreteType>(this->getOperation()).getNumInputs();
  }
  /// Return the number of output views. For internal use only.
  unsigned nOutputs() {
    return cast<ConcreteType>(this->getOperation()).getNumOutputs();
  }

public:
  /// Return the `i`-th input view, i.e. operand `i`. Asserts that `i` is
  /// smaller than the number of inputs.
  Value *getInput(unsigned i) {
    assert(i < nInputs() && "input view index out of range");
    return this->getOperation()->getOperand(i);
  }
  /// Return the index of `view` in the list of input views if found, llvm::None
  /// otherwise.
  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
    // Build the operand sub-range once and reuse it for the search, the end
    // check and the distance computation.
    auto inputs = getInputs();
    auto it = llvm::find(inputs, view);
    if (it == inputs.end())
      return llvm::None;
    return static_cast<unsigned>(it - inputs.begin());
  }
  /// Return the `i`-th input view type.
  mlir::linalg::ViewType getInputViewType(unsigned i) {
    return getInput(i)->getType().template cast<mlir::linalg::ViewType>();
  }
  /// Return the range over input views: operands `[0, nInputs())`.
  Operation::operand_range getInputs() {
    auto range = this->getOperation()->getOperands();
    return {range.begin(), range.begin() + nInputs()};
  }
  /// Return the `i`-th output view, i.e. operand `nInputs() + i`. Asserts that
  /// `i` is smaller than the number of outputs, mirroring `getInput`.
  Value *getOutput(unsigned i) {
    assert(i < nOutputs() && "output view index out of range");
    return this->getOperation()->getOperand(nInputs() + i);
  }
  /// Return the index of `view` in the list of output views if found,
  /// llvm::None otherwise. The index is relative to the first output.
  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
    auto outputs = getOutputs();
    auto it = llvm::find(outputs, view);
    if (it == outputs.end())
      return llvm::None;
    return static_cast<unsigned>(it - outputs.begin());
  }
  /// Return the `i`-th output view type.
  mlir::linalg::ViewType getOutputViewType(unsigned i) {
    return getOutput(i)->getType().template cast<mlir::linalg::ViewType>();
  }
  /// Return the range over output views: operands
  /// `[nInputs(), nInputs() + nOutputs())`.
  Operation::operand_range getOutputs() {
    auto range = this->getOperation()->getOperands();
    return {range.begin() + nInputs(),
            range.begin() + getNumInputsAndOutputs()};
  }
  /// Return the number of input and output views.
  unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
  /// Return the `i`-th view type, where inputs are numbered first and outputs
  /// follow.
  mlir::linalg::ViewType getViewType(unsigned i) {
    return (i < nInputs()) ? getInputViewType(i)
                           : getOutputViewType(i - nInputs());
  }
  /// Return the range over input and output views.
  Operation::operand_range getInputsAndOutputs() {
    auto range = this->getOperation()->getOperands();
    return {range.begin(), range.begin() + getNumInputsAndOutputs()};
  }
  /// Verify that `op` has at least as many operands as it declares views and
  /// that each of those leading operands has view type.
  static LogicalResult verifyTrait(Operation *op) {
    const unsigned nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
      return failure();
    // Only the first `nViews` operands are views; trailing operands are not
    // checked here.
    for (unsigned i = 0; i < nViews; ++i) {
      if (!op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>())
        return op->emitOpError("operand ") << i << " must have view type ";
    }
    return success();
  }
};
/// This class provides the API for ops that are known to have a specified
/// number of parallel, reduction and window loops. This is used as a trait like
/// this:
///
/// class MatmulOp
///     : public Op<MatmulOp, OpTrait::linalg::NLoopTypes<2, 1, 0>::Impl> {
///
template <unsigned NParallel, unsigned NReduction, unsigned NWindow = 0>
class NLoopTypes {
public:
  /// Trait implementation that gets mixed into the concrete op `ConcreteType`.
  template <typename ConcreteType>
  class Impl
      : public OpTrait::TraitBase<
            ConcreteType, NLoopTypes<NParallel, NReduction, NWindow>::Impl> {
  public:
    /// Return the statically known number of parallel loops (`NParallel`).
    static unsigned getNumParallelLoops() { return NParallel; }

    /// Return the statically known number of reduction loops (`NReduction`).
    static unsigned getNumReductionLoops() { return NReduction; }

    /// Return the statically known number of window loops (`NWindow`).
    static unsigned getNumWindowLoops() { return NWindow; }

    /// Return the total number of loops: parallel + reduction + window.
    static unsigned getNumLoops() {
      return getNumParallelLoops() + getNumReductionLoops() +
             getNumWindowLoops();
    }
  };
};
/// This class provides the API for ops that are known to have a specified
/// list of view ranks. This is used as a trait like this:
///
/// class MatvecOp
///     : public Op<MatvecOp, OpTrait::linalg::ViewRanks<2, 1, 1>::Impl> {
///
template <unsigned... Ranks> class ViewRanks {
public:
  /// Trait implementation that gets mixed into the concrete op `ConcreteType`.
  template <typename ConcreteType>
  class Impl
      : public OpTrait::TraitBase<ConcreteType, ViewRanks<Ranks...>::Impl> {
  public:
    /// Verify that `op` has exactly `sizeof...(Ranks)` operands, that every
    /// operand has view type, and that the rank of operand `i` equals the
    /// `i`-th entry of `Ranks`.
    static LogicalResult verifyTrait(Operation *op) {
      const unsigned numOperands = op->getNumOperands();
      if (numOperands != sizeof...(Ranks))
        return op->emitError("expected ") << sizeof...(Ranks) << " operands";

      // Expand the pack into an array so ranks can be looked up by operand
      // position.
      unsigned expectedRanks[] = {Ranks...};
      for (unsigned idx = 0; idx < numOperands; ++idx) {
        auto operandType = op->getOperand(idx)->getType();
        auto asView = operandType.dyn_cast<mlir::linalg::ViewType>();
        if (!asView)
          return op->emitOpError("operand ") << idx << " must have view type ";
        if (expectedRanks[idx] != asView.getRank())
          return op->emitOpError("operand ")
                 << idx << " must have rank " << expectedRanks[idx];
      }
      return success();
    }
  };
};
} // namespace linalg
} // namespace OpTrait
} // namespace mlir
#endif // MLIR_DIALECT_LINALG_LINALGTRAITS_H_