From 3ce23ef43ac70db6ff8d5b6fbbe10f9d5d691bcf Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 8 Jul 2022 18:01:09 +0200 Subject: [PATCH] Fixed acceptable and non-acceptable types for Cutmix/Mixup --- test/test_prototype_transforms.py | 22 +++++++++++++++++++- torchvision/prototype/transforms/_augment.py | 8 ++++--- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 2c8f65e3086..1934a8bd408 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -3,7 +3,13 @@ import pytest import torch from common_utils import assert_equal -from test_prototype_transforms_functional import make_images, make_bounding_boxes, make_one_hot_labels +from test_prototype_transforms_functional import ( + make_images, + make_bounding_boxes, + make_one_hot_labels, + make_segmentation_masks, + make_label, +) from torchvision.prototype import transforms, features from torchvision.transforms.functional import to_pil_image, pil_to_tensor @@ -102,6 +108,20 @@ def test_common(self, transform, input): def test_mixup_cutmix(self, transform, input): transform(input) + @pytest.mark.parametrize("transform", [transforms.RandomMixup(alpha=1.0), transforms.RandomCutmix(alpha=1.0)]) + def test_mixup_cutmix_assertions(self, transform): + for bbox in make_bounding_boxes(): + with pytest.raises(TypeError, match="does not support"): + transform(bbox) + break + for mask in make_segmentation_masks(): + with pytest.raises(TypeError, match="does not support"): + transform(mask) + break + label = make_label() + with pytest.raises(TypeError, match="does not support"): + transform(label) + @parametrize( [ ( diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 9a2c2c0f416..df1dd916467 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -9,7 +9,7 @@ from torchvision.prototype.transforms import Transform, functional as F from ._transform import _RandomApplyTransform -from ._utils import query_image, get_image_dimensions, has_all +from ._utils import query_image, get_image_dimensions, has_any class RandomErasing(_RandomApplyTransform): @@ -106,8 +106,10 @@ def __init__(self, *, alpha: float) -> None: def forward(self, *inpts: Any) -> Any: sample = inpts if len(inpts) > 1 else inpts[0] - if not has_all(sample, features.Image, features.OneHotLabel): - raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.") + if has_any(sample, features.BoundingBox, features.SegmentationMask, features.Label): + raise TypeError( + f"{type(self).__name__}() does not support bounding boxes, segmentation masks and plain labels." + ) return super().forward(sample) def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features.OneHotLabel: