From 977baea68a7d19d518f2f440ea0b64d00e305027 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 14 Sep 2022 22:57:58 +0100 Subject: [PATCH] Fix the error message of `_ovewrite_value_param` --- torchvision/models/_utils.py | 10 +++++----- torchvision/models/detection/faster_rcnn.py | 6 +++--- torchvision/models/detection/fcos.py | 2 +- torchvision/models/detection/keypoint_rcnn.py | 4 ++-- torchvision/models/detection/mask_rcnn.py | 4 ++-- torchvision/models/detection/retinanet.py | 4 ++-- torchvision/models/detection/ssd.py | 2 +- torchvision/models/detection/ssdlite.py | 2 +- torchvision/models/segmentation/deeplabv3.py | 12 ++++++------ torchvision/models/segmentation/fcn.py | 8 ++++---- torchvision/models/segmentation/lraspp.py | 2 +- 11 files changed, 28 insertions(+), 28 deletions(-) diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 5d930e60295..7c67ea4ec91 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -240,11 +240,11 @@ def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> N kwargs[param] = new_value -def _ovewrite_value_param(param: Optional[V], new_value: V) -> V: - if param is not None: - if param != new_value: - raise ValueError(f"The parameter '{param}' expected value {new_value} but got {param} instead.") - return new_value +def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V: + if actual is not None: + if actual != expected: + raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.") + return expected class _ModelURLs(dict): diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index f5cf588cc9d..9d99fd236c7 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -550,7 +550,7 @@ def fasterrcnn_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 @@ -621,7 +621,7 @@ def fasterrcnn_resnet50_fpn_v2( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 @@ -661,7 +661,7 @@ def _fasterrcnn_mobilenet_v3_large_fpn( ) -> FasterRCNN: if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 73c9a6e042d..2ac71c339a4 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -749,7 +749,7 @@ def fcos_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 21fb53c2a49..c19dd21a5ce 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -444,8 +444,8 @@ def keypointrcnn_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - num_keypoints = _ovewrite_value_param(num_keypoints, len(weights.meta["keypoint_names"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + num_keypoints = _ovewrite_value_param("num_keypoints", num_keypoints, len(weights.meta["keypoint_names"])) else: if num_classes is None: num_classes = 2 diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 16ad074e189..795f9b8f79c 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -484,7 +484,7 @@ def maskrcnn_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 @@ -552,7 +552,7 @@ def maskrcnn_resnet50_fpn_v2( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index e8df41926e8..ffa21b14f70 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -795,7 +795,7 @@ def retinanet_resnet50_fpn( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 @@ -868,7 +868,7 @@ def retinanet_resnet50_fpn_v2( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index c30e508f488..44102f7ac5a 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -649,7 +649,7 @@ def ssd300_vgg16( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 63ac0d2bc73..d34795d7286 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -268,7 +268,7 @@ def ssdlite320_mobilenet_v3_large( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 91 diff --git a/torchvision/models/segmentation/deeplabv3.py b/torchvision/models/segmentation/deeplabv3.py index 3e451a21aaf..29ab0154807 100644 --- a/torchvision/models/segmentation/deeplabv3.py +++ b/torchvision/models/segmentation/deeplabv3.py @@ -260,8 +260,8 @@ def deeplabv3_resnet50( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) elif num_classes is None: num_classes = 21 @@ -316,8 +316,8 @@ def deeplabv3_resnet101( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) elif num_classes is None: num_classes = 21 @@ -370,8 +370,8 @@ def deeplabv3_mobilenet_v3_large( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) elif num_classes is None: num_classes = 21 diff --git a/torchvision/models/segmentation/fcn.py b/torchvision/models/segmentation/fcn.py index 5ec0747b710..6f1c9c4b80b 100644 --- a/torchvision/models/segmentation/fcn.py +++ b/torchvision/models/segmentation/fcn.py @@ -155,8 +155,8 @@ def fcn_resnet50( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) elif num_classes is None: num_classes = 21 @@ -214,8 +214,8 @@ def fcn_resnet101( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) - aux_loss = _ovewrite_value_param(aux_loss, True) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) + aux_loss = _ovewrite_value_param("aux_loss", aux_loss, True) elif num_classes is None: num_classes = 21 diff --git a/torchvision/models/segmentation/lraspp.py b/torchvision/models/segmentation/lraspp.py index 4bf71e77ae2..44c96f1c272 100644 --- a/torchvision/models/segmentation/lraspp.py +++ b/torchvision/models/segmentation/lraspp.py @@ -163,7 +163,7 @@ def lraspp_mobilenet_v3_large( if weights is not None: weights_backbone = None - num_classes = _ovewrite_value_param(num_classes, len(weights.meta["categories"])) + num_classes = _ovewrite_value_param("num_classes", num_classes, len(weights.meta["categories"])) elif num_classes is None: num_classes = 21