-
Notifications
You must be signed in to change notification settings - Fork 339
/
remove_dropout.cpp
62 lines (51 loc) · 1.91 KB
/
remove_dropout.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "core/util/prelude.h"
namespace torch_tensorrt {
namespace core {
namespace lowering {
namespace passes {

// Node kinds of every dropout variant stripped by this pass.
// removeDropoutInBlock assumes each listed op carries the tensor being dropped out as
// input 0 and its result as output 0 (true for the aten dropout schemas listed here —
// confirm the same layout before adding a new entry). To support another variant with
// that layout, add its qualified name below.
const std::unordered_set<c10::Symbol> DropoutNodeKinds = {
    c10::Symbol::fromQualString("aten::dropout"),
    c10::Symbol::fromQualString("aten::dropout_"),
    c10::Symbol::fromQualString("aten::feature_dropout"),
    c10::Symbol::fromQualString("aten::feature_dropout_"),
    c10::Symbol::fromQualString("aten::alpha_dropout"),
    c10::Symbol::fromQualString("aten::alpha_dropout_"),
    c10::Symbol::fromQualString("aten::feature_alpha_dropout"),
    c10::Symbol::fromQualString("aten::feature_alpha_dropout_"),
};

/*
 * Recursively removes every node whose kind is in DropoutNodeKinds from `block`
 * and from all sub-blocks owned by its nodes.
 *
 * Each matched node is bypassed: every use of its first output is rewired to its
 * first input, and the node is then destroyed. Matching is purely by node kind;
 * no other inputs (e.g. a training flag) are inspected, so matched nodes are
 * removed unconditionally.
 *
 * Function adapted from:
 * torch/csrc/jit/passes/remove_dropout.cpp
 * Modified for conciseness, documentation, and allowing new variants of dropout operators to be quickly added
 */
void removeDropoutInBlock(torch::jit::Block* block) {
  std::vector<torch::jit::Node*> dropout_nodes_to_remove;
  for (torch::jit::Node* node : block->nodes()) {
    // Recurse into any sub-blocks owned by this node first.
    // Named `sub_block` so it does not shadow the `block` parameter.
    for (torch::jit::Block* sub_block : node->blocks()) {
      removeDropoutInBlock(sub_block);
    }
    // Only nodes whose kind is a known dropout variant are removed.
    if (DropoutNodeKinds.count(node->kind()) == 0) {
      continue;
    }
    // Forward the dropout's input tensor directly to every consumer of its output.
    torch::jit::Value* input_value = node->inputs()[0];
    torch::jit::Value* output_value = node->outputs()[0];
    output_value->replaceAllUsesWith(input_value);
    dropout_nodes_to_remove.push_back(node);
  }
  // Destroy after the traversal so the node list is not modified while it is being iterated.
  for (torch::jit::Node* del_node : dropout_nodes_to_remove) {
    del_node->destroy();
  }
}

// Lowering pass entry point: strips all dropout variants from `graph` (including nested
// blocks), runs dead-code elimination to clean up values that may have lost their last
// use (such as the removed nodes' other arguments), and logs the resulting graph.
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
  removeDropoutInBlock(graph->block());
  torch::jit::EliminateDeadCode(graph);
  LOG_GRAPH("Post remove dropout: " << *graph);
}

} // namespace passes
} // namespace lowering
} // namespace core
} // namespace torch_tensorrt