diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 0d95361eedb..5f1dabfaa85 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -27,7 +27,7 @@ def _resize_image_and_masks_onnx(image, self_min_size, self_max_size, target): if "masks" in target: mask = target["masks"] - mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() + mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte() target["masks"] = mask return image, target @@ -49,7 +49,7 @@ def _resize_image_and_masks(image, self_min_size, self_max_size, target): if "masks" in target: mask = target["masks"] - mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor)[:, 0].byte() + mask = F.interpolate(mask[:, None].float(), scale_factor=scale_factor, recompute_scale_factor=True)[:, 0].byte() target["masks"] = mask return image, target