Skip to content

Commit

Permalink
[WIP][dynamo] simplify module_key creation logic (#94945)
Browse files Browse the repository at this point in the history
After some thought, I find it difficult to come up with a robust naming convention that satisfies all of the following constraints at the same time: 1. the new name should be a valid nn.Module attribute (as required by the minifier, and a good property to have in general); 2. it can cover various cases such as GetItemSource and GetAttrSource; 3. it's easy to recover the original path; 4. it's robust to users' naming schemes.

Thanks to @yanboliang for pointing out that the original access path is preserved in Source; now we just need to add an additional value, source.name(), to node.meta["nn_module_stack"] to get the access path in the original module.

We also address some TODO in quantization, which relies on the original naming convention in nn_module_stack.

Pull Request resolved: #94945
Approved by: https://github.com/jansel, https://github.com/yanboliang
  • Loading branch information
ydwu4 authored and pytorchmergebot committed Feb 20, 2023
1 parent 954c767 commit 4d753b5
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 25 deletions.
8 changes: 5 additions & 3 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Tuple, Type
from unittest.mock import patch

import torch
Expand Down Expand Up @@ -1617,8 +1617,10 @@ def __init__(

# Execution record for replaying errors
self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
# Stack of module being parsed, current nn.module is at the end of ordered dict
self.nn_module_stack: Dict[str, str] = {}
# Stack of module being parsed, current nn.module is at the end of ordered dict.
# The first field of tuple is the fully qualified name of current module
# in original hierarchy. The second field is the type of current nn.module
self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
# Flag to indicate whether tracing is used for export.
self.export = export

Expand Down
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ def call_function(

@contextmanager
def record_nn_module_stack():
    # Context manager that pushes an entry for the current nn.Module onto
    # tx.nn_module_stack for the duration of the `with` body and always
    # pops it on exit, even if the body raises.
    # NOTE(review): `self`, `tx`, and `mod` are free variables captured from
    # the enclosing scope (a method of the surrounding class) — confirm
    # against the full file; only part of the enclosing definition is visible.
    fully_qualified_name = self.source.name()
    try:
        # NOTE(review): the next line is immediately overwritten by the one
        # after it — this looks like merged old/new diff lines; presumably
        # only the (fully_qualified_name, type(mod)) tuple form is intended.
        tx.nn_module_stack[self.module_key] = type(mod)
        tx.nn_module_stack[self.module_key] = (fully_qualified_name, type(mod))
        yield
    finally:
        # Pop the entry unconditionally so the stack stays balanced.
        del tx.nn_module_stack[self.module_key]
Expand Down
18 changes: 0 additions & 18 deletions torch/ao/quantization/_pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,8 @@
from torch.ao.quantization.fx.prepare import (
_is_activation_post_process_node,
)
from collections import OrderedDict
import operator

# TODO[qihan]: longer term, this should happen in the dynamo stack as well
def _get_renamed_nn_module_stack(nn_module_stack):
# initialize with top level parent scope
nn_module_stack_renamed = OrderedDict([("", None)])
if nn_module_stack:
# Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing
prev_key = ""
for key, value in nn_module_stack.items():
if not prev_key:
if key.startswith("self_"):
new_key = key[5:]
prev_key = new_key
else:
new_key = prev_key + "." + key[len(prev_key) + 6 :]
nn_module_stack_renamed[new_key] = value
prev_key = new_key
return nn_module_stack_renamed

def _get_tensor_constant_from_node(node, m):
if node is None:
Expand Down
8 changes: 5 additions & 3 deletions torch/ao/quantization/_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
_get_renamed_nn_module_stack,
_fuse_conv_bn_,
_rearrange_weight_observer_for_addmm,
)
Expand All @@ -21,8 +20,11 @@ def prepare_pt2e(
# TODO: move this information to fx node itself
node_name_to_scope: Dict[str, Tuple[str, type]] = {}
for n in model.graph.nodes:
renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None))
current_scope = list(renamed_stack.items())[-1]
nn_module_stack = n.meta.get("nn_module_stack", None)
current_scope = ("", type(None))
if nn_module_stack:
bt = list(nn_module_stack.values())[-1]
current_scope = (bt[0].split(".")[-1], bt[1])
node_name_to_scope[n.name] = current_scope

# TODO: check qconfig_mapping to make sure conv and bn are both configured
Expand Down

0 comments on commit 4d753b5

Please sign in to comment.