From 89d2e3fcb3b83b15042edc461e0d579159db200c Mon Sep 17 00:00:00 2001
From: albanD
Date: Wed, 31 Jan 2018 16:48:37 +0000
Subject: [PATCH] fix triu and tril for zero-strided inputs on gpu

---
 aten/src/THC/THCTensorMathPairwise.cu         |  6 +--
 aten/src/THC/generic/THCTensorMathPairwise.cu | 39 +++++++------------
 test/test_cuda.py                             | 28 +++++++++++--
 3 files changed, 43 insertions(+), 30 deletions(-)

diff --git a/aten/src/THC/THCTensorMathPairwise.cu b/aten/src/THC/THCTensorMathPairwise.cu
index 4f2a745e67fb1..a4e0711dc2da1 100644
--- a/aten/src/THC/THCTensorMathPairwise.cu
+++ b/aten/src/THC/THCTensorMathPairwise.cu
@@ -375,8 +375,8 @@ struct TensorTriOp {
   TensorTriOp(T *start_, int64_t stride0_, int64_t stride1_, int64_t k_)
     : start(start_), stride0(stride0_), stride1(stride1_), k(k_) {}
 
-  __device__ __forceinline__ int mask(T *in) {
-    ptrdiff_t n = in - start;
+  __device__ __forceinline__ int mask(T *out) {
+    ptrdiff_t n = out - start;
     int64_t row, col;
     if (stride0 > stride1)
     {
@@ -393,7 +393,7 @@
   }
 
   __device__ __forceinline__ void operator()(T* out, T* in) {
-    *out = mask(in) ? *in : ScalarConvert<int, T>::to(0);
+    *out = mask(out) ? *in : ScalarConvert<int, T>::to(0);
   }
 
   __device__ __forceinline__ void operator()(T* v) {
diff --git a/aten/src/THC/generic/THCTensorMathPairwise.cu b/aten/src/THC/generic/THCTensorMathPairwise.cu
index e14df076b266e..1ec9a23560171 100644
--- a/aten/src/THC/generic/THCTensorMathPairwise.cu
+++ b/aten/src/THC/generic/THCTensorMathPairwise.cu
@@ -193,31 +193,27 @@ void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, int64_
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
   THArgCheck(src_->nDimension == 2, 1, "expected a matrix");
 
-  THCTensor *src = src_;
-  if (self_ == src_)
-    src = THCTensor_(newContiguous)(state, src_);
+  if (self_ != src_)
+    THCTensor_(resizeAs)(state, self_, src_);
 
-  int64_t stride0 = src->stride[0];
-  int64_t stride1 = src->stride[1];
-  real *start = THCTensor_(data)(state, src);
+  int64_t stride0 = self_->stride[0];
+  int64_t stride1 = self_->stride[1];
+  real *start = THCTensor_(data)(state, self_);
 
   TensorTriOp<real, 0> op(start, stride0, stride1, k);
 
   if (self_ == src_) {
-    if (!THC_pointwiseApply1(state, src, op)) {
+    if (!THC_pointwiseApply1(state, src_, op)) {
       THArgCheck(false, 2, CUTORCH_DIM_WARNING);
     }
   } else {
-    THCTensor_(resizeAs)(state, self_, src);
+    THCTensor_(resizeAs)(state, self_, src_);
 
-    if (!THC_pointwiseApply2(state, self_, src, op)) {
+    if (!THC_pointwiseApply2(state, self_, src_, op)) {
       THArgCheck(false, 2, CUTORCH_DIM_WARNING);
     }
   }
 
-  if (self_ == src_)
-    THCTensor_(freeCopyTo)(state, src, src_);
-
   THCudaCheck(cudaGetLastError());
 }
 
@@ -226,31 +222,26 @@ void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, int64_
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
   THArgCheck(src_->nDimension == 2, 1, "expected a matrix");
 
-  THCTensor *src = src_;
-  if (self_ == src_)
-    src = THCTensor_(newContiguous)(state, src_);
+  if (self_ != src_)
+    THCTensor_(resizeAs)(state, self_, src_);
 
-  int64_t stride0 = src->stride[0];
-  int64_t stride1 = src->stride[1];
-  real *start = THCTensor_(data)(state, src);
+  int64_t stride0 = self_->stride[0];
+  int64_t stride1 = self_->stride[1];
+  real *start = THCTensor_(data)(state, self_);
 
   TensorTriOp<real, 1> op(start, stride0, stride1, k);
 
   if (self_ == src_) {
-    if (!THC_pointwiseApply1(state, src, op)) {
+    if (!THC_pointwiseApply1(state, src_, op)) {
       THArgCheck(false, 2, CUTORCH_DIM_WARNING);
     }
   } else {
-    THCTensor_(resizeAs)(state, self_, src);
 
-    if (!THC_pointwiseApply2(state, self_, src, op)) {
+    if (!THC_pointwiseApply2(state, self_, src_, op)) {
      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
     }
   }
 
-  if (self_ == src_)
-    THCTensor_(freeCopyTo)(state, src, src_);
-
   THCudaCheck(cudaGetLastError());
 }
 
diff --git a/test/test_cuda.py b/test/test_cuda.py
index 89b385d3bdbd1..010808adb2929 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -97,6 +97,10 @@ def medium_2d(t):
     return make_tensor(t, M, M)
 
 
+def medium_2d_expanded(t):
+    return t(1).expand(M, M)
+
+
 def medium_2d_scaled(t, scale=10):
     return make_tensor(t, M, M).mul(scale)
 
@@ -143,6 +147,13 @@ def tmp(t):
         return t(*sizes).copy_(torch.randn(*sizes))
     return tmp
 
+# Content of each tuple:
+# - function name
+# - constructor for the tensor, signature: fn(tensor_type) -> tensor
+# - constructor for the arguments, signature: fn(tensor_type) -> list
+# - postfix name for the test (must be unique for a given function) (default='')
+# - tensor types to use (default=types)
+# - disable inplace test, if set to True, no inplace test will be done (default=False)
 tests = [
     ('add', small_3d, lambda t: [number(3.14, 3, t)]),
     ('add', small_3d, lambda t: [small_3d_positive(t)], 'tensor'),
@@ -296,9 +307,11 @@ def tmp(t):
     ('topk', small_3d_unique, lambda t: [2, 1, True, True], 'dim_desc_sort'),
     ('trace', medium_2d, lambda t: [],),
     ('tril', medium_2d, lambda t: [],),
+    ('tril', medium_2d_expanded, lambda t: [], 'zero_stride', types, True),
    ('tril', medium_2d, lambda t: [2], 'positive'),
     ('tril', medium_2d, lambda t: [-2], 'negative'),
     ('triu', medium_2d, lambda t: [],),
+    ('triu', medium_2d_expanded, lambda t: [], 'zero_stride', types, True),
     ('triu', medium_2d, lambda t: [2], 'positive'),
     ('triu', medium_2d, lambda t: [-2], 'negative'),
     ('unsqueeze', new_t(2, 3, 4), lambda t: [2],),
@@ -1351,18 +1364,27 @@ def test_nvtx(self):
     for t in types:
         tensor = t()
         gpu_tensor = get_gpu_type(t)()
+
+        # Default values
+        desc = ''
+        type_subset = types
+        no_inplace = False
         if len(decl) == 3:
             name, constr, arg_constr = decl
-            desc = ''
         elif len(decl) == 4:
             name, constr, arg_constr, desc = decl
         elif len(decl) == 5:
             name, constr, arg_constr, desc, type_subset = decl
-            if t not in type_subset:
-                continue
+        elif len(decl) == 6:
+            name, constr, arg_constr, desc, type_subset, no_inplace = decl
+
+        if t not in type_subset:
+            continue
 
         precision = custom_precision.get(name, TestCuda.precision)
         for inplace in (True, False):
+            if inplace and no_inplace:
+                continue
             if inplace:
                 name_inner = name + '_'
             else: