
Updates for CUDA 9
csarofeen authored and soumith committed Aug 25, 2017
1 parent b079469 commit ec86d0b
Showing 5 changed files with 114 additions and 54 deletions.
9 changes: 9 additions & 0 deletions torch/csrc/cudnn/BatchNorm.cpp
@@ -69,6 +69,10 @@ void cudnn_batch_norm_forward(
    mode = CUDNN_BATCHNORM_PER_ACTIVATION;
  } else {
    mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
    if(training)
      mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif
  }

  TensorDescriptor idesc; // input descriptor
@@ -131,6 +135,11 @@ void cudnn_batch_norm_backward(
    mode = CUDNN_BATCHNORM_PER_ACTIVATION;
  } else {
    mode = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7000
    if(training)
      mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
#endif

  }

  THVoidTensor_assertContiguous(input);
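For context (not part of the diff): cuDNN 7 adds CUDNN_BATCHNORM_SPATIAL_PERSISTENT, a faster spatial batch-norm algorithm that is only valid during training, which is why the new branch checks `training`. A minimal user-level sketch of the path this enables, assuming a CUDA 9 / cuDNN 7 build of this branch and 2017-era PyTorch APIs:

# Sketch only: any BatchNorm2d call in training mode on a CUDA tensor goes
# through cudnn_batch_norm_forward and picks up the persistent mode on cuDNN >= 7.
import torch
import torch.nn as nn
from torch.autograd import Variable

bn = nn.BatchNorm2d(16).cuda()
bn.train()  # training=True selects CUDNN_BATCHNORM_SPATIAL_PERSISTENT on cuDNN >= 7
x = Variable(torch.randn(8, 16, 32, 32).cuda())
y = bn(x)
print(y.size())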
4 changes: 2 additions & 2 deletions torch/csrc/cudnn/Conv.cpp
@@ -238,11 +238,11 @@ struct algorithm_search<cudnnConvolutionBwdFilterAlgo_t> {
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
#if CUDNN_VERSION >= 6000
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
#endif
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
    };
    size_t max_ws_size = getMaxWorkspaceSize<cudnnConvolutionBwdFilterAlgo_t>(handle,conv,algo,sizeof(algo)/sizeof(algo[0]),state);
    Workspace ws(state, max_ws_size);
5 changes: 5 additions & 0 deletions torch/csrc/cudnn/Descriptors.h
@@ -67,6 +67,11 @@ struct ConvolutionDescriptor
    if (dataType == CUDNN_DATA_HALF) mathType = CUDNN_DATA_FLOAT;
    CHECK(cudnnSetConvolutionNdDescriptor(desc, dim, pad, stride, upscale,
                                          CUDNN_CROSS_CORRELATION, mathType));
#if CUDNN_VERSION >= 7000
    CHECK(cudnnSetConvolutionMathType(desc, CUDNN_DEFAULT_MATH));
    if(dataType == CUDNN_DATA_HALF)
      CHECK(cudnnSetConvolutionMathType(desc, CUDNN_TENSOR_OP_MATH));
#endif
  }
};

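For context (not part of the diff): on cuDNN 7, cudnnSetConvolutionMathType with CUDNN_TENSOR_OP_MATH makes half-precision convolutions eligible for Volta Tensor Core kernels; the default math type is set first so non-FP16 descriptors keep the previous behaviour. A user-level sketch of what this enables, assuming a CUDA 9 / cuDNN 7 build and a Volta GPU:

# Sketch only: an FP16 convolution dispatched through torch/csrc/cudnn/Conv.cpp;
# with the descriptor change above it may run on Tensor Cores with no user change.
import torch
import torch.nn as nn
from torch.autograd import Variable

conv = nn.Conv2d(64, 64, kernel_size=3, padding=1).cuda().half()
x = Variable(torch.randn(32, 64, 56, 56).cuda().half())
y = conv(x)
print(y.size(), y.type())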
146 changes: 94 additions & 52 deletions torch/cuda/nccl.py
@@ -5,18 +5,80 @@
from torch.backends.cudnn import int_array

lib = None

nccl_2_0 = None
__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

_communicators = {}

# ncclDataType_t
nccl_types = {
    'torch.cuda.ByteTensor': 0,
    'torch.cuda.CharTensor': 0,
    'torch.cuda.IntTensor': 1,
    'torch.cuda.HalfTensor': 2,
    'torch.cuda.FloatTensor': 3,
    'torch.cuda.DoubleTensor': 4,
    'torch.cuda.LongTensor': 5,
}
nccl_types_2_0 = {
    'torch.cuda.ByteTensor': 0,
    'torch.cuda.CharTensor': 0,
    'torch.cuda.IntTensor': 2,
    'torch.cuda.HalfTensor': 6,
    'torch.cuda.FloatTensor': 7,
    'torch.cuda.DoubleTensor': 8,
    'torch.cuda.LongTensor': 4,
}

# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3

status_codes_2_0 = {
    0: "Success",
    1: "Unhandled Cuda Error",
    2: "System Error",
    3: "Internal Error",
    4: "Invalid Argument Error",
    5: "Invalid Usage Error",
}

status_codes = {
    0: "Success",
    1: "Unhandled Cuda Error",
    2: "System Error",
    3: "Internal Error",
    4: "Invalid Device Pointer",
    5: "Invalid Rank",
    6: "Unsupported Device Count",
    7: "Device Not Found",
    8: "Invalid Device Index",
    9: "Lib Wrapper Not Set",
    10: "Cuda Malloc Failed",
    11: "Rank Mismatch",
    12: "Invalid Argument",
    13: "Invalid Type",
    14: "Invalid Operation",
}


def _libnccl():
    global nccl_2_0
    global lib
    global status_codes
    global nccl_types
    if lib is None:
        lib = ctypes.pydll.LoadLibrary(None)
        if hasattr(lib, 'ncclCommDestroy'):
            lib.ncclCommDestroy.restype = None
        else:
            lib = None
        if hasattr(lib, 'ncclGroupStart'):
            nccl_2_0 = True
            status_codes = status_codes_2_0
            nccl_types = nccl_types_2_0
    return lib
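For context (not part of the diff): NCCL 2 renumbered ncclDataType_t and changed its result codes, and it exports ncclGroupStart, which NCCL 1 does not; the loader uses that symbol to pick the matching tables. A small sketch of the intended behaviour inside this module (the names _libnccl, nccl_2_0 and nccl_types are the module-level ones defined above):

# Sketch only: after loading, FloatTensor maps to 3 (ncclFloat) under NCCL 1
# but to 7 (ncclFloat32) under NCCL 2, and the NCCL 2 status strings are used.
if _libnccl() is None:
    raise RuntimeError("NCCL library not found")
print("NCCL 2 detected:", nccl_2_0 is not None)
print("ncclDataType_t for FloatTensor:", nccl_types['torch.cuda.FloatTensor'])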


@@ -39,52 +39,6 @@ def is_available(tensors):
    return True


_communicators = {}

# ncclDataType_t
ncclChar = 0
ncclInt = 1
ncclHalf = 2
ncclFloat = 3
ncclDouble = 4
ncclInt64 = 5
ncclUint64 = 6

# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3

status_codes = {
    0: "Success",
    1: "Unhandled Cuda Error",
    2: "System Error",
    3: "Internal Error",
    4: "Invalid Device Pointer",
    5: "Invalid Rank",
    6: "Unsupported Device Count",
    7: "Device Not Found",
    8: "Invalid Device Index",
    9: "Lib Wrapper Not Set",
    10: "Cuda Malloc Failed",
    11: "Rank Mismatch",
    12: "Invalid Argument",
    13: "Invalid Type",
    14: "Invalid Operation",
}

nccl_types = {
    'torch.cuda.ByteTensor': ncclChar,
    'torch.cuda.CharTensor': ncclChar,
    'torch.cuda.IntTensor': ncclInt,
    'torch.cuda.HalfTensor': ncclHalf,
    'torch.cuda.FloatTensor': ncclFloat,
    'torch.cuda.DoubleTensor': ncclDouble,
    'torch.cuda.LongTensor': ncclInt64,
}


class NcclError(RuntimeError):

    def __init__(self, status):
@@ -131,7 +147,6 @@ def communicator(inputs, outputs=None):
    key = ','.join(str(d) for d in devices)
    if key not in _communicators:
        _communicators[key] = NcclCommList(devices)

    return _communicators[key]


@@ -149,12 +164,16 @@ def all_reduce(inputs, outputs=None, op=SUM):
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        if nccl_2_0 is not None:
            lib.ncclGroupStart()
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclAllReduce(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()),
                    count, data_type, op, comm[i], cudaStream()))
        if nccl_2_0 is not None:
            lib.ncclGroupEnd()
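For context (not part of the diff): NCCL 2 requires collectives issued from a single thread for multiple devices to be bracketed by ncclGroupStart/ncclGroupEnd, otherwise the per-device calls can deadlock; NCCL 1 has no group API, hence the nccl_2_0 guards. A hypothetical helper (not in this commit) that the repeated guard pattern could be factored into:

# Sketch only: relies on the module-level `lib` and `nccl_2_0` set by _libnccl().
import contextlib

@contextlib.contextmanager
def _nccl_group():
    if nccl_2_0 is not None:
        lib.ncclGroupStart()
    try:
        yield
    finally:
        if nccl_2_0 is not None:
            lib.ncclGroupEnd()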


def reduce(inputs, outputs=None, root=0, op=SUM, streams=None):
@@ -168,12 +187,16 @@ def reduce(inputs, outputs=None, root=0, op=SUM, streams=None):
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        if nccl_2_0 is not None:
            lib.ncclGroupStart()
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclReduce(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()), count,
                    data_type, op, root, comm[i], streams[i]))
        if nccl_2_0 is not None:
            lib.ncclGroupEnd()


def broadcast(inputs, root=0):
@@ -183,11 +206,15 @@
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        if nccl_2_0 is not None:
            lib.ncclGroupStart()
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclBcast(
                    ctypes.c_void_p(inputs[i].data_ptr()), count,
                    data_type, root, comm[i], cudaStream()))
        if nccl_2_0 is not None:
            lib.ncclGroupEnd()


def all_gather(inputs, outputs):
@@ -196,12 +223,23 @@
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        if nccl_2_0 is not None:
            lib.ncclGroupStart()
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclAllGather(
                    ctypes.c_void_p(inputs[i].data_ptr()), count, data_type,
                    ctypes.c_void_p(outputs[i].data_ptr()), comm[i],
                    cudaStream()))
                if nccl_2_0 is None:
                    check_error(lib.ncclAllGather(
                        ctypes.c_void_p(inputs[i].data_ptr()), count, data_type,
                        ctypes.c_void_p(outputs[i].data_ptr()), comm[i],
                        cudaStream()))
                else:
                    check_error(lib.ncclAllGather(
                        ctypes.c_void_p(inputs[i].data_ptr()),
                        ctypes.c_void_p(outputs[i].data_ptr()), count,
                        data_type, comm[i], cudaStream()))

        if nccl_2_0 is not None:
            lib.ncclGroupEnd()


def reduce_scatter(inputs, outputs, op=SUM):
@@ -210,12 +248,16 @@
    count = inputs[0].numel() // len(inputs)
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        if nccl_2_0 is not None:
            lib.ncclGroupStart()
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclReduceScatter(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()), count, data_type,
                    op, comm[i], cudaStream()))
        if nccl_2_0 is not None:
            lib.ncclGroupEnd()


def _check_inputs(inputs, outputs=None, size_multiplier=1):
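For context (not part of the diff): a minimal usage sketch of this module after the changes above, assuming at least two visible GPUs and a libnccl that the loader can find:

# Sketch only: in-place all-reduce, one tensor per device; every tensor ends
# up holding the elementwise sum across devices.
import torch
import torch.cuda.nccl as nccl

tensors = []
for i in range(torch.cuda.device_count()):
    with torch.cuda.device(i):
        tensors.append(torch.cuda.FloatTensor(4).fill_(i + 1))

nccl.all_reduce(tensors)
print(tensors[0])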
4 changes: 4 additions & 0 deletions torch/lib/THCS/CMakeLists.txt
@@ -35,6 +35,10 @@ if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  endif(CMAKE_CXX_COMPILER_VERSION VERSION_GREATER "4.9.3")
endif(CMAKE_CXX_COMPILER_ID STREQUAL "GNU")

if(CUDA_VERSION VERSION_GREATER "8.0")
  LIST(APPEND CUDA_NVCC_FLAGS "-D__CUDA_NO_HALF_OPERATORS__")
endif(CUDA_VERSION VERSION_GREATER "8.0")

if(NOT CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
  SET(CMAKE_CXX_STANDARD 11)
endif()
