Inconsistent results when trying to enable Tensor Cores on NVIDIA T4 #42311

Open
gerwin3 opened this issue Jul 30, 2020 · 8 comments
Labels
  • module: amp (automated mixed precision) - autocast
  • module: cudnn - Related to torch.backends.cudnn, and CuDNN support
  • module: memory format - Memory format/layout related issues/changes (channels_last, nhwc)
  • triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@gerwin3

gerwin3 commented Jul 30, 2020

🐛 Bug

I am trying to figure out if my model is correctly using Tensor Cores on NVIDIA T4 but it seems that PyTorch is not enabling them correctly.

Context: I'm trying to get Tensor Cores to work with I3D. The model is converted to TorchScript and I'm executing it through the C++ API. I've converted the model to FP16 with .half(), and I'm seeing the following cuDNN logs:

I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! Time: 2020-07-29T13:12:22.659680 (0d+0h+0m+10s since start)
i! Process=47432; Thread=47516; GPU=NULL; Handle=NULL; StreamId=NULL.

I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i! Time: 2020-07-29T13:12:22.659684 (0d+0h+0m+10s since start)
i! Process=47432; Thread=47516; GPU=NULL; Handle=NULL; StreamId=NULL.

I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! Time: 2020-07-29T13:12:22.659689 (0d+0h+0m+10s since start)
i! Process=47432; Thread=47516; GPU=NULL; Handle=NULL; StreamId=NULL.

I! CuDNN (v7605) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,512,4,14,14];
i!         strideA: type=int; val=[401408,784,196,14,1];
i!     xData: location=dev; addr=0x7fa487a00000;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[64,512,1,1,1];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x7fa4ac115e00;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=3;
i!         padA: type=int; val=[0,0,0];
i!         strideA: type=int; val=[1,1,1];
i!         dilationA: type=int; val=[1,1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM (0);
i!     workSpace: location=dev; addr=NULL_PTR;
i!     workSpaceSizeInBytes: type=size_t; val=0;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,64,4,14,14];
i!         strideA: type=int; val=[50176,784,196,14,1];
i!     yData: location=dev; addr=0x7fa489ddce00;
i! Time: 2020-07-29T13:12:22.659706 (0d+0h+0m+10s since start)
i! Process=47432; Thread=47516; GPU=0; Handle=0x7fa9ac05c4f0; StreamId=(nil) (defaultStream).

As can be seen from the logs, PyTorch calls cudnnSetConvolutionMathType three times. First it sets the math type to DEFAULT, then to TENSOR_OP_MATH, then back to DEFAULT. This only happens about 1 out of 4 times; the other times it works correctly and PyTorch calls cudnnSetConvolutionMathType with DEFAULT once and then TENSOR_OP_MATH twice.

Even when PyTorch sets the math type correctly, I'm not seeing Tensor Core usage (I think). The output from nvprof shows:

======== Profiling result:
            Type  Time(%)      Time     Calls       Avg       Min       Max  Name
 GPU activities:   39.91%  216.292s    383537  563.94us  47.103us  6.8770ms  volta_fp16_scudnn_fp16_128x128_stridedB_splitK_small_nn_v1
                   18.95%  102.713s    369770  277.78us  4.5440us  8.7415ms  void cudnn::detail::implicit_convolveNd_sgemm<__half, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>(int, int, int, __half const *, int, cudnn::detail::implicit_convolveNd_sgemm<__half, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>*, __half const *, kernel_convNd_params, int, float, float, int, __half const *, __half const *)
                   13.10%  71.0103s    468766  151.48us  75.006us  403.87us  volta_fp16_scudnn_fp16_128x128_stridedB_splitK_interior_nn_v1
                    9.37%  50.7830s    853079  59.529us     960ns  46.2806s  [CUDA memcpy DtoD]
[...]

I would expect to see "h884" in the kernel name (i.e. volta_fp16_h884cudnn_fp16_128x128_stridedB_splitK_small_nn_v1 rather than volta_fp16_scudnn_fp16_128x128_stridedB_splitK_small_nn_v1) if Tensor Cores were used.

The performance profiling tools do show some Tensor Core usage (less than 1%), but I think that's incorrect.

How can I correctly trigger Tensor Core usage in this case?

To Reproduce

Steps to reproduce the behavior:

  1. Take the code from: https://github.com/hassony2/kinetics_i3d_pytorch
  2. Call model.half()
  3. Export the model with torch.jit.script
  4. Load in C++ and do inference

I'm pretty sure the same would happen without steps 3 and 4, but I haven't gotten around to testing that yet.
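
For reference, here is a minimal Python sketch of steps 1–3 (the import path and constructor arguments for the I3D class are from memory, so treat them as approximate):

import torch
from src.i3dpt import I3D  # assumed module path in hassony2/kinetics_i3d_pytorch

# Build the RGB I3D model, move it to the GPU and convert the weights to FP16.
model = I3D(num_classes=400, modality='rgb').cuda().half().eval()

# Dummy FP16 input, N x C x T x H x W, in line with the shapes seen in the logs.
x = torch.randn(1, 3, 16, 224, 224, device='cuda', dtype=torch.half)

# Export to TorchScript; the C++ side then loads the saved file with
# torch::jit::load and feeds it an equivalent half-precision tensor.
scripted = torch.jit.script(model)
scripted.save('i3d_fp16.pt')

with torch.no_grad():
    out = scripted(x)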

Expected behavior

I would expect Tensor Core usage in this case.

Environment

  • PyTorch Version: 1.5, 1.5.1, 1.6 (Tested on all of those)
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): pip, libtorch via website
  • Python version: 3.7
  • CUDA/cuDNN version: 10.1, 10.2 (Tested on both)
  • GPU models and configuration: 1x NVIDIA Tesla T4
  • Any other relevant information:

Additional context

I'm still trying to see if I can get Tensor Core usage by converting the model through torch.cuda.amp, but I'm running into this bug: #36428

Will report back if I have results from that. I'm not sure if it matters, though: as far as I understand, if Tensor Cores can't even be activated when the whole model is converted to FP16, then automatic mixed precision won't help either, but I might be wrong about that.
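
(For reference, the AMP attempt is essentially this, inference only; model and x stand in for the usual FP32 module and input:)

import torch

model = model.float().cuda().eval()  # keep FP32 weights; autocast casts eligible ops to FP16
x = x.float().cuda()

with torch.no_grad(), torch.cuda.amp.autocast():
    out = model(x)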

EDIT: I got automatic mixed precision to work and I'm seeing the same results in the cuDNN logs. Sometimes PyTorch correctly sets CUDNN_TENSOR_OP_MATH, sometimes it doesn't. Example:

✅ Correct:

I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i! Time: 2020-07-30T16:26:38.640234 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102128; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnCreateTensorDescriptor() called:
i! Time: 2020-07-30T16:26:38.640240 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102129; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnSetFilterNdDescriptor() called:
i!     dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!     format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     nbDims: type=int; val=5;
i!     filterDimA: type=int; val=[64,24,3,3,3];
i! Time: 2020-07-30T16:26:38.640246 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102130; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnSetTensorNdDescriptor() called:
i!     dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!     nbDims: type=int; val=5;
i!     dimA: type=int; val=[1,160,4,14,14];
i!     strideA: type=int; val=[125440,784,196,14,1];
i! Time: 2020-07-30T16:26:38.640258 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102129; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,128,4,14,14];
i!         strideA: type=int; val=[100352,784,196,14,1];
i!     xData: location=dev; addr=0x7f0f555b9000;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[256,128,3,3,3];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x7f0f2b404000;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=3;
i!         padA: type=int; val=[1,1,1];
i!         strideA: type=int; val=[1,1,1];
i!         dilationA: type=int; val=[1,1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM (1);
i!     workSpace: location=dev; addr=0x7f0f574e5200;
i!     workSpaceSizeInBytes: type=size_t; val=16364;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,256,4,14,14];
i!         strideA: type=int; val=[200704,784,196,14,1];
i!     yData: location=dev; addr=0x7f0f55400000;
i! Time: 2020-07-30T16:26:38.640262 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102128; GPU=0; Handle=0x7f146c0a0770; StreamId=(nil) (defaultStream).

🚫 Incorrect:

I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! Time: 2020-07-30T16:26:38.639380 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102129; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnCreateTensorDescriptor() called:
i! Time: 2020-07-30T16:26:38.639383 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102130; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnSetTensorNdDescriptor() called:
i!     dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!     nbDims: type=int; val=5;
i!     dimA: type=int; val=[1,128,4,14,14];
i!     strideA: type=int; val=[100352,784,196,14,1];
i! Time: 2020-07-30T16:26:38.639390 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102130; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnCreateFilterDescriptor() called:
i! Time: 2020-07-30T16:26:38.639395 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102130; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,512,4,14,14];
i!         strideA: type=int; val=[401408,784,196,14,1];
i!     xData: location=dev; addr=0x7f0f55400000;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[64,512,1,1,1];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x7f0f435e8e00;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=3;
i!         padA: type=int; val=[0,0,0];
i!         strideA: type=int; val=[1,1,1];
i!         dilationA: type=int; val=[1,1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM (0);
i!     workSpace: location=dev; addr=NULL_PTR;
i!     workSpaceSizeInBytes: type=size_t; val=0;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,64,4,14,14];
i!         strideA: type=int; val=[50176,784,196,14,1];
i!     yData: location=dev; addr=0x7f0f574d2c00;
i! Time: 2020-07-30T16:26:38.639396 (0d+0h+1m+34s since start)
i! Process=102036; Thread=102129; GPU=0; Handle=0x7f0f911512d0; StreamId=(nil) (defaultStream).

EDIT 2: Output from nvprof still suggests no usage of Tensor Core kernels when using automatic mixed precision, just like before 😞

Thanks in advance!

cc @mcarilli @csarofeen @ptrblck @VitalyFedyunin @jamesr66a

@VitalyFedyunin VitalyFedyunin added the module: amp (automated mixed precision), module: cudnn, module: memory format, and triaged labels on Jul 30, 2020
@VitalyFedyunin
Contributor

Can you please try to convert your input tensors and model into the channels-last memory format?
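
i.e. something along these lines for a 2D-convolution model (a sketch only; model and x stand in for your module and input, and as discussed below the 3D/NDHWC case is a separate question):

import torch

# Convert both the module parameters and the input to channels-last (NHWC) layout;
# with FP16 inputs this lets cuDNN pick layout-friendly (Tensor Core) conv kernels.
model = model.to(memory_format=torch.channels_last)
x = x.to(memory_format=torch.channels_last)

with torch.no_grad():
    out = model(x)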

@mcarilli
Collaborator

If inputs are FP16 and CUDNN_TENSOR_OP_MATH is set on the convolution descriptor, Tensor Core use is allowed but not guaranteed. cuDNN may still choose a non-Tensor-Core algorithm if its heuristics expect the non-Tensor-Core algorithm to be faster for your convolution sizes.

The fact that CUDNN_TENSOR_OP_MATH appears not to be set sometimes for FP16 inputs is unexpected. AFAIK, PyTorch's convolution backend should always set CUDNN_TENSOR_OP_MATH for FP16 inputs, at least for cuDNN 7605. Do you have a minimal example of a convolution with FP16 inputs that doesn't get CUDNN_TENSOR_OP_MATH set for its descriptor? If so, we should track down why.

To sandbox potential issues with JIT/autocast, try running the ordinary Python model with .half() under nvprof or Nsight Systems (no autocast, no JIT scripting, no C++) and see if you observe the same weird behavior. Also try setting torch.backends.cudnn.benchmark=True, run a warmup iteration so benchmark tries and caches all the fastest algos, then see if Tensor Cores are used for an iteration after the warmup. If you're not sure how to use nvprof/Nsight Systems that way, I can elaborate.
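
Something along these lines (a rough sketch; MyConvModel and the input shape are placeholders, and the torch.cuda.profiler.start()/stop() calls only matter if you launch nvprof with --profile-from-start off):

import torch

torch.backends.cudnn.benchmark = True  # benchmark and cache the fastest algo per conv shape

model = MyConvModel().cuda().half().eval()  # placeholder for the actual model
x = torch.randn(1, 3, 16, 224, 224, device='cuda', dtype=torch.half)

with torch.no_grad():
    # Warmup iteration: cudnn.benchmark tries the available algos and caches the winner.
    model(x)
    torch.cuda.synchronize()

    # Profile only the steady-state iteration, e.g.
    #   nvprof --profile-from-start off python script.py
    torch.cuda.profiler.start()
    model(x)
    torch.cuda.synchronize()
    torch.cuda.profiler.stop()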

@gerwin3
Author

gerwin3 commented Jul 31, 2020

cuDNN may still choose a non-Tensor-Core algorithm if its heuristics expect the non-Tensor-Core algorithm to be faster for your convolution sizes.

Looking at some of the convolutions I posted above, do you think that might be the case? I can't really see anything that indicates that Tensor Cores wouldn't help.

To sandbox potential issues with JIT/autocast, try running the ordinary Python model with .half() under nvprof or Nsight Systems (no autocast, no JIT scripting, no C++) and see if you observe the same weird behavior. Also try setting torch.backends.cudnn.benchmark=True, run a warmup iteration so benchmark tries and caches all the fastest algos, then see if Tensor Cores are used for an iteration after the warmup. If you're not sure how to use nvprof/Nsight Systems that way, I can elaborate.

I ran the experiment from the Python code base as you suggested. I set torch.backends.cudnn.benchmark = True, took the vanilla model, applied model.half(), converted the input to FP16 and then did some inference (I also didn't use AMP or JIT). The results from nvprof look identical, with no Tensor Core usage:

 GPU activities:   40.24%  22.8876s     34084  671.50us  69.855us  6.8746ms  volta_fp16_scudnn_fp16_128x128_stridedB_splitK_small_nn_v1
                   22.45%  12.7698s     55978  228.12us  5.9840us  8.7513ms  void cudnn::detail::implicit_convolveNd_sgemm<__half, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>(int, int, int, __half const *, int, cudnn::detail::implicit_convolveNd_sgemm<__half, int=3, int=1024, int=5, int=5, int=3, int=3, int=3, int=1, bool=1, bool=0, bool=1>*, __half const *, kernel_convNd_params, int, float, float, int, __half const *, __half const *)
                    7.53%  4.28137s     31616  135.42us  24.447us  316.03us  void at::native::_GLOBAL__N__63_tmpxft_00001915_00000000_10_DilatedMaxPool3d_compute_75_cpp1_ii_2c55c1ee::max_pool3d_with_indices_single_out_frame<c10::Half>(c10::Half*, at::GenericPackedTensorAccessor<at::native::_GLOBAL__N__63_tmpxft_00001915_00000000_10_DilatedMaxPool3d_compute_75_cpp1_ii_2c55c1ee::max_pool3d_with_indices_single_out_frame<c10::Half>, unsigned long=4, at::DefaultPtrTraits, long>, c10::Half*, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int, int)
                    6.67%  3.79440s      3785  1.0025ms     928ns  12.057ms  [CUDA memcpy HtoD]
                    6.21%  3.53297s     24354  145.07us  87.007us  2.0281ms  void cudnn::detail::implicit_convolveNd_sgemm<__half, int=3, int=1024, int=6, int=7, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>(int, int, int, __half const *, int, cudnn::detail::implicit_convolveNd_sgemm<__half, int=3, int=1024, int=6, int=7, int=3, int=3, int=5, int=1, bool=1, bool=0, bool=1>*, __half const *, kernel_convNd_params, int, float, float, int, __half const *, __half const *)
                    5.56%  3.15995s     22637  139.59us  15.232us  223.39us  volta_fp16_scudnn_fp16_128x128_stridedB_splitK_interior_nn_v1
[... etc ...]

Also, I'm still getting the weird behaviour where PyTorch sets DEFAULT_MATH when it should be TENSOR_OP_MATH. Sample:

I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! Time: 2020-07-31T09:27:56.608726 (0d+0h+0m+49s since start)
i! Process=111274; Thread=111362; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_TENSOR_OP_MATH (1);
i! Time: 2020-07-31T09:27:56.608731 (0d+0h+0m+49s since start)
i! Process=111274; Thread=111362; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnSetConvolutionMathType() called:
i!     mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i! Time: 2020-07-31T09:27:56.608737 (0d+0h+0m+49s since start)
i! Process=111274; Thread=111362; GPU=NULL; Handle=NULL; StreamId=NULL.


I! CuDNN (v7605) function cudnnConvolutionForward() called:
i!     handle: type=cudnnHandle_t; streamId=(nil) (defaultStream);
i!     alpha: type=CUDNN_DATA_FLOAT; val=1.000000;
i!     xDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,512,4,14,14];
i!         strideA: type=int; val=[401408,784,196,14,1];
i!     xData: location=dev; addr=0x7fca5368f400;
i!     wDesc: type=cudnnFilterDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[112,512,1,1,1];
i!         format: type=cudnnTensorFormat_t; val=CUDNN_TENSOR_NCHW (0);
i!     wData: location=dev; addr=0x7fca535c4a00;
i!     convDesc: type=cudnnConvolutionDescriptor_t:
i!         mode: type=cudnnConvolutionMode_t; val=CUDNN_CROSS_CORRELATION (1);
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_FLOAT (0);
i!         mathType: type=cudnnMathType_t; val=CUDNN_DEFAULT_MATH (0);
i!         reorderType: type=int; val=0;
i!         arrayLength: type=int; val=3;
i!         padA: type=int; val=[0,0,0];
i!         strideA: type=int; val=[1,1,1];
i!         dilationA: type=int; val=[1,1,1];
i!         groupCount: type=int; val=1;
i!     algo: type=cudnnConvolutionFwdAlgo_t; val=CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM (0);
i!     workSpace: location=dev; addr=NULL_PTR;
i!     workSpaceSizeInBytes: type=size_t; val=0;
i!     beta: type=CUDNN_DATA_FLOAT; val=0.000000;
i!     yDesc: type=cudnnTensorDescriptor_t:
i!         dataType: type=cudnnDataType_t; val=CUDNN_DATA_HALF (2);
i!         nbDims: type=int; val=5;
i!         dimA: type=int; val=[1,112,4,14,14];
i!         strideA: type=int; val=[87808,784,196,14,1];
i!     yData: location=dev; addr=0x7fca55577000;
i! Time: 2020-07-31T09:27:56.608763 (0d+0h+0m+49s since start)
i! Process=111274; Thread=111362; GPU=0; Handle=0x7fcb2c6f9ac0; StreamId=(nil) (defaultStream).

I think we can rule out interference from AMP, JIT, or C++.

As I understand it, there are basically two issues:

  1. PyTorch sometimes sets CUDNN_DEFAULT_MATH even when input and output are FP16, where it should have set CUDNN_TENSOR_OP_MATH. Note that it doesn't do this every time, only about 1/4th of the time.

  2. Even when CUDNN_TENSOR_OP_MATH is set correctly, cuDNN doesn't select Tensor Core kernels for operations where it should (maybe I'm missing something here).

I'll also try Vitaly's suggestion and report back!
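
For issue 1, the kind of minimal standalone repro I plan to try is a single convolution with one of the shapes from the logs above, run under cuDNN API logging (CUDNN_LOGINFO_DBG=1 CUDNN_LOGDEST_DBG=stdout) to see which math type ends up on the descriptor:

import torch
import torch.nn as nn

# Single 1x1x1 3D convolution matching the [1,512,4,14,14] x [64,512,1,1,1] case above.
conv = nn.Conv3d(512, 64, kernel_size=1, bias=False).cuda().half()
x = torch.randn(1, 512, 4, 14, 14, device='cuda', dtype=torch.half)

with torch.no_grad():
    y = conv(x)
torch.cuda.synchronize()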

@gerwin3
Author

gerwin3 commented Jul 31, 2020

Can you please try to convert your input tensors and model into channels last memory format

@VitalyFedyunin I think there's no support for NDHWC in Conv3d layers? At least I can't seem to find it.

@xsacha
Contributor

xsacha commented Aug 4, 2020

Shouldn't it be automatically transposed to whichever memory format (channels last) will allow half precision? I thought cuDNN was doing this already, or otherwise TorchScript.

@VitalyFedyunin
Contributor

@gerwin3 There is no NDHWC support so far.
@xsacha cuDNN does support both memory formats; the problem is finding the one best suited to the model, and since PyTorch can't (yet) do that automatically, we rely on the user to convert the module/inputs to the desired format.

@xsacha
Contributor

xsacha commented Aug 29, 2020

How do we convert to the memory format?
If we do it on the input, will it stay in that format for the duration of the model?

@gerwin3
Author

gerwin3 commented Aug 29, 2020

Does this mean that there's no Tensor Core support for models with 3D convolutions? (because that would require them to be in NDHWC format)
