Skip to content

Commit

Permalink
[dynamo] add source.name() in node meta nn_module_stack
Browse files Browse the repository at this point in the history
  • Loading branch information
ydwu4 committed Feb 17, 2023
1 parent c16b291 commit 4608186
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 22 deletions.
3 changes: 2 additions & 1 deletion torch/_dynamo/variables/nn_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,9 @@ def call_function(

@contextmanager
def record_nn_module_stack():
fully_qualified_name = self.source.name()
try:
tx.nn_module_stack[self.module_key] = type(mod)
tx.nn_module_stack[self.module_key] = (fully_qualified_name, type(mod))
yield
finally:
del tx.nn_module_stack[self.module_key]
Expand Down
18 changes: 0 additions & 18 deletions torch/ao/quantization/_pt2e/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,8 @@
from torch.ao.quantization.fx.prepare import (
_is_activation_post_process_node,
)
from collections import OrderedDict
import operator

# TODO[qihan]: longer term, this should happen in the dynamo stack as well
def _get_renamed_nn_module_stack(nn_module_stack):
# initialize with top level parent scope
nn_module_stack_renamed = OrderedDict([("", None)])
if nn_module_stack:
# Rename module_key, e.g. "self_layer1_1__conv1" to "self.layer1.1._conv1", for easier downstream parsing
prev_key = ""
for key, value in nn_module_stack.items():
if not prev_key:
if key.startswith("self_"):
new_key = key[5:]
prev_key = new_key
else:
new_key = prev_key + "." + key[len(prev_key) + 6 :]
nn_module_stack_renamed[new_key] = value
prev_key = new_key
return nn_module_stack_renamed

def _get_tensor_constant_from_node(node, m):
if node is None:
Expand Down
8 changes: 5 additions & 3 deletions torch/ao/quantization/_quantize_pt2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from .fx import prepare
from .quantize_fx import _convert_to_reference_decomposed_fx
from ._pt2e.utils import (
_get_renamed_nn_module_stack,
_fuse_conv_bn_,
_rearrange_weight_observer_for_addmm,
)
Expand All @@ -21,8 +20,11 @@ def prepare_pt2e(
# TODO: move this information to fx node itself
node_name_to_scope: Dict[str, Tuple[str, type]] = {}
for n in model.graph.nodes:
renamed_stack = _get_renamed_nn_module_stack(n.meta.get("nn_module_stack", None))
current_scope = list(renamed_stack.items())[-1]
nn_module_stack = n.meta.get("nn_module_stack", None)
current_scope = ("", type(None))
if nn_module_stack:
bt = list(nn_module_stack.values())[-1]
current_scope = (bt[0].split(".")[-1], bt[1])
node_name_to_scope[n.name] = current_scope

# TODO: check qconfig_mapping to make sure conv and bn are both configured
Expand Down

0 comments on commit 4608186

Please sign in to comment.