diff --git a/torchvision/csrc/ops/cuda/cuda_helpers.h b/torchvision/csrc/ops/cuda/cuda_helpers.h
index cec4a183899..e53a4fb6250 100644
--- a/torchvision/csrc/ops/cuda/cuda_helpers.h
+++ b/torchvision/csrc/ops/cuda/cuda_helpers.h
@@ -3,10 +3,12 @@
 namespace vision {
 namespace ops {
 
-#define CUDA_1D_KERNEL_LOOP(i, n)                                \
-  for (int i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
+#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t)                         \
+  for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \
        i += (blockDim.x * gridDim.x))
 
+#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int)
+
 template <typename integer>
 constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) {
   return (n + m - 1) / m;
diff --git a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
index c72e0b666a3..b5422cce26d 100644
--- a/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
+++ b/torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu
@@ -88,29 +88,26 @@ inline unsigned int GET_THREADS() {
   return 512;
 }
 
-inline unsigned int GET_BLOCKS(
-    const unsigned int THREADS,
-    const unsigned int N) {
-  unsigned int kMaxGridNum =
-      at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
-  return std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
+inline unsigned int GET_BLOCKS(const unsigned int THREADS, const int64_t N) {
+  int64_t kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
+  return (unsigned int)std::min(kMaxGridNum, (N + THREADS - 1) / THREADS);
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __device__ scalar_t bilinear_interpolate(
     const scalar_t* in,
-    int height,
-    int width,
+    index_t height,
+    index_t width,
     scalar_t h,
     scalar_t w) {
   if (h <= -1 || height <= h || w <= -1 || width <= w) {
     return 0;
   }
 
-  int h_low = floor(h);
-  int w_low = floor(w);
-  int h_high = h_low + 1;
-  int w_high = w_low + 1;
+  index_t h_low = floor(h);
+  index_t w_low = floor(w);
+  index_t h_high = h_low + 1;
+  index_t w_high = w_low + 1;
 
   scalar_t lh = h - h_low;
   scalar_t lw = w - w_low;
@@ -135,38 +132,38 @@ __device__ scalar_t bilinear_interpolate(
   return val;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_im2col_kernel(
-    int n,
+    index_t n,
     const scalar_t* input_ptr,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int height,
-    int width,
-    int weight_h,
-    int weight_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int n_in_channels,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t height,
+    index_t width,
+    index_t weight_h,
+    index_t weight_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t n_in_channels,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     bool use_mask,
     scalar_t* columns_ptr) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    const int out_x = index % out_w;
-    const int out_y = (index / out_w) % out_h;
-    const int out_b = (index / (out_w * out_h)) % batch_sz;
-    const int in_c = index / (out_w * out_h * batch_sz);
-    const int out_c = in_c * weight_h * weight_w;
+  CUDA_1D_KERNEL_LOOP_T(index, n, index_t) {
+    const index_t out_x = index % out_w;
+    const index_t out_y = (index / out_w) % out_h;
+    const index_t out_b = (index / (out_w * out_h)) % batch_sz;
+    const index_t in_c = index / (out_w * out_h * batch_sz);
+    const index_t out_c = in_c * weight_h * weight_w;
 
-    int c_per_offset_grp = n_in_channels / n_offset_grps;
-    const int grp_idx = in_c / c_per_offset_grp;
+    index_t c_per_offset_grp = n_in_channels / n_offset_grps;
+    const index_t grp_idx = in_c / c_per_offset_grp;
 
     columns_ptr +=
         (out_c * (batch_sz * out_h * out_w) + out_b * (out_h * out_w) +
@@ -185,8 +182,8 @@ __global__ void deformable_im2col_kernel(
 
     for (int i = 0; i < weight_h; ++i) {
       for (int j = 0; j < weight_w; ++j) {
-        const int mask_idx = i * weight_w + j;
-        const int offset_idx = 2 * mask_idx;
+        const index_t mask_idx = i * weight_w + j;
+        const index_t offset_idx = 2 * mask_idx;
 
         scalar_t mask_value = 1;
         if (use_mask) {
@@ -231,36 +228,75 @@ void deformable_im2col(
     int deformable_group,
     bool use_mask,
     at::Tensor data_col) {
-  int num_kernels = n_in_channels * out_h * out_w * parallel_imgs;
+  int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
 
   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      input.scalar_type(), "deformable_im2col", ([&] {
-        deformable_im2col_kernel<<<blocks, threads>>>(
-            num_kernels,
-            input.data_ptr<scalar_t>(),
-            data_offset.data_ptr<scalar_t>(),
-            data_mask.data_ptr<scalar_t>(),
-            height,
-            width,
-            weight_h,
-            weight_w,
-            pad_h,
-            pad_w,
-            stride_h,
-            stride_w,
-            dilation_h,
-            dilation_w,
-            parallel_imgs,
-            n_in_channels,
-            deformable_group,
-            out_h,
-            out_w,
-            use_mask,
-            data_col.data_ptr<scalar_t>());
-      }));
+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if num_kernels or columns numel larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+  use_64bits_indexing |=
+      ((int64_t)n_in_channels * weight_h * weight_w * parallel_imgs * out_h *
+           out_w >
+       (1 << 31));
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        input.scalar_type(), "deformable_im2col", ([&] {
+          deformable_im2col_kernel<scalar_t, int64_t><<<blocks, threads>>>(
+              num_kernels,
+              input.data_ptr<scalar_t>(),
+              data_offset.data_ptr<scalar_t>(),
+              data_mask.data_ptr<scalar_t>(),
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_in_channels,
+              deformable_group,
+              out_h,
+              out_w,
+              use_mask,
+              data_col.data_ptr<scalar_t>());
+        }));
+
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        input.scalar_type(), "deformable_im2col", ([&] {
+          deformable_im2col_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              input.data_ptr<scalar_t>(),
+              data_offset.data_ptr<scalar_t>(),
+              data_mask.data_ptr<scalar_t>(),
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_in_channels,
+              deformable_group,
+              out_h,
+              out_w,
+              use_mask,
+              data_col.data_ptr<scalar_t>());
+        }));
+  }
 
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
@@ -277,39 +313,40 @@ int get_greatest_divisor_below_bound(int n, int bound) {
   return 1;
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_col2im_kernel(
-    int n,
+    index_t n,
     const scalar_t* col,
     const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int channels,
-    int height,
-    int width,
-    int kernel_h,
-    int kernel_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t channels,
+    index_t height,
+    index_t width,
+    index_t kernel_h,
+    index_t kernel_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     bool use_mask,
     scalar_t* grad_im) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
-    const int out_x = index % out_w;
-    const int out_y = (index / out_w) % out_h;
-    const int b = (index / (out_w * out_h)) % batch_sz;
-    const int j = (index / (out_w * out_h * batch_sz)) % kernel_w;
-    const int i = (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
-    const int c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
-
-    int c_per_offset_grp = channels / n_offset_grps;
-    const int offset_grp = c / c_per_offset_grp;
+  CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
+    const index_t out_x = index % out_w;
+    const index_t out_y = (index / out_w) % out_h;
+    const index_t b = (index / (out_w * out_h)) % batch_sz;
+    const index_t j = (index / (out_w * out_h * batch_sz)) % kernel_w;
+    const index_t i =
+        (index / (out_w * out_h * batch_sz * kernel_w)) % kernel_h;
+    const index_t c = index / (out_w * out_h * batch_sz * kernel_w * kernel_h);
+
+    index_t c_per_offset_grp = channels / n_offset_grps;
+    const index_t offset_grp = c / c_per_offset_grp;
 
     offset_ptr += (b * n_offset_grps + offset_grp) * 2 * kernel_h * kernel_w *
         out_h * out_w;
@@ -319,11 +356,12 @@ __global__ void deformable_col2im_kernel(
           out_h * out_w;
     }
 
-    const int mask_idx = i * kernel_w + j;
-    const int offset_idx = 2 * mask_idx;
+    const index_t mask_idx = i * kernel_w + j;
+    const index_t offset_idx = 2 * mask_idx;
 
-    const int offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
-    const int offset_w_ptr = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
+    const index_t offset_h_ptr = ((offset_idx)*out_h + out_y) * out_w + out_x;
+    const index_t offset_w_ptr =
+        ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
 
     const scalar_t offset_h = offset_ptr[offset_h_ptr];
     const scalar_t offset_w = offset_ptr[offset_w_ptr];
@@ -336,13 +374,13 @@ __global__ void deformable_col2im_kernel(
     const scalar_t y = (out_y * stride_h - pad_h) + i * dilation_h + offset_h;
     const scalar_t x = (out_x * stride_w - pad_w) + j * dilation_w + offset_w;
 
-    for (int dy = -1; dy <= 1; dy++) {
-      for (int dx = -1; dx <= 1; dx++) {
-        int yp = int(y) + dy;
-        int xp = int(x) + dx;
+    for (index_t dy = -1; dy <= 1; dy++) {
+      for (index_t dx = -1; dx <= 1; dx++) {
+        index_t yp = (index_t)y + dy;
+        index_t xp = (index_t)x + dx;
         if (0 <= yp && yp < height && 0 <= xp && xp < width &&
             std::abs(y - yp) < 1 && std::abs(x - xp) < 1) {
-          int grad_pos = ((b * channels + c) * height + yp) * width + xp;
+          index_t grad_pos = ((b * channels + c) * height + yp) * width + xp;
           scalar_t weight = (1 - std::abs(y - yp)) * (1 - std::abs(x - xp));
           atomicAdd(grad_im + grad_pos, mask_value * weight * col[index]);
         }
@@ -374,37 +412,72 @@ void compute_grad_input(
       (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
   int out_w =
       (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
-  int num_kernels =
-      channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
+
+  int64_t num_kernels =
+      (int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
 
   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      columns.scalar_type(), "compute_grad_input", ([&] {
-        deformable_col2im_kernel<<<blocks, threads>>>(
-            num_kernels,
-            columns.data_ptr<scalar_t>(),
-            offset.data_ptr<scalar_t>(),
-            mask.data_ptr<scalar_t>(),
-            channels,
-            height,
-            width,
-            weight_h,
-            weight_w,
-            pad_h,
-            pad_w,
-            stride_h,
-            stride_w,
-            dilation_h,
-            dilation_w,
-            parallel_imgs,
-            n_offset_grps,
-            out_h,
-            out_w,
-            use_mask,
-            grad_im.data_ptr<scalar_t>());
-      }));
+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if num_kernels or columns numel larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_input", ([&] {
+          deformable_col2im_kernel<scalar_t, int64_t><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_im.data_ptr<scalar_t>());
+        }));
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_input", ([&] {
+          deformable_col2im_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_im.data_ptr<scalar_t>());
+        }));
+  }
 
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
@@ -412,18 +485,18 @@ void compute_grad_input(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __device__ scalar_t get_coordinate_weight(
     const scalar_t* im_data,
-    int height,
-    int width,
+    index_t height,
+    index_t width,
     scalar_t y,
     scalar_t x,
     bool is_y_direction) {
-  int y_l = floor(y);
-  int x_l = floor(x);
-  int y_h = y_l + 1;
-  int x_h = x_l + 1;
+  index_t y_l = floor(y);
+  index_t x_l = floor(x);
+  index_t y_h = y_l + 1;
+  index_t x_h = x_l + 1;
 
   bool valid_y_l = 0 <= y_l && y_l < height;
   bool valid_y_h = 0 <= y_h && y_h < height;
@@ -445,47 +518,47 @@ __device__ scalar_t get_coordinate_weight(
   }
 }
 
-template <typename scalar_t>
+template <typename scalar_t, typename index_t>
 __global__ void deformable_col2im_coord_kernel(
-    int n,
+    index_t n,
     const scalar_t* col_ptr,
     const scalar_t* im_ptr,
    const scalar_t* offset_ptr,
     const scalar_t* mask_ptr,
-    int channels,
-    int height,
-    int width,
-    int weight_h,
-    int weight_w,
-    int pad_h,
-    int pad_w,
-    int stride_h,
-    int stride_w,
-    int dilation_h,
-    int dilation_w,
-    int batch_sz,
-    int offset_channels,
-    int n_offset_grps,
-    int out_h,
-    int out_w,
+    index_t channels,
+    index_t height,
+    index_t width,
+    index_t weight_h,
+    index_t weight_w,
+    index_t pad_h,
+    index_t pad_w,
+    index_t stride_h,
+    index_t stride_w,
+    index_t dilation_h,
+    index_t dilation_w,
+    index_t batch_sz,
+    index_t offset_channels,
+    index_t n_offset_grps,
+    index_t out_h,
+    index_t out_w,
     const bool use_mask,
     scalar_t* grad_offset,
     scalar_t* grad_mask) {
-  CUDA_1D_KERNEL_LOOP(index, n) {
+  CUDA_1D_KERNEL_LOOP_T(index, n, int64_t) {
     scalar_t grad_offset_val = 0;
     scalar_t grad_mask_val = 0;
 
-    int w = index % out_w;
-    int h = (index / out_w) % out_h;
-    int w_w = (index / (out_w * out_h * 2)) % weight_w;
-    int w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
-    int c = (index / (out_w * out_h)) % offset_channels;
-    int b = index / (out_w * out_h * offset_channels);
+    index_t w = index % out_w;
+    index_t h = (index / out_w) % out_h;
+    index_t w_w = (index / (out_w * out_h * 2)) % weight_w;
+    index_t w_h = (index / (out_w * out_h * 2 * weight_w)) % weight_h;
+    index_t c = (index / (out_w * out_h)) % offset_channels;
+    index_t b = index / (out_w * out_h * offset_channels);
 
-    const int offset_grp = c / (2 * weight_h * weight_w);
-    const int col_step = weight_h * weight_w;
+    const index_t offset_grp = c / (2 * weight_h * weight_w);
+    const index_t col_step = weight_h * weight_w;
 
-    int c_per_offset_grp = channels / n_offset_grps;
+    index_t c_per_offset_grp = channels / n_offset_grps;
 
     col_ptr += offset_grp * c_per_offset_grp * weight_h * weight_w * batch_sz *
         out_w * out_h;
@@ -499,23 +572,24 @@ __global__ void deformable_col2im_coord_kernel(
           out_h * out_w;
     }
 
-    const int offset_c = c - offset_grp * 2 * weight_h * weight_w;
+    const index_t offset_c = c - offset_grp * 2 * weight_h * weight_w;
     const bool is_y_direction = offset_c % 2 == 0;
 
-    const int c_bound = c_per_offset_grp * weight_h * weight_w;
-    for (int col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
-      const int col_pos = (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
+    const index_t c_bound = c_per_offset_grp * weight_h * weight_w;
+    for (index_t col_c = (offset_c / 2); col_c < c_bound; col_c += col_step) {
+      const index_t col_pos =
+          (((col_c * batch_sz + b) * out_h) + h) * out_w + w;
 
-      int out_x = col_pos % out_w;
-      int out_y = (col_pos / out_w) % out_h;
-      int j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
-      int i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
+      index_t out_x = col_pos % out_w;
+      index_t out_y = (col_pos / out_w) % out_h;
+      index_t j = (col_pos / (out_w * out_h * batch_sz)) % weight_w;
+      index_t i = (col_pos / (out_w * out_h * batch_sz * weight_w)) % weight_h;
 
-      const int mask_idx = i * weight_w + j;
+      const index_t mask_idx = i * weight_w + j;
 
-      const int offset_h_ptr =
+      const index_t offset_h_ptr =
           (((2 * mask_idx) * out_h + out_y) * out_w + out_x);
-      const int offset_w_ptr =
+      const index_t offset_w_ptr =
           (((2 * mask_idx + 1) * out_h + out_y) * out_w + out_x);
       const scalar_t offset_h = offset_ptr[offset_h_ptr];
       const scalar_t offset_w = offset_ptr[offset_w_ptr];
@@ -543,7 +617,7 @@ __global__ void deformable_col2im_coord_kernel(
     grad_offset[index] = grad_offset_val;
 
     if (use_mask && is_y_direction) {
-      const int idx =
+      const index_t idx =
          ((((b * n_offset_grps + offset_grp) * weight_h + w_h) * weight_w +
            w_w) *
               out_h +
@@ -580,40 +654,81 @@ void compute_grad_offset_and_mask(
       (height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
   int out_w =
       (width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
-  int num_kernels =
-      out_h * out_w * 2 * weight_h * weight_w * n_offset_grps * parallel_imgs;
+  int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
+      n_offset_grps * parallel_imgs;
 
   const unsigned int threads = GET_THREADS();
   const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
 
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
-        deformable_col2im_coord_kernel<<<blocks, threads>>>(
-            num_kernels,
-            columns.data_ptr<scalar_t>(),
-            input.data_ptr<scalar_t>(),
-            offset.data_ptr<scalar_t>(),
-            mask.data_ptr<scalar_t>(),
-            channels,
-            height,
-            width,
-            weight_h,
-            weight_w,
-            pad_h,
-            pad_w,
-            stride_h,
-            stride_w,
-            dilation_h,
-            dilation_w,
-            parallel_imgs,
-            2 * weight_h * weight_w * n_offset_grps,
-            n_offset_grps,
-            out_h,
-            out_w,
-            use_mask,
-            grad_offset.data_ptr<scalar_t>(),
-            grad_mask.data_ptr<scalar_t>());
-      }));
+  // Checks if we should use 64bits indexing
+  // https://github.com/pytorch/vision/issues/4269
+  bool use_64bits_indexing = false;
+  // Checks if columns numel is larger than 2 ** 31
+  use_64bits_indexing |= num_kernels > (1 << 31);
+  use_64bits_indexing |=
+      ((int64_t)channels * weight_h * weight_w * parallel_imgs * out_h * out_w >
+       (1 << 31));
+
+  if (use_64bits_indexing) {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
+          deformable_col2im_coord_kernel<scalar_t, int64_t>
+              <<<blocks, threads>>>(
+                  num_kernels,
+                  columns.data_ptr<scalar_t>(),
+                  input.data_ptr<scalar_t>(),
+                  offset.data_ptr<scalar_t>(),
+                  mask.data_ptr<scalar_t>(),
+                  channels,
+                  height,
+                  width,
+                  weight_h,
+                  weight_w,
+                  pad_h,
+                  pad_w,
+                  stride_h,
+                  stride_w,
+                  dilation_h,
+                  dilation_w,
+                  parallel_imgs,
+                  2 * weight_h * weight_w * n_offset_grps,
+                  n_offset_grps,
+                  out_h,
+                  out_w,
+                  use_mask,
+                  grad_offset.data_ptr<scalar_t>(),
+                  grad_mask.data_ptr<scalar_t>());
+        }));
+  } else {
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        columns.scalar_type(), "compute_grad_offset_and_mask", ([&] {
+          deformable_col2im_coord_kernel<scalar_t, int><<<blocks, threads>>>(
+              num_kernels,
+              columns.data_ptr<scalar_t>(),
+              input.data_ptr<scalar_t>(),
+              offset.data_ptr<scalar_t>(),
+              mask.data_ptr<scalar_t>(),
+              channels,
+              height,
+              width,
+              weight_h,
+              weight_w,
+              pad_h,
+              pad_w,
+              stride_h,
+              stride_w,
+              dilation_h,
+              dilation_w,
+              parallel_imgs,
+              2 * weight_h * weight_w * n_offset_grps,
+              n_offset_grps,
+              out_h,
+              out_w,
+              use_mask,
+              grad_offset.data_ptr<scalar_t>(),
+              grad_mask.data_ptr<scalar_t>());
+        }));
+  }
 
   cudaError_t err = cudaGetLastError();
   if (err != cudaSuccess) {
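Note (not part of the patch above): a minimal standalone sketch of the indexing-width dispatch pattern the diff applies at each launch site. The kernel and launcher names below are invented for illustration; the idea, assuming this reading of the patch, is that each kernel is templated on an index type and the host code falls back to 64-bit indices only when the indexed range may not fit in a 32-bit int.

// Illustrative sketch only; simplified relative to the torchvision code.
#include <algorithm>
#include <cstdint>

// Grid-stride loop templated on the index type, mirroring CUDA_1D_KERNEL_LOOP_T.
template <typename index_t>
__global__ void fill_iota_kernel(index_t n, float* out) {
  for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < n;
       i += (blockDim.x * gridDim.x)) {
    out[i] = static_cast<float>(i);
  }
}

void launch_fill_iota(float* out, int64_t numel) {
  const unsigned int threads = 512;
  // Simplified grid clamp; the patch instead queries maxGridSize[0] in GET_BLOCKS.
  const unsigned int blocks =
      (unsigned int)std::min<int64_t>(65535, (numel + threads - 1) / threads);

  // Launch with 64-bit indices only when the element count no longer fits in a
  // 32-bit int; shifting a 64-bit literal keeps the threshold computation itself
  // free of 32-bit overflow.
  const bool use_64bits_indexing = numel >= (int64_t(1) << 31);
  if (use_64bits_indexing) {
    fill_iota_kernel<int64_t><<<blocks, threads>>>(numel, out);
  } else {
    fill_iota_kernel<int><<<blocks, threads>>>((int)numel, out);
  }
}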