diff --git a/references/classification/transforms.py b/references/classification/transforms.py
index e72cd67fbfd..9a8ef7877d6 100644
--- a/references/classification/transforms.py
+++ b/references/classification/transforms.py
@@ -21,8 +21,14 @@ class RandomMixup(torch.nn.Module):

     def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
         super().__init__()
-        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert alpha > 0, "Alpha param can't be zero."
+
+        if num_classes < 1:
+            raise ValueError(
+                f"Please provide a valid positive value for the num_classes. Got num_classes={num_classes}"
+            )
+
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")

         self.num_classes = num_classes
         self.p = p
@@ -99,8 +105,10 @@ class RandomCutmix(torch.nn.Module):

     def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0, inplace: bool = False) -> None:
         super().__init__()
-        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
-        assert alpha > 0, "Alpha param can't be zero."
+        if num_classes < 1:
+            raise ValueError("Please provide a valid positive value for the num_classes.")
+        if alpha <= 0:
+            raise ValueError("Alpha param can't be zero.")

         self.num_classes = num_classes
         self.p = p
diff --git a/references/detection/coco_eval.py b/references/detection/coco_eval.py
index ec0709c5d91..ba1359f8c65 100644
--- a/references/detection/coco_eval.py
+++ b/references/detection/coco_eval.py
@@ -12,7 +12,8 @@ class CocoEvaluator:

     def __init__(self, coco_gt, iou_types):
-        assert isinstance(iou_types, (list, tuple))
+        if not isinstance(iou_types, (list, tuple)):
+            raise TypeError(f"This constructor expects iou_types of type list or tuple, instead got {type(iou_types)}")
         coco_gt = copy.deepcopy(coco_gt)
         self.coco_gt = coco_gt
diff --git a/references/detection/coco_utils.py b/references/detection/coco_utils.py
index b0f193135ee..396de63297b 100644
--- a/references/detection/coco_utils.py
+++ b/references/detection/coco_utils.py
@@ -126,7 +126,10 @@ def _has_valid_annotation(anno):
             return True
         return False

-    assert isinstance(dataset, torchvision.datasets.CocoDetection)
+    if not isinstance(dataset, torchvision.datasets.CocoDetection):
+        raise TypeError(
+            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+        )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
diff --git a/references/optical_flow/transforms.py b/references/optical_flow/transforms.py
index b6a42f402e1..6011608183a 100644
--- a/references/optical_flow/transforms.py
+++ b/references/optical_flow/transforms.py
@@ -7,16 +7,21 @@ class ValidateModelInput(torch.nn.Module):
     # Pass-through transform that checks the shape and dtypes to make sure the model gets what it expects
     def forward(self, img1, img2, flow, valid_flow_mask):

-        assert all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None)
-        assert all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None)
+        if not all(isinstance(arg, torch.Tensor) for arg in (img1, img2, flow, valid_flow_mask) if arg is not None):
+            raise TypeError("This method expects all input arguments to be of type torch.Tensor.")
+        if not all(arg.dtype == torch.float32 for arg in (img1, img2, flow) if arg is not None):
+            raise TypeError("This method expects the tensors img1, img2 and flow to be of dtype torch.float32.")

-        assert img1.shape == img2.shape
+        if img1.shape != img2.shape:
+            raise ValueError("img1 and img2 should have the same shape.")

         h, w = img1.shape[-2:]
-        if flow is not None:
-            assert flow.shape == (2, h, w)
+        if flow is not None and flow.shape != (2, h, w):
+            raise ValueError(f"flow.shape should be (2, {h}, {w}) instead of {flow.shape}")
         if valid_flow_mask is not None:
-            assert valid_flow_mask.shape == (h, w)
-            assert valid_flow_mask.dtype == torch.bool
+            if valid_flow_mask.shape != (h, w):
+                raise ValueError(f"valid_flow_mask.shape should be ({h}, {w}) instead of {valid_flow_mask.shape}")
+            if valid_flow_mask.dtype != torch.bool:
+                raise TypeError(f"valid_flow_mask should be of dtype torch.bool instead of {valid_flow_mask.dtype}")

         return img1, img2, flow, valid_flow_mask
@@ -109,7 +114,8 @@ class RandomErasing(T.RandomErasing):
     def __init__(self, p=0.5, scale=(0.02, 0.33), ratio=(0.3, 3.3), value=0, inplace=False, max_erase=1):
         super().__init__(p=p, scale=scale, ratio=ratio, value=value, inplace=inplace)
         self.max_erase = max_erase
-        assert self.max_erase > 0
+        if self.max_erase <= 0:
+            raise ValueError("max_erase should be greater than 0")

     def forward(self, img1, img2, flow, valid_flow_mask):
         if torch.rand(1) > self.p:
diff --git a/references/optical_flow/utils.py b/references/optical_flow/utils.py
index 4b6d0049f54..065a2be8bfc 100644
--- a/references/optical_flow/utils.py
+++ b/references/optical_flow/utils.py
@@ -71,7 +71,10 @@ def update(self, **kwargs):
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
                 v = v.item()
-            assert isinstance(v, (float, int))
+            if not isinstance(v, (float, int)):
+                raise TypeError(
+                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
+                )
             self.meters[k].update(v)

     def __getattr__(self, attr):
diff --git a/references/segmentation/coco_utils.py b/references/segmentation/coco_utils.py
index 4d37187f7ec..e02434012f1 100644
--- a/references/segmentation/coco_utils.py
+++ b/references/segmentation/coco_utils.py
@@ -68,7 +68,11 @@ def _has_valid_annotation(anno):
         # if more than 1k pixels occupied in the image
         return sum(obj["area"] for obj in anno) > 1000

-    assert isinstance(dataset, torchvision.datasets.CocoDetection)
+    if not isinstance(dataset, torchvision.datasets.CocoDetection):
+        raise TypeError(
+            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+        )
+
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
diff --git a/references/segmentation/utils.py b/references/segmentation/utils.py
index 22096c9dd2c..27c8f4ce51e 100644
--- a/references/segmentation/utils.py
+++ b/references/segmentation/utils.py
@@ -118,7 +118,10 @@ def update(self, **kwargs):
         for k, v in kwargs.items():
             if isinstance(v, torch.Tensor):
                 v = v.item()
-            assert isinstance(v, (float, int))
+            if not isinstance(v, (float, int)):
+                raise TypeError(
+                    f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}"
+                )
             self.meters[k].update(v)

     def __getattr__(self, attr):
diff --git a/references/similarity/sampler.py b/references/similarity/sampler.py
index 591155fb449..f4564eca33e 100644
--- a/references/similarity/sampler.py
+++ b/references/similarity/sampler.py
@@ -47,7 +47,8 @@ def __init__(self, groups, p, k):
         self.groups = create_groups(groups, self.k)

         # Ensures there are enough classes to sample from
-        assert len(self.groups) >= p
+        if len(self.groups) < p:
+            raise ValueError("There are not enough classes to sample from")
ValueError("There are not enought classes to sample from") def __iter__(self): # Shuffle samples within groups diff --git a/references/video_classification/utils.py b/references/video_classification/utils.py index a68c2386bcf..116adf8d72f 100644 --- a/references/video_classification/utils.py +++ b/references/video_classification/utils.py @@ -76,7 +76,10 @@ def update(self, **kwargs): for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() - assert isinstance(v, (float, int)) + if not isinstance(v, (float, int)): + raise TypeError( + f"This method expects the value of the input arguments to be of type float or int, instead got {type(v)}" + ) self.meters[k].update(v) def __getattr__(self, attr): diff --git a/test/test_backbone_utils.py b/test/test_backbone_utils.py index ed9b52d0499..a3ba427f1de 100644 --- a/test/test_backbone_utils.py +++ b/test/test_backbone_utils.py @@ -144,16 +144,16 @@ def test_build_fx_feature_extractor(self, model_name): model, train_return_nodes=train_return_nodes, eval_return_nodes=eval_return_nodes ) # Check must specify return nodes - with pytest.raises(AssertionError): + with pytest.raises(ValueError): self._create_feature_extractor(model) # Check return_nodes and train_return_nodes / eval_return nodes # mutual exclusivity - with pytest.raises(AssertionError): + with pytest.raises(ValueError): self._create_feature_extractor( model, return_nodes=train_return_nodes, train_return_nodes=train_return_nodes ) # Check train_return_nodes / eval_return nodes must both be specified - with pytest.raises(AssertionError): + with pytest.raises(ValueError): self._create_feature_extractor(model, train_return_nodes=train_return_nodes) # Check invalid node name raises ValueError with pytest.raises(ValueError): diff --git a/test/test_models.py b/test/test_models.py index af433049a94..fb024c8da3f 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -767,7 +767,7 @@ def test_detection_model_validation(model_fn): # validate type targets = [{"boxes": 0.0}] - with pytest.raises(ValueError): + with pytest.raises(TypeError): model(x, targets=targets) # validate boxes shape diff --git a/test/test_ops.py b/test/test_ops.py index ad9aaefee52..d1562b00a42 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -138,13 +138,13 @@ def test_autocast(self, x_dtype, rois_dtype): def _helper_boxes_shape(self, func): # test boxes as Tensor[N, 5] - with pytest.raises(AssertionError): + with pytest.raises(ValueError): a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) boxes = torch.tensor([[0, 0, 3, 3]], dtype=a.dtype) func(a, boxes, output_size=(2, 2)) # test boxes as List[Tensor[N, 4]] - with pytest.raises(AssertionError): + with pytest.raises(ValueError): a = torch.linspace(1, 8 * 8, 8 * 8).reshape(1, 1, 8, 8) boxes = torch.tensor([[0, 0, 3]], dtype=a.dtype) ops.roi_pool(a, [boxes], output_size=(2, 2)) diff --git a/torchvision/datasets/kinetics.py b/torchvision/datasets/kinetics.py index 189142a5e67..e1ac4ac500c 100644 --- a/torchvision/datasets/kinetics.py +++ b/torchvision/datasets/kinetics.py @@ -118,7 +118,8 @@ def __init__( print("Using legacy structure") self.split_folder = root self.split = "unknown" - assert not download, "Cannot download the videos using legacy_structure." 
+            if download:
+                raise ValueError("Cannot download the videos using legacy_structure.")
         else:
             self.split_folder = path.join(root, split)
             self.split = verify_str_arg(split, arg="split", valid_values=["train", "val"])
diff --git a/torchvision/datasets/mnist.py b/torchvision/datasets/mnist.py
index 660de3d420f..9f9ec457499 100644
--- a/torchvision/datasets/mnist.py
+++ b/torchvision/datasets/mnist.py
@@ -442,11 +442,14 @@ def _check_exists(self) -> bool:

     def _load_data(self):
         data = read_sn3_pascalvincent_tensor(self.images_file)
-        assert data.dtype == torch.uint8
-        assert data.ndimension() == 3
+        if data.dtype != torch.uint8:
+            raise TypeError(f"data should be of dtype torch.uint8 instead of {data.dtype}")
+        if data.ndimension() != 3:
+            raise ValueError(f"data should have 3 dimensions instead of {data.ndimension()}")

         targets = read_sn3_pascalvincent_tensor(self.labels_file).long()
-        assert targets.ndimension() == 2
+        if targets.ndimension() != 2:
+            raise ValueError(f"targets should have 2 dimensions instead of {targets.ndimension()}")

         if self.what == "test10k":
             data = data[0:10000, :, :].clone()
@@ -530,13 +533,17 @@ def read_sn3_pascalvincent_tensor(path: str, strict: bool = True) -> torch.Tenso

 def read_label_file(path: str) -> torch.Tensor:
     x = read_sn3_pascalvincent_tensor(path, strict=False)
-    assert x.dtype == torch.uint8
-    assert x.ndimension() == 1
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 1:
+        raise ValueError(f"x should have 1 dimension instead of {x.ndimension()}")
     return x.long()


 def read_image_file(path: str) -> torch.Tensor:
     x = read_sn3_pascalvincent_tensor(path, strict=False)
-    assert x.dtype == torch.uint8
-    assert x.ndimension() == 3
+    if x.dtype != torch.uint8:
+        raise TypeError(f"x should be of dtype torch.uint8 instead of {x.dtype}")
+    if x.ndimension() != 3:
+        raise ValueError(f"x should have 3 dimensions instead of {x.ndimension()}")
     return x
diff --git a/torchvision/datasets/samplers/clip_sampler.py b/torchvision/datasets/samplers/clip_sampler.py
index ad7427f1949..f4975f8c021 100644
--- a/torchvision/datasets/samplers/clip_sampler.py
+++ b/torchvision/datasets/samplers/clip_sampler.py
@@ -52,12 +52,10 @@ def __init__(
             if not dist.is_available():
                 raise RuntimeError("Requires distributed package to be available")
             rank = dist.get_rank()
-        assert (
-            len(dataset) % group_size == 0
-        ), "dataset length must be a multiplier of group size dataset length: %d, group size: %d" % (
-            len(dataset),
-            group_size,
-        )
+        if len(dataset) % group_size != 0:
+            raise ValueError(
+                f"dataset length must be a multiple of group size. dataset length: {len(dataset)}, group size: {group_size}"
+            )
         self.dataset = dataset
         self.group_size = group_size
         self.num_replicas = num_replicas
diff --git a/torchvision/datasets/sbd.py b/torchvision/datasets/sbd.py
index ce485680910..030643dc794 100644
--- a/torchvision/datasets/sbd.py
+++ b/torchvision/datasets/sbd.py
@@ -92,7 +92,6 @@ def __init__(

         self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
         self.masks = [os.path.join(mask_dir, x + ".mat") for x in file_names]
-        assert len(self.images) == len(self.masks)

         self._get_target = self._get_segmentation_target if self.mode == "segmentation" else self._get_boundaries_target
diff --git a/torchvision/datasets/video_utils.py b/torchvision/datasets/video_utils.py
index efa3836c8d1..d444496ffe7 100644
--- a/torchvision/datasets/video_utils.py
+++ b/torchvision/datasets/video_utils.py
@@ -38,7 +38,8 @@ def unfold(tensor: torch.Tensor, size: int, step: int, dilation: int = 1) -> tor
     `step` between windows. The distance between each element in a window
     is given by `dilation`.
     """
-    assert tensor.dim() == 1
+    if tensor.dim() != 1:
+        raise ValueError(f"tensor should have 1 dimension instead of {tensor.dim()}")
     o_stride = tensor.stride(0)
     numel = tensor.numel()
     new_stride = (step * o_stride, dilation * o_stride)
diff --git a/torchvision/io/_video_opt.py b/torchvision/io/_video_opt.py
index de4b25bb7b5..5357d25ea62 100644
--- a/torchvision/io/_video_opt.py
+++ b/torchvision/io/_video_opt.py
@@ -67,13 +67,9 @@ def __init__(self) -> None:


 def _validate_pts(pts_range: Tuple[int, int]) -> None:
-    if pts_range[1] > 0:
-        assert (
-            pts_range[0] <= pts_range[1]
-        ), """Start pts should not be smaller than end pts, got
-    start pts: {:d} and end pts: {:d}""".format(
-            pts_range[0],
-            pts_range[1],
+    if pts_range[0] > pts_range[1] > 0:
+        raise ValueError(
+            f"Start pts should not be larger than end pts, got start pts: {pts_range[0]} and end pts: {pts_range[1]}"
         )
diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py
index 36e99e6506d..d7126ef681b 100644
--- a/torchvision/models/detection/_utils.py
+++ b/torchvision/models/detection/_utils.py
@@ -159,8 +159,10 @@ def encode_single(self, reference_boxes: Tensor, proposals: Tensor) -> Tensor:
         return targets

     def decode(self, rel_codes: Tensor, boxes: List[Tensor]) -> Tensor:
-        assert isinstance(boxes, (list, tuple))
-        assert isinstance(rel_codes, torch.Tensor)
+        if not isinstance(boxes, (list, tuple)):
+            raise TypeError(f"This function expects boxes of type list or tuple, instead got {type(boxes)}")
+        if not isinstance(rel_codes, torch.Tensor):
+            raise TypeError(f"This function expects rel_codes of type torch.Tensor, instead got {type(rel_codes)}")
         boxes_per_image = [b.size(0) for b in boxes]
         concat_boxes = torch.cat(boxes, dim=0)
         box_sum = 0
@@ -333,7 +335,8 @@ def __init__(self, high_threshold: float, low_threshold: float, allow_low_qualit
         """
         self.BELOW_LOW_THRESHOLD = -1
         self.BETWEEN_THRESHOLDS = -2
-        assert low_threshold <= high_threshold
+        if low_threshold > high_threshold:
+            raise ValueError("low_threshold should be <= high_threshold")
         self.high_threshold = high_threshold
         self.low_threshold = low_threshold
         self.allow_low_quality_matches = allow_low_quality_matches
@@ -371,7 +374,8 @@ def __call__(self, match_quality_matrix: Tensor) -> Tensor:
         matches[between_thresholds] = self.BETWEEN_THRESHOLDS

         if self.allow_low_quality_matches:
-            assert all_matches is not None
+            if all_matches is None:
+                raise ValueError("all_matches should not be None")
             self.set_low_quality_matches_(matches, all_matches, match_quality_matrix)

         return matches
diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py
index 6771dda0ce4..3248fc2e1aa 100644
--- a/torchvision/models/detection/anchor_utils.py
+++ b/torchvision/models/detection/anchor_utils.py
@@ -45,8 +45,6 @@ def __init__(
         if not isinstance(aspect_ratios[0], (list, tuple)):
             aspect_ratios = (aspect_ratios,) * len(sizes)

-        assert len(sizes) == len(aspect_ratios)
-
         self.sizes = sizes
         self.aspect_ratios = aspect_ratios
         self.cell_anchors = [
@@ -86,7 +84,9 @@ def num_anchors_per_location(self):
     def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         anchors = []
         cell_anchors = self.cell_anchors
-        assert cell_anchors is not None
+
+        if cell_anchors is None:
+            raise ValueError("cell_anchors should not be None")

         if not (len(grid_sizes) == len(strides) == len(cell_anchors)):
             raise ValueError(
@@ -164,8 +164,8 @@ def __init__(
         clip: bool = True,
     ):
         super().__init__()
-        if steps is not None:
-            assert len(aspect_ratios) == len(steps)
+        if steps is not None and len(aspect_ratios) != len(steps):
+            raise ValueError("aspect_ratios and steps should have the same length")
         self.aspect_ratios = aspect_ratios
         self.steps = steps
         self.clip = clip
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 790740fe9c5..35cb968d711 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -187,8 +187,14 @@ def __init__(
                 "same for all the levels)"
             )

-        assert isinstance(rpn_anchor_generator, (AnchorGenerator, type(None)))
-        assert isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None)))
+        if not isinstance(rpn_anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"rpn_anchor_generator should be of type AnchorGenerator or None instead of {type(rpn_anchor_generator)}"
+            )
+        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"box_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(box_roi_pool)}"
+            )

         if num_classes is not None:
             if box_predictor is not None:
@@ -299,7 +305,10 @@ def __init__(self, in_channels, num_classes):

     def forward(self, x):
         if x.dim() == 4:
-            assert list(x.shape[2:]) == [1, 1]
+            if list(x.shape[2:]) != [1, 1]:
+                raise ValueError(
+                    f"x has the wrong shape, expecting the last two dimensions to be [1,1] instead of {list(x.shape[2:])}"
+                )
         x = x.flatten(start_dim=1)
         scores = self.cls_score(x)
         bbox_deltas = self.bbox_pred(x)
diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py
index c4c2e6f5842..c15702f5e18 100644
--- a/torchvision/models/detection/fcos.py
+++ b/torchvision/models/detection/fcos.py
@@ -378,14 +378,20 @@ def __init__(
             )

         self.backbone = backbone

-        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
+        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"anchor_generator should be of type AnchorGenerator or None, instead got {type(anchor_generator)}"
+            )

         if anchor_generator is None:
             anchor_sizes = ((8,), (16,), (32,), (64,), (128,))  # equal to strides of multi-level feature map
             aspect_ratios = ((1.0,),) * len(anchor_sizes)  # set only one anchor
             anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
         self.anchor_generator = anchor_generator
-        assert self.anchor_generator.num_anchors_per_location()[0] == 1
+        if self.anchor_generator.num_anchors_per_location()[0] != 1:
+            raise ValueError(
+                f"anchor_generator.num_anchors_per_location()[0] should be 1 instead of {anchor_generator.num_anchors_per_location()[0]}"
+            )

         if head is None:
             head = FCOSHead(backbone.out_channels, anchor_generator.num_anchors_per_location()[0], num_classes)
@@ -560,12 +566,15 @@ def forward(
                     if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                         raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                 else:
-                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            assert len(val) == 2
+            if len(val) != 2:
+                raise ValueError(
+                    f"expecting the last two dimensions of the Tensor to be H and W instead got {img.shape[-2:]}"
+                )
             original_image_sizes.append((val[0], val[1]))

         # transform the input
@@ -603,9 +612,10 @@ def forward(
         losses = {}
         detections: List[Dict[str, Tensor]] = []
         if self.training:
-            assert targets is not None
-            # compute the losses
+            if targets is None:
+                raise ValueError("targets should not be None when in training mode")
+
             losses = self.compute_loss(targets, head_outputs, anchors, num_anchors_per_level)
         else:
             # split outputs per level
diff --git a/torchvision/models/detection/generalized_rcnn.py b/torchvision/models/detection/generalized_rcnn.py
index 37ef1820d71..dba8e5b8148 100644
--- a/torchvision/models/detection/generalized_rcnn.py
+++ b/torchvision/models/detection/generalized_rcnn.py
@@ -57,22 +57,25 @@ def forward(self, images, targets=None):
                 like `scores`, `labels` and `mask` (for Mask R-CNN models).

         """
-        if self.training and targets is None:
-            raise ValueError("In training mode, targets should be passed")
         if self.training:
-            assert targets is not None
+            if targets is None:
+                raise ValueError("In training mode, targets should be passed")
+
             for target in targets:
                 boxes = target["boxes"]
                 if isinstance(boxes, torch.Tensor):
                     if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                         raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                 else:
-                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            assert len(val) == 2
+            if len(val) != 2:
+                raise ValueError(
+                    f"Expecting the last two dimensions of the input tensor to be H and W, instead got {img.shape[-2:]}"
+                )
             original_image_sizes.append((val[0], val[1]))

         images, targets = self.transform(images, targets)
diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py
index 9f23e66e0c5..aadd390afc8 100644
--- a/torchvision/models/detection/keypoint_rcnn.py
+++ b/torchvision/models/detection/keypoint_rcnn.py
@@ -191,7 +191,10 @@ def __init__(
         num_keypoints=None,
     ):

-        assert isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None)))
+        if not isinstance(keypoint_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"keypoint_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(keypoint_roi_pool)}"
+            )

         if min_size is None:
             min_size = (640, 672, 704, 736, 768, 800)
diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py
index 37f88116c5e..c733613452a 100644
--- a/torchvision/models/detection/mask_rcnn.py
+++ b/torchvision/models/detection/mask_rcnn.py
@@ -191,7 +191,10 @@ def __init__(
         mask_predictor=None,
     ):

-        assert isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None)))
+        if not isinstance(mask_roi_pool, (MultiScaleRoIAlign, type(None))):
+            raise TypeError(
+                f"mask_roi_pool should be of type MultiScaleRoIAlign or None instead of {type(mask_roi_pool)}"
+            )

         if num_classes is not None:
             if mask_predictor is not None:
diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py
index 4f79b5ddbfc..6d6463d6894 100644
--- a/torchvision/models/detection/retinanet.py
+++ b/torchvision/models/detection/retinanet.py
@@ -347,7 +347,10 @@ def __init__(
             )

         self.backbone = backbone

-        assert isinstance(anchor_generator, (AnchorGenerator, type(None)))
+        if not isinstance(anchor_generator, (AnchorGenerator, type(None))):
+            raise TypeError(
+                f"anchor_generator should be of type AnchorGenerator or None instead of {type(anchor_generator)}"
+            )

         if anchor_generator is None:
             anchor_sizes = tuple((x, int(x * 2 ** (1.0 / 3)), int(x * 2 ** (2.0 / 3))) for x in [32, 64, 128, 256, 512])
@@ -488,20 +491,24 @@ def forward(self, images, targets=None):
             raise ValueError("In training mode, targets should be passed")

         if self.training:
-            assert targets is not None
+            if targets is None:
+                raise ValueError("In training mode, targets should be passed")

             for target in targets:
                 boxes = target["boxes"]
                 if isinstance(boxes, torch.Tensor):
                     if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                         raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                 else:
-                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

         # get the original image sizes
         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            assert len(val) == 2
+            if len(val) != 2:
+                raise ValueError(
+                    f"Expecting the last two dimensions of the input tensors to be H and W instead got {img.shape[-2:]}"
+                )
             original_image_sizes.append((val[0], val[1]))

         # transform the input
@@ -539,8 +546,8 @@ def forward(self, images, targets=None):
         losses = {}
         detections: List[Dict[str, Tensor]] = []
         if self.training:
-            assert targets is not None
-
+            if targets is None:
+                raise ValueError("In training mode, targets should be passed")
             # compute the losses
             losses = self.compute_loss(targets, head_outputs, anchors)
         else:
diff --git a/torchvision/models/detection/roi_heads.py b/torchvision/models/detection/roi_heads.py
index b7bbb81111e..9f2ef20d17c 100644
--- a/torchvision/models/detection/roi_heads.py
+++ b/torchvision/models/detection/roi_heads.py
@@ -299,7 +299,10 @@ def heatmaps_to_keypoints(maps, rois):
 def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched_idxs):
     # type: (Tensor, List[Tensor], List[Tensor], List[Tensor]) -> Tensor
     N, K, H, W = keypoint_logits.shape
-    assert H == W
+    if H != W:
+        raise ValueError(
+            f"keypoint_logits height and width (last two elements of shape) should be equal. Instead got H = {H} and W = {W}"
+        )
     discretization_size = H
     heatmaps = []
     valid = []
@@ -615,11 +618,15 @@ def add_gt_proposals(self, proposals, gt_boxes):

     def check_targets(self, targets):
         # type: (Optional[List[Dict[str, Tensor]]]) -> None
-        assert targets is not None
-        assert all(["boxes" in t for t in targets])
-        assert all(["labels" in t for t in targets])
+        if targets is None:
+            raise ValueError("targets should not be None")
+        if not all(["boxes" in t for t in targets]):
+            raise ValueError("Every element of targets should have a boxes key")
+        if not all(["labels" in t for t in targets]):
+            raise ValueError("Every element of targets should have a labels key")
         if self.has_mask():
-            assert all(["masks" in t for t in targets])
+            if not all(["masks" in t for t in targets]):
+                raise ValueError("Every element of targets should have a masks key")

     def select_training_samples(
         self,
@@ -628,7 +635,8 @@ def select_training_samples(
     ):
         # type: (...) -> Tuple[List[Tensor], List[Tensor], List[Tensor], List[Tensor]]
         self.check_targets(targets)
-        assert targets is not None
+        if targets is None:
+            raise ValueError("targets should not be None")

         dtype = proposals[0].dtype
         device = proposals[0].device
@@ -736,10 +744,13 @@ def forward(
             for t in targets:
                 # TODO: https://github.com/pytorch/pytorch/issues/26731
                 floating_point_types = (torch.float, torch.double, torch.half)
-                assert t["boxes"].dtype in floating_point_types, "target boxes must of float type"
-                assert t["labels"].dtype == torch.int64, "target labels must of int64 type"
+                if not t["boxes"].dtype in floating_point_types:
+                    raise TypeError(f"target boxes must be of float type, instead got {t['boxes'].dtype}")
+                if not t["labels"].dtype == torch.int64:
+                    raise TypeError(f"target labels must be of int64 type, instead got {t['labels'].dtype}")
                 if self.has_keypoint():
-                    assert t["keypoints"].dtype == torch.float32, "target keypoints must of float type"
+                    if not t["keypoints"].dtype == torch.float32:
+                        raise TypeError(f"target keypoints must be of float type, instead got {t['keypoints'].dtype}")

         if self.training:
             proposals, matched_idxs, labels, regression_targets = self.select_training_samples(proposals, targets)
@@ -755,7 +766,10 @@ def forward(
         result: List[Dict[str, torch.Tensor]] = []
         losses = {}
         if self.training:
-            assert labels is not None and regression_targets is not None
+            if labels is None:
+                raise ValueError("labels cannot be None")
+            if regression_targets is None:
+                raise ValueError("regression_targets cannot be None")
             loss_classifier, loss_box_reg = fastrcnn_loss(class_logits, box_regression, labels, regression_targets)
             losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
         else:
@@ -773,7 +787,9 @@ def forward(
         if self.has_mask():
             mask_proposals = [p["boxes"] for p in result]
             if self.training:
-                assert matched_idxs is not None
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
                 # during training, only focus on positive boxes
                 num_images = len(proposals)
                 mask_proposals = []
@@ -794,9 +810,8 @@ def forward(

             loss_mask = {}
             if self.training:
-                assert targets is not None
-                assert pos_matched_idxs is not None
-                assert mask_logits is not None
+                if targets is None or pos_matched_idxs is None or mask_logits is None:
+                    raise ValueError("targets, pos_matched_idxs, mask_logits cannot be None when training")

                 gt_masks = [t["masks"] for t in targets]
                 gt_labels = [t["labels"] for t in targets]
@@ -823,7 +838,9 @@ def forward(
                 num_images = len(proposals)
                 keypoint_proposals = []
                 pos_matched_idxs = []
-                assert matched_idxs is not None
+                if matched_idxs is None:
+                    raise ValueError("if in training, matched_idxs should not be None")
+
                 for img_id in range(num_images):
                     pos = torch.where(labels[img_id] > 0)[0]
                     keypoint_proposals.append(proposals[img_id][pos])
@@ -837,8 +854,8 @@ def forward(

             loss_keypoint = {}
             if self.training:
-                assert targets is not None
-                assert pos_matched_idxs is not None
+                if targets is None or pos_matched_idxs is None:
+                    raise ValueError("both targets and pos_matched_idxs should not be None when in training mode")

                 gt_keypoints = [t["keypoints"] for t in targets]
                 rcnn_loss_keypoint = keypointrcnn_loss(
                     keypoint_logits, keypoint_proposals, gt_keypoints, pos_matched_idxs
                 )
                 loss_keypoint = {"loss_keypoint": rcnn_loss_keypoint}
             else:
-                assert keypoint_logits is not None
-                assert keypoint_proposals is not None
+                if keypoint_logits is None or keypoint_proposals is None:
+                    raise ValueError(
+                        "both keypoint_logits and keypoint_proposals should not be None when not in training mode"
+                    )

                 keypoints_probs, kp_scores = keypointrcnn_inference(keypoint_logits, keypoint_proposals)
                 for keypoint_prob, kps, r in zip(keypoints_probs, kp_scores, result):
                     r["keypoints"] = keypoint_prob
                     r["keypoints_scores"] = kps
-
             losses.update(loss_keypoint)

         return result, losses
diff --git a/torchvision/models/detection/rpn.py b/torchvision/models/detection/rpn.py
index 1d63bcc8a54..18379ac25f6 100644
--- a/torchvision/models/detection/rpn.py
+++ b/torchvision/models/detection/rpn.py
@@ -339,7 +339,8 @@ def forward(

         losses = {}
         if self.training:
-            assert targets is not None
+            if targets is None:
+                raise ValueError("targets should not be None")
             labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
             regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
             loss_objectness, loss_rpn_box_reg = self.compute_loss(
diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py
index 08a9ed68e4e..bd7f1b2863f 100644
--- a/torchvision/models/detection/ssd.py
+++ b/torchvision/models/detection/ssd.py
@@ -196,7 +196,10 @@ def __init__(
             else:
                 out_channels = det_utils.retrieve_out_channels(backbone, size)

-            assert len(out_channels) == len(anchor_generator.aspect_ratios)
+            if len(out_channels) != len(anchor_generator.aspect_ratios):
+                raise ValueError(
+                    f"The length of the output channels from the backbone ({len(out_channels)}) does not match the length of the anchor generator aspect ratios ({len(anchor_generator.aspect_ratios)})"
+                )

             num_anchors = self.anchor_generator.num_anchors_per_location()
             head = SSDHead(out_channels, num_anchors, num_classes)
@@ -308,20 +311,24 @@ def forward(
             raise ValueError("In training mode, targets should be passed")

         if self.training:
-            assert targets is not None
+            if targets is None:
+                raise ValueError("targets should not be None")

             for target in targets:
                 boxes = target["boxes"]
                 if isinstance(boxes, torch.Tensor):
                     if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
                         raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
                 else:
-                    raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
+                    raise TypeError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")

         # get the original image sizes
         original_image_sizes: List[Tuple[int, int]] = []
         for img in images:
             val = img.shape[-2:]
-            assert len(val) == 2
+            if len(val) != 2:
+                raise ValueError(
+                    f"The last two dimensions of the input tensors should contain H and W, instead got {img.shape[-2:]}"
+                )
             original_image_sizes.append((val[0], val[1]))

         # transform the input
@@ -356,7 +363,8 @@ def forward(
         losses = {}
         detections: List[Dict[str, Tensor]] = []
         if self.training:
-            assert targets is not None
+            if targets is None:
+                raise ValueError("targets should not be None when in training mode")

             matched_idxs = []
             for anchors_per_image, targets_per_image in zip(anchors, targets):
@@ -527,7 +535,8 @@ def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int):
     num_stages = len(stage_indices)

     # find the index of the layer from which we wont freeze
-    assert 0 <= trainable_layers <= num_stages
+    if not 0 <= trainable_layers <= num_stages:
+        raise ValueError(f"trainable_layers should be in the range [0, {num_stages}]. Instead got {trainable_layers}")
Instead got {trainable_layers}") freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] for b in backbone[:freeze_before]: diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 1ee59e069ea..1c59814f8d4 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -122,7 +122,9 @@ def __init__( super().__init__() _log_api_usage_once(self) - assert not backbone[c4_pos].use_res_connect + if backbone[c4_pos].use_res_connect: + raise ValueError("backbone[c4_pos].use_res_connect should be False") + self.features = nn.Sequential( # As described in section 6.3 of MobileNetV3 paper nn.Sequential(*backbone[:c4_pos], backbone[c4_pos].block[0]), # from start until C4 expansion layer @@ -168,7 +170,8 @@ def _mobilenet_extractor( num_stages = len(stage_indices) # find the index of the layer from which we wont freeze - assert 0 <= trainable_layers <= num_stages + if not 0 <= trainable_layers <= num_stages: + raise ValueError("trainable_layers should be in the range [0, {num_stages}], instead got {trainable_layers}") freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers] for b in backbone[:freeze_before]: @@ -244,7 +247,10 @@ def ssdlite320_mobilenet_v3_large( anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) out_channels = det_utils.retrieve_out_channels(backbone, size) num_anchors = anchor_generator.num_anchors_per_location() - assert len(out_channels) == len(anchor_generator.aspect_ratios) + if len(out_channels) != len(anchor_generator.aspect_ratios): + raise ValueError( + f"The length of the output channels from the backbone {len(out_channels)} do not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}" + ) defaults = { "score_thresh": 0.001, diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py index 960e28500a1..58b38baee04 100644 --- a/torchvision/models/detection/transform.py +++ b/torchvision/models/detection/transform.py @@ -134,7 +134,10 @@ def forward( images = self.batch_images(images, size_divisible=self.size_divisible) image_sizes_list: List[Tuple[int, int]] = [] for image_size in image_sizes: - assert len(image_size) == 2 + if len(image_size) != 2: + raise ValueError( + f"Input tensors expected to have in the last two elements H and W, instead got {image_size}" + ) image_sizes_list.append((image_size[0], image_size[1])) image_list = ImageList(images, image_sizes_list) diff --git a/torchvision/models/feature_extraction.py b/torchvision/models/feature_extraction.py index 0a2b597da23..a6c26913093 100644 --- a/torchvision/models/feature_extraction.py +++ b/torchvision/models/feature_extraction.py @@ -277,7 +277,8 @@ def __init__( # eval graphs) for node in chain(iter(train_graph.nodes), iter(eval_graph.nodes)): if node.op in ["get_attr", "call_module"]: - assert isinstance(node.target, str) + if not isinstance(node.target, str): + raise TypeError(f"node.target should be of type str instead of {type(node.target)}") _copy_attr(root, self, node.target) # train mode by default @@ -290,9 +291,10 @@ def __init__( # Locally defined Tracers are not pickleable. This is needed because torch.package will # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer # to re-create the Graph during deserialization. 
-        assert (
-            self.eval_graph._tracer_cls == self.train_graph._tracer_cls
-        ), "Train mode and eval mode should use the same tracer class"
+        if self.eval_graph._tracer_cls != self.train_graph._tracer_cls:
+            raise TypeError(
+                f"Train mode and eval mode should use the same tracer class. Instead got {self.eval_graph._tracer_cls} for eval vs {self.train_graph._tracer_cls} for train"
+            )
         self._tracer_cls = None
         if self.graph._tracer_cls and "<locals>" not in self.graph._tracer_cls.__qualname__:
             self._tracer_cls = self.graph._tracer_cls
@@ -431,17 +433,19 @@ def create_feature_extractor(
     }

     is_training = model.training

-    assert any(
-        arg is not None for arg in [return_nodes, train_return_nodes, eval_return_nodes]
-    ), "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
+    if all(arg is None for arg in [return_nodes, train_return_nodes, eval_return_nodes]):
-    assert not (
-        (train_return_nodes is None) ^ (eval_return_nodes is None)
-    ), "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
+        raise ValueError(
+            "Either `return_nodes` or `train_return_nodes` and `eval_return_nodes` together, should be specified"
+        )
+
+    if (train_return_nodes is None) ^ (eval_return_nodes is None):
+        raise ValueError(
+            "If any of `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
+        )

-    assert (return_nodes is None) ^ (
-        train_return_nodes is None
-    ), "If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified"
+    if not ((return_nodes is None) ^ (train_return_nodes is None)):
+        raise ValueError("If `train_return_nodes` and `eval_return_nodes` are specified, then both should be specified")

     # Put *_return_nodes into Dict[str, str] format
     def to_strdict(n) -> Dict[str, str]:
@@ -476,9 +480,10 @@ def to_strdict(n) -> Dict[str, str]:
         available_nodes = list(tracer.node_to_qualname.values())
         # FIXME We don't know if we should expect this to happen
-        assert len(set(available_nodes)) == len(
-            available_nodes
-        ), "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
+        if len(set(available_nodes)) != len(available_nodes):
+            raise ValueError(
+                "There are duplicate nodes! Please raise an issue https://github.com/pytorch/vision/issues"
+            )
         # Check that all outputs in return_nodes are present in the model
         for query in mode_return_nodes[mode].keys():
             # To check if a query is available we need to check that at least
@@ -497,7 +502,9 @@ def to_strdict(n) -> Dict[str, str]:
     for n in reversed(graph_module.graph.nodes):
         if n.op == "output":
             orig_output_nodes.append(n)
-    assert len(orig_output_nodes)
+    if not orig_output_nodes:
+        raise ValueError("No output nodes found in graph_module.graph.nodes")
+
     for n in orig_output_nodes:
         graph_module.graph.erase_node(n)
diff --git a/torchvision/models/googlenet.py b/torchvision/models/googlenet.py
index 9e4c3498aab..f3487b44c09 100644
--- a/torchvision/models/googlenet.py
+++ b/torchvision/models/googlenet.py
@@ -50,7 +50,8 @@ def __init__(
                 FutureWarning,
             )
             init_weights = True
-        assert len(blocks) == 3
+        if len(blocks) != 3:
+            raise ValueError(f"blocks length should be 3 instead of {len(blocks)}")
         conv_block = blocks[0]
         inception_block = blocks[1]
         inception_aux_block = blocks[2]
diff --git a/torchvision/models/inception.py b/torchvision/models/inception.py
index c489925cb45..0fe6400a681 100644
--- a/torchvision/models/inception.py
+++ b/torchvision/models/inception.py
@@ -48,7 +48,8 @@ def __init__(
                 FutureWarning,
             )
             init_weights = True
-        assert len(inception_blocks) == 7
+        if len(inception_blocks) != 7:
+            raise ValueError(f"length of inception_blocks should be 7 instead of {len(inception_blocks)}")
         conv_block = inception_blocks[0]
         inception_a = inception_blocks[1]
         inception_b = inception_blocks[2]
diff --git a/torchvision/models/mnasnet.py b/torchvision/models/mnasnet.py
index c3d4013f30c..9608c555a88 100644
--- a/torchvision/models/mnasnet.py
+++ b/torchvision/models/mnasnet.py
@@ -27,8 +27,10 @@ def __init__(
         self, in_ch: int, out_ch: int, kernel_size: int, stride: int, expansion_factor: int, bn_momentum: float = 0.1
     ) -> None:
         super().__init__()
-        assert stride in [1, 2]
-        assert kernel_size in [3, 5]
+        if stride not in [1, 2]:
+            raise ValueError(f"stride should be 1 or 2 instead of {stride}")
+        if kernel_size not in [3, 5]:
+            raise ValueError(f"kernel_size should be 3 or 5 instead of {kernel_size}")
         mid_ch = in_ch * expansion_factor
         self.apply_residual = in_ch == out_ch and stride == 1
         self.layers = nn.Sequential(
@@ -56,7 +58,8 @@ def _stack(
     in_ch: int, out_ch: int, kernel_size: int, stride: int, exp_factor: int, repeats: int, bn_momentum: float
 ) -> nn.Sequential:
     """Creates a stack of inverted residuals."""
-    assert repeats >= 1
+    if repeats < 1:
+        raise ValueError(f"repeats should be >= 1, instead got {repeats}")
     # First one has no skip, because feature map size changes.
     first = _InvertedResidual(in_ch, out_ch, kernel_size, stride, exp_factor, bn_momentum=bn_momentum)
     remaining = []
@@ -69,7 +72,8 @@ def _round_to_multiple_of(val: float, divisor: int, round_up_bias: float = 0.9)
     """Asymmetric rounding to make `val` divisible by `divisor`. With default
     bias, will round up, unless the number is no more than 10% greater than the
     smaller divisible value, i.e. (83, 8) -> 80, but (84, 8) -> 88."""
-    assert 0.0 < round_up_bias < 1.0
+    if not 0.0 < round_up_bias < 1.0:
+        raise ValueError(f"round_up_bias should be greater than 0.0 and smaller than 1.0 instead of {round_up_bias}")
     new_val = max(divisor, int(val + divisor / 2) // divisor * divisor)
     return new_val if new_val >= round_up_bias * val else new_val + divisor
@@ -99,7 +103,8 @@ class MNASNet(torch.nn.Module):
     def __init__(self, alpha: float, num_classes: int = 1000, dropout: float = 0.2) -> None:
         super().__init__()
         _log_api_usage_once(self)
-        assert alpha > 0.0
+        if alpha <= 0.0:
+            raise ValueError(f"alpha should be greater than 0.0 instead of {alpha}")
         self.alpha = alpha
         self.num_classes = num_classes
         depths = _get_depths(alpha)
@@ -158,7 +163,8 @@ def _load_from_state_dict(
         error_msgs: List[str],
     ) -> None:
         version = local_metadata.get("version", None)
-        assert version in [1, 2]
+        if version not in [1, 2]:
+            raise ValueError(f"version should be set to 1 or 2 instead of {version}")

         if version == 1 and not self.alpha == 1.0:
             # In the initial version of the model (v1), stem was fixed-size.
diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py
index 930f68d13e9..f65993b0a5a 100644
--- a/torchvision/models/mobilenetv2.py
+++ b/torchvision/models/mobilenetv2.py
@@ -44,7 +44,8 @@ def __init__(
     ) -> None:
         super().__init__()
         self.stride = stride
-        assert stride in [1, 2]
+        if stride not in [1, 2]:
+            raise ValueError(f"stride should be 1 or 2 instead of {stride}")

         if norm_layer is None:
             norm_layer = nn.BatchNorm2d
diff --git a/torchvision/models/optical_flow/raft.py b/torchvision/models/optical_flow/raft.py
index 4dfd232d499..00200529f66 100644
--- a/torchvision/models/optical_flow/raft.py
+++ b/torchvision/models/optical_flow/raft.py
@@ -121,7 +121,8 @@ class FeatureEncoder(nn.Module):

     def __init__(self, *, block=ResidualBlock, layers=(64, 64, 96, 128, 256), norm_layer=nn.BatchNorm2d):
         super().__init__()

-        assert len(layers) == 5
+        if len(layers) != 5:
+            raise ValueError(f"The expected number of layers is 5, instead got {len(layers)}")

         # See note in ResidualBlock for the reason behind bias=True
         self.convnormrelu = Conv2dNormActivation(
@@ -169,8 +170,10 @@ class MotionEncoder(nn.Module):

     def __init__(self, *, in_channels_corr, corr_layers=(256, 192), flow_layers=(128, 64), out_channels=128):
         super().__init__()

-        assert len(flow_layers) == 2
-        assert len(corr_layers) in (1, 2)
+        if len(flow_layers) != 2:
+            raise ValueError(f"The expected number of flow_layers is 2, instead got {len(flow_layers)}")
+        if len(corr_layers) not in (1, 2):
+            raise ValueError(f"The number of corr_layers should be 1 or 2, instead got {len(corr_layers)}")

         self.convcorr1 = Conv2dNormActivation(in_channels_corr, corr_layers[0], norm_layer=None, kernel_size=1)
         if len(corr_layers) == 2:
@@ -234,8 +237,12 @@ class RecurrentBlock(nn.Module):

     def __init__(self, *, input_size, hidden_size, kernel_size=((1, 5), (5, 1)), padding=((0, 2), (2, 0))):
         super().__init__()

-        assert len(kernel_size) == len(padding)
-        assert len(kernel_size) in (1, 2)
+        if len(kernel_size) != len(padding):
+            raise ValueError(
+                f"kernel_size should have the same length as padding, instead got len(kernel_size) = {len(kernel_size)} and len(padding) = {len(padding)}"
+            )
+        if len(kernel_size) not in (1, 2):
+            raise ValueError(f"kernel_size should be either 1 or 2, instead got {len(kernel_size)}")

         self.convgru1 = ConvGRU(
             input_size=input_size, hidden_size=hidden_size, kernel_size=kernel_size[0], padding=padding[0]
@@ -351,7 +358,10 @@ def build_pyramid(self, fmap1, fmap2):
         to build the correlation pyramid.
         """
-        torch._assert(fmap1.shape == fmap2.shape, "Input feature maps should have the same shapes")
+        if fmap1.shape != fmap2.shape:
+            raise ValueError(
+                f"Input feature maps should have the same shape, instead got {fmap1.shape} (fmap1.shape) != {fmap2.shape} (fmap2.shape)"
+            )
         corr_volume = self._compute_corr_volume(fmap1, fmap2)

         batch_size, h, w, num_channels, _, _ = corr_volume.shape  # _, _ = h, w
@@ -384,10 +394,10 @@ def index_pyramid(self, centroids_coords):
         corr_features = torch.cat(indexed_pyramid, dim=-1).permute(0, 3, 1, 2).contiguous()

         expected_output_shape = (batch_size, self.out_channels, h, w)
-        torch._assert(
-            corr_features.shape == expected_output_shape,
-            f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}",
-        )
+        if corr_features.shape != expected_output_shape:
+            raise ValueError(
+                f"Output shape of index pyramid is incorrect. Should be {expected_output_shape}, got {corr_features.shape}"
+            )

         return corr_features
@@ -454,28 +464,31 @@ def __init__(self, *, feature_encoder, context_encoder, corr_block, update_block

     def forward(self, image1, image2, num_flow_updates: int = 12):

         batch_size, _, h, w = image1.shape
-        torch._assert((h, w) == image2.shape[-2:], "input images should have the same shape")
-        torch._assert((h % 8 == 0) and (w % 8 == 0), "input image H and W should be divisible by 8")
+        if (h, w) != image2.shape[-2:]:
+            raise ValueError(f"input images should have the same shape, instead got ({h}, {w}) != {image2.shape[-2:]}")
+        if not ((h % 8 == 0) and (w % 8 == 0)):
+            raise ValueError(f"input image H and W should be divisible by 8, instead got {h} (h) and {w} (w)")

         fmaps = self.feature_encoder(torch.cat([image1, image2], dim=0))
         fmap1, fmap2 = torch.chunk(fmaps, chunks=2, dim=0)
-        torch._assert(fmap1.shape[-2:] == (h // 8, w // 8), "The feature encoder should downsample H and W by 8")
+        if fmap1.shape[-2:] != (h // 8, w // 8):
+            raise ValueError("The feature encoder should downsample H and W by 8")

         self.corr_block.build_pyramid(fmap1, fmap2)

         context_out = self.context_encoder(image1)
-        torch._assert(context_out.shape[-2:] == (h // 8, w // 8), "The context encoder should downsample H and W by 8")
+        if context_out.shape[-2:] != (h // 8, w // 8):
+            raise ValueError("The context encoder should downsample H and W by 8")

         # As in the original paper, the actual output of the context encoder is split in 2 parts:
         # - one part is used to initialize the hidden state of the recurent units of the update block
         # - the rest is the "actual" context.
         hidden_state_size = self.update_block.hidden_state_size
         out_channels_context = context_out.shape[1] - hidden_state_size
-        torch._assert(
-            out_channels_context > 0,
-            f"The context encoder outputs {context_out.shape[1]} channels, but it should have at strictly more than"
-            f"hidden_state={hidden_state_size} channels",
-        )
+        if out_channels_context <= 0:
+            raise ValueError(
+                f"The context encoder outputs {context_out.shape[1]} channels, but it should have strictly more than hidden_state={hidden_state_size} channels"
+            )
         hidden_state, context = torch.split(context_out, [hidden_state_size, out_channels_context], dim=1)
         hidden_state = torch.tanh(hidden_state)
         context = F.relu(context)
diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py
index f3758c54aaf..9a893ba1510 100644
--- a/torchvision/models/shufflenetv2.py
+++ b/torchvision/models/shufflenetv2.py
@@ -42,7 +42,10 @@ def __init__(self, inp: int, oup: int, stride: int) -> None:
         self.stride = stride

         branch_features = oup // 2
-        assert (self.stride != 1) or (inp == branch_features << 1)
+        if (self.stride == 1) and (inp != branch_features << 1):
+            raise ValueError(
+                f"Invalid combination of stride {stride}, inp {inp} and oup {oup} values. If stride == 1 then inp should be equal to oup // 2 << 1."
+            )

         if self.stride > 1:
             self.branch1 = nn.Sequential(
diff --git a/torchvision/models/vision_transformer.py b/torchvision/models/vision_transformer.py
index 29f756ccbe5..43e4d315cec 100644
--- a/torchvision/models/vision_transformer.py
+++ b/torchvision/models/vision_transformer.py
@@ -434,7 +434,10 @@ def interpolate_embeddings(
     # (1, seq_length, hidden_dim) -> (1, hidden_dim, seq_length)
     pos_embedding_img = pos_embedding_img.permute(0, 2, 1)
     seq_length_1d = int(math.sqrt(seq_length))
-    torch._assert(seq_length_1d * seq_length_1d == seq_length, "seq_length is not a perfect square!")
+    if seq_length_1d * seq_length_1d != seq_length:
+        raise ValueError(
+            f"seq_length is not a perfect square! Instead got seq_length_1d * seq_length_1d = {seq_length_1d * seq_length_1d} and seq_length = {seq_length}"
+        )

     # (1, hidden_dim, seq_length) -> (1, hidden_dim, seq_l_1d, seq_l_1d)
     pos_embedding_img = pos_embedding_img.reshape(1, hidden_dim, seq_length_1d, seq_length_1d)
diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py
index 3a07c747f58..30f28e51c4c 100644
--- a/torchvision/ops/_utils.py
+++ b/torchvision/ops/_utils.py
@@ -28,13 +28,13 @@ def convert_boxes_to_roi_format(boxes: List[Tensor]) -> Tensor:
 def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
     if isinstance(boxes, (list, tuple)):
         for _tensor in boxes:
-            assert (
-                _tensor.size(1) == 4
-            ), "The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]]"
+            if _tensor.size(1) != 4:
+                raise ValueError("The shape of the tensor in the boxes list is not correct as List[Tensor[L, 4]].")
     elif isinstance(boxes, torch.Tensor):
-        assert boxes.size(1) == 5, "The boxes tensor shape is not correct as Tensor[K, 5]"
+        if boxes.size(1) != 5:
+            raise ValueError("The boxes tensor shape is not correct as Tensor[K, 5].")
     else:
-        assert False, "boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]"
+        raise TypeError(f"boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]], instead got {type(boxes)}")
     return
diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py
index 38cb4c1a836..2f1f984ca25 100644
--- a/torchvision/ops/boxes.py
+++ b/torchvision/ops/boxes.py
@@ -300,8 +300,10 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor:
         _log_api_usage_once(generalized_box_iou)

     # degenerate boxes gives inf / nan results
     # so do an early check
-    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
-    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    if (boxes1[:, 2:] < boxes1[:, :2]).any():
+        raise ValueError("Some of the input boxes1 are invalid.")
+    if not (boxes2[:, 2:] >= boxes2[:, :2]).all():
+        raise ValueError("Some of the input boxes2 are invalid.")

     inter, union = _box_inter_union(boxes1, boxes2)
     iou = inter / union
diff --git a/torchvision/ops/feature_pyramid_network.py b/torchvision/ops/feature_pyramid_network.py
index 93caa47d04b..2e1ac0cd8cf 100644
--- a/torchvision/ops/feature_pyramid_network.py
+++ b/torchvision/ops/feature_pyramid_network.py
@@ -95,7 +95,8 @@ def __init__(
                 nn.init.constant_(m.bias, 0)

         if extra_blocks is not None:
-            assert isinstance(extra_blocks, ExtraFPNBlock)
+            if not isinstance(extra_blocks, ExtraFPNBlock):
+                raise TypeError(f"extra_blocks should be of type ExtraFPNBlock not {type(extra_blocks)}")
         self.extra_blocks = extra_blocks

     def get_result_from_inner_blocks(self, x: Tensor, idx: int) -> Tensor:
diff --git a/torchvision/ops/poolers.py b/torchvision/ops/poolers.py
index ceabb77732b..f881201a2d2 100644
--- a/torchvision/ops/poolers.py
+++ b/torchvision/ops/poolers.py
@@ -104,7 +104,6 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
         approx_scale = float(s1) / float(s2)
         scale = 2 ** float(torch.tensor(approx_scale).log2().round())
         possible_scales.append(scale)
-    assert possible_scales[0] == possible_scales[1]

     return possible_scales[0]

@@ -112,7 +111,8 @@ def _infer_scale(feature: Tensor, original_size: List[int]) -> float:
 def _setup_scales(
     features: List[Tensor], image_shapes: List[Tuple[int, int]], canonical_scale: int, canonical_level: int
 ) -> Tuple[List[float], LevelMapper]:
-    assert len(image_shapes) != 0
+    if not image_shapes:
+        raise ValueError("images list should not be empty")
     max_x = 0
     max_y = 0
     for shape in image_shapes:
@@ -166,8 +166,8 @@ def _multiscale_roi_align(
     Returns:
         result (Tensor)
     """
-    assert scales is not None
-    assert mapper is not None
+    if scales is None or mapper is None:
+        raise ValueError("scales and mapper should not be None")

     num_levels = len(x_filtered)
     rois = _convert_to_roi_format(boxes)
diff --git a/torchvision/prototype/models/detection/ssdlite.py b/torchvision/prototype/models/detection/ssdlite.py
index 6de34acb5ae..f69e860dff4 100644
--- a/torchvision/prototype/models/detection/ssdlite.py
+++ b/torchvision/prototype/models/detection/ssdlite.py
@@ -98,7 +98,10 @@ def ssdlite320_mobilenet_v3_large(
     anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95)
     out_channels = det_utils.retrieve_out_channels(backbone, size)
     num_anchors = anchor_generator.num_anchors_per_location()
-    assert len(out_channels) == len(anchor_generator.aspect_ratios)
+    if len(out_channels) != len(anchor_generator.aspect_ratios):
+        raise ValueError(
+            f"The length of the output channels from the backbone {len(out_channels)} does not match the length of the anchor generator aspect ratios {len(anchor_generator.aspect_ratios)}"
+        )

     defaults = {
         "score_thresh": 0.001,
diff --git a/torchvision/transforms/_functional_video.py b/torchvision/transforms/_functional_video.py
index 2ab7adb8af9..f969a2542d0 100644
--- a/torchvision/transforms/_functional_video.py
+++ b/torchvision/transforms/_functional_video.py
@@ -24,12 +24,14 @@ def crop(clip, i, j, h, w):
     Args:
         clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
     """
-    assert len(clip.size()) == 4, "clip should be a 4D tensor"
+    if len(clip.size()) != 4:
+        raise ValueError("clip should be a 4D tensor")
     return clip[..., i : i + h, j : j + w]


 def resize(clip, target_size, interpolation_mode):
-    assert len(target_size) == 2, "target size should be tuple (height, width)"
+    if len(target_size) != 2:
+        raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
     return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)


@@ -46,17 +48,20 @@ def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
     Returns:
         clip (torch.tensor): Resized and cropped clip.
                Size is (C, T, H, W)
     """
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
     clip = crop(clip, i, j, h, w)
     clip = resize(clip, size, interpolation_mode)
     return clip


 def center_crop(clip, crop_size):
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
     h, w = clip.size(-2), clip.size(-1)
     th, tw = crop_size
-    assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
+    if h < th or w < tw:
+        raise ValueError("height and width must be no smaller than crop_size")

     i = int(round((h - th) / 2.0))
     j = int(round((w - tw) / 2.0))
@@ -87,7 +92,8 @@ def normalize(clip, mean, std, inplace=False):
     Returns:
         normalized clip (torch.tensor): Size is (C, T, H, W)
     """
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
     if not inplace:
         clip = clip.clone()
     mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
@@ -103,5 +109,6 @@ def hflip(clip):
     Returns:
         flipped clip (torch.tensor): Size is (C, T, H, W)
     """
-    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    if not _is_tensor_video_clip(clip):
+        raise ValueError("clip should be a 4D torch.tensor")
     return clip.flip(-1)
diff --git a/torchvision/transforms/_transforms_video.py b/torchvision/transforms/_transforms_video.py
index 4a36c8abbf9..69512af6eb1 100644
--- a/torchvision/transforms/_transforms_video.py
+++ b/torchvision/transforms/_transforms_video.py
@@ -59,7 +59,8 @@ def __init__(
         interpolation_mode="bilinear",
     ):
         if isinstance(size, tuple):
-            assert len(size) == 2, "size should be tuple (height, width)"
+            if len(size) != 2:
+                raise ValueError(f"size should be tuple (height, width), instead got {size}")
             self.size = size
         else:
             self.size = (size, size)
diff --git a/torchvision/utils.py b/torchvision/utils.py
index 6d3293d103d..4737a047327 100644
--- a/torchvision/utils.py
+++ b/torchvision/utils.py
@@ -82,10 +82,8 @@ def make_grid(
     if normalize is True:
         tensor = tensor.clone()  # avoid modifying tensor in-place
-        if value_range is not None:
-            assert isinstance(
-                value_range, tuple
-            ), "value_range has to be a tuple (min, max) if specified. min and max are numbers"
+        if value_range is not None and not isinstance(value_range, tuple):
+            raise TypeError("value_range has to be a tuple (min, max) if specified. min and max are numbers")

         def norm_ip(img, low, high):
             img.clamp_(min=low, max=high)
@@ -103,7 +101,8 @@ def norm_range(t, value_range):
         else:
             norm_range(tensor, value_range)

-    assert isinstance(tensor, torch.Tensor)
+    if not isinstance(tensor, torch.Tensor):
+        raise TypeError("tensor should be of type torch.Tensor")
     if tensor.size(0) == 1:
         return tensor.squeeze(0)
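Reviewer note, not part of the patch: a minimal sketch of how the new exception types surface to callers, assuming a build of torchvision with this patch applied is installed. The tensors and box values below are illustrative only; they mirror the `ValueError` expectations already encoded in the updated `test/test_ops.py`.

```python
import pytest
import torch
from torchvision.ops import generalized_box_iou, roi_pool

# Boxes passed as Tensor[K, 5] are validated by check_roi_boxes_shape; a wrong
# trailing dimension now surfaces as a ValueError rather than an AssertionError.
feature_map = torch.rand(1, 1, 8, 8)
bad_rois = torch.tensor([[0.0, 0.0, 3.0]])  # expected layout is [batch_idx, x1, y1, x2, y2]
with pytest.raises(ValueError):
    roi_pool(feature_map, bad_rois, output_size=(2, 2))

# Degenerate boxes (x2 < x1, y2 < y1) are likewise rejected with a ValueError.
valid = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
degenerate = torch.tensor([[10.0, 10.0, 0.0, 0.0]])
with pytest.raises(ValueError):
    generalized_box_iou(degenerate, valid)
```

Exercising just these two ops keeps the sanity check independent of any model weights or reference training scripts.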