Clean Up MobileOptimizerType Rewrite Flags Public API and Documentation (#91600)

Summary:
X-link: facebookresearch/d2go#452

Remove MobileOptimizerType and all rewrite flags from torch.X and torch._C.X to clean up torch.X and torch._C.X namespaces

The affected rewrite flags are:
- CONV_BN_FUSION
- FUSE_ADD_RELU
- HOIST_CONV_PACKED_PARAMS
- INSERT_FOLD_PREPACK_OPS
- REMOVE_DROPOUT
- VULKAN_AUTOMATIC_GPU_TRANSFER

BC-Breaking Change:

Before this change, the rewrite flags were accessible through all of the following paths:
1. torch.utils.mobile_optimizer.MobileOptimizerType.X
2. torch._C.MobileOptimizerType.X
3. torch.X
4. torch.MobileOptimizerType.X
5. torch._C.X

After this change, only torch.utils.mobile_optimizer.MobileOptimizerType.X (option 1 above) and the newly added torch._C._MobileOptimizerType.X remain.

Corresponding updates to PyTorch Tutorial Docs are in pytorch/tutorials#2163

Test Plan:
```buck test caffe2/test:test_mobile_optimizer```
```
Summary
  Pass: 6
  Skip: 1
    ↻ caffe2/test:test_mobile_optimizer - test_mobilenet_optimize_for_mobile (test_mobile_optimizer.TestOptimizer)
  ListingSuccess: 1
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/4222124793514412
```
___

With temporary testing changes in D41690204:

```buck run caffe2:test_rewrite_flags_api```
Before:
```
torch.utils.mobile_optimizer.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ✅ | Result: ✅
torch._C._MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ✅ | Result: ❌ (module 'torch._C' has no attribute '_MobileOptimizerType')
torch._C.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ✅
torch.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ✅
torch.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ✅
torch._C.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ✅
```
After:
```
torch.utils.mobile_optimizer.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ✅ | Result: ✅
torch._C._MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ✅ | Result: ✅
torch._C.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ❌ (module 'torch._C' has no attribute 'MobileOptimizerType')
torch.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ❌ (module 'torch' has no attribute 'VULKAN_AUTOMATIC_GPU_TRANSFER')
torch.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ❌ (module 'torch' has no attribute 'MobileOptimizerType')
torch._C.VULKAN_AUTOMATIC_GPU_TRANSFER
        Expected: ❌ | Result: ❌ (module 'torch._C' has no attribute 'VULKAN_AUTOMATIC_GPU_TRANSFER')
```
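
The `caffe2:test_rewrite_flags_api` script itself is part of D41690204 and is not shown in this commit. As a rough sketch only, assuming the check simply tries to resolve each dotted path and compares the outcome with an expectation, it could look like the following. The helper names `check_path` and `report` are hypothetical; the expectations table is copied from the "After" output above.

```python
import importlib
from typing import Dict, List, Tuple

# Expected accessibility after the change, copied from the "After" output.
EXPECTATIONS: Dict[str, bool] = {
    "torch.utils.mobile_optimizer.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER": True,
    "torch._C._MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER": True,
    "torch._C.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER": False,
    "torch.VULKAN_AUTOMATIC_GPU_TRANSFER": False,
    "torch.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER": False,
    "torch._C.VULKAN_AUTOMATIC_GPU_TRANSFER": False,
}


def check_path(dotted: str) -> Tuple[bool, str]:
    """Hypothetical helper: return (True, "") if `dotted` resolves, else (False, error)."""
    parts = dotted.split(".")
    module = None
    # Import the longest prefix that is a real module, then walk attributes.
    for split in range(len(parts), 0, -1):
        try:
            module = importlib.import_module(".".join(parts[:split]))
            break
        except ImportError:
            continue
    if module is None:
        return False, f"cannot import '{parts[0]}'"
    obj = module
    for name in parts[split:]:
        try:
            obj = getattr(obj, name)
        except AttributeError as exc:
            return False, str(exc)
    return True, ""


def _mark(flag: bool) -> str:
    return "✅" if flag else "❌"


def report(expectations: Dict[str, bool]) -> List[str]:
    """Format one entry per path in the same style as the output above."""
    lines = []
    for path, expected in expectations.items():
        ok, err = check_path(path)
        suffix = f" ({err})" if err else ""
        lines.append(
            f"{path}\n        Expected: {_mark(expected)} | Result: {_mark(ok)}{suffix}"
        )
    return lines


if __name__ == "__main__":
    print("\n".join(report(EXPECTATIONS)))
```

The sketch imports the longest importable module prefix first because `torch.utils.mobile_optimizer` is a submodule and may not be reachable by attribute access alone.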

```buck test caffe2/test:public_bindings -- test_no_new_bindings```
```
Summary
  Pass: 1
  ListingSuccess: 1
Finished test run: https://www.internalfb.com/intern/testinfra/testrun/7881299473114294
```

Differential Revision: D41690203

Pull Request resolved: #91600
Approved by: https://github.com/albanD, https://github.com/malfet
salilsdesai authored and pytorchmergebot committed Jan 10, 2023
1 parent e9cd7e0 commit 370df96
Showing 6 changed files with 24 additions and 31 deletions.
19 changes: 10 additions & 9 deletions docs/source/mobile_optimizer.rst
@@ -4,20 +4,21 @@ torch.utils.mobile_optimizer
 .. warning::
 This API is in beta and may change in the near future.

-Torch mobile supports ``torch.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode.
-The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set and a preserved method list
+Torch mobile supports ``torch.utils.mobile_optimizer.optimize_for_mobile`` utility to run a list of optimization pass with modules in eval mode.
+The method takes the following parameters: a torch.jit.ScriptModule object, a blocklisting optimization set, a preserved method list, and a backend.

 For CPU Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimizations:
-- **Conv2D + BatchNorm fusion** (blocklisting option `MobileOptimizerType::CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in ``forward`` method of this module and all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated.
-- **Insert and Fold prepacked ops** (blocklisting option `MobileOptimizerType::INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops in that, they require some state to be created, such as weight prepacking and use this state, i.e. prepacked weights, during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking of weight enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` pass rewrites the graph to replace ``Conv2D/Linear`` with 1) op that pre-packs weight for XNNPACK conv2d/linear ops and 2) op that takes pre-packed weight and activation as input and generates output activations. Since 1 needs to be done only once, we fold the weight pre-packing such that it is done only once at model load time. This pass of the ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, weight pre-packing ops.
+- **Conv2D + BatchNorm fusion** (blocklisting option `mobile_optimizer.MobileOptimizerType.CONV_BN_FUSION`): This optimization pass folds ``Conv2d-BatchNorm2d`` into ``Conv2d`` in ``forward`` method of this module and all its submodules. The weight and bias of the ``Conv2d`` are correspondingly updated.
+- **Insert and Fold prepacked ops** (blocklisting option `mobile_optimizer.MobileOptimizerType.INSERT_FOLD_PREPACK_OPS`): This optimization pass rewrites the graph to replace 2D convolutions and linear ops with their prepacked counterparts. Prepacked ops are stateful ops in that, they require some state to be created, such as weight prepacking and use this state, i.e. prepacked weights, during op execution. XNNPACK is one such backend that provides prepacked ops, with kernels optimized for mobile platforms (such as ARM CPUs). Prepacking of weight enables efficient memory access and thus faster kernel execution. At the moment ``optimize_for_mobile`` pass rewrites the graph to replace ``Conv2D/Linear`` with 1) op that pre-packs weight for XNNPACK conv2d/linear ops and 2) op that takes pre-packed weight and activation as input and generates output activations. Since 1 needs to be done only once, we fold the weight pre-packing such that it is done only once at model load time. This pass of the ``optimize_for_mobile`` does 1 and 2 and then folds, i.e. removes, weight pre-packing ops.
 - **ReLU/Hardtanh fusion**: XNNPACK ops support fusion of clamping. That is clamping of output activation is done as part of the kernel, including for 2D convolution and linear op kernels. Thus clamping effectively comes for free. Thus any op that can be expressed as clamping op, such as ``ReLU`` or ``hardtanh``, can be fused with previous ``Conv2D`` or ``linear`` op in XNNPACK. This pass rewrites graph by finding ``ReLU/hardtanh`` ops that follow XNNPACK ``Conv2D/linear`` ops, written by the previous pass, and fuses them together.
-- **Dropout removal** (blocklisting option `MobileOptimizerType::REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false.
-- **Conv packed params hoisting** (blocklisting option `MobileOptimizerType::HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics.
+- **Dropout removal** (blocklisting option `mobile_optimizer.MobileOptimizerType.REMOVE_DROPOUT`): This optimization pass removes ``dropout`` and ``dropout_`` nodes from this module when training is false.
+- **Conv packed params hoisting** (blocklisting option `mobile_optimizer.MobileOptimizerType.HOIST_CONV_PACKED_PARAMS`): This optimization pass moves convolution packed params to the root module, so that the convolution structs can be deleted. This decreases model size without impacting numerics.
+- **Add/ReLU fusion** (blocklisting option `mobile_optimizer.MobileOptimizerType.FUSE_ADD_RELU`): This pass finds instances of ``relu`` ops that follow ``add`` ops and fuses them into a single ``add_relu``.

-for Vulkan Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the folllwing optimization:
-- **Automatic GPU Transfer** (blocklisting option `MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER`): This optimization pass rewrites the graph such that inputs are transferred to Vulkan backend, and outputs are transferred to CPU backend
+for Vulkan Backend, by default, if optimization blocklist is None or empty, ``optimize_for_mobile`` will run the following optimization:
+- **Automatic GPU Transfer** (blocklisting option `mobile_optimizer.MobileOptimizerType.VULKAN_AUTOMATIC_GPU_TRANSFER`): This optimization pass rewrites the graph so that moving input and output data to and from the GPU becomes part of the model.

-``optimize_for_mobile`` will also invoke freeze_module pass which only preserves ``forward`` method. If you have other method to that needed to be preserved, add them into the preserved method list and pass into the method.
+``optimize_for_mobile`` will also invoke freeze_module pass which only preserves ``forward`` method. If you have other method to that needed to be preserved, add them into the preserved method list and pass into the method.


 .. currentmodule:: torch.utils.mobile_optimizer
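
The blocklist semantics described in the documentation hunk above (every pass runs unless its flag is in the blocklist, and `None` or an empty set means "run everything") can be illustrated with a toy model. This is not PyTorch's implementation: `RewriteFlag`, `fuse_pairs`, and `run_passes` are hypothetical names, the "graph" is just a list of op names, and only three of the flags are modelled.

```python
from enum import Enum, auto
from typing import Callable, List, Optional, Set, Tuple


class RewriteFlag(Enum):
    """Hypothetical stand-in for MobileOptimizerType (three flags only)."""
    CONV_BN_FUSION = auto()
    REMOVE_DROPOUT = auto()
    FUSE_ADD_RELU = auto()


def fuse_pairs(ops: List[str], first: str, second: str, fused: str) -> List[str]:
    """Replace each adjacent (first, second) pair with a single fused op."""
    out: List[str] = []
    i = 0
    while i < len(ops):
        if ops[i] == first and i + 1 < len(ops) and ops[i + 1] == second:
            out.append(fused)
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out


# Each toy pass is paired with the flag that disables it.
PASSES: List[Tuple[RewriteFlag, Callable[[List[str]], List[str]]]] = [
    (RewriteFlag.CONV_BN_FUSION,
     lambda ops: fuse_pairs(ops, "conv2d", "batch_norm", "conv2d")),
    (RewriteFlag.REMOVE_DROPOUT,
     lambda ops: [op for op in ops if op != "dropout"]),
    (RewriteFlag.FUSE_ADD_RELU,
     lambda ops: fuse_pairs(ops, "add", "relu", "add_relu")),
]


def run_passes(ops: List[str],
               blocklist: Optional[Set[RewriteFlag]] = None) -> List[str]:
    """Run every pass whose flag is not blocklisted; None means run all."""
    blocked = blocklist or set()
    for flag, rewrite in PASSES:
        if flag not in blocked:
            ops = rewrite(ops)
    return ops


graph = ["conv2d", "batch_norm", "dropout", "add", "relu"]
print(run_passes(graph))                                # ['conv2d', 'add_relu']
print(run_passes(graph, {RewriteFlag.REMOVE_DROPOUT}))  # ['conv2d', 'dropout', 'add_relu']
```

In real use the set is passed as the `optimization_blocklist` argument of `optimize_for_mobile`, with members taken from `torch.utils.mobile_optimizer.MobileOptimizerType`.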
4 changes: 2 additions & 2 deletions test/test_mobile_optimizer.py
@@ -8,9 +8,9 @@
 from torch.testing._internal.jit_utils import get_forward, get_forward_graph
 from torch.utils.mobile_optimizer import (LintCode,
 generate_mobile_module_lints,
-optimize_for_mobile)
+optimize_for_mobile,
+MobileOptimizerType)
 from torch.nn import functional as F
-from torch._C import MobileOptimizerType
 from torch.testing._internal.common_quantized import override_quantized_engine

 try:
7 changes: 0 additions & 7 deletions test/test_public_bindings.py
@@ -69,7 +69,6 @@ def test_no_new_bindings(self):
 "ComplexType",
 "ConcreteModuleType",
 "ConcreteModuleTypeBuilder",
-"CONV_BN_FUSION",
 "cpp",
 "CudaBFloat16TensorBase",
 "CudaBFloat16TensorBase",
@@ -113,7 +112,6 @@ def test_no_new_bindings(self):
 "FloatType",
 "fork",
 "FunctionSchema",
-"FUSE_ADD_RELU",
 "Future",
 "FutureType",
 "Generator",
@@ -132,13 +130,11 @@ def test_no_new_bindings(self):
 "has_mps",
 "has_openmp",
 "has_spectral",
-"HOIST_CONV_PACKED_PARAMS",
 "iinfo",
 "import_ir_module_from_buffer",
 "import_ir_module",
 "InferredType",
 "init_num_threads",
-"INSERT_FOLD_PREPACK_OPS",
 "InterfaceType",
 "IntType",
 "SymFloatType",
@@ -159,7 +155,6 @@ def test_no_new_bindings(self):
 "LoggerBase",
 "memory_format",
 "merge_type_from_type_comment",
-"MobileOptimizerType",
 "ModuleDict",
 "Node",
 "NoneType",
@@ -176,7 +171,6 @@ def test_no_new_bindings(self):
 "PyTorchFileWriter",
 "qscheme",
 "read_vitals",
-"REMOVE_DROPOUT",
 "RRefType",
 "ScriptClass",
 "ScriptClassFunction",
@@ -261,7 +255,6 @@ def test_no_new_bindings(self):
 "set_num_threads",
 "unify_type_list",
 "vitals_enabled",
-"VULKAN_AUTOMATIC_GPU_TRANSFER",
 "wait",
 "Tag",
 }
18 changes: 9 additions & 9 deletions torch/_C/__init__.pyi.in
@@ -177,15 +177,15 @@ class Future(object):
 def _jit_set_num_profiled_runs(num: _size) -> _size: ...

 # Defined in torch/csrc/jit/passes/mobile_optimizer_type.h
-class MobileOptimizerType:
+class _MobileOptimizerType:
 ...

-CONV_BN_FUSION: MobileOptimizerType
-INSERT_FOLD_PREPACK_OPS: MobileOptimizerType
-REMOVE_DROPOUT: MobileOptimizerType
-FUSE_ADD_RELU: MobileOptimizerType
-HOIST_CONV_PACKED_PARAMS: MobileOptimizerType
-VULKAN_AUTOMATIC_GPU_TRANSFER: MobileOptimizerType
+CONV_BN_FUSION: _MobileOptimizerType
+INSERT_FOLD_PREPACK_OPS: _MobileOptimizerType
+REMOVE_DROPOUT: _MobileOptimizerType
+FUSE_ADD_RELU: _MobileOptimizerType
+HOIST_CONV_PACKED_PARAMS: _MobileOptimizerType
+VULKAN_AUTOMATIC_GPU_TRANSFER: _MobileOptimizerType

 def fork(*args: Any, **kwargs: Any) -> Future: ...
 def wait(fut: Future) -> Any: ...
@@ -217,13 +217,13 @@ def _jit_get_operation(op_name: str) -> Tuple[Callable, List[str]]: ...
 def _get_operation_overload(op_name: str, op_overload_name: str) -> Tuple[Callable, Callable, List[Any]]: ...
 def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
 def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
-optimization_blocklist: Set[MobileOptimizerType],
+optimization_blocklist: Set[_MobileOptimizerType],
 preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _clone_module_with_class(module: 'torch.jit.ScriptModule',
 ignored_methods: List[AnyStr],
 ignored_attributes: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_vulkan_optimize_for_mobile(module: 'torch.jit.ScriptModule',
-optimization_blocklist: Set[MobileOptimizerType],
+optimization_blocklist: Set[_MobileOptimizerType],
 preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
 def _jit_pass_metal_optimize_for_mobile(module: 'torch.jit.ScriptModule',
 preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
5 changes: 2 additions & 3 deletions torch/csrc/jit/python/init.cpp
@@ -1316,7 +1316,7 @@ void initJITBindings(PyObject* module) {
 "get_all_written_records",
 &PyTorchStreamWriter::getAllWrittenRecords);

-py::enum_<MobileOptimizerType>(m, "MobileOptimizerType")
+py::enum_<MobileOptimizerType>(m, "_MobileOptimizerType")
 .value("CONV_BN_FUSION", MobileOptimizerType::CONV_BN_FUSION)
 .value(
 "INSERT_FOLD_PREPACK_OPS",
@@ -1328,8 +1328,7 @@ void initJITBindings(PyObject* module) {
 MobileOptimizerType::HOIST_CONV_PACKED_PARAMS)
 .value(
 "VULKAN_AUTOMATIC_GPU_TRANSFER",
-MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER)
-.export_values();
+MobileOptimizerType::VULKAN_AUTOMATIC_GPU_TRANSFER);

 // This allows PyTorchStreamReader to read from a Python buffer. It requires
 // that the buffer implement `seek()`, `tell()`, and `read()`.
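
The binding change above does two things: it renames the enum to a leading-underscore name and drops pybind11's `export_values()`, which copies each enum entry into the parent scope. A pure-Python emulation of that effect is sketched below. `bind_enum` and `public_names` are hypothetical helpers, `SimpleNamespace` stands in for the `torch._C` module, and only two flags are listed for brevity. It also assumes, as the summary implies, that non-underscore names on `torch._C` are what end up visible as `torch.X`.

```python
from enum import Enum
from types import SimpleNamespace
from typing import List


def bind_enum(scope: SimpleNamespace, name: str, members: List[str],
              export_values: bool) -> type:
    """Emulate py::enum_<T>(scope, name), optionally with .export_values()."""
    enum_cls = Enum(name, members)
    setattr(scope, name, enum_cls)
    if export_values:
        # export_values() also places every entry directly in the parent scope.
        for member in enum_cls:
            setattr(scope, member.name, member)
    return enum_cls


def public_names(scope: SimpleNamespace) -> List[str]:
    """Names that a non-underscore re-export would pick up."""
    return sorted(n for n in vars(scope) if not n.startswith("_"))


FLAGS = ["CONV_BN_FUSION", "REMOVE_DROPOUT"]

before = SimpleNamespace()  # stand-in for torch._C before the change
bind_enum(before, "MobileOptimizerType", FLAGS, export_values=True)

after = SimpleNamespace()   # stand-in for torch._C after the change
bind_enum(after, "_MobileOptimizerType", FLAGS, export_values=False)

print(public_names(before))  # ['CONV_BN_FUSION', 'MobileOptimizerType', 'REMOVE_DROPOUT']
print(public_names(after))   # []

# The public module then re-exports the private class under its old name,
# mirroring the aliased import in torch/utils/mobile_optimizer.py.
MobileOptimizerType = after._MobileOptimizerType
print(MobileOptimizerType.CONV_BN_FUSION is after._MobileOptimizerType.CONV_BN_FUSION)  # True
```

With both changes applied, the only public route to the flags in this model is the alias, which matches the "After" results in the test plan.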
2 changes: 1 addition & 1 deletion torch/utils/mobile_optimizer.py
@@ -4,7 +4,7 @@

 import torch
 from enum import Enum
-from torch._C import MobileOptimizerType
+from torch._C import _MobileOptimizerType as MobileOptimizerType
 from typing import Optional, Set, List, AnyStr

 class LintCode(Enum):

1 comment on commit 370df96

@pytorchmergebot
Reverted #91600 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally
