// exception_elimination.cpp
#include "torch/csrc/jit/ir/alias_analysis.h"
#include "torch/csrc/jit/jit_log.h"
#include "torch/csrc/jit/passes/constant_pooling.h"
#include "torch/csrc/jit/passes/constant_propagation.h"
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/jit/passes/guard_elimination.h"
#include "torch/csrc/jit/passes/peephole.h"
#include "torch/csrc/jit/runtime/graph_executor.h"
#include "core/util/prelude.h"
#include <vector>
namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {
namespace {
using namespace torch::jit;
/// Removes prim::If nodes that only conditionally raise an exception.
/// A node qualifies when one arm is a lone prim::RaiseException, the other arm
/// is empty, and neither arm produces outputs. Only nodes directly inside the
/// graph's top-level block are considered.
struct ExceptionOrPassPatternElimination {
  explicit ExceptionOrPassPatternElimination(std::shared_ptr<Graph> graph) : graph_(std::move(graph)) {}

  /// Destroys matching prim::If nodes in the top-level block, runs dead code
  /// elimination on the graph, and logs the result.
  void run() {
    findExceptionOrPassNodes(graph_->block());
    torch::jit::EliminateDeadCode(graph_);
    LOG_GRAPH("Post exception or pass elimination: " << *graph_);
  }

 private:
  /// True if `b` contains no nodes, i.e. the first node reached when iterating
  /// it is the block's prim::Return terminator.
  static bool isPassArm(Block* b) {
    return (*b->nodes().begin())->kind() == prim::Return;
  }

  /// True if `b` contains exactly one node, a prim::RaiseException, followed
  /// immediately by the block's prim::Return terminator.
  static bool isSoleExceptionArm(Block* b) {
    auto it = b->nodes().begin();
    if ((*it)->kind() != prim::RaiseException) {
      return false;
    }
    ++it;
    return (*it)->kind() == prim::Return;
  }

  /// Returns true if `n` hosts one of these two patterns:
  ///   = prim::If(%5958)            = prim::If(%5958)
  ///     block0():                    block0():
  ///       = prim::RaiseException(%45)    -> ()
  ///       -> ()                      block1():
  ///     block1():                      = prim::RaiseException(%45)
  ///       -> ()                          -> ()
  /// Each arm is examined independently. The previous implementation advanced
  /// a shared iterator during the first check and reused it in the second, so
  /// an If whose arms BOTH consisted of a lone raise was wrongly accepted and
  /// removed.
  bool isExceptionOrPassNode(Node* n) {
    if (n->blocks().size() != 2) {
      return false;
    }
    auto arm1 = n->blocks()[0];
    auto arm2 = n->blocks()[1];
    if (arm1->outputs().size() != 0 || arm2->outputs().size() != 0) {
      // Make sure that the node doesn't actually produce any Value that are
      // used by other nodes
      return false;
    }
    return (isSoleExceptionArm(arm1) && isPassArm(arm2)) || (isPassArm(arm1) && isSoleExceptionArm(arm2));
  }

  /// Destroys every node directly inside `b` that is a prim::If matching the
  /// exception-or-pass pattern. Nested blocks are not visited.
  void findExceptionOrPassNodes(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
      auto n = *it;
      if (n->kind() == prim::If && isExceptionOrPassNode(n)) {
        LOG_GRAPH("Found that node " << *n << " is an exception or pass node (EliminateChecks)" << std::endl);
        // NOTE(review): relies on destroyCurrent() leaving `it` usable for the
        // loop increment, as the original code did.
        it.destroyCurrent();
      }
    }
  }

  std::shared_ptr<Graph> graph_;
};
} // namespace
/// Runs ExceptionOrPassPatternElimination on `graph` in place, then logs the
/// resulting graph. A null graph is ignored.
void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
  if (!graph) {
    return;
  }
  // Give the pass a copy of the pointer rather than moving it. Moving left
  // `graph` null, so the post-pass log below was never emitted.
  ExceptionOrPassPatternElimination eppe(graph);
  eppe.run();
  LOG_GRAPH("Post Eliminate Exception or Pass Patterns: " << *graph);
}
/*
Below is a fork of the torch::jit::EliminateExceptions pass. Instead of calling
replaceAllUsesWith on the prim::If condition value, it substitutes only the condition
input of the offending prim::If node (via insertInput/removeInput). This avoids
invalidating the IR in challenging cases, such as nested Ifs.
Original source from which it was adapted:
https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
*/
/// Reports whether `block` is treated as always throwing: true when any node
/// directly inside it is a prim::RaiseException. Nodes inside nested blocks
/// are not inspected.
bool certainlyThrows(Block* block) {
  bool found_raise = false;
  for (Node* node : block->nodes()) {
    found_raise = node->kind() == prim::RaiseException;
    if (found_raise) {
      break;
    }
  }
  return found_raise;
}
void EliminateExceptionsSafe(Block* block) {
auto graph = block->owningGraph();
// Generate false and true constant placeholders
Value* false_const = graph->insertConstant(IValue(false));
Value* true_const = graph->insertConstant(IValue(true));
// For each prim::If node, if either block certainly throws an exception,
// replace input conditional of the node input with the logical opposite
for (Node* n : block->nodes()) {
if (n->kind() == prim::If) {
Block* true_block = n->blocks()[0];
Block* false_block = n->blocks()[1];
bool removed_exception = false;
Value* input_value_replacement;
// If the block throws an exception, replace input with logical opposite
if (certainlyThrows(true_block)) {
removed_exception = true;
input_value_replacement = false_const;
} else if (certainlyThrows(false_block)) {
removed_exception = true;
input_value_replacement = true_const;
}
// Log node and perform input replacement
if (removed_exception) {
LOG_WARNING("Detected and removing exception in TorchScript IR for node: " << util::node_info(n));
n->insertInput(0, input_value_replacement);
n->removeInput(1);
}
}
// Inspect and replace all instances within subblocks of the current node
for (Block* subblock : n->blocks()) {
EliminateExceptionsSafe(subblock);
}
}
}
void EliminateExceptionsSafe(std::shared_ptr<Graph>& graph) {
EliminateExceptionsSafe(graph->block());
ConstantPropagation(graph);
ConstantPooling(graph);
}
} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt