
Quantized model using boolean_dispatch not picklable #60210

Open
kev-zheng opened this issue Jun 17, 2021 · 6 comments
Assignees
Labels
low priority We're unlikely to get around to doing this in the near future oncall: quantization Quantization support in PyTorch triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@kev-zheng
Contributor

kev-zheng commented Jun 17, 2021

🐛 Bug

Follow-up bug from #57352. A model quantized using prepare_fx is not picklable, due to this boolean_dispatch function:

Traceback (most recent call last):
  File "torch_pickle.py", line 45, in <module>
    pickle.dumps(model)
AttributeError: Can't pickle local object 'boolean_dispatch.<locals>.fn'

To Reproduce

Same steps as in #57352; use a recent PyTorch nightly.

Version:

>> python -c "import torch; print(torch.__version__)"
1.10.0.dev20210617
import torch
import pickle

from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx

class TwoLayerNet(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super(TwoLayerNet, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        """
        In the forward function we accept a Tensor of input data and we must return
        a Tensor of output data. We can use Modules defined in the constructor as
        well as arbitrary operators on Tensors.
        """
        h_relu = self.linear1(x).clamp(min=0)
        y_pred = self.linear2(h_relu)
        return y_pred


# N is batch size; D_in is input dimension;
# H is hidden dimension; D_out is output dimension.
N, D_in, H, D_out = 64, 1000, 100, 10

# Create random Tensors to hold inputs and outputs
x = torch.randn(N, D_in)
y = torch.randn(N, D_out)

# Construct our model by instantiating the class defined above
model = TwoLayerNet(D_in, H, D_out)

# Try to pickle quantized model
qconfig = get_default_qconfig("qnnpack")
qconfig_dict = {"": qconfig}

model = prepare_qat_fx(model, qconfig_dict)

pickle.dumps(model)

Expected behavior

Empty output; pickle.dumps should succeed

Environment

PyTorch Version (e.g., 1.0): pytorch=1.10.0.dev20210617   (py3.6_0  pytorch-nightly)
OS (e.g., Linux): OSX
How you installed PyTorch (conda, pip, source): conda
Build command you used (if compiling from source): N/A
Python version: 3.6.10
CUDA/cuDNN version: cpu only
GPU models and configuration:
Any other relevant information:

Additional context

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel

@kev-zheng
Contributor Author

cc @jerryzh168 @vkuzo

@vkuzo
Contributor

vkuzo commented Jun 17, 2021

This is because functions which use boolean_dispatch are currently not picklable, and quantization is defined for some of these functions. Here is a simpler repro:

import torch
import pickle

pickle.dumps({torch.nn.functional.max_pool3d})

...

Traceback (most recent call last):
  File "/home/vasiliy/local/tmp/test.py", line 23, in <module>
    print(pickle.dumps(foo))
AttributeError: Can't pickle local object 'boolean_dispatch.<locals>.fn'
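The underlying mechanism can be shown without PyTorch at all. Below is a minimal sketch (the `make_dispatcher` factory is a hypothetical stand-in for boolean_dispatch, not the real implementation): pickle serializes plain functions by qualified name only, so a function defined inside another function cannot be re-imported on load and pickling fails, while a module-level function round-trips fine.

```python
import pickle

def make_dispatcher(if_true, if_false):
    # Mimics boolean_dispatch's structure: the returned function is
    # local to this factory, so its qualified name is
    # 'make_dispatcher.<locals>.fn' and pickle cannot import it.
    def fn(arg, flag=True):
        return if_true(arg) if flag else if_false(arg)
    return fn

dispatched = make_dispatcher(abs, int)

try:
    pickle.dumps(dispatched)
except (AttributeError, pickle.PicklingError) as e:
    print(e)  # Can't pickle local object 'make_dispatcher.<locals>.fn'

# A module-level function pickles fine: pickle stores only its
# qualified name and re-imports it on load.
assert pickle.loads(pickle.dumps(abs)) is abs
```

Any object reachable from the model that holds a reference to such a local function (as the quantization pattern tables do) makes the whole model unpicklable.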

@vkuzo
Contributor

vkuzo commented Jun 17, 2021

cc @driazati as the original author of boolean_dispatch — would you know who can help take a look?

@zou3519 zou3519 added the oncall: quantization Quantization support in PyTorch label Jun 17, 2021
@github-actions github-actions bot added this to Need Triage in Quantization Triage Jun 17, 2021
@vkuzo
Contributor

vkuzo commented Jun 22, 2021

If you are not using any ops which use boolean_dispatch in your model, you can use this workaround:

keys_to_delete = []
for p in model._patterns:
    try:
        pickle.dumps(p)
    except Exception:
        keys_to_delete.append(p)

for k in keys_to_delete:
    del model._patterns[k]

# runs without errors
pickle.dumps(model)

@driazati
Contributor

The problem here is that boolean_dispatch returns a local function, which pickle can't import on loading, so it just gives up. We could fix this in a couple of ways; I don't know enough about quantization or the state of TorchScript today to say which would be best, though. I see #53180 is coming along. If that were to land, we could probably delete boolean_dispatch altogether, since it was originally just a hack around the fact that we have functions with different return types based on their arguments. In that case the functions would just be normal nn.functional functions again, which pickle is okay with, so that would fix this (and probably make everyone's Python type checkers a little happier).

cc @ansley

@vkuzo vkuzo added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 20, 2021
@vkuzo
Contributor

vkuzo commented Jul 20, 2021

We discussed offline; someone on @gmagogsfm's team will take a look.

@gmagogsfm gmagogsfm assigned ansley and unassigned gmagogsfm Jul 21, 2021
@andrewor14 andrewor14 added the low priority We're unlikely to get around to doing this in the near future label Nov 17, 2023