Skip to content

Commit

Permalink
[fbsync] Support encoded RLE format for COCO segmentations (#8387)
Browse files Browse the repository at this point in the history
Reviewed By: vmoens

Differential Revision: D58283856

fbshipit-source-id: 98805162e3209173811108468037ea490d123082
  • Loading branch information
NicolasHug authored and facebook-github-bot committed Jun 7, 2024
1 parent 4ce50d6 commit 27dc11c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 11 deletions.
42 changes: 36 additions & 6 deletions test/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,32 +782,46 @@ def inject_fake_data(self, tmpdir, config):

annotation_folder = tmpdir / self._ANNOTATIONS_FOLDER
os.makedirs(annotation_folder)

segmentation_kind = config.pop("segmentation_kind", "list")
info = self._create_annotation_file(
annotation_folder, self._ANNOTATIONS_FILE, file_names, num_annotations_per_image
annotation_folder,
self._ANNOTATIONS_FILE,
file_names,
num_annotations_per_image,
segmentation_kind=segmentation_kind,
)

info["num_examples"] = num_images
return info

# Diff view: the two signature lines below are the same method before and after
# this commit. The first (no segmentation_kind) appears to be the removed line;
# the second, which adds segmentation_kind="list", appears to be its replacement.
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image):
def _create_annotation_file(self, root, name, file_names, num_annotations_per_image, segmentation_kind="list"):
# Each image id is the integer value of the file name's stem.
image_ids = [int(file_name.stem) for file_name in file_names]
images = [dict(file_name=str(file_name), id=id) for file_name, id in zip(file_names, image_ids)]

# Same old/new pairing as the signature: the second call is the one that
# forwards segmentation_kind to _create_annotations.
annotations, info = self._create_annotations(image_ids, num_annotations_per_image)
annotations, info = self._create_annotations(image_ids, num_annotations_per_image, segmentation_kind)
# Hand the images and annotations to self._create_json together with root and name.
self._create_json(root, name, dict(images=images, annotations=annotations))

# info is whatever _create_annotations returned alongside the annotations.
return info

def _create_annotations(self, image_ids, num_annotations_per_image):
def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
annotations = []
annotion_id = 0

for image_id in itertools.islice(itertools.cycle(image_ids), len(image_ids) * num_annotations_per_image):
segmentation = {
"list": [torch.rand(8).tolist()],
"rle": {"size": [10, 10], "counts": [1]},
"rle_encoded": {"size": [2400, 2400], "counts": "PQRQ2[1\\Y2f0gNVNRhMg2"},
"bad": 123,
}[segmentation_kind]

annotations.append(
dict(
image_id=image_id,
id=annotion_id,
bbox=torch.rand(4).tolist(),
segmentation=[torch.rand(8).tolist()],
segmentation=segmentation,
category_id=int(torch.randint(91, ())),
area=float(torch.rand(1)),
iscrowd=int(torch.randint(2, size=(1,))),
Expand All @@ -832,11 +846,27 @@ def test_slice_error(self):
with pytest.raises(ValueError, match="Index must be of type integer"):
dataset[:2]

def test_segmentation_kind(self):
# Checks that a wrapped dataset can be iterated for each accepted segmentation
# kind, and that an unsupported segmentation raises ValueError.
# When run on CocoCaptionsTestCase the test returns early without asserting anything.
if isinstance(self, CocoCaptionsTestCase):
return

# Each of these kinds is expected to iterate through the wrapper without raising.
for segmentation_kind in ("list", "rle", "rle_encoded"):
config = {"segmentation_kind": segmentation_kind}
with self.create_dataset(config) as (dataset, _):
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
# list() iterates the whole dataset, so every sample goes through the wrapper.
list(dataset)

# "bad" selects the fake segmentation value 123 in _create_annotations, which is
# neither a dict nor a list; iterating the wrapped dataset must raise ValueError.
config = {"segmentation_kind": "bad"}
with self.create_dataset(config) as (dataset, _):
dataset = datasets.wrap_dataset_for_transforms_v2(dataset, target_keys="all")
with pytest.raises(ValueError, match="COCO segmentation expected to be a dict or a list"):
list(dataset)


class CocoCaptionsTestCase(CocoDetectionTestCase):
DATASET_CLASS = datasets.CocoCaptions

def _create_annotations(self, image_ids, num_annotations_per_image):
def _create_annotations(self, image_ids, num_annotations_per_image, segmentation_kind="list"):
captions = [str(idx) for idx in range(num_annotations_per_image)]
annotations = combinations_grid(image_id=image_ids, caption=captions)
for id, annotation in enumerate(annotations):
Expand Down
13 changes: 8 additions & 5 deletions torchvision/tv_tensors/_dataset_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,14 @@ def coco_dectection_wrapper_factory(dataset, target_keys):
def segmentation_to_mask(segmentation, *, canvas_size):
# Converts a COCO segmentation (a dict or a list) into a torch tensor using
# pycocotools. canvas_size is keyword-only and is unpacked into frPyObjects.
# pycocotools is imported inside the function rather than at module level.
from pycocotools import mask

# Diff view: the five-line expression below appears to be the pre-commit code
# (the file header reports 5 deletions); the if/elif/else that follows appears
# to be its replacement (8 additions). The page has lost its +/- markers.
segmentation = (
mask.frPyObjects(segmentation, *canvas_size)
if isinstance(segmentation, dict)
else mask.merge(mask.frPyObjects(segmentation, *canvas_size))
)
if isinstance(segmentation, dict):
# if counts is a string, it is already an encoded RLE mask
# and is passed to mask.decode unchanged; otherwise it is converted first.
if not isinstance(segmentation["counts"], str):
segmentation = mask.frPyObjects(segmentation, *canvas_size)
elif isinstance(segmentation, list):
# A list is converted with frPyObjects and the result combined with mask.merge.
segmentation = mask.merge(mask.frPyObjects(segmentation, *canvas_size))
else:
# Anything that is neither a dict nor a list is rejected explicitly.
raise ValueError(f"COCO segmentation expected to be a dict or a list, got {type(segmentation)}")
# mask.decode's result is handed to torch.from_numpy to build the returned tensor.
return torch.from_numpy(mask.decode(segmentation))

def wrapper(idx, sample):
Expand Down

0 comments on commit 27dc11c

Please sign in to comment.