
MKLDNN Segmentation Fault on backward pass on CPU with Conv1D layer #45746

@kkoehncke

Description

🐛 Bug

A segmentation fault and/or an indefinite hang occurs when using the Conv1d layer with varying kernel sizes. Based on my investigation, this happens during the backward pass and is related to the MKLDNN backend used to optimize CPU operations.

To Reproduce

The code snippet below reproduces this behavior:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.embedding_size = 16
        self.filter_num = 512
        self.padding_length = 25
        # Conv1d modules with varying kernel heights K = 1..8; each kernel spans the
        # full embedding dimension, so every layer effectively performs a 2D convolution
        # over (padding_length, embedding_size).
        self.convolutions = nn.ModuleList(
            [nn.Conv1d(1, self.filter_num // 8, kernel_size=(K, self.embedding_size), stride=1)
             for K in range(1, 9)]
        )

    def forward(self):
        X = torch.randn([300, 1, self.padding_length, self.embedding_size])
        X = [torch.tanh(convolution(X).squeeze(3)) for convolution in self.convolutions]
        X = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in X]
        X = torch.cat(X, dim=1)
        return X

if __name__ == "__main__":
    model = Model()
    output = model()
    output.mean().backward()  # segfaults or hangs here on CPU with MKLDNN enabled

Expected behavior

I expect this code to execute without error on CPU.

Environment

Using an Intel® Core™ i7-9800X CPU @ 3.80GHz

PyTorch version: 1.6.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: Could not collect

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration: GPU 0: GeForce RTX 2080 Ti
Nvidia driver version: 440.100
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] numpydoc==1.1.0
[pip3] pytorch-memlab==0.1.0
[pip3] torch==1.6.0
[pip3] torchvision==0.7.0
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.2.89              hfd86e86_1  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py37he904b0f_0  
[conda] mkl_fft                   1.2.0            py37h23d657b_0  
[conda] mkl_random                1.1.1            py37h0573a6f_0  
[conda] numpy                     1.18.4                   pypi_0    pypi
[conda] numpy-base                1.19.1           py37hfa32c7d_0  
[conda] numpydoc                  1.1.0                      py_0  
[conda] pytorch                   1.6.0           py3.7_cuda10.2.89_cudnn7.6.5_0    pytorch
[conda] pytorch-memlab            0.1.0                    pypi_0    pypi
[conda] torchvision               0.7.0                py37_cu102    pytorch

Additional context

This problem does not occur on GPU.

When the MKLDNN backend is turned off via torch.backends.mkldnn.enabled = False, the example executes without error.
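For reference, a minimal sketch of the workaround; it assumes the Model class from the repro above is in scope, and only the flag assignment is the actual workaround:

import torch

torch.backends.mkldnn.enabled = False  # fall back to the non-MKLDNN CPU path

model = Model()
output = model()
output.mean().backward()  # completes without the segfault once MKLDNN is disabled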

When self.embedding_size is >= 14, the example produces a segmentation fault; any lower integer value executes successfully.
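To confirm the threshold, embedding sizes can be scanned with each value in its own subprocess, so a crash or hang in one configuration does not stop the scan. This is only a sketch (the 64-filter / K = 1..8 setup mirrors the repro above; the scan script itself is not part of the original run):

import subprocess
import sys

# Stand-alone copy of the repro, parameterised on embedding_size.
SCRIPT = """
import torch
import torch.nn as nn
import torch.nn.functional as F

embedding_size = {size}
convs = nn.ModuleList(
    [nn.Conv1d(1, 64, kernel_size=(K, embedding_size), stride=1) for K in range(1, 9)]
)
X = torch.randn([300, 1, 25, embedding_size])
feats = [torch.tanh(c(X).squeeze(3)) for c in convs]
feats = [F.max_pool1d(x, x.size(2)).squeeze(2) for x in feats]
torch.cat(feats, dim=1).mean().backward()
"""

for size in range(10, 17):
    try:
        proc = subprocess.run([sys.executable, "-c", SCRIPT.format(size=size)],
                              capture_output=True, text=True, timeout=300)
        status = "ok" if proc.returncode == 0 else "crashed (returncode %d)" % proc.returncode
    except subprocess.TimeoutExpired:
        status = "hung (killed after timeout)"
    print("embedding_size=%d: %s" % (size, status))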

When running the above code snippet with MKLDNN_VERBOSE=1, the following verbose log is produced:

dnnl_verbose,info,oneDNN v1.5.0 (commit e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
dnnl_verbose,info,cpu,runtime:OpenMP
dnnl_verbose,info,cpu,isa:Intel AVX-512 with AVX512BW, AVX512VL, and AVX512DQ extensions
dnnl_verbose,info,gpu,runtime:none
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x1x16,0.00195312
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh25kh1sh1dh0ph0_iw16ow1kw16sw1dw0pw0,0.908203
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x25x1,0.529053
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x2x16,0.0151367
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh24kh2sh1dh0ph0_iw16ow1kw16sw1dw0pw0,0.844971
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x24x1,0.456055
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x3x16,0.00610352
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh23kh3sh1dh0ph0_iw16ow1kw16sw1dw0pw0,0.763916
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x23x1,0.393799
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x4x16,0.00585938
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh22kh4sh1dh0ph0_iw16ow1kw16sw1dw0pw0,1.00586
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x22x1,0.0761719
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x5x16,0.00610352
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh21kh5sh1dh0ph0_iw16ow1kw16sw1dw0pw0,1.39478
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x21x1,0.406982
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x6x16,0.00708008
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh20kh6sh1dh0ph0_iw16ow1kw16sw1dw0pw0,1.36694
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x20x1,0.0708008
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x7x16,0.0090332
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh19kh7sh1dh0ph0_iw16ow1kw16sw1dw0pw0,1.31787
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x19x1,0.0649414
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:Acdb16a:f0,,,64x1x8x16,0.00708008
dnnl_verbose,exec,cpu,convolution,jit:avx512_common,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:Acdb16a:f0 bia_f32::blocked:a:f0 dst_f32::blocked:aBcd16b:f0,scratchpad_mode:user;,alg:convolution_direct,mb300_ic1oc64_ih25oh18kh8sh1dh0ph0_iw16ow1kw16sw1dw0pw0,1.46704
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd16b:f0 dst_f32::blocked:abcd:f0,,,300x64x18x1,0.310059
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:abcd:f0 dst_f32::blocked:aBcd16b:f0,,,300x64x18x1,0.0441895
Segmentation fault (core dumped)
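For completeness, the verbose output above can be collected roughly like this, with the repro saved to a file; the name repro.py is only a placeholder:

import os
import subprocess
import sys

# Run the repro with oneDNN verbose logging enabled; "repro.py" stands in for
# the snippet from the "To Reproduce" section saved to disk.
env = dict(os.environ, MKLDNN_VERBOSE="1")
proc = subprocess.run([sys.executable, "repro.py"], env=env)
print("exit code:", proc.returncode)  # -11 corresponds to SIGSEGV on Linux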

Output of torch.__config__.show():

PyTorch built with:
  - GCC 7.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v1.5.0 (Git Hash e2ac1fac44c5078ca927cb9b90e1b3066a0b2ed0)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 10.2
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_37,code=compute_37
  - CuDNN 7.6.5
  - Magma 2.5.2
  - Build settings: BLAS=MKL, BUILD_TYPE=Release, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DUSE_VULKAN_WRAPPER -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wno-sign-compare -Wno-unused-parameter -Wno-unused-variable -Wno-unused-function -Wno-unused-result -Wno-unused-local-typedefs -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Wno-stringop-overflow, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, USE_CUDA=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_STATIC_DISPATCH=OFF, 

cc @ezyang @gchanan @zou3519 @bdhirsh @ejguan @gujinghui @PenghuiCheng @XiaobingSuper @jianyuh @VitalyFedyunin

Labels

high priority, module: crash, module: dependency bug, module: mkldnn, triaged
