Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fx quant: add types to observed_module.py #49607

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 7 additions & 5 deletions torch/quantization/fx/observed_module.py
@@ -1,6 +1,8 @@
import torch
import copy
from torch.fx import GraphModule # type: ignore
from torch.fx.graph import Graph
from typing import Union, Dict, Any

class ObservedGraphModule(GraphModule):

Expand All @@ -10,7 +12,7 @@ def get_preserved_attr_names(self):
'_qconfig_map',
'_prepare_custom_config_dict']

def __init__(self, root, graph):
def __init__(self, root: Union[torch.nn.Module, Dict[str, Any]], graph: Graph):
preserved_attrs = dict()
for attr in self.get_preserved_attr_names():
preserved_attrs[attr] = getattr(root, attr)
Expand All @@ -26,10 +28,10 @@ def __deepcopy__(self, memo):
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedGraphModule(fake_mod, self.graph)

def mark_observed_module(module):
def mark_observed_module(module: GraphModule) -> GraphModule:
return ObservedGraphModule(module, module.graph)

def is_observed_module(module):
def is_observed_module(module: Any) -> bool:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this be GraphModule as well?

return isinstance(module, ObservedGraphModule)

class ObservedStandaloneGraphModule(ObservedGraphModule):
Expand All @@ -38,8 +40,8 @@ def __deepcopy__(self, memo):
fake_mod.__dict__ = copy.deepcopy(self.__dict__)
return ObservedStandaloneGraphModule(fake_mod, self.graph)

def mark_observed_standalone_module(module):
def mark_observed_standalone_module(module: GraphModule) -> GraphModule:
return ObservedStandaloneGraphModule(module, module.graph)

def is_observed_standalone_module(module):
def is_observed_standalone_module(module: Any) -> bool:
return isinstance(module, ObservedStandaloneGraphModule)