diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 5564866c571..9059c184949 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -20,7 +20,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): scale_factor = torch.min(self_min_size / min_size, self_max_size / max_size) image = torch.nn.functional.interpolate( - image[None], scale_factor=scale_factor, mode='bilinear', + image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)[0] if target is None: @@ -42,7 +42,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target): if max_size * scale_factor > self_max_size: scale_factor = self_max_size / max_size image = torch.nn.functional.interpolate( - image[None], scale_factor=scale_factor, mode='bilinear', + image[None], scale_factor=scale_factor, mode='bilinear', recompute_scale_factor=True, align_corners=False)[0] if target is None: