Skip to content
Closed
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
87f6eff
[quant][graphmode][refactor] swapDeQuant takes block as argument
jerryzh168 Mar 20, 2020
b4a22de
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 21, 2020
a8f4b6c
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 22, 2020
74e6dcc
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 23, 2020
8a26f47
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 23, 2020
47108b2
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 24, 2020
a5daf24
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 24, 2020
9f66599
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 24, 2020
2312f52
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 24, 2020
01b9e07
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 24, 2020
9d1a54a
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 25, 2020
c198284
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 25, 2020
079cc05
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 25, 2020
b580f09
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 25, 2020
fe723bd
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 26, 2020
849913e
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 27, 2020
852a531
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 27, 2020
f4d4154
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 27, 2020
b0708ac
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 27, 2020
7a15a20
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 27, 2020
561627f
Update on "[quant][graphmode][refactor] swapDeQuant takes block as ar…
jerryzh168 Mar 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 38 additions & 41 deletions torch/csrc/jit/passes/quantization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2221,6 +2221,43 @@ void addBiasForConv2dIfNone(Module& module) {
}
}

// Swaps dequantize nodes past "general ops" — ops that behave identically on
// quantized and dequantized tensors (per getGeneralOpTensorInputIndexes) — so
// that the op consumes the quantized tensor directly and dequantization
// happens on the op's outputs instead. Recurses into sub-blocks (if/loop).
// Precondition: ReplicateDeQuant has run, so each dequantize value has
// exactly one use (asserted below).
void swapDeQuant(Block* block) {
auto graph = block->owningGraph();
for (Node* n : block->nodes()) {
// Indexes of tensor inputs for which this "general op" is
// quantization-transparent; empty if n is not such an op.
auto input_indexes = getGeneralOpTensorInputIndexes(n);
if (input_indexes.size() > 0) {
// Only swap when ALL relevant inputs come from dequantize nodes;
// a partial swap would mix quantized and fp32 inputs into one op.
bool is_dequantized = true;
for (auto i : input_indexes) {
is_dequantized &= n->inputs()[i]->node()->kind() == Symbol::aten("dequantize");
}
if (!is_dequantized) {
continue;
}
// Delete dequantize node, we have one dequantize
// for each use of the value
for (auto i : input_indexes) {
auto* dequantized_val = n->inputs()[i];
auto* dequantize_node = dequantized_val->node();
TORCH_INTERNAL_ASSERT(dequantized_val->uses().size() == 1,
"Expect to have one dequantize node for each use");
// Replace uses of dequantized_val with the input of
// dequantize node (i.e. the still-quantized value)
dequantized_val->replaceAllUsesWith(dequantize_node->inputs()[0]);
dequantize_node->removeAllInputs();
// Safe while iterating block->nodes(): we destroy an input-producing
// node, never the node the iterator is currently on.
dequantize_node->destroy();
}
// The op now produces quantized outputs; re-insert a dequantize
// for each use so downstream consumers still see fp32 values.
for (auto* output: n->outputs()) {
// Snapshot uses first — insertDeQuantCall rewrites them.
std::vector<Use> uses = output->uses();
// Insert new dequantize node for each use of the output
insertDeQuantCall(graph, output, output, uses);
}
}
// Recurse into nested blocks (e.g. bodies of prim::If / prim::Loop).
for (Block* subblock : n->blocks()) {
swapDeQuant(subblock);
}
}
}

} // namespace

TORCH_API Module InsertObservers(
Expand Down Expand Up @@ -2383,47 +2420,7 @@ void ReplicateDeQuant(std::shared_ptr<Graph>& graph) {
// for example: flatten, average_pool, upsample
// This is called after inline and before graph execution
// Swaps dequantize nodes past functional ops that work on both quantized
// and non-quantized tensors (for example: flatten, average_pool, upsample),
// so quantized tensors flow through such ops and dequantization happens on
// their outputs. This is called after inline and before graph execution.
//
// The traversal itself (including recursion into sub-blocks) lives in the
// Block-based helper swapDeQuant; keeping the old explicit-stack body here
// alongside the delegation would apply the transformation twice over the
// same graph, re-propagating the freshly inserted dequantize nodes.
void SwapDeQuant(std::shared_ptr<Graph>& graph) {
  swapDeQuant(graph->block());
}

void QuantFusion(std::shared_ptr<Graph>& graph) {
Expand Down