From d319e4698dc3ef7be98424b971310ac6990b5098 Mon Sep 17 00:00:00 2001 From: "HE, Tao" Date: Sat, 13 Jan 2018 00:11:56 +0800 Subject: [PATCH 1/3] More strict shape check on Conv operators. Signed-off-by: HE, Tao --- .../THCUNN/generic/SpatialConvolutionMM.cu | 14 +++++++++++-- .../THCUNN/generic/VolumetricConvolution.cu | 17 ++++++++++++--- aten/src/THNN/generic/SpatialConvolutionMM.c | 17 ++++++++++++--- .../THNN/generic/VolumetricConvolutionMM.c | 21 ++++++++++++++++--- 4 files changed, 58 insertions(+), 11 deletions(-) diff --git a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu index fc53d1997651e..ba6860fcfec50 100644 --- a/aten/src/THCUNN/generic/SpatialConvolutionMM.cu +++ b/aten/src/THCUNN/generic/SpatialConvolutionMM.cu @@ -38,8 +38,18 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)( int64_t inputHeight = input->size[dimh]; int64_t inputWidth = input->size[dimw]; int64_t nOutputPlane = weight->size[0]; - int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; - int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; + + int64_t exactInputHeight = inputHeight + 2 * padH; + int64_t exactInputWidth = inputWidth + 2 * padW; + + if (exactInputHeight < kH || exactInputWidth < kW) { + THError("Calculated input size: (%d x %d). " + "Kernel size: (%d x %d). Kernel size can't greater than actual input size", + exactInputHeight,exactInputWidth,kH,kW); + } + + int64_t outputHeight = (exactInputHeight - kH) / dH + 1; + int64_t outputWidth = (exactInputWidth - kW) / dW + 1; if (outputWidth < 1 || outputHeight < 1) THError("Given input size: (%d x %d x %d). " diff --git a/aten/src/THCUNN/generic/VolumetricConvolution.cu b/aten/src/THCUNN/generic/VolumetricConvolution.cu index 7c5148af44a62..3d5a93e563159 100644 --- a/aten/src/THCUNN/generic/VolumetricConvolution.cu +++ b/aten/src/THCUNN/generic/VolumetricConvolution.cu @@ -72,9 +72,20 @@ static inline void THNN_(VolumetricConvolution_shapeCheck) int64_t inputWidth = input->size[dimw]; int64_t inputHeight = input->size[dimh]; int64_t inputDepth = input->size[dimd]; - int64_t outputWidth = (inputWidth + 2*padH - kH) / dH + 1; - int64_t outputHeight = (inputHeight + 2*padT - kT) / dT + 1; - int64_t outputDepth = (inputDepth + 2*padW - kW) / dW + 1; + + int64_t exactInputDepth = inputDepth + 2*padT; + int64_t exactInputHeight = inputHeight + 2*padH; + int64_t exactInputWidth = inputWidth + 2*padW; + + if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) { + THError("Calculated input size: (%d x %d x %d). " + "Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size", + exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW); + } + + int64_t outputWidth = (exactInputDepth - kH) / dH + 1; + int64_t outputHeight = (exactInputHeight - kT) / dT + 1; + int64_t outputDepth = (exactInputWidth - kW) / dW + 1; if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1) { diff --git a/aten/src/THNN/generic/SpatialConvolutionMM.c b/aten/src/THNN/generic/SpatialConvolutionMM.c index 87ce76b01830b..bce477bb32426 100644 --- a/aten/src/THNN/generic/SpatialConvolutionMM.c +++ b/aten/src/THNN/generic/SpatialConvolutionMM.c @@ -36,13 +36,24 @@ static inline void THNN_(SpatialConvolutionMM_shapeCheck)( int64_t inputHeight = input->size[dimh]; int64_t inputWidth = input->size[dimw]; int64_t nOutputPlane = weight->size[0]; - int64_t outputHeight = (inputHeight + 2*padH - kH) / dH + 1; - int64_t outputWidth = (inputWidth + 2*padW - kW) / dW + 1; - if (outputWidth < 1 || outputHeight < 1) + int64_t exactInputHeight = inputHeight + 2 * padH; + int64_t exactInputWidth = inputWidth + 2 * padW; + + if (exactInputHeight < kH || exactInputWidth < kW) { + THError("Calculated input size: (%d x %d). " + "Kernel size: (%d x %d). Kernel size can't greater than actual input size", + exactInputHeight,exactInputWidth,kH,kW); + } + + int64_t outputHeight = (exactInputHeight - kH) / dH + 1; + int64_t outputWidth = (exactInputWidth - kW) / dW + 1; + + if (outputWidth < 1 || outputHeight < 1) { THError("Given input size: (%d x %d x %d). " "Calculated output size: (%d x %d x %d). Output size is too small", nInputPlane,inputHeight,inputWidth,nOutputPlane,outputHeight,outputWidth); + } THNN_CHECK_DIM_SIZE(input, ndim, dimf, nInputPlane); diff --git a/aten/src/THNN/generic/VolumetricConvolutionMM.c b/aten/src/THNN/generic/VolumetricConvolutionMM.c index 5f83f13cfe3a5..607f4738c46cc 100644 --- a/aten/src/THNN/generic/VolumetricConvolutionMM.c +++ b/aten/src/THNN/generic/VolumetricConvolutionMM.c @@ -43,6 +43,10 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)( int64_t inputHeight; int64_t inputWidth; int64_t nOutputPlane; + + int64_t exactInputDepth; + int64_t exactInputHeight; + int64_t exactInputWidth; int64_t outputDepth; int64_t outputHeight; int64_t outputWidth; @@ -52,9 +56,20 @@ static void inline THNN_(VolumetricConvolutionMM_shapeCheck)( inputHeight = input->size[dimh]; inputWidth = input->size[dimw]; nOutputPlane = weight->size[0]; - outputDepth = (inputDepth + 2*pT - kT) / dT + 1; - outputHeight = (inputHeight + 2*pH - kH) / dH + 1; - outputWidth = (inputWidth + 2*pW - kW) / dW + 1; + + exactInputDepth = inputDepth + 2*pT; + exactInputHeight = inputHeight + 2*pH; + exactInputWidth = inputWidth + 2*pW; + + if (exactInputDepth < kT || exactInputHeight < kH || exactInputWidth < kW) { + THError("Calculated input size: (%d x %d x %d). " + "Kernel size: (%d x %d x %d). Kernel size can't greater than actual input size", + exactInputDepth,exactInputHeight,exactInputWidth,kT,kH,kW); + } + + outputDepth = (exactInputDepth - kT) / dT + 1; + outputHeight = (exactInputHeight - kH) / dH + 1; + outputWidth = (exactInputWidth - kW) / dW + 1; if (outputWidth < 1 || outputHeight < 1 || outputDepth < 1) { From 49cc92a091367169e4c668c7def1a6ab123eb676 Mon Sep 17 00:00:00 2001 From: "HE, Tao" Date: Sat, 13 Jan 2018 00:14:07 +0800 Subject: [PATCH 2/3] Test case for conv's shape check. Signed-off-by: HE, Tao --- test/test_nn.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/test_nn.py b/test/test_nn.py index a60d6da9a5bb5..6d998ae169370 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2202,6 +2202,31 @@ def test_conv_modules_raise_error_on_incorrect_input_size(self): input = Variable(torch.Tensor(torch.Size((3, ) * dims))) self.assertRaises(RuntimeError, lambda: module(input)) + def test_conv_shapecheck(self): + def test(should_raise, module, input_size): + input = Variable(torch.Tensor(3, *input_size)) + if should_raise: + self.assertRaises(RuntimeError, lambda: module(input)) + else: + module(input) ## just run it to ensure no exception raised. + + # Conv1d + test(True, nn.Conv1d(1, 1, 3), (1, 2)) + test(True, nn.Conv1d(1, 1, 3, stride=2), (1, 2)) + test(False, nn.Conv1d(1, 1, 2), (1, 2)) + test(False, nn.Conv1d(1, 1, 2, stride=2), (1, 2)) + test(False, nn.Conv1d(1, 1, 3, stride=2, padding=1), (1, 2)) + + # Conv2d + test(True, nn.Conv2d(1, 1, (3, 3)), (1, 2, 2)) + test(False, nn.Conv2d(1, 1, (3, 3)), (1, 3, 3)) + test(False, nn.Conv2d(1, 1, (3, 3), padding=1), (1, 2, 2)) + + # Conv3D + test(True, nn.Conv3d(1, 1, (3, 3, 3)), (1, 2, 2, 2)) + test(False, nn.Conv3d(1, 1, (3, 3, 3)), (1, 3, 3, 3)) + test(False, nn.Conv3d(1, 1, (3, 3, 3), padding=1), (1, 2, 2, 2)) + def test_ConvTranspose2d_output_size(self): m = nn.ConvTranspose2d(3, 4, 3, 3, 0, 2) i = Variable(torch.randn(2, 3, 6, 6)) From 0288e51481968acdfde79711e411b204b725eaf0 Mon Sep 17 00:00:00 2001 From: "HE, Tao" Date: Sat, 13 Jan 2018 00:17:24 +0800 Subject: [PATCH 3/3] Fix lint. Signed-off-by: HE, Tao --- test/test_nn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_nn.py b/test/test_nn.py index 6d998ae169370..786f7d36515d2 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -2208,7 +2208,8 @@ def test(should_raise, module, input_size): if should_raise: self.assertRaises(RuntimeError, lambda: module(input)) else: - module(input) ## just run it to ensure no exception raised. + # just run it to ensure no exception raised. + module(input) # Conv1d test(True, nn.Conv1d(1, 1, 3), (1, 2))