Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Allow autocast for 1.6 #2384

Merged
merged 13 commits into the base branch on Jul 9, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
23 changes: 23 additions & 0 deletions torchvision/csrc/ROIAlign.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#include "autocast.h"
fmassa marked this conversation as resolved.
Show resolved Hide resolved
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif

// TODO: put this stuff in torchvision namespace

// roi_align dispatch nexus
at::Tensor roi_align(
const at::Tensor& input, // Input feature map.
const at::Tensor& rois, // List of ROIs to pool over.
Expand All @@ -35,6 +37,27 @@ at::Tensor roi_align(
aligned);
}

#ifdef WITH_CUDA
// Autocast wrapper for roi_align: runs the underlying op in float32 and then
// converts the result back to the dtype of `input` (e.g. half under autocast).
at::Tensor ROIAlign_autocast(
    const at::Tensor& input,
    const at::Tensor& rois,
    const double spatial_scale,
    const int64_t pooled_height,
    const int64_t pooled_width,
    const int64_t sampling_ratio,
    const bool aligned) {
  // Drop the Autocast dispatch key for the nested call so that dispatching
  // roi_align below reaches the real backend kernel instead of re-entering
  // this wrapper.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  // Promote the tensor arguments to float32 before invoking the kernel;
  // non-eligible tensors pass through _cast unchanged.
  const auto result = roi_align(
      autocast::_cast(at::kFloat, input),
      autocast::_cast(at::kFloat, rois),
      spatial_scale,
      pooled_height,
      pooled_width,
      sampling_ratio,
      aligned);
  // Cast the output back to the caller's original dtype.
  return result.to(input.scalar_type());
}
#endif

at::Tensor _roi_align_backward(
const at::Tensor& grad,
const at::Tensor& rois,
Expand Down
26 changes: 26 additions & 0 deletions torchvision/csrc/autocast.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#pragma once

#ifdef WITH_CUDA
namespace autocast {

// A tensor participates in autocast dtype conversion only when it lives on
// the GPU, is floating point, and is not already double precision.
inline bool is_eligible(const at::Tensor& arg) {
  return arg.is_cuda() && arg.is_floating_point() &&
      arg.scalar_type() != at::kDouble;
}

// Tensor overload: convert an eligible tensor to `to_type`; any other tensor
// (or one already at the target dtype) is returned untouched.
inline at::Tensor _cast(at::ScalarType to_type, const at::Tensor& arg) {
  const bool needs_cast = is_eligible(arg) && arg.scalar_type() != to_type;
  return needs_cast ? arg.to(to_type) : arg;
}

// Generic overload so call sites can uniformly wrap every argument:
// non-Tensor arguments are simply passed through unchanged.
template <typename T>
inline T _cast(at::ScalarType to_type, T arg) {
  return arg;
}

} // namespace autocast
#endif
23 changes: 21 additions & 2 deletions torchvision/csrc/cpu/nms_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ template <typename scalar_t>
at::Tensor nms_cpu_kernel(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold) {
const double iou_threshold) {
AT_ASSERTM(!dets.is_cuda(), "dets must be a CPU tensor");
AT_ASSERTM(!scores.is_cuda(), "scores must be a CPU tensor");
AT_ASSERTM(
Expand Down Expand Up @@ -72,7 +72,26 @@ at::Tensor nms_cpu_kernel(
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold) {
const double iou_threshold) {
TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0));

auto result = at::empty({0}, dets.options());

AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms", [&] {
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/cpu/vision_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ at::Tensor PSROIAlign_backward_cpu(
at::Tensor nms_cpu(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
const double iou_threshold);

at::Tensor DeformConv2d_forward_cpu(
const at::Tensor& input,
Expand Down
37 changes: 36 additions & 1 deletion torchvision/csrc/cuda/nms_cuda.cu
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

#if defined(WITH_CUDA)
#include <c10/cuda/CUDAGuard.h>
#elif defined(WITH_HIP)
#include <c10/hip/HIPGuard.h>
#endif

#include "cuda_helpers.h"

Expand Down Expand Up @@ -70,10 +75,40 @@ __global__ void nms_kernel(

at::Tensor nms_cuda(const at::Tensor& dets,
const at::Tensor& scores,
float iou_threshold) {
const double iou_threshold) {
AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor");
AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor");

TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0))

#if defined(WITH_CUDA)
at::cuda::CUDAGuard device_guard(dets.device());
#elif defined(WITH_HIP)
at::cuda::HIPGuard device_guard(dets.device());
#else
AT_ERROR("Not compiled with GPU support");
fmassa marked this conversation as resolved.
Show resolved Hide resolved
#endif

if (dets.numel() == 0) {
return at::empty({0}, dets.options().dtype(at::kLong));
}

auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
auto dets_sorted = dets.index_select(0, order_t).contiguous();
Expand Down
7 changes: 1 addition & 6 deletions torchvision/csrc/cuda/vision_cuda.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
#pragma once
#if defined(WITH_CUDA)
#include <c10/cuda/CUDAGuard.h>
#elif defined(WITH_HIP)
#include <c10/hip/HIPGuard.h>
#endif
#include <torch/extension.h>

at::Tensor ROIAlign_forward_cuda(
Expand Down Expand Up @@ -90,7 +85,7 @@ at::Tensor PSROIAlign_backward_cuda(
at::Tensor nms_cuda(
const at::Tensor& dets,
const at::Tensor& scores,
const float iou_threshold);
const double iou_threshold);

at::Tensor DeformConv2d_forward_cuda(
const at::Tensor& input,
Expand Down
58 changes: 21 additions & 37 deletions torchvision/csrc/nms.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,51 +3,35 @@

#ifdef WITH_CUDA
#include "cuda/vision_cuda.h"
#include "autocast.h"
#endif
#ifdef WITH_HIP
#include "hip/vision_cuda.h"
#endif

// nms dispatch nexus
at::Tensor nms(
const at::Tensor& dets,
const at::Tensor& scores,
const double iou_threshold) {
TORCH_CHECK(
dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D");
TORCH_CHECK(
dets.size(1) == 4,
"boxes should have 4 elements in dimension 1, got ",
dets.size(1));
TORCH_CHECK(
scores.dim() == 1,
"scores should be a 1d tensor, got ",
scores.dim(),
"D");
TORCH_CHECK(
dets.size(0) == scores.size(0),
"boxes and scores should have same number of elements in ",
"dimension 0, got ",
dets.size(0),
" and ",
scores.size(0));
if (dets.is_cuda()) {
#if defined(WITH_CUDA)
if (dets.numel() == 0) {
at::cuda::CUDAGuard device_guard(dets.device());
return at::empty({0}, dets.options().dtype(at::kLong));
}
return nms_cuda(dets, scores, iou_threshold);
#elif defined(WITH_HIP)
if (dets.numel() == 0) {
at::cuda::HIPGuard device_guard(dets.device());
return at::empty({0}, dets.options().dtype(at::kLong));
}
return nms_cuda(dets, scores, iou_threshold);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::nms", "")
.typed<decltype(nms)>();
return op.call(
dets,
scores,
iou_threshold);
}

at::Tensor result = nms_cpu(dets, scores, iou_threshold);
return result;
#ifdef WITH_CUDA
// Autocast wrapper for nms: evaluates the op in float32. The result is
// returned as-is — nms produces indices, so no cast back to the input
// dtype is needed.
at::Tensor nms_autocast(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const double iou_threshold) {
  // Exclude the Autocast key so the nested nms call dispatches to the real
  // backend kernel rather than re-entering this wrapper.
  c10::impl::ExcludeDispatchKeyGuard no_autocast(c10::DispatchKey::Autocast);
  const auto dets_fp32 = autocast::_cast(at::kFloat, dets);
  const auto scores_fp32 = autocast::_cast(at::kFloat, scores);
  return nms(dets_fp32, scores_fp32, iou_threshold);
}
#endif
12 changes: 11 additions & 1 deletion torchvision/csrc/vision.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ int64_t _cuda_version() {
}

TORCH_LIBRARY(torchvision, m) {
m.def("nms", &nms);
m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor");
m.def(
"roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, bool aligned) -> Tensor");
m.def(
Expand All @@ -59,13 +59,23 @@ TORCH_LIBRARY(torchvision, m) {
// Register the CPU kernels for the torchvision operator schemas declared in
// the TORCH_LIBRARY(torchvision, ...) block above.
TORCH_LIBRARY_IMPL(torchvision, CPU, m) {
m.impl("roi_align", ROIAlign_forward_cpu);
m.impl("_roi_align_backward", ROIAlign_backward_cpu);
m.impl("nms", nms_cpu);
}

// TODO: Place this in a hypothetical separate torchvision_cuda library
// Register the GPU kernels. The CUDA dispatch key is used for both CUDA and
// HIP builds here; the *_cuda functions are compiled from the hip/ tree when
// WITH_HIP is defined (see the includes in ROIAlign.h / nms.h).
#if defined(WITH_CUDA) || defined(WITH_HIP)
TORCH_LIBRARY_IMPL(torchvision, CUDA, m) {
m.impl("roi_align", ROIAlign_forward_cuda);
m.impl("_roi_align_backward", ROIAlign_backward_cuda);
m.impl("nms", nms_cuda);
}
#endif

// Autocast only needs to wrap forward pass ops.
// The *_autocast wrappers cast tensor args to float32, re-dispatch to the
// real kernel with the Autocast key excluded, and (for roi_align) cast the
// result back to the input dtype.
#if defined(WITH_CUDA)
TORCH_LIBRARY_IMPL(torchvision, Autocast, m) {
m.impl("roi_align", ROIAlign_autocast);
m.impl("nms", nms_autocast);
}
#endif

Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ def forward(
if torchvision._is_tracing():
tracing_results.append(result_idx_in_level.to(dtype))
else:
result[idx_in_level] = result_idx_in_level
result[idx_in_level] = result_idx_in_level.to(result.dtype)
fmassa marked this conversation as resolved.
Show resolved Hide resolved

if torchvision._is_tracing():
result = _onnx_merge_levels(levels, tracing_results)
Expand Down