diff --git a/torch/csrc/cudnn/BatchNorm.cpp b/torch/csrc/cudnn/BatchNorm.cpp index 6d0350e735388ff..99be075e1382c9b 100644 --- a/torch/csrc/cudnn/BatchNorm.cpp +++ b/torch/csrc/cudnn/BatchNorm.cpp @@ -69,6 +69,10 @@ void cudnn_batch_norm_forward( mode = CUDNN_BATCHNORM_PER_ACTIVATION; } else { mode = CUDNN_BATCHNORM_SPATIAL; +#if CUDNN_VERSION >= 7000 + if(training) + mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#endif } TensorDescriptor idesc; // input descriptor @@ -131,6 +135,11 @@ void cudnn_batch_norm_backward( mode = CUDNN_BATCHNORM_PER_ACTIVATION; } else { mode = CUDNN_BATCHNORM_SPATIAL; +#if CUDNN_VERSION >= 7000 + if(training) + mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; +#endif + } THVoidTensor_assertContiguous(input); diff --git a/torch/csrc/cudnn/Conv.cpp b/torch/csrc/cudnn/Conv.cpp index a813442380188f6..1943e0697969ddf 100644 --- a/torch/csrc/cudnn/Conv.cpp +++ b/torch/csrc/cudnn/Conv.cpp @@ -238,11 +238,11 @@ struct algorithm_search { CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1, CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, + CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, #if CUDNN_VERSION >= 6000 CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING, #endif - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3, - CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED, }; size_t max_ws_size = getMaxWorkspaceSize(handle,conv,algo,sizeof(algo)/sizeof(algo[0]),state); Workspace ws(state, max_ws_size); diff --git a/torch/csrc/cudnn/Descriptors.h b/torch/csrc/cudnn/Descriptors.h index 339536cda563d2a..37ffc0be3b7887f 100644 --- a/torch/csrc/cudnn/Descriptors.h +++ b/torch/csrc/cudnn/Descriptors.h @@ -67,6 +67,11 @@ struct ConvolutionDescriptor if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT; CHECK(cudnnSetConvolutionNdDescriptor(desc, dim, pad, stride, upscale, CUDNN_CROSS_CORRELATION, mathType)); +#if CUDNN_VERSION >= 7000 + CHECK(cudnnSetConvolutionMathType(desc, CUDNN_DEFAULT_MATH)); + if(dataType == CUDNN_DATA_HALF) + CHECK(cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH)); +#endif } }; diff --git a/torch/cuda/nccl.py b/torch/cuda/nccl.py index b7838e29a0b93dc..09465c586321313 100644 --- a/torch/cuda/nccl.py +++ b/torch/cuda/nccl.py @@ -5,18 +5,80 @@ from torch.backends.cudnn import int_array lib = None - +nccl_2_0 = None __all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter'] +_communicators = {} + +# ncclDataType_t +nccl_types = { + 'torch.cuda.ByteTensor': 0, + 'torch.cuda.CharTensor': 0, + 'torch.cuda.IntTensor': 1, + 'torch.cuda.HalfTensor': 2, + 'torch.cuda.FloatTensor': 3, + 'torch.cuda.DoubleTensor': 4, + 'torch.cuda.LongTensor': 5, +} +nccl_types_2_0 = { + 'torch.cuda.ByteTensor': 0, + 'torch.cuda.CharTensor': 0, + 'torch.cuda.IntTensor': 2, + 'torch.cuda.HalfTensor': 6, + 'torch.cuda.FloatTensor': 7, + 'torch.cuda.DoubleTensor': 8, + 'torch.cuda.LongTensor': 4, +} + +# ncclRedOp_t +SUM = 0 +PROD = 1 +MAX = 2 +MIN = 3 + +status_codes_2_0 = { + 0: "Success", + 1: "Unhandled Cuda Error", + 2: "System Error", + 3: "Internal Error", + 4: "Invalid Argument Error", + 5: "Invalid Usage Error", +} + +status_codes = { + 0: "Success", + 1: "Unhandled Cuda Error", + 2: "System Error", + 3: "Internal Error", + 4: "Invalid Device Pointer", + 5: "Invalid Rank", + 6: "Unsupported Device Count", + 7: "Device Not Found", + 8: "Invalid Device Index", + 9: "Lib Wrapper Not Set", + 10: "Cuda Malloc Failed", + 11: "Rank Mismatch", + 12: "Invalid Argument", + 13: "Invalid Type", + 14: "Invalid Operation", +} + def _libnccl(): + global nccl_2_0 global lib + global status_codes + global nccl_types if lib is None: lib = ctypes.pydll.LoadLibrary(None) if hasattr(lib, 'ncclCommDestroy'): lib.ncclCommDestroy.restype = None else: lib = None + if hasattr(lib, 'ncclGroupStart'): + nccl_2_0 = True + status_codes = status_codes_2_0 + nccl_types = nccl_types_2_0 return lib @@ -39,52 +101,6 @@ def is_available(tensors): return True -_communicators = {} - -# ncclDataType_t -ncclChar = 0 -ncclInt = 1 -ncclHalf = 2 -ncclFloat = 3 -ncclDouble = 4 -ncclInt64 = 5 -ncclUint64 = 6 - -# ncclRedOp_t -SUM = 0 -PROD = 1 -MAX = 2 -MIN = 3 - -status_codes = { - 0: "Success", - 1: "Unhandled Cuda Error", - 2: "System Error", - 3: "Internal Error", - 4: "Invalid Device Pointer", - 5: "Invalid Rank", - 6: "Unsupported Device Count", - 7: "Device Not Found", - 8: "Invalid Device Index", - 9: "Lib Wrapper Not Set", - 10: "Cuda Malloc Failed", - 11: "Rank Mismatch", - 12: "Invalid Argument", - 13: "Invalid Type", - 14: "Invalid Operation", -} - -nccl_types = { - 'torch.cuda.ByteTensor': ncclChar, - 'torch.cuda.CharTensor': ncclChar, - 'torch.cuda.IntTensor': ncclInt, - 'torch.cuda.HalfTensor': ncclHalf, - 'torch.cuda.FloatTensor': ncclFloat, - 'torch.cuda.DoubleTensor': ncclDouble, - 'torch.cuda.LongTensor': ncclInt64, -} - - class NcclError(RuntimeError): def __init__(self, status): @@ -131,7 +147,6 @@ def communicator(inputs, outputs=None): key = ','.join(str(d) for d in devices) if key not in _communicators: _communicators[key] = NcclCommList(devices) - return _communicators[key] @@ -149,12 +164,16 @@ def all_reduce(inputs, outputs=None, op=SUM): count = inputs[0].numel() data_type = nccl_types[inputs[0].type()] with torch.cuda._free_mutex(): + if nccl_2_0 is not None: + lib.ncclGroupStart() for i in range(len(inputs)): with torch.cuda.device(comm.devices[i]): check_error(lib.ncclAllReduce( ctypes.c_void_p(inputs[i].data_ptr()), ctypes.c_void_p(outputs[i].data_ptr()), count, data_type, op, comm[i], cudaStream())) + if nccl_2_0 is not None: + lib.ncclGroupEnd() def reduce(inputs, outputs=None, root=0, op=SUM, streams=None): @@ -168,12 +187,16 @@ def reduce(inputs, outputs=None, root=0, op=SUM, streams=None): count = inputs[0].numel() data_type = nccl_types[inputs[0].type()] with torch.cuda._free_mutex(): + if nccl_2_0 is not None: + lib.ncclGroupStart() for i in range(len(inputs)): with torch.cuda.device(comm.devices[i]): check_error(lib.ncclReduce( ctypes.c_void_p(inputs[i].data_ptr()), ctypes.c_void_p(outputs[i].data_ptr()), count, data_type, op, root, comm[i], streams[i])) + if nccl_2_0 is not None: + lib.ncclGroupEnd() def broadcast(inputs, root=0): @@ -183,11 +206,15 @@ def broadcast(inputs, root=0): count = inputs[0].numel() data_type = nccl_types[inputs[0].type()] with torch.cuda._free_mutex(): + if nccl_2_0 is not None: + lib.ncclGroupStart() for i in range(len(inputs)): with torch.cuda.device(comm.devices[i]): check_error(lib.ncclBcast( ctypes.c_void_p(inputs[i].data_ptr()), count, data_type, root, comm[i], cudaStream())) + if nccl_2_0 is not None: + lib.ncclGroupEnd() def all_gather(inputs, outputs): @@ -196,12 +223,23 @@ def all_gather(inputs, outputs): count = inputs[0].numel() data_type = nccl_types[inputs[0].type()] with torch.cuda._free_mutex(): + if nccl_2_0 is not None: + lib.ncclGroupStart() for i in range(len(inputs)): with torch.cuda.device(comm.devices[i]): - check_error(lib.ncclAllGather( - ctypes.c_void_p(inputs[i].data_ptr()), count, data_type, - ctypes.c_void_p(outputs[i].data_ptr()), comm[i], - cudaStream())) + if nccl_2_0 is None: + check_error(lib.ncclAllGather( + ctypes.c_void_p(inputs[i].data_ptr()), count, data_type, + ctypes.c_void_p(outputs[i].data_ptr()), comm[i], + cudaStream())) + else: + check_error(lib.ncclAllGather( + ctypes.c_void_p(inputs[i].data_ptr()), + ctypes.c_void_p(outputs[i].data_ptr()), count, + data_type, comm[i], cudaStream())) + + if nccl_2_0 is not None: + lib.ncclGroupEnd() def reduce_scatter(inputs, outputs, op=SUM): @@ -210,12 +248,16 @@ def reduce_scatter(inputs, outputs, op=SUM): count = inputs[0].numel() // len(inputs) data_type = nccl_types[inputs[0].type()] with torch.cuda._free_mutex(): + if nccl_2_0 is not None: + lib.ncclGroupStart() for i in range(len(inputs)): with torch.cuda.device(comm.devices[i]): check_error(lib.ncclReduceScatter( ctypes.c_void_p(inputs[i].data_ptr()), ctypes.c_void_p(outputs[i].data_ptr()), count, data_type, op, comm[i], cudaStream())) + if nccl_2_0 is not None: + lib.ncclGroupEnd() def _check_inputs(inputs, outputs=None, size_multiplier=1): diff --git a/torch/lib/THCS/CMakeLists.txt b/torch/lib/THCS/CMakeLists.txt index 10537710fdfa6b4..993a640a630635c 100644 --- a/torch/lib/THCS/CMakeLists.txt +++ b/torch/lib/THCS/CMakeLists.txt @@ -35,6 +35,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3") endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") +if(CUDA_VERSION VERSION_GREATER "8.0") + LIST(APPEND CUDA_NVCC_FLAGS "-D__CUDA_NO_HALF_OPERATORS__") +endif(CUDA_VERSION VERSION_GREATER "8.0") + if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU") SET(CMAKE_CXX_STANDARD 11) endif()