support LSTM for quantization aware training #42594

Open
Robin-Simmons opened this issue Aug 5, 2020 · 4 comments
Labels
low priority (We're unlikely to get around to doing this in the near future) · oncall: quantization (Quantization support in PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

Robin-Simmons commented Aug 5, 2020

🐛 Bug

An LSTM network cannot be evaluated after being prepared for quantization-aware training. The same error does not appear if the model is evaluated before preparation.

To Reproduce

Steps to reproduce the behavior:

import os

import torch

# May be needed on some systems to work around a duplicate-MKL-library error
os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "True")

device = torch.device("cpu")


class Net(torch.nn.Module):
    def __init__(self, seq_length):
        super().__init__()

        self.hidden_size = 16
        self.input_size = 18
        self.seq_length = seq_length

        self.relu1 = torch.nn.ReLU()

        # batch_first specifies an input shape of (nBatches, nSeq, nFeatures);
        # otherwise the expected shape is (nSeq, nBatches, nFeatures)
        self.lstm = torch.nn.LSTM(input_size=self.input_size,
                                  hidden_size=self.hidden_size,
                                  batch_first=True)
        self.linear1 = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.dropout = torch.nn.Dropout(0.5)
        self.linearOut = torch.nn.Linear(self.hidden_size, 1)
        self.sigmoidOut = torch.nn.Sigmoid()
        self.squeeze1 = torch.Tensor.squeeze  # unbound method used as a plain function
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        # Initial hidden/cell states default to zeros
        x, (h, c) = self.lstm(x)

        # The last hidden state h has shape (1, nBatches, hidden_size);
        # squeeze removes the length-1 dim
        x = self.squeeze1(h)

        x = self.dropout(x)
        x = self.linear1(x)
        x = self.relu1(x)
        x = self.linearOut(x)

        # Apply sigmoid either in the loss function or in evaluate(...)
        return x

    def evaluate(self, x):
        return self.sigmoidOut(self.forward(x))


model = Net(100)
model = model.to(device)

lossF = torch.nn.BCEWithLogitsLoss()  # expects logits (i.e., without sigmoid)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# The error appears with or without this line
# model.eval()
model = torch.quantization.prepare_qat(model)
out = model(torch.rand(1, 100, 18))

Error:

Traceback (most recent call last):
  File "rnn_min_error.py", line 71, in <module>
    out = model(torch.rand(1,100,18))
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "rnn_min_error.py", line 42, in forward
    x, (h,c) = self.lstm(x)#, self.h0)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/quantization/quantize.py", line 74, in _observer_forward_hook
    return self.activation_post_process(output)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/mnt/storage/home/rs17751/.local/lib/python3.7/site-packages/torch/quantization/fake_quantize.py", line 93, in forward
    self.activation_post_process(X.detach())
AttributeError: 'tuple' object has no attribute 'detach'
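
The traceback shows the cause: prepare_qat attaches an observer forward hook to every module that has a qconfig, and that hook calls .detach() on the module's output. nn.LSTM.forward returns a tuple (output, (h_n, c_n)) rather than a tensor, hence the AttributeError. A minimal workaround sketch (it does not add QAT support; it simply leaves the LSTM in float precision) is to clear the qconfig on the LSTM submodule so no observer is attached to it:

# Workaround sketch: exclude the LSTM from observation, so prepare_qat
# only instruments the Linear layers and the LSTM stays in float
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model.lstm.qconfig = None  # submodules with qconfig=None are skipped
model = torch.quantization.prepare_qat(model)
out = model(torch.rand(1, 100, 18))  # no AttributeError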

Expected behavior

The model should output a tensor when a tensor is fed to it.

Environment

PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: CentOS Linux release 7.3.1611 (Core)
GCC version: (GCC) 5.4.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: No
CUDA runtime version: 10.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Could not collect

Versions of relevant libraries:

[pip3] numpy==1.18.1
[pip3] numpydoc==0.9.2
[pip3] torch==1.6.0
[pip3] torchfile==0.1.0
[pip3] torchnet==0.0.4
[pip3] torchvision==0.4.2
[conda] blas                      1.0                         mkl
[conda] cudatoolkit               10.1.243             h6bb024c_0
[conda] mkl                       2020.0                      166
[conda] mkl-service               2.3.0            py37he904b0f_0
[conda] mkl_fft                   1.0.15           py37ha843d7b_0
[conda] mkl_random                1.1.0            py37hd6b4f25_0
[conda] numpy                     1.18.1           py37h4f9e942_0
[conda] numpy-base                1.18.1           py37hde5b4d6_1
[conda] numpydoc                  0.9.2                      py_0
[conda] pytorch                   1.5.1           py3.7_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchvision               0.6.1                py37_cu101    pytorch

Additional context

The error can also be produced if these lines are used instead:
qat_model.qconfig = torch.quantization.default_qconfig
qat_model = torch.quantization.prepare(qat_model)
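
For comparison, dynamic (weight-only) quantization does support nn.LSTM, so it can serve as a stopgap while QAT support is missing. A minimal sketch using the model above:

# Sketch: quantize_dynamic swaps nn.LSTM and nn.Linear for dynamically
# quantized versions with int8 weights; no prepare/observe step is needed
dyn_model = torch.quantization.quantize_dynamic(
    Net(100), {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
)
out = dyn_model(torch.rand(1, 100, 18))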

cc @jerryzh168 @jianyuh @dzhulgakov @raghuramank100 @jamesr66a @vkuzo

@malfet malfet added the oncall: quantization Quantization support in PyTorch label Aug 6, 2020
vkuzo (Contributor) commented Aug 6, 2020

hi @Robin-Simmons, thanks for the report! LSTM support for QAT is not ready at the moment. We have some preliminary work to enable it, but no timeline yet.

@vkuzo vkuzo added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Aug 6, 2020
Robin-Simmons (Author) commented
Hi @vkuzo , that's a shame, but thanks for replying!

@vkuzo vkuzo changed the title from "LSTM network can not be evaluated after preparing for quantisation aware training." to "support LSTM for quantization aware training" Mar 31, 2021
@vkuzo vkuzo added the low priority We're unlikely to get around to doing this in the near future label Mar 31, 2021
@github-actions github-actions bot added this to Need Triage in Quantization Triage Mar 31, 2021
@jerryzh168 jerryzh168 moved this from Need Triage to Low Priority in Quantization Triage Apr 27, 2021
JorgeCMurillo commented
Hi, I was wondering why there is no LSTM/GRU support for QAT, and what difficulties or challenges (if any) are faced when attempting to quantize these models.

Haibit commented Jan 16, 2024

> hi @Robin-Simmons, thanks for the report! LSTM support for QAT is not ready at the moment. We have some preliminary work to enable it, but no timeline yet.

@vkuzo Hi, does QAT support LSTM quantization using PyTorch 1.13.1?
