
QNNPACK convolution race condition #58055

@dbalchev

Description


🐛 Bug

We're hitting a QNNPACK race condition when executing quantized convolutions concurrently in more than one thread.

To Reproduce

Steps to reproduce the behavior:

  1. Run the following script:
    import random
    import threading
    
    import torch
    
    torch.set_grad_enabled(False)
    torch.backends.quantized.engine = "qnnpack"
    
    
    def _make_quant_layer():
        layer = torch.nn.intrinsic.ConvReLU2d(
            torch.nn.Conv2d(64, 64, 3, padding=1), torch.nn.ReLU()
        )
        layer.qconfig = torch.quantization.default_per_channel_qconfig
        fake_model = torch.nn.Sequential(layer)
        torch.quantization.prepare(fake_model, inplace=True)
        fake_model(torch.zeros(1, 64, 1, 1))
        layer = torch.nn.intrinsic.quantized.ConvReLU2d.from_float(layer)
        return layer
    
    
    class _RaceExperiment:
        def __init__(self, num_threads):
            self.module = torch.jit.script(
                torch.nn.Sequential(*[_make_quant_layer() for _ in range(50)])
            )
            self.module.eval()
    
            self.enter_barrier = threading.Barrier(num_threads)
            self.exit_barrier = threading.Barrier(num_threads)
            self.num_threads = num_threads
    
        def thread_run(self):
            data = torch.zeros((1, 64, 50, random.choice([50, 100, 200, 400])))
            data = torch.quantize_per_tensor(data, 1.0, 0, dtype=torch.quint8)
            self.enter_barrier.wait()
            self.module(data)
            self.exit_barrier.wait()
            self.module(data)
    
        def run_experiment(self):
            threads = [
                threading.Thread(target=self.thread_run) for _ in range(self.num_threads)
            ]
            for t in threads:
                t.start()
            for t in threads:
                t.join()
            print("done")
    
    
    _RaceExperiment(1).run_experiment()
    
    for _ in range(10):
        _RaceExperiment(3).run_experiment()
  2. Pretty consistently I get either a segmentation fault, another memory error (e.g. a double free), or
    the following message:
    RuntimeError: The following operation failed in the TorchScript interpreter.
     Traceback of TorchScript (most recent call last):
       File "/home/ubuntu/repro-venv/lib/python3.7/site-packages/torch/nn/modules/container.py", line 119, in forward
         def forward(self, input):
             for module in self:
                 input = module(input)
                         ~~~~~~ <--- HERE
             return input
       File "/home/ubuntu/repro-venv/lib/python3.7/site-packages/torch/nn/intrinsic/quantized/modules/conv_relu.py", line 85, in forward
                 input = F.pad(input, _reversed_padding_repeated_twice,
                               mode=self.padding_mode)
             return torch.ops.quantized.conv2d_relu(
                    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
                 input, self._packed_params, self.scale, self.zero_point)
     RuntimeError: pack_w != nullptr INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp":678, please report a bug to PyTorch. Packed Weights are NULL
    

Expected behavior

The script should finish without any issues. This can be achieved either by using the FBGEMM
backend or by setting num_threads = 1 in _RaceExperiment.
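
For completeness, here is a minimal sketch of a user-side mitigation we could use until the race is fixed: run one single-threaded warm-up inference per quantized module so that QNNPACK packs the weights before any concurrent calls happen. The warm_up helper below is hypothetical, and this only helps when later inputs keep the same quantization scale (as in the repro script, which always uses scale 1.0), so the lazy repack is not re-triggered.

    def warm_up(module, example_input):
        # Hypothetical helper: a single inference on one thread lets QNNPACK
        # finish packing the weights before any concurrent callers run.
        with torch.no_grad():
            module(example_input)

    # e.g. at the end of _RaceExperiment.__init__:
    #     example = torch.quantize_per_tensor(
    #         torch.zeros(1, 64, 50, 50), 1.0, 0, dtype=torch.quint8)
    #     warm_up(self.module, example)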

Environment

Output from the environment collection script:

Collecting environment information...
PyTorch version: 1.8.1+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.7 (64-bit runtime)
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.20.2
[pip3] torch==1.8.1+cpu
[pip3] torchaudio==0.8.1
[pip3] torchvision==0.9.1+cpu
[conda] Could not collect

Additional context

We ran into this issue on one of our client's machines, so the bug is not theoretical. The actual
failure we observed there was the error message:

RuntimeError: pack_w != nullptr INTERNAL ASSERT FAILED at "/pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp":678, please report a bug to PyTorch. Packed Weights are NULL

(This message is adapted to 1.8.1 since we run 1.7.1, but I'm able to reproduce it with the above
script on 1.8.1.)

Unfortunately we can't access AVX2 on that machine, so we cannot use FBGEMM.

My hypothesis about the cause of the race condition is as follows (a minimal sketch of the suspected interleaving appears after the list):

  1. Before the first execution of the convolution, PackedConvWeightsQnnp::w is nullptr and
    PackedConvWeightsQnnp::input_scale doesn't contain a value (as outlined in https://github.com/pytorch/pytorch/blob/v1.8.1/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp#L254-L257)
  2. One of the threads enters the branch at https://github.com/pytorch/pytorch/blob/v1.8.1/aten/src/ATen/native/quantized/cpu/qconv.cpp#L604
    and stores a value in PackedConvWeightsQnnp::input_scale, but hasn't yet assigned
    PackedConvWeightsQnnp::w.
  3. Another thread sees that value of PackedConvWeightsQnnp::input_scale, skips the branch at https://github.com/pytorch/pytorch/blob/v1.8.1/aten/src/ATen/native/quantized/cpu/qconv.cpp#L604,
    reaches https://github.com/pytorch/pytorch/blob/v1.8.1/aten/src/ATen/native/quantized/cpu/qconv.cpp#L678 and exits
    with the error shown above.
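
To make that interleaving concrete, here is a minimal Python sketch of the same check-then-act pattern. The names mirror PackedConvWeightsQnnp::input_scale and PackedConvWeightsQnnp::w; the real code is the C++ in qconv.cpp, so this is only an illustration of the hypothesis, not the actual implementation.

    class LazyPackedWeights:
        def __init__(self):
            self.input_scale = None  # stands in for the optional input_scale
            self.w = None            # stands in for the packed-weight pointer

        def apply(self, scale):
            if self.input_scale is None or self.input_scale != scale:
                # Thread A stores input_scale first ...
                self.input_scale = scale
                # ... and is preempted here, before the expensive repack.
                # Thread B now calls apply(), sees input_scale already set,
                # skips this branch, and hits the assert below while w is
                # still None -- i.e. "Packed Weights are NULL".
                self.w = object()  # packed weights are assigned last
            assert self.w is not None, "Packed Weights are NULL"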

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo

Metadata

Labels

oncall: quantization (Quantization support in PyTorch), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
