
FX graph mode quantization does not support F.linear and F.conv{n}d with kwargs only #87686

Closed
wong00 opened this issue Oct 25, 2022 · 8 comments
Labels
oncall: quantization Quantization support in PyTorch

Comments

@wong00

wong00 commented Oct 25, 2022

🐛 Describe the bug

I used the quantize_fx module to quantize my own model; the code is:

import copy

import torch
import yaml
from torch.ao.quantization import get_default_qconfig
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

# KPDetector, OcclusionAwareGenerator, GeneratorFullModel and cuda_device
# come from the rest of the first-order-model script.
if __name__ == "__main__":
    checkpoint1 = torch.load('/mnt/vox-cpk.pth.tar')
    with open('config/vox-256.yaml') as f:
        config = yaml.safe_load(f)
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if torch.cuda.is_available():
        kp_detector.to(cuda_device)
    kp_detector.load_state_dict(checkpoint1['kp_detector'])
    generator = OcclusionAwareGenerator(**config['model_params']['generator_params'],
                                        **config['model_params']['common_params'])
    if torch.cuda.is_available():
        generator.to(cuda_device)
    generator.load_state_dict(checkpoint1['generator'])
    train_params = config['train_params']

    model_fp32 = GeneratorFullModel(kp_detector, generator, None, train_params)

    model_to_quantize = copy.deepcopy(model_fp32)
    model_to_quantize.eval()
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {"": qconfig}
    prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
    # calibrate(prepared_model, data_loader)
    quantized_model = convert_fx(prepared_model)
    print(quantized_model)
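The commented-out calibrate step would typically just run representative data through the prepared model so the inserted observers can record activation ranges. A minimal sketch, assuming data_loader yields batches in whatever format GeneratorFullModel accepts:

def calibrate(model, data_loader):
    # Feed representative batches through the prepared model so the observers
    # record the activation ranges later used by convert_fx.
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            model(batch)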

The GeneratorFullModel consists of two separate models, and the error is:

Traceback (most recent call last):
File "/home/lab239-5/users/wangxin/first-order-model-master/ccc.py", line 115, in
quantized_model = convert_fx(prepared_model)
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 658, in convert_fx
return _convert_fx(
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/quantize_fx.py", line 563, in _convert_fx
quantized = convert(
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/convert.py", line 754, in convert
model = lower_to_fbgemm(model, qconfig_map, node_name_to_scope)
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/lower_to_fbgemm.py", line 14, in lower_to_fbgemm
return _lower_to_native_backend(model, qconfig_map, node_name_to_scope)
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py", line 958, in _lower_to_native_backend
_lower_static_weighted_ref_functional(model, qconfig_map)
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py", line 606, in _lower_static_weighted_ref_functional
(q_node, relu_node, func_node) = _match_static_pattern(
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/ao/quantization/fx/_lower_to_native_backend.py", line 450, in _match_static_pattern
assert i < len(ref_node.args),
AssertionError: Dequantize index 1 exceeded reference node's arg length 1

Versions

pytorch==1.12.1

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo

@malfet malfet added the oncall: quantization Quantization support in PyTorch label Oct 25, 2022
@vkuzo
Contributor

vkuzo commented Oct 27, 2022

This sounds like a bug in FX graph mode quantization. Can you provide a reproducible example to help us look into it?

@wong00
Author

wong00 commented Oct 27, 2022

This sounds like a bug in FX graph mode quantization. Can you provide a reproducible example to help us look into it?

I uploaded my code to GitHub: https://github.com/wong00/Quantization-of-FOM-model
Thanks

@vkuzo
Contributor

vkuzo commented Oct 28, 2022

thank you!

I narrowed it down to a small reproducible example:

import torch
import torch.nn.functional as F
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import get_default_qconfig_mapping

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(1, 1))

    def forward(self, x):
        # F.linear is called with kwargs only, which triggers the lowering failure
        x = F.linear(input=x, weight=self.w)
        return x

m = M()
mp = quantize_fx.prepare_fx(
    m, get_default_qconfig_mapping('fbgemm'), (torch.randn(1, 1),))
mq = quantize_fx.convert_fx(mp)

This fails because the lowering code is assuming that ops such as F.linear, F.conv{n}d are using args and not kwargs for input and weight. We will fix this, thanks for the report!
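In the meantime, a minimal sketch of a workaround, assuming the kwargs-only call sites can be rewritten: passing input and weight positionally sidesteps the lowering assumption, and the same example then converts cleanly. This is not an official fix.

import torch
import torch.nn.functional as F
import torch.ao.quantization.quantize_fx as quantize_fx
from torch.ao.quantization import get_default_qconfig_mapping

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.randn(1, 1))

    def forward(self, x):
        # positional args instead of kwargs, so the lowering pass can find input/weight
        return F.linear(x, self.w)

m = M().eval()
mp = quantize_fx.prepare_fx(
    m, get_default_qconfig_mapping('fbgemm'), (torch.randn(1, 1),))
mq = quantize_fx.convert_fx(mp)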

@vkuzo
Contributor

vkuzo commented Oct 28, 2022

@jerryzh168 , @andrewor14 , what's the latest on arg/kwarg normalization? I remember we talked about running this in the very beginning after capturing the FX graph, but not sure if it was rolled out.

@jerryzh168
Contributor

@jerryzh168 , @andrewor14 , what's the latest on arg/kwarg normalization? I remember we talked about running this in the very beginning after capturing the FX graph, but not sure if it was rolled out.

No plans to do it in the IR using torch ops, since the current normalization (https://github.com/pytorch/pytorch/blob/master/torch/fx/experimental/normalize.py) has some corner cases, like overrides, and no one is working on it.

We'll get normalization automatically in the new PT2 Export stack.

Maybe we can just support normalization for a few selected ops like F.linear and F.conv2d for now to unblock?
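For illustration, a selected-op normalization along those lines might look roughly like the sketch below. The helper name, the op list, and the argument orders are assumptions based on the F.linear/F.conv2d signatures, not an existing PyTorch pass.

import torch
import torch.fx
import torch.nn.functional as F

# Positional parameter order for the handful of ops we want to normalize.
# Extend with F.conv1d / F.conv3d as needed; orders follow the functional signatures.
_ARG_ORDER = {
    F.linear: ('input', 'weight', 'bias'),
    F.conv2d: ('input', 'weight', 'bias', 'stride', 'padding', 'dilation', 'groups'),
}

def normalize_selected_ops(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for n in gm.graph.nodes:
        if n.op == 'call_function' and n.target in _ARG_ORDER and n.kwargs:
            order = _ARG_ORDER[n.target]
            args, kwargs = list(n.args), dict(n.kwargs)
            # Move keyword arguments into their positional slots, left to right.
            for name in order[len(args):]:
                if name not in kwargs:
                    break  # cannot skip a positional slot; leave the rest as kwargs
                args.append(kwargs.pop(name))
            n.args = tuple(args)
            n.kwargs = kwargs
    gm.recompile()
    return gm

One would run this on the symbolically traced float model (torch.fx.symbolic_trace(model)) before prepare_fx; whether it is sufficient for the full FOM model is untested here.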

@vkuzo
Contributor

vkuzo commented Oct 31, 2022

Hi @wong00 , I created a manual normalization pass you could try on your model: https://gist.github.com/vkuzo/21c1ae37a262696faa5914843187426a . Could you see if it fixes your use case?

We aren't sure if we will be able to check this into PyTorch in the near future because we are working on a new program capture frontend which should resolve these issues in a cleaner way.

@vkuzo vkuzo changed the title AssertionError: Dequantize index 1 exceeded reference node's arg length 1 FX graph mode quantization does not support F.linear and F.conv{n}d with kwargs only Oct 31, 2022
@wong00
Author

wong00 commented Nov 2, 2022

Hi @wong00 , I created a manual normalization pass you could try on your model: https://gist.github.com/vkuzo/21c1ae37a262696faa5914843187426a . Could you see if it fixes your use case?

We aren't sure if we will be able to check this into PyTorch in the near future because we are working on a new program capture frontend which should resolve these issues in a cleaner way.

Thanks for your solution. I tried the normalize_conv_linear method and met this error:

Use predefined train-test split.
%conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%pad,), kwargs = {weight: %kp_extractor_down_weight, groups: 3})
Traceback (most recent call last):
File "/home/lab239-5/users/wangxin/first-order-model-fx/ccc.py", line 140, in
model_to_quantize = normalize_conv_linear(model_to_quantize)
File "/home/lab239-5/users/wangxin/first-order-model-fx/ccc.py", line 74, in normalize_conv_linear
norm_args, norm_kwargs = n.normalized_arguments(mt)
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/fx/node.py", line 564, in normalized_arguments
return normalize_function(self.target, self.args, self.kwargs, arg_types, kwarg_types) # type: ignore[arg-type]
File "/home/lab239-5/users/wangxin/anaconda3/envs/pytorch1.12.1/lib/python3.10/site-packages/torch/fx/operator_schemas.py", line 324, in normalize_function
raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
RuntimeError: Tried to normalize arguments to torch._VariableFunctionsClass.conv2d but the schema match was ambiguous! Please provide argument types to the normalize_arguments() call. Available schemas:
(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: List[int] = [1, 1], padding: List[int] = [0, 0], dilation: List[int] = [1, 1], groups: int = 1) -> torch.Tensor
(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: List[int] = [1, 1], padding: str = 'valid', dilation: List[int] = [1, 1], groups: int = 1) -> torch.Tensor

@jerryzh168
Contributor

This won't be fixed, since we are moving to the new PyTorch 2.0 export quantization flow, where this should be supported.
