PSRoiAlign: SymInt support + meta-implem #8058

Merged · 10 commits · Oct 27, 2023
26 changes: 26 additions & 0 deletions test/optests_failures_dict.json
@@ -2,6 +2,32 @@
"_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit",
"_version": 1,
"data": {
"torchvision::ps_roi_align": {
"TestPSRoIAlign.test_aot_dispatch_dynamic__test_mps_error_inputs": {
"comment": "RuntimeError: MPS does not support ps_roi_align backward with float16 inputs",
"status": "xfail"
},
"TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0]": {
"comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}",
"status": "xfail"
Member Author (NicolasHug) commented:

CC @zou3519

I just opened pytorch/pytorch#111797 which I believe could be a fix for the problem I'm facing here:

This test (and a bunch of others)

vision/test/test_ops.py, lines 186 to 189 at 3fb88b3:

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
def test_backward(self, seed, device, contiguous, deterministic=False):

is parametrized over CPU, CUDA, and MPS, and fails on MPS.

I'm trying to xfail the MPS parametrization as above (lines 10-12), but optests complains with:

E RuntimeError: In failures dict, got test name 'TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0]'. We parsed this as running test 'test_autograd_registration' on 'test_backward[True-mps-0]', but test_backward[True-mps-0] does not exist on the TestCase 'TestPSRoIAlign]. Maybe you need to change the test name?

The problem is, if I replace TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0] with TestPSRoIAlign.test_autograd_registration__test_backward in line 10, then I'm getting an "unexpected success":

E torch.testing._internal.optests.generate_tests.OpCheckError: generate_opcheck_tests: Unexpected success for operator torchvision::ps_roi_align on test TestPSRoIAlign.test_autograd_registration__test_backward. This may mean that you have fixed this test failure. Please rerun the test with PYTORCH_OPCHECK_ACCEPT=1 to automatically update the test runner or manually remove the expected failure in the failure dict at /home/nicolashug/dev/vision/test/optests_failures_dict.jsonFor more details, see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit
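
For context, a minimal sketch (not the actual optests source) of the name parsing implied by the error message above: the generated name is split at the double underscore into a check name and a base test id, and the base id must exist verbatim on the TestCase.

    # Hypothetical illustration only; the real parser lives in
    # torch.testing._internal.optests.
    def parse_failure_key(key):
        test_class, _, generated = key.partition(".")
        check_name, _, base_test = generated.partition("__")
        # e.g. ("TestPSRoIAlign", "test_autograd_registration",
        #       "test_backward[True-mps-0]")
        return test_class, check_name, base_test

    print(parse_failure_key(
        "TestPSRoIAlign.test_autograd_registration__test_backward[True-mps-0]"
    ))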

Contributor commented:

@NicolasHug can the tests be run under unittest, or are they pytest only?

Member Author (NicolasHug) commented:

It's pytest only

},
"TestPSRoIAlign.test_autograd_registration__test_mps_error_inputs": {
"comment": "NotImplementedError: autograd_registration_check: NYI devices other than CPU/CUDA, got {'mps'}",
"status": "xfail"
},
"TestPSRoIAlign.test_faketensor__test_backward[True-mps-0]": {
"comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!",
"status": "xfail"
},
"TestPSRoIAlign.test_faketensor__test_forward[x_dtype0-True-mps]": {
"comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!",
"status": "xfail"
},
"TestPSRoIAlign.test_faketensor__test_mps_error_inputs": {
"comment": "AssertionError: Dtypes torch.int64 and torch.int32 are not equal!",
"status": "xfail"
}
},
"torchvision::roi_align": {
"TestRoIAlign.test_aot_dispatch_dynamic__test_mps_error_inputs": {
"comment": "RuntimeError: MPS does not support roi_align backward with float16 inputs",
14 changes: 13 additions & 1 deletion test/test_ops.py
@@ -117,8 +117,9 @@ class RoIOpTester(ABC):
torch.float32,
torch.float64,
),
ids=str,
# ids=str,
)
@pytest.mark.opcheck_only_one()
def test_forward(self, device, contiguous, x_dtype, rois_dtype=None, deterministic=False, **kwargs):
if device == "mps" and x_dtype is torch.float64:
pytest.skip("MPS does not support float64")
@@ -186,6 +187,7 @@ def test_torch_fx_trace(self, device, x_dtype=torch.float, rois_dtype=torch.floa
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("device", cpu_and_cuda_and_mps())
@pytest.mark.parametrize("contiguous", (True, False))
@pytest.mark.opcheck_only_one()
def test_backward(self, seed, device, contiguous, deterministic=False):
atol = self.mps_backward_atol if device == "mps" else 1e-05
dtype = self.mps_dtype if device == "mps" else self.dtype
@@ -228,6 +230,7 @@ def func(z):
@needs_cuda
@pytest.mark.parametrize("x_dtype", (torch.float, torch.half))
@pytest.mark.parametrize("rois_dtype", (torch.float, torch.half))
@pytest.mark.opcheck_only_one()
def test_autocast(self, x_dtype, rois_dtype):
with torch.cuda.amp.autocast():
self.test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)
@@ -659,6 +662,15 @@ def test_boxes_shape(self):
self._helper_boxes_shape(ops.ps_roi_align)


optests.generate_opcheck_tests(
testcase=TestPSRoIAlign,
namespaces=["torchvision"],
failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
additional_decorators=[],
test_utils=OPTESTS,
)


class TestMultiScaleRoIAlign:
def make_obj(self, fmap_names=None, output_size=(7, 7), sampling_ratio=2, wrap=False):
if fmap_names is None:
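The generate_opcheck_tests call above is what produces the TestPSRoIAlign.test_faketensor__*, test_autograd_registration__*, and test_aot_dispatch_dynamic__* names seen in the failures dict, and @pytest.mark.opcheck_only_one() limits the opcheck variants to a single parametrization. As a rough sketch, the same per-operator checks can also be run directly on one sample input; this assumes the opcheck helper exposed under PyTorch's internal optests path in this era:

    import torch
    import torchvision  # noqa: F401  (registers the torchvision:: ops)

    # Assumed entry point; hedged, as it is an internal API at this point.
    from torch.testing._internal.optests import opcheck

    x = torch.randn(2, 8, 16, 16)
    rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]])
    opcheck(torch.ops.torchvision.ps_roi_align, (x, rois, 1.0, 2, 2, 2))
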
47 changes: 46 additions & 1 deletion torchvision/_meta_registrations.py
@@ -33,7 +33,7 @@ def meta_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, samp
),
)
num_rois = rois.size(0)
_, channels, height, width = input.size()
channels = input.size(1)
return input.new_empty((num_rois, channels, pooled_height, pooled_width))


@@ -51,6 +51,51 @@ def meta_roi_align_backward(
return grad.new_empty((batch_size, channels, height, width))


@register_meta("ps_roi_align")
def meta_ps_roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio):
torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]")
torch._check(
input.dtype == rois.dtype,
lambda: (
"Expected tensor for input to have the same type as tensor for rois; "
f"but type {input.dtype} does not equal {rois.dtype}"
),
)
channels = input.size(1)
torch._check(
channels % (pooled_height * pooled_width) == 0,
"input channels must be a multiple of pooling height * pooling width",
)

num_rois = rois.size(0)
out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width)
return input.new_empty(out_size), torch.empty(out_size, dtype=torch.int32, device="meta")


@register_meta("_ps_roi_align_backward")
def meta_ps_roi_align_backward(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width,
):
torch._check(
grad.dtype == rois.dtype,
lambda: (
"Expected tensor for grad to have the same type as tensor for rois; "
f"but type {grad.dtype} does not equal {rois.dtype}"
),
)
return grad.new_empty((batch_size, channels, height, width))


@torch._custom_ops.impl_abstract("torchvision::nms")
def meta_nms(dets, scores, iou_threshold):
torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D")
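To illustrate what the new meta_ps_roi_align registration enables, a minimal sketch, assuming torchvision is installed with these registrations: fake tensors can propagate shapes through the op without running a real kernel.

    import torch
    import torchvision  # noqa: F401  (registers the torchvision:: ops)
    from torch._subclasses.fake_tensor import FakeTensorMode

    # The meta kernel above only computes output metadata: channels (8) must
    # be divisible by pooled_height * pooled_width (2 * 2), and the second
    # output (channel_mapping) is int32.
    with FakeTensorMode():
        x = torch.empty(2, 8, 16, 16)
        rois = torch.empty(3, 5)
        out, channel_mapping = torch.ops.torchvision.ps_roi_align(
            x, rois, 1.0, 2, 2, 2
        )

    print(out.shape)              # torch.Size([3, 2, 2, 2])
    print(channel_mapping.dtype)  # torch.int32
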
54 changes: 27 additions & 27 deletions torchvision/csrc/ops/autograd/ps_roi_align_kernel.cpp
@@ -16,16 +16,16 @@ class PSROIAlignFunction
const torch::autograd::Variable& input,
const torch::autograd::Variable& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio) {
ctx->saved_data["spatial_scale"] = spatial_scale;
ctx->saved_data["pooled_height"] = pooled_height;
ctx->saved_data["pooled_width"] = pooled_width;
ctx->saved_data["sampling_ratio"] = sampling_ratio;
ctx->saved_data["input_shape"] = input.sizes();
ctx->saved_data["input_shape"] = input.sym_sizes();
at::AutoDispatchBelowADInplaceOrView g;
auto result = ps_roi_align(
auto result = ps_roi_align_symint(
input,
rois,
spatial_scale,
@@ -48,19 +48,19 @@
auto saved = ctx->get_saved_variables();
auto rois = saved[0];
auto channel_mapping = saved[1];
auto input_shape = ctx->saved_data["input_shape"].toIntList();
auto grad_in = detail::_ps_roi_align_backward(
auto input_shape = ctx->saved_data["input_shape"].toList();
auto grad_in = detail::_ps_roi_align_backward_symint(
grad_output[0],
rois,
channel_mapping,
ctx->saved_data["spatial_scale"].toDouble(),
ctx->saved_data["pooled_height"].toInt(),
ctx->saved_data["pooled_width"].toInt(),
ctx->saved_data["pooled_height"].toSymInt(),
ctx->saved_data["pooled_width"].toSymInt(),
ctx->saved_data["sampling_ratio"].toInt(),
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3]);
input_shape[0].get().toSymInt(),
input_shape[1].get().toSymInt(),
input_shape[2].get().toSymInt(),
input_shape[3].get().toSymInt());

return {
grad_in,
@@ -82,15 +82,15 @@ class PSROIAlignBackwardFunction
const torch::autograd::Variable& rois,
const torch::autograd::Variable& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
at::AutoDispatchBelowADInplaceOrView g;
auto grad_in = detail::_ps_roi_align_backward(
auto grad_in = detail::_ps_roi_align_backward_symint(
grad,
rois,
channel_mapping,
@@ -117,8 +117,8 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align_autograd(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio) {
auto result = PSROIAlignFunction::apply(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
@@ -131,13 +131,13 @@ at::Tensor ps_roi_align_backward_autograd(
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
int64_t pooled_height,
int64_t pooled_width,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
int64_t batch_size,
int64_t channels,
int64_t height,
int64_t width) {
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
return PSROIAlignBackwardFunction::apply(
grad,
rois,
49 changes: 47 additions & 2 deletions torchvision/csrc/ops/ps_roi_align.cpp
@@ -22,6 +22,21 @@ std::tuple<at::Tensor, at::Tensor> ps_roi_align(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_align.ps_roi_align");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::ps_roi_align", "")
.typed<decltype(ps_roi_align_symint)>();
return op.call(
input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}

namespace detail {

at::Tensor _ps_roi_align_backward(
@@ -54,13 +69,43 @@ at::Tensor _ps_roi_align_backward(
width);
}

at::Tensor _ps_roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_ps_roi_align_backward", "")
.typed<decltype(_ps_roi_align_backward_symint)>();
return op.call(
grad,
rois,
channel_mapping,
spatial_scale,
pooled_height,
pooled_width,
sampling_ratio,
batch_size,
channels,
height,
width);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
"torchvision::ps_roi_align(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int sampling_ratio, int batch_size, int channels, int height, int width) -> Tensor"));
"torchvision::_ps_roi_align_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, int sampling_ratio, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor"));
}

} // namespace ops
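A final sketch of what the SymInt schema above is for: with the pooled sizes and the backward's shape arguments declared as SymInt, the op can be traced with dynamic shapes instead of specializing them to concrete ints. A hedged end-to-end example, assuming the meta kernels and this schema are in place:

    import torch
    import torchvision  # noqa: F401  (registers the torchvision:: ops)

    def f(x, rois):
        out, _ = torch.ops.torchvision.ps_roi_align(x, rois, 1.0, 2, 2, 2)
        return out.sum()

    x = torch.randn(2, 8, 16, 16, requires_grad=True)
    rois = torch.tensor([[0.0, 0.0, 0.0, 8.0, 8.0]])

    # dynamic=True exercises the symbolic path; the backward call goes
    # through _ps_roi_align_backward with symbolic batch/channels/height/width.
    torch.compile(f, dynamic=True)(x, rois).backward()
    print(x.grad.shape)  # torch.Size([2, 8, 16, 16])
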
21 changes: 21 additions & 0 deletions torchvision/csrc/ops/ps_roi_align.h
@@ -14,6 +14,14 @@ VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align(
int64_t pooled_width,
int64_t sampling_ratio);

VISION_API std::tuple<at::Tensor, at::Tensor> ps_roi_align_symint(
const at::Tensor& input,
const at::Tensor& rois,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio);

namespace detail {

at::Tensor _ps_roi_align_backward(
@@ -29,6 +37,19 @@ at::Tensor _ps_roi_align_backward(
int64_t height,
int64_t width);

at::Tensor _ps_roi_align_backward_symint(
const at::Tensor& grad,
const at::Tensor& rois,
const at::Tensor& channel_mapping,
double spatial_scale,
c10::SymInt pooled_height,
c10::SymInt pooled_width,
int64_t sampling_ratio,
c10::SymInt batch_size,
c10::SymInt channels,
c10::SymInt height,
c10::SymInt width);

} // namespace detail

} // namespace ops