Skip to content

Commit

Permalink
allow sequence fill for v2 AA scripted
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Aug 31, 2023
1 parent 96950a5 commit ac6bd0c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
21 changes: 12 additions & 9 deletions test/test_transforms_v2_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -755,10 +755,11 @@ def test_randaug(self, inpt, interpolation, mocker):
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_randaug_jit(self, interpolation):
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_randaug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1)
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)
t = v2_transforms.RandAugment(interpolation=interpolation, num_ops=1, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
Expand Down Expand Up @@ -830,10 +831,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_trivial_aug_jit(self, interpolation):
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_trivial_aug_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation)
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)
t = v2_transforms.TrivialAugmentWide(interpolation=interpolation, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
Expand Down Expand Up @@ -906,11 +908,12 @@ def test_augmix(self, inpt, interpolation, mocker):
v2_transforms.InterpolationMode.BILINEAR,
],
)
def test_augmix_jit(self, interpolation):
@pytest.mark.parametrize("fill", [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1])
def test_augmix_jit(self, interpolation, fill):
inpt = torch.randint(0, 256, size=(1, 3, 256, 256), dtype=torch.uint8)

t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)
t = v2_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1, fill=fill)

tt_ref = torch.jit.script(t_ref)
tt = torch.jit.script(t)
Expand Down
4 changes: 2 additions & 2 deletions torchvision/transforms/v2/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def __init__(
def _extract_params_for_v1_transform(self) -> Dict[str, Any]:
params = super()._extract_params_for_v1_transform()

if not (params["fill"] is None or isinstance(params["fill"], (int, float))):
raise ValueError(f"{type(self).__name__}() can only be scripted for a scalar `fill`, but got {self.fill}.")
if isinstance(params["fill"], dict):
raise ValueError(f"{type(self).__name__}() can not be scripted for when `fill` is a dictionary.")

return params

Expand Down

0 comments on commit ac6bd0c

Please sign in to comment.