diff --git a/CMakeLists.txt b/CMakeLists.txt index 81ca559d530..e6b97786888 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,9 +32,11 @@ file(GLOB HEADERS torchvision/csrc/*.h) # Image extension file(GLOB IMAGE_HEADERS torchvision/csrc/cpu/image/*.h) file(GLOB IMAGE_SOURCES torchvision/csrc/cpu/image/*.cpp) -file(GLOB OPERATOR_SOURCES torchvision/csrc/cpu/*.h torchvision/csrc/cpu/*.cpp ${IMAGE_HEADERS} ${IMAGE_SOURCES} ${HEADERS} torchvision/csrc/*.cpp) +file(GLOB OPERATOR_HEADERS torchvision/csrc/cpu/*.h) +file(GLOB OPERATOR_SOURCES ${OPERATOR_HEADERS} torchvision/csrc/cpu/*.cpp ${IMAGE_HEADERS} ${IMAGE_SOURCES} ${HEADERS} torchvision/csrc/*.cpp) if(WITH_CUDA) - file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} torchvision/csrc/cuda/*.h torchvision/csrc/cuda/*.cu) + file(GLOB OPERATOR_CUDA_HEADERS torchvision/csrc/cuda/*.h) + file(GLOB OPERATOR_SOURCES ${OPERATOR_SOURCES} ${OPERATOR_CUDA_HEADERS} torchvision/csrc/cuda/*.cu) endif() file(GLOB MODELS_HEADERS torchvision/csrc/models/*.h) file(GLOB MODELS_SOURCES torchvision/csrc/models/*.h torchvision/csrc/models/*.cpp) @@ -95,11 +97,11 @@ install(EXPORT TorchVisionTargets install(FILES ${HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}) install(FILES - torchvision/csrc/cpu/vision_cpu.h + ${OPERATOR_HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cpu) if(WITH_CUDA) install(FILES - torchvision/csrc/cuda/vision_cuda.h + ${OPERATOR_CUDA_HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/cuda) endif() install(FILES ${MODELS_HEADERS} DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/${PROJECT_NAME}/models) diff --git a/torchvision/csrc/cpu/nms_cpu.cpp b/torchvision/csrc/cpu/nms_kernel.cpp similarity index 95% rename from torchvision/csrc/cpu/nms_cpu.cpp rename to torchvision/csrc/cpu/nms_kernel.cpp index 00a4c61db7a..036a91f56dc 100644 --- a/torchvision/csrc/cpu/nms_cpu.cpp +++ b/torchvision/csrc/cpu/nms_kernel.cpp @@ -1,7 +1,9 @@ -#include "vision_cpu.h" +#include 
"nms_kernel.h" + +namespace { template -at::Tensor nms_cpu_kernel( +at::Tensor nms_kernel( const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { @@ -69,6 +71,8 @@ at::Tensor nms_cpu_kernel( return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); } +} // namespace + at::Tensor nms_cpu( const at::Tensor& dets, const at::Tensor& scores, @@ -95,7 +99,7 @@ at::Tensor nms_cpu( auto result = at::empty({0}, dets.options()); AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] { - result = nms_cpu_kernel(dets, scores, iou_threshold); + result = nms_kernel(dets, scores, iou_threshold); }); return result; } diff --git a/torchvision/csrc/cpu/nms_kernel.h b/torchvision/csrc/cpu/nms_kernel.h new file mode 100644 index 00000000000..7b6ef442626 --- /dev/null +++ b/torchvision/csrc/cpu/nms_kernel.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API at::Tensor nms_cpu( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); diff --git a/torchvision/csrc/cpu/vision_cpu.h b/torchvision/csrc/cpu/vision_cpu.h index 6f85d9c0256..39d89bf6515 100644 --- a/torchvision/csrc/cpu/vision_cpu.h +++ b/torchvision/csrc/cpu/vision_cpu.h @@ -4,11 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API at::Tensor nms_cpu( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); - VISION_API std::tuple PSROIAlign_forward_cpu( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/cuda/nms_cuda.cu b/torchvision/csrc/cuda/nms_kernel.cu similarity index 96% rename from torchvision/csrc/cuda/nms_cuda.cu rename to torchvision/csrc/cuda/nms_kernel.cu index 548dc2f69cb..8785bd84897 100644 --- a/torchvision/csrc/cuda/nms_cuda.cu +++ b/torchvision/csrc/cuda/nms_kernel.cu @@ -3,14 +3,17 @@ #include #include "cuda_helpers.h" +#include "nms_kernel.h" -#include -#include +namespace { int const threadsPerBlock = sizeof(unsigned long long) * 8; template 
-__device__ inline bool devIoU(T const* const a, T const* const b, const float threshold) { +__device__ inline bool devIoU( + T const* const a, + T const* const b, + const float threshold) { T left = max(a[0], b[0]), right = min(a[2], b[2]); T top = max(a[1], b[1]), bottom = min(a[3], b[3]); T width = max(right - left, (T)0), height = max(bottom - top, (T)0); @@ -29,7 +32,8 @@ __global__ void nms_kernel( const int row_start = blockIdx.y; const int col_start = blockIdx.x; - if (row_start > col_start) return; + if (row_start > col_start) + return; const int row_size = min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); @@ -68,6 +72,8 @@ __global__ void nms_kernel( } } +} // namespace + at::Tensor nms_cuda(const at::Tensor& dets, const at::Tensor& scores, double iou_threshold) { diff --git a/torchvision/csrc/cuda/nms_kernel.h b/torchvision/csrc/cuda/nms_kernel.h new file mode 100644 index 00000000000..1eceddaccf3 --- /dev/null +++ b/torchvision/csrc/cuda/nms_kernel.h @@ -0,0 +1,9 @@ +#pragma once + +#include +#include "../macros.h" + +VISION_API at::Tensor nms_cuda( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold); diff --git a/torchvision/csrc/cuda/vision_cuda.h b/torchvision/csrc/cuda/vision_cuda.h index 834973c5327..b17f00d6acf 100644 --- a/torchvision/csrc/cuda/vision_cuda.h +++ b/torchvision/csrc/cuda/vision_cuda.h @@ -4,11 +4,6 @@ // TODO: Delete this file once all the methods are gone -VISION_API at::Tensor nms_cuda( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold); - VISION_API std::tuple PSROIAlign_forward_cuda( const at::Tensor& input, const at::Tensor& rois, diff --git a/torchvision/csrc/nms.cpp b/torchvision/csrc/nms.cpp new file mode 100644 index 00000000000..075f3101937 --- /dev/null +++ b/torchvision/csrc/nms.cpp @@ -0,0 +1,29 @@ +#include "nms.h" +#include + +#if defined(WITH_CUDA) || defined(WITH_HIP) +#include +#endif + +at::Tensor nms( + const at::Tensor& dets, + const at::Tensor& 
scores, + double iou_threshold) { + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::nms", "") + .typed(); + return op.call(dets, scores, iou_threshold); +} + +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor nms_autocast( + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); + return nms( + at::autocast::cached_cast(at::kFloat, dets), + at::autocast::cached_cast(at::kFloat, scores), + iou_threshold); +} +#endif diff --git a/torchvision/csrc/nms.h b/torchvision/csrc/nms.h index aed675e5d26..87b07548454 100644 --- a/torchvision/csrc/nms.h +++ b/torchvision/csrc/nms.h @@ -1,36 +1,24 @@ #pragma once -#include "cpu/vision_cpu.h" +#include "cpu/nms_kernel.h" #ifdef WITH_CUDA -#include "autocast.h" -#include "cuda/vision_cuda.h" +#include "cuda/nms_kernel.h" #endif #ifdef WITH_HIP -#include "autocast.h" -#include "hip/vision_cuda.h" +#include "hip/nms_kernel.h" #endif -// nms dispatch nexus +// C++ Forward at::Tensor nms( const at::Tensor& dets, const at::Tensor& scores, - double iou_threshold) { - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("torchvision::nms", "") - .typed(); - return op.call(dets, scores, iou_threshold); -} + double iou_threshold); +// Autocast Forward #if defined(WITH_CUDA) || defined(WITH_HIP) at::Tensor nms_autocast( const at::Tensor& dets, const at::Tensor& scores, - double iou_threshold) { - c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast); - return nms( - at::autocast::cached_cast(at::kFloat, dets), - at::autocast::cached_cast(at::kFloat, scores), - iou_threshold); -} + double iou_threshold); #endif