diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 71893b3b0d0..93e966bae4b 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -188,16 +188,18 @@ def __init__( keypoint_roi_pool=None, keypoint_head=None, keypoint_predictor=None, - num_keypoints=17, + num_keypoints=None, ): assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))) if min_size is None: min_size = (640, 672, 704, 736, 768, 800) - if num_classes is not None: + if num_keypoints is not None: if keypoint_predictor is not None: - raise ValueError("num_classes should be None when keypoint_predictor is specified") + raise ValueError("num_keypoints should be None when keypoint_predictor is specified") + else: + num_keypoints = 17 out_channels = backbone.out_channels