DeformConv2d: SymInt support + meta-implem + opchecks (#8063)
NicolasHug committed Oct 30, 2023
1 parent 668348e commit a8ebd0b
Showing 5 changed files with 218 additions and 45 deletions.
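In short: deform_conv2d and its backward now accept SymInt scalar arguments and gain meta implementations, so the ops can be traced with fake/meta tensors (e.g. under torch.compile with dynamic shapes), with generated opcheck tests guarding the new paths. A minimal sketch of what this enables; the shapes and argument values below are illustrative, not taken from the commit:

import torch
import torchvision  # noqa: F401  (assumes a build with this patch)

# Meta tensors carry only shape/dtype metadata, so this call exercises the
# new meta kernel: no real convolution runs, only shape inference.
x = torch.empty(2, 3, 10, 10, device="meta")
weight = torch.empty(5, 3, 3, 3, device="meta")
offset = torch.empty(2, 18, 8, 8, device="meta")  # 2 * offset_groups * kh * kw channels
mask = torch.empty(2, 9, 8, 8, device="meta")     # offset_groups * kh * kw channels
bias = torch.empty(5, device="meta")

out = torch.ops.torchvision.deform_conv2d(
    x, weight, offset, mask, bias,
    1, 1,  # stride_h, stride_w
    0, 0,  # pad_h, pad_w
    1, 1,  # dilation_h, dilation_w
    1, 1,  # groups, offset_groups
    True,  # use_mask
)
print(out.shape)  # torch.Size([2, 5, 8, 8]), inferred by the meta kernel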
13 changes: 13 additions & 0 deletions test/test_ops.py
@@ -1019,6 +1019,7 @@ def test_is_leaf_node(self, device):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, batch_sz, dtype=None):
dtype = dtype or self.dtype
x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
@@ -1071,6 +1072,7 @@ def test_wrong_sizes(self):
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.parametrize("batch_sz", (0, 33))
@pytest.mark.opcheck_only_one()
def test_backward(self, device, contiguous, batch_sz):
x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
device, contiguous, batch_sz, self.dtype
@@ -1120,6 +1122,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):

     @needs_cuda
     @pytest.mark.parametrize("contiguous", (True, False))
+    @pytest.mark.opcheck_only_one()
     def test_compare_cpu_cuda_grads(self, contiguous):
         # Test from https://github.com/pytorch/vision/issues/2598
         # Run on CUDA only
@@ -1154,6 +1157,7 @@ def test_compare_cpu_cuda_grads(self, contiguous):
     @needs_cuda
     @pytest.mark.parametrize("batch_sz", (0, 33))
     @pytest.mark.parametrize("dtype", (torch.float, torch.half))
+    @pytest.mark.opcheck_only_one()
     def test_autocast(self, batch_sz, dtype):
         with torch.cuda.amp.autocast():
             self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
@@ -1163,6 +1167,15 @@ def test_forward_scriptability(self):
         torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))


+optests.generate_opcheck_tests(
+    testcase=TestDeformConv,
+    namespaces=["torchvision"],
+    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
+    additional_decorators=[],
+    test_utils=OPTESTS,
+)
+
+
 class TestFrozenBNT:
     def test_frozenbatchnorm2d_repr(self):
         num_features = 32
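For context: optests.generate_opcheck_tests (from torch.testing._internal.optests) clones each test in TestDeformConv into generated variants that validate the op's schema, fake-tensor behavior, autograd registration, and AOT dispatch; the new @pytest.mark.opcheck_only_one() marks run just one parametrization per test to bound runtime, and known failures are tracked in optests_failures_dict.json. A hedged sketch of running the same battery on a single call; the opcheck entry point has moved between torch releases (it is torch.library.opcheck in newer ones), so treat the import path as an assumption:

import torch
import torchvision  # noqa: F401  (assumes a build with this patch)
from torch.testing._internal.optests import opcheck

args = (
    torch.rand(2, 3, 10, 10),  # input
    torch.rand(5, 3, 3, 3),    # weight
    torch.rand(2, 18, 8, 8),   # offset: 2 * offset_groups * kh * kw channels
    torch.rand(2, 9, 8, 8),    # mask: offset_groups * kh * kw channels
    torch.rand(5),             # bias
    1, 1, 0, 0, 1, 1, 1, 1, True,
)
opcheck(torch.ops.torchvision.deform_conv2d.default, args)  # raises if any check fails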
51 changes: 51 additions & 0 deletions torchvision/_meta_registrations.py
@@ -172,3 +172,54 @@ def meta_nms(dets, scores, iou_threshold):
     ctx = torch._custom_ops.get_ctx()
     num_to_keep = ctx.create_unbacked_symint()
     return dets.new_empty(num_to_keep, dtype=torch.long)
+
+
+@register_meta("deform_conv2d")
+def meta_deform_conv2d(
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dil_h,
+    dil_w,
+    n_weight_grps,
+    n_offset_grps,
+    use_mask,
+):
+
+    out_height, out_width = offset.shape[-2:]
+    out_channels = weight.shape[0]
+    batch_size = input.shape[0]
+    return input.new_empty((batch_size, out_channels, out_height, out_width))
+
+
+@register_meta("_deform_conv2d_backward")
+def meta_deform_conv2d_backward(
+    grad,
+    input,
+    weight,
+    offset,
+    mask,
+    bias,
+    stride_h,
+    stride_w,
+    pad_h,
+    pad_w,
+    dilation_h,
+    dilation_w,
+    groups,
+    offset_groups,
+    use_mask,
+):
+
+    grad_input = input.new_empty(input.shape)
+    grad_weight = weight.new_empty(weight.shape)
+    grad_offset = offset.new_empty(offset.shape)
+    grad_mask = mask.new_empty(mask.shape)
+    grad_bias = bias.new_empty(bias.shape)
+    return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
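Note the design of the forward meta kernel above: it reads the output spatial size directly from offset.shape[-2:] rather than recomputing it from stride/padding/dilation, which keeps the meta function free of SymInt arithmetic and of guards on symbolic shapes. For reference, the size the real kernels compute (and which offset must already match) follows the usual convolution formula; a small illustrative sketch, not code from the commit:

# Standard conv output-size arithmetic that offset.shape[-2:] is expected
# to satisfy; the meta kernel deliberately avoids redoing it.
def conv_out_size(in_size, kernel, pad, stride, dilation):
    return (in_size + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

assert conv_out_size(10, 3, pad=0, stride=1, dilation=1) == 8

The backward meta is simpler still: each gradient mirrors the shape of the tensor it differentiates, so new_empty(shape) is all that is needed.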
86 changes: 43 additions & 43 deletions torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
@@ -18,17 +18,17 @@ class DeformConv2dFunction
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g;
auto output = deform_conv2d(
auto output = deform_conv2d_symint(
input,
weight,
offset,
@@ -70,17 +70,17 @@ class DeformConv2dFunction
     auto mask = saved[3];
     auto bias = saved[4];

-    auto stride_h = ctx->saved_data["stride_h"].toInt();
-    auto stride_w = ctx->saved_data["stride_w"].toInt();
-    auto pad_h = ctx->saved_data["pad_h"].toInt();
-    auto pad_w = ctx->saved_data["pad_w"].toInt();
-    auto dilation_h = ctx->saved_data["dilation_h"].toInt();
-    auto dilation_w = ctx->saved_data["dilation_w"].toInt();
-    auto groups = ctx->saved_data["groups"].toInt();
-    auto offset_groups = ctx->saved_data["offset_groups"].toInt();
+    auto stride_h = ctx->saved_data["stride_h"].toSymInt();
+    auto stride_w = ctx->saved_data["stride_w"].toSymInt();
+    auto pad_h = ctx->saved_data["pad_h"].toSymInt();
+    auto pad_w = ctx->saved_data["pad_w"].toSymInt();
+    auto dilation_h = ctx->saved_data["dilation_h"].toSymInt();
+    auto dilation_w = ctx->saved_data["dilation_w"].toSymInt();
+    auto groups = ctx->saved_data["groups"].toSymInt();
+    auto offset_groups = ctx->saved_data["offset_groups"].toSymInt();
     auto use_mask = ctx->saved_data["use_mask"].toBool();

-    auto grads = detail::_deform_conv2d_backward(
+    auto grads = detail::_deform_conv2d_backward_symint(
         grad_output[0],
         input,
         weight,
@@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_deform_conv2d_backward(
auto result = detail::_deform_conv2d_backward_symint(
grad,
input,
weight,
@@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd(
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
@@ -222,14 +222,14 @@ deform_conv2d_backward_autograd(
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
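The changes in this file are mechanical but load-bearing: the autograd wrappers now take c10::SymInt, round-trip the scalars through saved_data with toSymInt() instead of toInt(), and call the _symint dispatcher entry points, so symbolic sizes survive the forward/backward round trip instead of being forced to concrete ints inside autograd. A hedged sketch of the effect, assuming autograd over meta tensors behaves as in recent torch releases (only shapes propagate, no real math runs):

import torch
import torchvision  # noqa: F401  (assumes a build with this patch)

x = torch.empty(2, 3, 10, 10, device="meta", requires_grad=True)
weight = torch.empty(5, 3, 3, 3, device="meta", requires_grad=True)
offset = torch.empty(2, 18, 8, 8, device="meta")
mask = torch.empty(2, 9, 8, 8, device="meta")
bias = torch.empty(5, device="meta", requires_grad=True)

out = torch.ops.torchvision.deform_conv2d(
    x, weight, offset, mask, bias, 1, 1, 0, 0, 1, 1, 1, 1, True
)
out.sum().backward()  # routed through the registered backward meta kernel
print(x.grad.shape, weight.grad.shape)  # mirror the input/weight shapes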
79 changes: 77 additions & 2 deletions torchvision/csrc/ops/deform_conv2d.cpp
@@ -43,6 +43,42 @@ at::Tensor deform_conv2d(
       use_mask);
 }

+at::Tensor deform_conv2d_symint(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    c10::SymInt stride_h,
+    c10::SymInt stride_w,
+    c10::SymInt pad_h,
+    c10::SymInt pad_w,
+    c10::SymInt dilation_h,
+    c10::SymInt dilation_w,
+    c10::SymInt groups,
+    c10::SymInt offset_groups,
+    bool use_mask) {
+  C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
+  static auto op = c10::Dispatcher::singleton()
+                       .findSchemaOrThrow("torchvision::deform_conv2d", "")
+                       .typed<decltype(deform_conv2d_symint)>();
+  return op.call(
+      input,
+      weight,
+      offset,
+      mask,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups,
+      use_mask);
+}
+
 namespace detail {

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -84,13 +120,52 @@ _deform_conv2d_backward(
       use_mask);
 }

+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_deform_conv2d_backward_symint(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    c10::SymInt stride_h,
+    c10::SymInt stride_w,
+    c10::SymInt pad_h,
+    c10::SymInt pad_w,
+    c10::SymInt dilation_h,
+    c10::SymInt dilation_w,
+    c10::SymInt groups,
+    c10::SymInt offset_groups,
+    bool use_mask) {
+  static auto op =
+      c10::Dispatcher::singleton()
+          .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
+          .typed<decltype(_deform_conv2d_backward_symint)>();
+  return op.call(
+      grad,
+      input,
+      weight,
+      offset,
+      mask,
+      bias,
+      stride_h,
+      stride_w,
+      pad_h,
+      pad_w,
+      dilation_h,
+      dilation_w,
+      groups,
+      offset_groups,
+      use_mask);
+}
+
 } // namespace detail

 TORCH_LIBRARY_FRAGMENT(torchvision, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
+      "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor"));
   m.def(TORCH_SELECTIVE_SCHEMA(
-      "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
+      "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
 }

 } // namespace ops
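The schema strings above are the user-visible contract: every scalar argument flips from int to SymInt, which is what lets symbolic sizes flow through the dispatcher without being specialized to concrete values. A quick way to inspect what got registered; _schema is an internal attribute, so treat this purely as a debugging aid:

import torch
import torchvision  # noqa: F401  (assumes a build with this patch)

print(torch.ops.torchvision.deform_conv2d.default._schema)
# torchvision::deform_conv2d(Tensor input, ..., SymInt stride_h, SymInt stride_w, ...) -> Tensor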
34 changes: 34 additions & 0 deletions torchvision/csrc/ops/deform_conv2d.h
@@ -22,6 +22,22 @@ VISION_API at::Tensor deform_conv2d(
     int64_t offset_groups,
     bool use_mask);

+VISION_API at::Tensor deform_conv2d_symint(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    c10::SymInt stride_h,
+    c10::SymInt stride_w,
+    c10::SymInt pad_h,
+    c10::SymInt pad_w,
+    c10::SymInt dilation_h,
+    c10::SymInt dilation_w,
+    c10::SymInt groups,
+    c10::SymInt offset_groups,
+    bool use_mask);
+
 namespace detail {

 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -42,6 +58,24 @@ _deform_conv2d_backward(
     int64_t offset_groups,
     bool use_mask);

+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
+_deform_conv2d_backward_symint(
+    const at::Tensor& grad,
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const at::Tensor& offset,
+    const at::Tensor& mask,
+    const at::Tensor& bias,
+    c10::SymInt stride_h,
+    c10::SymInt stride_w,
+    c10::SymInt pad_h,
+    c10::SymInt pad_w,
+    c10::SymInt dilation_h,
+    c10::SymInt dilation_w,
+    c10::SymInt groups,
+    c10::SymInt offset_groups,
+    bool use_mask);
+
 } // namespace detail

 } // namespace ops
