Commit
Add Post Freezing Optimizations, turn on by default in torch.jit.freeze (#50222)

Summary:
Pull Request resolved: #50222

This PR adds a pass which runs a set of optimizations to be done after freezing. Currently this encompasses Conv-BN folding and Conv->Add/Sub/Mul/Div folding, and I'm also planning on adding dropout removal.

I would like some feedback on the API. torch.jit.freeze is technically in ~prototype~ phase, so we have some leeway around making changes. I think in the majority of cases, the user is going to want to freeze their model and then run it in inference. I would prefer if the optimization was opt-out instead of opt-in. All internal/framework use cases of freezing use `freeze_module`, not the Python API, so this shouldn't break anything.

I have separated out the optimization pass as a separate API to make things potentially modular, even though I suspect that is an unlikely case. In a future PR I would like to add a `torch::jit::freeze` which follows the same API as `torch.jit.freeze`, is intended for C++ use, and runs the optimizations.

Test Plan: Imported from OSS

Reviewed By: tugsbayasgalan

Differential Revision: D25856264

Pulled By: eellison

fbshipit-source-id: 56be1f12cfc459b4c4421d4dfdedff8b9ac77112
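For readers unfamiliar with the folds named in the summary, here is a minimal sketch of the arithmetic behind Conv-BN and Conv->Add/Mul folding. It is not the code from this PR: it assumes a toy per-channel 1x1 convolution (`y = w * x + b`) and inference-mode batch norm, and the names `ConvParams`, `BatchNormParams`, `foldConvBatchNorm`, `foldConvAdd`, and `foldConvMul` are hypothetical, chosen only for illustration.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical toy model: a per-channel 1x1 convolution, y[c] = weight[c] * x + bias[c].
struct ConvParams {
  std::vector<double> weight;
  std::vector<double> bias;
};

// Hypothetical container for inference-mode batch-norm parameters, one entry per channel.
struct BatchNormParams {
  std::vector<double> running_mean;
  std::vector<double> running_var;
  std::vector<double> gamma;
  std::vector<double> beta;
  double eps;
};

// Fold BN(conv(x)) into a single conv:
//   scale = gamma / sqrt(var + eps)
//   w'    = w * scale
//   b'    = (b - mean) * scale + beta
ConvParams foldConvBatchNorm(const ConvParams& conv, const BatchNormParams& bn) {
  assert(conv.weight.size() == bn.gamma.size());
  ConvParams folded = conv;
  for (std::size_t c = 0; c < conv.weight.size(); ++c) {
    const double scale = bn.gamma[c] / std::sqrt(bn.running_var[c] + bn.eps);
    folded.weight[c] = conv.weight[c] * scale;
    folded.bias[c] = (conv.bias[c] - bn.running_mean[c]) * scale + bn.beta[c];
  }
  return folded;
}

// Fold conv(x) + k into the bias: b' = b + k (weights unchanged).
ConvParams foldConvAdd(const ConvParams& conv, double k) {
  ConvParams folded = conv;
  for (double& b : folded.bias) {
    b += k;
  }
  return folded;
}

// Fold conv(x) * k into both weight and bias: w' = w * k, b' = b * k.
ConvParams foldConvMul(const ConvParams& conv, double k) {
  ConvParams folded = conv;
  for (double& w : folded.weight) {
    w *= k;
  }
  for (double& b : folded.bias) {
    b *= k;
  }
  return folded;
}
```

Sub and Div follow the same pattern with `-k` and `1 / k`. These folds are only valid when the conv and the following constants are fixed, which is why they are natural to run after freezing.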
1 parent 30aeed7, commit a389b30

Showing 8 changed files with 110 additions and 8 deletions.
@@ -0,0 +1,21 @@
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

namespace torch {
namespace jit {

void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
  // run a couple times to capture Conv -> Mul -> Add etc
  for (size_t i = 0; i < 2; i++) {
    FoldFrozenConvBatchnorm(graph);
    FoldFrozenConvAddOrSub(graph);
    FoldFrozenConvMulOrDiv(graph);
  }
}

} // namespace jit
} // namespace torch
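
The "run a couple times" comment can be illustrated with a toy model. The sketch below is not the real pass: it assumes a graph can be stood in for by a linear chain of op names, and that each fold only absorbs an op sitting directly after a conv. `Chain`, `foldAfterConv`, `runRound`, and `optimize` are hypothetical names. Under those assumptions, a Conv -> Mul -> Add chain needs two rounds, because the Add only becomes adjacent to the conv after the Mul has been folded.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a graph: a linear chain of op names, e.g. {"conv", "mul", "add"}.
using Chain = std::vector<std::string>;

// Hypothetical helper: erase any `target` op that directly follows a "conv",
// modelling its constant being absorbed into the conv's weight/bias.
bool foldAfterConv(Chain& chain, const std::string& target) {
  bool changed = false;
  for (std::size_t i = 0; i + 1 < chain.size(); ++i) {
    if (chain[i] == "conv" && chain[i + 1] == target) {
      chain.erase(chain.begin() + static_cast<std::ptrdiff_t>(i + 1));
      changed = true;
    }
  }
  return changed;
}

// One round, in the same order as the loop body in the diff:
// batchnorm first, then add/sub, then mul/div (only "add" and "mul" modelled here).
void runRound(Chain& chain) {
  foldAfterConv(chain, "batchnorm");
  foldAfterConv(chain, "add");
  foldAfterConv(chain, "mul");
}

// Run a fixed number of rounds and return the resulting chain.
Chain optimize(Chain chain, std::size_t rounds) {
  for (std::size_t i = 0; i < rounds; ++i) {
    runRound(chain);
  }
  return chain;
}
```

In this model, `optimize({"conv", "mul", "add"}, 1)` leaves the `add` behind, while two rounds reduce the chain to a single `conv`. A fixed count of two is a simple trade-off; iterating until nothing changes would be an alternative design, but the diff does not do that.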
@@ -0,0 +1,19 @@
#pragma once

#include <torch/csrc/jit/ir/ir.h>

/** \brief Runs a set of Optimizations that Optimize Frozen Graphs
 *
 * Currently this set of optimizations is:
 * - FoldFrozenConvBatchnorm
 * - FoldFrozenConvAddOrSub
 * - FoldFrozenConvMulOrDiv
 */

namespace torch {
namespace jit {

TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);

} // namespace jit
} // namespace torch