diff --git a/include/nbla/cuda/function/transpose.hpp b/include/nbla/cuda/function/transpose.hpp index b3f373175..82597b8eb 100644 --- a/include/nbla/cuda/function/transpose.hpp +++ b/include/nbla/cuda/function/transpose.hpp @@ -25,10 +25,10 @@ namespace nbla { template class TransposeCuda : public Transpose { protected: public: - typedef typename CudaType::type Tc; + typedef typename CudaType::type Tcu; explicit TransposeCuda(const Context &ctx, const vector &axes) - : Transpose(ctx, axes) {} + : Transpose(ctx, axes), device_(std::stoi(ctx.device_id)) {} virtual ~TransposeCuda() {} virtual string name() { return "TransposeCuda"; } virtual vector allowed_array_classes() { @@ -36,6 +36,9 @@ template class TransposeCuda : public Transpose { } protected: + int device_; + std::shared_ptr var_strides_; + virtual void setup_impl(const Variables &inputs, const Variables &outputs); virtual void forward_impl(const Variables &inputs, const Variables &outputs); virtual void backward_impl(const Variables &inputs, const Variables &outputs, diff --git a/src/nbla/cuda/function/generic/transpose.cu b/src/nbla/cuda/function/generic/transpose.cu index 1bccac739..47c1bfcf7 100644 --- a/src/nbla/cuda/function/generic/transpose.cu +++ b/src/nbla/cuda/function/generic/transpose.cu @@ -12,75 +12,190 @@ // See the License for the specific language governing permissions and // limitations under the License. -// transpose.cpp - +#include #include -#include - #include #include +#include namespace nbla { -template -__global__ void -kernel_transpose_forward(const int num, const int ndim, const int64_t *axes, - const int64_t *x_strides, const int64_t *y_strides, - const int64_t *y_shape, const T *x, T *y) { - - NBLA_CUDA_KERNEL_LOOP(o, num) { - int i = 0; - for (int d = 0; d < ndim; ++d) { - const int k = int(o / y_strides[d]) % y_shape[d]; - i += k * x_strides[axes[d]]; +namespace { // local namespace + +template +__global__ void transpose_1d(const int size, const T *src, T *dst) { + NBLA_CUDA_KERNEL_LOOP(idx, size) { + dst[idx] = accum ? dst[idx] + src[idx] : src[idx]; + } +} + +template +__global__ void transpose_2d(const int2 shape, const T *src, T *dst) { + // One extra column to avoid memory bank conflicts. + __shared__ T tile[CUDA_WARP_SIZE][CUDA_WARP_SIZE + 1]; + + int x = blockIdx.x * CUDA_WARP_SIZE + threadIdx.x; + int y = blockIdx.y * CUDA_WARP_SIZE + threadIdx.y; + +#pragma unroll + for (int j = 0; j < CUDA_WARP_SIZE; j += 8) { + if ((x < shape.x) && (y + j < shape.y)) { + tile[threadIdx.y + j][threadIdx.x] = src[(y + j) * shape.x + x]; + } + } + + __syncthreads(); + + x = blockIdx.y * CUDA_WARP_SIZE + threadIdx.x; // transpose block offset + y = blockIdx.x * CUDA_WARP_SIZE + threadIdx.y; + +#pragma unroll + for (int j = 0; j < CUDA_WARP_SIZE; j += 8) { + if ((x < shape.y) && (y + j < shape.x)) { + auto val = tile[threadIdx.x][threadIdx.y + j]; + auto idx = (y + j) * shape.y + x; + dst[idx] = accum ? dst[idx] + val : val; } - y[o] = x[i]; } } -template -__global__ void -kernel_transpose_backward(const int num, const int ndim, const int64_t *axes, - const int64_t *x_strides, const int64_t *y_strides, - const int64_t *y_shape, const T *dy, T *dx) { - NBLA_CUDA_KERNEL_LOOP(o, num) { - int i = 0; - for (int d = 0; d < ndim; ++d) { - const int k = int(o / y_strides[d]) % y_shape[d]; - i += k * x_strides[axes[d]]; +template +__global__ void transpose_3d(const int size, const int3 ostride, + const int3 tstride, const T *src, T *dst) { + // ostride - strides of the transposed input shape + // tstride - transpose of the input shape strides + // Ex. transpose(ishape=(2, 3, 4), axes=(2, 1, 0)) => oshape (4, 3, 2) + // then ostride is (6, 2, 1) and tstride is (1, 4, 12) + NBLA_CUDA_KERNEL_LOOP(idx, size) { + auto z = (idx / ostride.z); + auto y = (idx - z * ostride.z) / ostride.y; + auto x = (idx - z * ostride.z - y * ostride.y); + T val = src[z * tstride.z + y * tstride.y + x * tstride.x]; + dst[idx] = accum ? dst[idx] + val : val; + } +} + +template +__global__ void transpose_4d(const int size, const int4 ostride, + const int4 tstride, const T *src, T *dst) { + // ostride - strides of the transposed input shape + // tstride - transpose of the input shape strides + // Ex. transpose(ishape=(2,3,4,5), axes=(3,2,1,0)) => oshape (5,4,3,2) + // then ostride is (24, 6, 2, 1) and tstride is (1, 5, 20, 60) + NBLA_CUDA_KERNEL_LOOP(idx, size) { + auto w = (idx / ostride.w); + auto z = (idx - w * ostride.w) / ostride.z; + auto y = (idx - w * ostride.w - z * ostride.z) / ostride.y; + auto x = (idx - w * ostride.w - z * ostride.z - y * ostride.y); + T val = src[w * tstride.w + z * tstride.z + y * tstride.y + x * tstride.x]; + dst[idx] = accum ? dst[idx] + val : val; + } +} + +template struct TransposeStrides { + T ostride; // strides of transposed input shape + T tstride; // transposed strides of input shape +}; + +template +__global__ void transpose_nd(const int size, const T *src, T *dst, + const TransposeStrides *strides, + const int ndim) { + NBLA_CUDA_KERNEL_LOOP(idx, size) { + int src_index = 0, dst_index = idx; + for (int axis = 0; axis < ndim; axis++) { + const auto k = idx / strides[axis].ostride; + src_index += k * strides[axis].tstride; + idx -= k * strides[axis].ostride; } - dx[i] = (accum ? dx[i] : (T)0) + dy[o]; + dst[dst_index] = accum ? dst[dst_index] + src[src_index] : src[src_index]; } } +} // end local namespace + template void TransposeCuda::setup_impl(const Variables &inputs, const Variables &outputs) { Transpose::setup_impl(inputs, outputs); + + const int ndim = this->x_shape_.size(); + const int size = outputs[0]->size(); + + NBLA_CHECK(size <= std::numeric_limits::max(), error_code::value, + "Maximum supported array size is %d elements", + std::numeric_limits::max()); + + if (ndim > 4) { + Shape_t shape = {2, ndim * (int)sizeof(TransposeStrides)}; + this->var_strides_ = std::make_shared(); + this->var_strides_->reshape(shape, true); + auto var = static_cast(this->var_strides_); + auto ptr = var->cast_data_and_get_pointer(Context()); + auto strides = reinterpret_cast *>(ptr); + for (int i = 0; i < ndim; i++) { + // strides for forward where we iterate the output shape + strides[i].ostride = this->y_strides_[i]; + strides[i].tstride = this->x_strides_transposed_[i]; + // strides for backward where we iterate the input shape + strides[i + ndim].ostride = this->x_strides_[i]; + strides[i + ndim].tstride = this->y_strides_transposed_[i]; + } + } } +template inline int2 make_int2_from(vector vec, int skip = 0) { + return make_int2(vec[1 + skip], vec[0 + skip]); +} + +template inline int3 make_int3_from(vector vec, int skip = 0) { + return make_int3(vec[2 + skip], vec[1 + skip], vec[0 + skip]); +} + +template inline int4 make_int4_from(vector vec, int skip = 0) { + return make_int4(vec[3 + skip], vec[2 + skip], vec[1 + skip], vec[0 + skip]); +} + +#define WARPS_FOR(N) NBLA_CEIL_INT_DIV(N, CUDA_WARP_SIZE) + template void TransposeCuda::forward_impl(const Variables &inputs, const Variables &outputs) { - cuda_set_device(std::stoi(this->ctx_.device_id)); - const Tc *x = inputs[0]->get_data_pointer(this->ctx_); - Tc *y = outputs[0]->cast_data_and_get_pointer(this->ctx_, true); - - // To avoid compiler error : type name is not allowed. - // The following statement causes a compiler error. - // this->v_axes_.get_data_pointer(this->ctx_) - auto get_ = [this](Variable &var) { - return var.get_data_pointer(this->ctx_); - }; - const int64_t *axes = get_(this->v_axes_); - const int64_t *x_strides = get_(this->v_x_strides_); - const int64_t *y_strides = get_(this->v_y_strides_); - const int64_t *y_shape = get_(this->v_y_shape_); - const int ndim = inputs[0]->ndim(); + cuda_set_device(this->device_); + auto x = inputs[0]->get_data_pointer(this->ctx_); + auto y = outputs[0]->cast_data_and_get_pointer(this->ctx_, true); + const int ndim = this->x_shape_.size(); const int size = outputs[0]->size(); - NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel_transpose_forward, size, ndim, axes, - x_strides, y_strides, y_shape, x, y); + if (ndim == 1) { + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(transpose_1d, size, x, y); + } else if (ndim == 2) { + const auto shape = make_int2_from(this->x_shape_); + const dim3 grid_dim(WARPS_FOR(shape.x), WARPS_FOR(shape.y)); + const dim3 block_dim(CUDA_WARP_SIZE, 8); + transpose_2d<<>>(shape, x, y); + NBLA_CUDA_KERNEL_CHECK(); + } else if (ndim == 3 && this->axes_[0] == 0) { + const auto shape = make_int2_from(this->x_shape_, 1); + const dim3 grid_dim(WARPS_FOR(shape.x), WARPS_FOR(shape.y)); + const dim3 block_dim(CUDA_WARP_SIZE, 8); + for (int i = 0, w = shape.x * shape.y; i < this->x_shape_[0]; i++) + transpose_2d<<>>(shape, x + i * w, y + i * w); + NBLA_CUDA_KERNEL_CHECK(); + } else if (ndim == 3) { + const auto ostride = make_int3_from(this->y_strides_); + const auto tstride = make_int3_from(this->x_strides_transposed_); + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(transpose_3d, size, ostride, tstride, x, y); + } else if (ndim == 4) { + const auto ostride = make_int4_from(this->y_strides_); + const auto tstride = make_int4_from(this->x_strides_transposed_); + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(transpose_4d, size, ostride, tstride, x, y); + } else { + auto var = static_cast(this->var_strides_); + auto ptr = var->get_data_pointer(this->ctx_); + auto strides = reinterpret_cast *>(ptr); + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(transpose_nd, size, x, y, strides, ndim); + } } template @@ -91,30 +206,56 @@ void TransposeCuda::backward_impl(const Variables &inputs, if (!propagate_down[0]) return; - cuda_set_device(std::stoi(this->ctx_.device_id)); - const Tc *dy = outputs[0]->get_grad_pointer(this->ctx_); - Tc *dx = inputs[0]->cast_grad_and_get_pointer(this->ctx_, !accum[0]); - - // To avoid compiler error : type name is not allowed. - // The following statement causes a compiler error. - // this->v_axes_.get_data_pointer(this->ctx_) - auto get_ = [this](Variable &var) { - return var.get_data_pointer(this->ctx_); - }; - const int64_t *axes = get_(this->v_axes_); - const int64_t *x_strides = get_(this->v_x_strides_); - const int64_t *y_strides = get_(this->v_y_strides_); - const int64_t *y_shape = get_(this->v_y_shape_); - const int ndim = inputs[0]->ndim(); + cuda_set_device(this->device_); + auto dy = outputs[0]->get_grad_pointer(this->ctx_); + auto dx = inputs[0]->cast_grad_and_get_pointer(this->ctx_, !accum[0]); + const int ndim = this->x_shape_.size(); const int size = outputs[0]->size(); - if (accum[0]) { - NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_transpose_backward), size, - ndim, axes, x_strides, y_strides, y_shape, - dy, dx); + + if (ndim == 1) { + auto kernel = transpose_1d; + if (accum[0]) + kernel = transpose_1d; + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel, size, dy, dx); + } else if (ndim == 2) { + const auto shape = make_int2_from(this->y_shape_); + const dim3 grid(WARPS_FOR(shape.x), WARPS_FOR(shape.y)); + auto kernel = transpose_2d; + if (accum[0]) + kernel = transpose_2d; + kernel<<>>(shape, dy, dx); + NBLA_CUDA_KERNEL_CHECK(); + } else if (ndim == 3 && this->axes_[0] == 0) { + const auto shape = make_int2_from(this->y_shape_, 1); + const dim3 grid(WARPS_FOR(shape.x), WARPS_FOR(shape.y)); + auto kernel = transpose_2d; + if (accum[0]) + kernel = transpose_2d; + for (int i = 0, w = shape.x * shape.y; i < this->x_shape_[0]; i++) + kernel<<>>(shape, dy + i * w, dx + i * w); + NBLA_CUDA_KERNEL_CHECK(); + } else if (ndim == 3) { + const auto ostride = make_int3_from(this->x_strides_); + const auto tstride = make_int3_from(this->y_strides_transposed_); + auto kernel = transpose_3d; + if (accum[0]) + kernel = transpose_3d; + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel, size, ostride, tstride, dy, dx); + } else if (ndim == 4) { + const auto ostride = make_int4_from(this->x_strides_); + const auto tstride = make_int4_from(this->y_strides_transposed_); + auto kernel = transpose_4d; + if (accum[0]) + kernel = transpose_4d; + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel, size, ostride, tstride, dy, dx); } else { - NBLA_CUDA_LAUNCH_KERNEL_SIMPLE((kernel_transpose_backward), size, - ndim, axes, x_strides, y_strides, y_shape, - dy, dx); + auto var = static_cast(this->var_strides_); + auto ptr = var->get_data_pointer(this->ctx_); + auto strides = reinterpret_cast *>(ptr); + auto kernel = transpose_nd; + if (accum[0]) + kernel = transpose_nd; + NBLA_CUDA_LAUNCH_KERNEL_SIMPLE(kernel, size, dy, dx, strides + ndim, ndim); } } }