Skip to content
89 changes: 88 additions & 1 deletion test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import enum
import inspect
from importlib.machinery import SourceFileLoader
from pathlib import Path

import numpy as np
import PIL.Image
import pytest

import torch
from prototype_common_utils import ArgsKwargs, assert_equal, make_images
from prototype_common_utils import (
ArgsKwargs,
assert_equal,
make_bounding_box,
make_detection_mask,
make_image,
make_images,
make_label,
)
from torchvision import transforms as legacy_transforms
from torchvision._utils import sequence_to_str
from torchvision.prototype import features, transforms as prototype_transforms
Expand Down Expand Up @@ -840,3 +850,80 @@ def test_aa(self, inpt, interpolation):
output = t(inpt)

assert_equal(expected_output, output)


# Import reference detection transforms here for consistency checks
# torchvision/references/detection/transforms.py
ref_det_filepath = Path(__file__).parent.parent / "references" / "detection" / "transforms.py"
det_transforms = SourceFileLoader(ref_det_filepath.stem, ref_det_filepath.as_posix()).load_module()


class TestRefDetTransforms:
def make_datapoints(self, with_mask=True):
size = (600, 800)
num_objects = 22

pil_image = to_image_pil(make_image(size=size, color_space=features.ColorSpace.RGB))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

yield (pil_image, target)

tensor_image = torch.randint(0, 256, size=(3, *size), dtype=torch.uint8)
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

yield (tensor_image, target)

feature_image = features.Image(torch.randint(0, 256, size=(3, *size), dtype=torch.uint8))
target = {
"boxes": make_bounding_box(image_size=size, format="XYXY", extra_dims=(num_objects,), dtype=torch.float),
"labels": make_label(extra_dims=(num_objects,), categories=80),
}
if with_mask:
target["masks"] = make_detection_mask(size=size, num_objects=num_objects, dtype=torch.long)

yield (feature_image, target)

@pytest.mark.parametrize(
"t_ref, t, data_kwargs",
[
(det_transforms.RandomHorizontalFlip(p=1.0), prototype_transforms.RandomHorizontalFlip(p=1.0), {}),
(det_transforms.RandomIoUCrop(), prototype_transforms.RandomIoUCrop(), {"with_mask": False}),
(det_transforms.RandomZoomOut(), prototype_transforms.RandomZoomOut(), {"with_mask": False}),
(det_transforms.ScaleJitter((1024, 1024)), prototype_transforms.ScaleJitter((1024, 1024)), {}),
(
det_transforms.FixedSizeCrop((1024, 1024), fill=0),
prototype_transforms.FixedSizeCrop((1024, 1024), fill=0),
{},
),
(
det_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
prototype_transforms.RandomShortestSize(
min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
),
{},
),
],
)
def test_transform(self, t_ref, t, data_kwargs):
for dp in self.make_datapoints(**data_kwargs):

# We should use prototype transform first as reference transform performs inplace target update
torch.manual_seed(12)
output = t(dp)

torch.manual_seed(12)
expected_output = t_ref(*dp)

assert_equal(expected_output, output)