diff --git a/aten/src/THC/THCBlas.cu b/aten/src/THC/THCBlas.cu
index f304b9cf06b52..de1051fd41a08 100644
--- a/aten/src/THC/THCBlas.cu
+++ b/aten/src/THC/THCBlas.cu
@@ -49,7 +49,7 @@ double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, doub
 }
 
 #ifdef CUDA_HALF_TENSOR
-float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
+half THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy)
 {
 #if CUDA_VERSION >= 8000
   if (n == 1) {
@@ -58,22 +58,23 @@ float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y
   }
 
   if ((n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX)) {
-    int i_n = (int)n;
-    int i_incx = (int)incx;
-    int i_incy = (int)incy;
-    float result;
+    half result;
     cublasHandle_t handle = THCState_getCurrentBlasHandle(state);
     cublasSetStream(handle, THCState_getCurrentStream(state));
-    THCublasCheck(cublasDotEx(handle, i_n, x, CUDA_R_16F, i_incx, y, CUDA_R_16F, i_incy, &result, CUDA_R_32F, CUDA_R_32F));
+    THCublasCheck(cublasDotEx(handle, n,
+                              x, CUDA_R_16F, incx,
+                              y, CUDA_R_16F, incy,
+                              &result, CUDA_R_16F,
+                              CUDA_R_32F));
     return result;
-}
+  }
 
   THError("Cublas_Hdot only supports n, incx and incy "
           "up to signed integer limits: %d", INT_MAX);
-  return 0;
+  return THC_float2half(0);
 #else
   THError("Cublas_Hdot requires CUDA 8.0+");
-  return 0;
+  return THC_float2half(0);
 #endif
 }
 #endif
@@ -360,7 +361,7 @@ void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char transb, i
                              float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount)
 {
   if( (m >= INT_MAX) || (n >= INT_MAX) || (k >= INT_MAX) || (lda >= INT_MAX) || (ldb >= INT_MAX) || (ldc >= INT_MAX) || (batchCount >= INT_MAX) )
-  
+
   {
     THError("Cublas_SgemmStridedBatched only supports m, n, k, lda, ldb, ldc, batchCount"
             "with the bound [val] <= %d", INT_MAX);
@@ -420,7 +421,7 @@ void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, i
   cublasSetStream(handle, THCState_getCurrentStream(state));
   THCublasCheck(cublasDgemmStridedBatched(handle, opa, opb, (int)m, (int)n, (int)k,
-                                   &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC, 
+                                   &alpha, a, (int)lda, strideA, b, (int)ldb, strideB, &beta, c, (int)ldc, strideC,
                                    (int)batchCount));
 }
 #endif
diff --git a/aten/src/THC/THCBlas.h b/aten/src/THC/THCBlas.h
index 8fe9b63c24486..9aac7ab228f3a 100644
--- a/aten/src/THC/THCBlas.h
+++ b/aten/src/THC/THCBlas.h
@@ -8,7 +8,7 @@
 THC_API float THCudaBlas_Sdot(THCState *state, int64_t n, float *x, int64_t incx, float *y, int64_t incy);
 THC_API double THCudaBlas_Ddot(THCState *state, int64_t n, double *x, int64_t incx, double *y, int64_t incy);
 #ifdef CUDA_HALF_TENSOR
-THC_API float THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy);
+THC_API half THCudaBlas_Hdot(THCState *state, int64_t n, half *x, int64_t incx, half *y, int64_t incy);
 #endif
 
 /* Level 2 */
@@ -36,7 +36,7 @@ THC_API void THCudaBlas_SgemmStridedBatched(THCState *state, char transa, char t
                              float alpha, const float *a, int64_t lda, int64_t strideA, const float *b, int64_t ldb, int64_t strideB,
                              float beta, float *c, int64_t ldc, int64_t strideC, int64_t batchCount);
 THC_API void THCudaBlas_DgemmStridedBatched(THCState *state, char transa, char transb, int64_t m, int64_t n, int64_t k,
-                             double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB, 
+                             double alpha, const double *a, int64_t lda, int64_t strideA, const double *b, int64_t ldb, int64_t strideB,
                             double beta, double *c, int64_t ldc, int64_t strideC, int64_t batchCount);
 #endif
 
 /* Inverse */
diff --git a/aten/src/THC/generic/THCTensorMathBlas.cu b/aten/src/THC/generic/THCTensorMathBlas.cu
index 4f7381073e36d..a765f7818e3e2 100644
--- a/aten/src/THC/generic/THCTensorMathBlas.cu
+++ b/aten/src/THC/generic/THCTensorMathBlas.cu
@@ -24,10 +24,11 @@ THCTensor_(dot)(THCState *state, THCTensor *self, THCTensor *src)
                  THCTensor_(data)(state, self), 1,
                  THCTensor_(data)(state, src), 1);
 #elif defined(THC_REAL_IS_HALF)
-  accreal result = THCudaBlas_Hdot(state,
+  accreal result = ScalarConvert<half, accreal>::to(
+                   THCudaBlas_Hdot(state,
                    THCTensor_(nElement)(state, self),
                    THCTensor_(data)(state, self), 1,
-                   THCTensor_(data)(state, src), 1);
+                   THCTensor_(data)(state, src), 1));
 #endif
 
   THCTensor_(free)(state, src);
@@ -492,14 +493,14 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
     ldc = result_->stride[2];
   }
 
-  if (batch1->stride[transpose_result ? 2 : 1] == 1 && 
+  if (batch1->stride[transpose_result ? 2 : 1] == 1 &&
      batch1->stride[transpose_result ? 1 : 2] != 0)
   {
     transpose_batch1 = 'n';
     batch1_ = batch1;
     lda = batch1_->stride[transpose_result ? 1 : 2];
   }
-  else if (batch1->stride[transpose_result ? 1 : 2] == 1 && 
+  else if (batch1->stride[transpose_result ? 1 : 2] == 1 &&
           batch1->stride[transpose_result ? 2 : 1] != 0)
   {
     transpose_batch1 = 't';
@@ -513,7 +514,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
     lda = batch1_->stride[1];
   }
 
-  if (batch2->stride[transpose_result ? 2 : 1] == 1 && 
+  if (batch2->stride[transpose_result ? 2 : 1] == 1 &&
      batch2->stride[transpose_result ? 1 : 2] != 0)
   {
     transpose_batch2 = 'n';
@@ -537,7 +538,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
 
 #if defined(THC_REAL_IS_FLOAT) || defined(THC_REAL_IS_DOUBLE)
   // Compute pointers to matrices in each batch.
-#if CUDA_VERSION < 8000 
+#if CUDA_VERSION < 8000
   size_t matrices_size = num_batches * sizeof(real*);
 
   // Copy pointers to device.
@@ -592,7 +593,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
   THCudaFree(state, d_matrices1);
   THCudaFree(state, d_matrices2);
   THCudaFree(state, d_result_matrices);
-  
+
 #else
 #ifdef THC_REAL_IS_FLOAT
   THCudaBlas_SgemmStridedBatched(
@@ -606,7 +607,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
     THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
     THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
     beta,
-    THCTensor_(data)(state, result_), ldc, result_->stride[0], 
+    THCTensor_(data)(state, result_), ldc, result_->stride[0],
     num_batches);
 #elif defined(THC_REAL_IS_DOUBLE)
   THCudaBlas_DgemmStridedBatched(
@@ -620,7 +621,7 @@ THCTensor_(baddbmm)(THCState *state, THCTensor *result, real beta, THCTensor *t,
     THCTensor_(data)(state, batch1_), lda, batch1_->stride[0],
     THCTensor_(data)(state, batch2_), ldb, batch2_->stride[0],
     beta,
-    THCTensor_(data)(state, result_), ldc, result_->stride[0], 
+    THCTensor_(data)(state, result_), ldc, result_->stride[0],
     num_batches);
 #endif
 #endif
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 5271898760fcd..67f8f9025550e 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -384,12 +384,18 @@ def get_cycles_per_ms():
     return _cycles_per_ms
 
 
-def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5):
+def compare_cpu_gpu(tensor_constructor, arg_constructor, fn, t, precision=1e-5, force_gpu_half=False):
     def tmp(self):
         cpu_tensor = tensor_constructor(t)
-        gpu_tensor = to_gpu(cpu_tensor)
+        type_map = {}
+        if force_gpu_half:
+            type_map = {
+                'torch.FloatTensor': 'torch.cuda.HalfTensor',
+                'torch.DoubleTensor': 'torch.cuda.HalfTensor',
+            }
+        gpu_tensor = to_gpu(cpu_tensor, type_map)
         cpu_args = arg_constructor(t)
-        gpu_args = [to_gpu(arg) for arg in cpu_args]
+        gpu_args = [to_gpu(arg, type_map) for arg in cpu_args]
         cpu_result = getattr(cpu_tensor, fn)(*cpu_args)
         try:
             gpu_result = getattr(gpu_tensor, fn)(*gpu_args)
@@ -1099,7 +1105,15 @@ def test_nvtx(self):
             test_name += '_' + desc
         assert not hasattr(TestCuda, test_name), "Duplicated test name: " + test_name
-        setattr(TestCuda, test_name, compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
+        setattr(TestCuda,
+                test_name,
+                compare_cpu_gpu(constr, arg_constr, name_inner, t, precision))
+        if t == torch.FloatTensor:
+            assert not hasattr(TestCuda, test_name + '_gpu_half'), "Duplicated test name: " + test_name
+            setattr(TestCuda,
+                    test_name + '_gpu_half',
+                    compare_cpu_gpu(constr, arg_constr, name_inner, t,
+                                    precision, force_gpu_half=True))
 
 
 if __name__ == '__main__':
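
Note: for context, a minimal sketch of the comparison the new _gpu_half test variants perform, assuming a CUDA build with half-tensor support; the tensor size and tolerance below are illustrative only and are not taken from the test suite.

import torch

# CPU float reference and a half-precision CUDA copy, mirroring what
# compare_cpu_gpu does when force_gpu_half=True.
cpu = torch.randn(100)
gpu_half = cpu.cuda().half()

# dot on a half tensor now routes through THCudaBlas_Hdot, which runs the
# reduction in fp32 (CUDA_R_32F execution type) but stores the result in
# fp16 (CUDA_R_16F).
cpu_result = float(cpu.dot(cpu))
gpu_result = float(gpu_half.dot(gpu_half))

# fp16 inputs and result storage limit precision, so compare with a loose
# relative tolerance rather than the default 1e-5.
assert abs(cpu_result - gpu_result) <= 1e-2 * max(1.0, abs(cpu_result))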