-
Notifications
You must be signed in to change notification settings - Fork 21.4k
/
normalize_ops.cpp
90 lines (79 loc) · 3.15 KB
/
normalize_ops.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include <torch/csrc/jit/passes/normalize_ops.h>
#include <c10/util/Exception.h>
namespace torch {
namespace jit {
namespace {
// Builds a node of kind `new_symbol` that mirrors `node` and reroutes all
// uses of `node`'s outputs to it.
//
// The new node is inserted with the insert point set at `node`, receives the
// same inputs in the same order, and gets one output per output of `node`
// (each carrying the metadata of the output it replaces). `node` itself is
// NOT removed here; the caller is responsible for destroying it.
//
// Fails an internal assert if `maybeOperator()` yields nothing for the new
// node, i.e. `new_symbol` is not a valid replacement for this node.
void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {
  WithInsertPoint insert_guard{node};
  auto graph = node->owningGraph();
  // Created with 0 as the second argument; outputs are appended one by one
  // below so each can copy the metadata of the output it stands in for.
  auto replace_node = graph->insertNode(graph->create(new_symbol, 0));
  for (Value* v : node->inputs()) {
    replace_node->addInput(v);
  }
  for (Value* v : node->outputs()) {
    auto new_out = replace_node->addOutput()->copyMetadata(v);
    v->replaceAllUsesWith(new_out);
  }
  replace_node->copyMetadata(node);
  // Print qualified symbol names with explicit separators so the failure
  // message reads "<old> -> <new>" instead of two values run together.
  TORCH_INTERNAL_ASSERT(
      replace_node->maybeOperator(),
      "invalid symbol replacement: ",
      node->kind().toQualString(),
      " -> ",
      new_symbol.toQualString());
}
// Having multiple ops in our IR that do the same thing makes the IR more
// difficult to consume for downstream users of the IR, such as our own
// optimization passes. Here, we convert op aliases into a standard form.
//
// If the node at `iter` has a kind listed in the operator alias map, a
// replacement node of the mapped kind is created, the alias node is destroyed
// through the iterator, and true is returned. Otherwise nothing is changed
// and false is returned.
bool normalizeOpAliases(graph_node_list_iterator& iter) {
  const auto& aliases = getOperatorAliasMap();
  auto match = aliases.find(iter->kind());
  if (match == aliases.end()) {
    return false;
  }
  replaceNodeWithNewSymbol(*iter, match->second);
  iter.destroyCurrent();
  return true;
}
// Walks every node of `block`, first recursing into the node's nested blocks
// and then normalizing the node itself if it is an op alias.
void NormalizeOps(Block* block) {
  auto it = block->nodes().begin();
  auto end = block->nodes().end();
  while (it != end) {
    for (Block* nested : it->blocks()) {
      NormalizeOps(nested);
    }
    // When the node was replaced, the iterator is deliberately not advanced
    // here: normalizeOpAliases destroyed the current node via the iterator,
    // which presumably already moved it on (confirm against destroyCurrent).
    if (!normalizeOpAliases(it)) {
      ++it;
    }
  }
}
} // namespace
const std::unordered_map<Symbol, Symbol>& getOperatorAliasMap() {
// map from op alias -> normalized op
static const std::unordered_map<Symbol, Symbol> alias_map = {
{aten::absolute, aten::abs}, {aten::absolute_, aten::abs_},
{aten::clip, aten::clamp}, {aten::clip_, aten::clamp_},
{aten::linalg_det, aten::det}, {aten::ger, aten::outer},
{aten::arccos, aten::acos}, {aten::arccos_, aten::acos_},
{aten::arcsin, aten::asin}, {aten::arcsin_, aten::asin_},
{aten::arctan, aten::atan}, {aten::arctan_, aten::atan_},
{aten::arccosh, aten::acosh}, {aten::arccosh_, aten::acosh_},
{aten::arcsinh, aten::asinh}, {aten::arcsinh_, aten::asinh_},
{aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_},
{aten::fix, aten::trunc}, {aten::fix_, aten::trunc_},
{aten::negative, aten::neg}, {aten::negative_, aten::neg_},
{aten::subtract, aten::sub}, {aten::subtract_, aten::sub_},
{aten::greater_equal, aten::ge}, {aten::greater_equal_, aten::ge_},
{aten::greater, aten::gt}, {aten::greater_, aten::gt_},
{aten::less_equal, aten::le}, {aten::less_equal_, aten::le_},
{aten::less, aten::lt}, {aten::less_, aten::lt_},
{aten::not_equal, aten::ne}, {aten::not_equal_, aten::ne_},
{aten::divide, aten::div}, {aten::divide_, aten::div_},
{aten::multiply, aten::mul}, {aten::multiply_, aten::mul_},
{aten::true_divide, aten::div}, {aten::true_divide_, aten::div_},
};
return alias_map;
}
void NormalizeOps(const std::shared_ptr<Graph>& graph) {
NormalizeOps(graph->block());
}
} // namespace jit
} // namespace torch