
fusion in fx graph mode did not take care of direct attribute access #68892

Open

jerryzh168 opened this issue Nov 24, 2021 · 2 comments

Labels: low priority (We're unlikely to get around to doing this in the near future) · oncall: quantization (Quantization support in PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

jerryzh168 (Contributor) commented Nov 24, 2021

🐛 Bug

As the title says: fusion in FX graph mode does not handle direct attribute access on a submodule that gets fused (here, reading `self.linear.bias` in `forward`).

To Reproduce

```python
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        # Direct attribute access on a submodule that fusion will rewrite.
        b = self.linear.bias
        return torch.nn.functional.relu(self.linear(x)) + b

m = M().eval()
m = prepare_fx(m, {"": torch.ao.quantization.default_qconfig})  # raises AttributeError
print(m)
m = convert_fx(m)
```

Output

```
'LinearReLU' object has no attribute 'bias'
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-2-e6547926d9ec> in <module>
     12 
     13 m = M().eval()
---> 14 m = prepare_fx(m, {"": torch.ao.quantization.default_qconfig})
     15 print(m)
     16 m = convert_fx(m)
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/ao/quantization/quantize_fx.py in prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict)
    498     torch._C._log_api_usage_once("quantization_api.quantize_fx.prepare_fx")
    499     assert not model.training, "prepare_fx only works for models in " + "eval mode"
--> 500     return _prepare_fx(
    501         model,
    502         qconfig_dict,
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/ao/quantization/quantize_fx.py in _prepare_fx(model, qconfig_dict, prepare_custom_config_dict, equalization_qconfig_dict, backend_config_dict, is_standalone_module)
    232     for attr_name in preserved_attributes:
    233         setattr(graph_module, attr_name, getattr(model, attr_name))
--> 234     graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
    235     prepared = prepare(
    236         graph_module,
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/ao/quantization/quantize_fx.py in _fuse_fx(graph_module, fuse_custom_config_dict)
     56     _check_is_graph_module(graph_module)
     57     fuser = Fuser()
---> 58     return fuser.fuse(graph_module, fuse_custom_config_dict)
     59 
     60 
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/ao/quantization/fx/fuse.py in fuse(self, model, fuse_custom_config_dict)
     60 
     61         preserved_attributes = set(fuse_custom_config_dict.get("preserved_attributes", []))
---> 62         model = FusedGraphModule(input_root, self.fused_graph, preserved_attributes)
     63         return model
     64 
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/ao/quantization/fx/graph_module.py in __init__(self, root, graph, preserved_attr_names)
      9         self.preserved_attr_names = preserved_attr_names
     10         preserved_attrs = {attr: getattr(root, attr) for attr in self.preserved_attr_names if hasattr(root, attr)}
---> 11         super().__init__(root, graph)
     12         for attr in preserved_attrs:
     13             setattr(self, attr, preserved_attrs[attr])
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/fx/graph_module.py in __init__(self, root, graph, class_name)
    288                 if node.op in ['get_attr', 'call_module']:
    289                     assert isinstance(node.target, str)
--> 290                     _copy_attr(root, self, node.target)
    291         elif isinstance(root, dict):
    292             targets_to_copy = []
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/fx/graph_module.py in _copy_attr(from_module, to_module, target)
    196         from_module, to_module = f, t
    197 
--> 198     orig = getattr(from_module, field)
    199     # If it is a tensor and not a parameter attribute of a module, it should be a named buffer.
    200     # So, we register it as a named buffer in the target module.
/mnt/xarfuse/uid-127034/b9cbaed7-seed-nspid4026531836_cgpid354053-ns-4026531840/torch/nn/modules/module.py in __getattr__(self, name)
   1183             if name in modules:
   1184                 return modules[name]
-> 1185         raise AttributeError("'{}' object has no attribute '{}'".format(
   1186             type(self).__name__, name))
   1187 
AttributeError: 'LinearReLU' object has no attribute 'bias'
```
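
What the traceback shows: tracing records `b = self.linear.bias` as a `get_attr` node whose target is the dotted path `linear.bias`. Fusion then swaps the `linear` submodule for a `LinearReLU` wrapper, and when the fused `GraphModule` is constructed it copies every `get_attr` target from the root, so `getattr` reaches into the wrapper for a `bias` attribute it does not have. A small sketch to inspect the offending node with plain `torch.fx`, independent of quantization:

```python
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)

    def forward(self, x):
        b = self.linear.bias
        return torch.nn.functional.relu(self.linear(x)) + b

traced = torch.fx.symbolic_trace(M())
# The printed graph contains a get_attr node whose target is the dotted
# path "linear.bias"; after fusion rewrites what "linear" refers to, that
# path no longer resolves.
print(traced.graph)
```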

Expected behavior

Should run successfully.
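
The information is not actually lost: the wrapper that fusion produces keeps the original `Linear` as its first child. A minimal sketch illustrating this (uses `torch.ao.nn.intrinsic`; on older builds the same class lives in `torch.nn.intrinsic`):

```python
import torch
import torch.ao.nn.intrinsic as nni

linear = torch.nn.Linear(5, 5)
fused = nni.LinearReLU(linear, torch.nn.ReLU())  # the wrapper fusion creates

print(hasattr(fused, "bias"))        # False: the wrapper itself has no .bias
print(fused[0].bias is linear.bias)  # True: the original Linear sits at index 0
```

So one plausible direction for a fix, not confirmed here, is for fusion to remap such `get_attr` targets (e.g. `linear.bias` to `linear.0.bias`) or to preserve the directly accessed attributes on the fused module.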

Environment

pytorch master

cc @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @jerryzh168
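
A possible workaround until fusion handles this (a sketch under the assumption that only the fused submodule's name becomes unresolvable; `bias_ref` is a name made up for illustration): alias the parameter on the root module, so that tracing records a `get_attr` target that does not pass through `linear`:

```python
import torch
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 5)
        # Alias the parameter on the root; tracing should then record
        # get_attr[target=bias_ref] rather than get_attr[target=linear.bias],
        # and "bias_ref" survives the fusion rewrite of "linear".
        self.bias_ref = self.linear.bias

    def forward(self, x):
        b = self.bias_ref
        return torch.nn.functional.relu(self.linear(x)) + b

m = M().eval()
m = prepare_fx(m, {"": torch.ao.quantization.default_qconfig})  # fusion no longer crashes
m = convert_fx(m)
```

Whether the final `+ b` ends up quantized the way you want still depends on the qconfig; the sketch only avoids the attribute-copy failure during fusion.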

jerryzh168 added the oncall: quantization label and removed the module: fx label on Nov 24, 2021
github-actions bot added this to Need Triage in Quantization Triage on Nov 24, 2021
jerryzh168 changed the title from "fusion in fx graph mode did not take of direct attribute access" to "fusion in fx graph mode did not take care of direct attribute access" on Dec 13, 2021
vkuzo added the low priority label on Feb 23, 2022
andrewor14 added the triaged label on Sep 18, 2023
parakh08 commented Apr 3, 2024

I am also facing a similar issue. Did you get a solution?

AndroidDevelopersTools commented:

> I am also facing a similar issue. Did you get a solution?

Did you find any solution?
