Skip to content

Commit

Permalink
require categories to be explicitly passed
Browse files Browse the repository at this point in the history
  • Loading branch information
brimoor committed May 6, 2024
1 parent 6b600a8 commit b9ed7e9
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 30 deletions.
9 changes: 5 additions & 4 deletions docs/source/user_guide/export_datasets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1742,11 +1742,12 @@ format as follows:
.. note::

You can pass the optional `classes` parameter to
You can pass the optional `classes` or `categories` parameters to
:meth:`export() <fiftyone.core.collections.SampleCollection.export>` to
explicitly define the class list to use in the exported labels. Otherwise,
the strategy outlined in :ref:`this section <export-class-lists>` will be
used to populate the class list.
explicitly define the class list/category IDs to use in the exported
labels. Otherwise, the strategy outlined in
:ref:`this section <export-class-lists>` will be used to populate the class
list.

You can also perform labels-only exports of COCO-formatted labels by providing
the `labels_path` parameter instead of `export_dir`:
Expand Down
40 changes: 14 additions & 26 deletions fiftyone/utils/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,9 @@ class COCODetectionDatasetExporter(
images to disk. By default, ``fiftyone.config.default_image_ext``
is used
classes (None): the list of possible class labels
categories (None): a list of category dicts in the format of
:meth:`parse_coco_categories` specifying the classes and their
category IDs
info (None): a dict of info as returned by
:meth:`load_coco_detection_annotations` to include in the exported
JSON. If not provided, this info will be extracted when
Expand Down Expand Up @@ -725,6 +728,7 @@ def __init__(
abs_paths=False,
image_format=None,
classes=None,
categories=None,
info=None,
extra_attrs=True,
annotation_id=None,
Expand Down Expand Up @@ -754,6 +758,7 @@ def __init__(
self.abs_paths = abs_paths
self.image_format = image_format
self.classes = classes
self.categories = categories
self.info = info
self.extra_attrs = extra_attrs
self.annotation_id = annotation_id
Expand Down Expand Up @@ -799,8 +804,6 @@ def setup(self):
def log_collection(self, sample_collection):
if self.info is None:
self.info = sample_collection.info
if "categories" in self.info:
self._parse_classes()

def export_sample(self, image_or_path, label, metadata=None):
out_image_path, uuid = self._media_exporter.export(image_or_path)
Expand Down Expand Up @@ -902,13 +905,9 @@ def close(self, *args):

licenses = _info.get("licenses", [])

try:
categories = _info.get("categories", None)
parse_coco_categories(categories)
except:
categories = None

if categories is None:
if self.categories is not None:
categories = self.categories
else:
categories = [
{
"id": i,
Expand All @@ -933,13 +932,10 @@ def close(self, *args):
self._media_exporter.close()

def _parse_classes(self):
if self.info is not None:
labels_map_rev = _parse_categories(self.info, self.classes)
else:
labels_map_rev = None

if labels_map_rev is not None:
self._labels_map_rev = labels_map_rev
if self.categories is not None:
self._labels_map_rev = _parse_categories(
self.categories, classes=self.classes
)
self._dynamic_classes = False
elif self.classes is None:
self._classes = set()
Expand Down Expand Up @@ -2032,16 +2028,8 @@ def _get_matching_objects(coco_objects, class_ids):
return [obj for obj in coco_objects if obj.category_id in class_ids]


def _parse_categories(info, classes):
categories = info.get("categories", None)
if categories is None:
return None

try:
classes_map, _ = parse_coco_categories(categories)
except:
logger.debug("Failed to parse categories from info")
return None
def _parse_categories(categories, classes=None):
classes_map, _ = parse_coco_categories(categories)

if classes is None:
return {c: i for i, c in classes_map.items()}
Expand Down
28 changes: 28 additions & 0 deletions tests/unittests/import_export_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1287,6 +1287,34 @@ def test_coco_detection_dataset(self):
# data/_images/<filename>
self.assertEqual(len(relpath.split(os.path.sep)), 3)

# Non-sequential categories

export_dir = self._new_dir()

categories = [
{"supercategory": "animal", "id": 10, "name": "cat"},
{"supercategory": "vehicle", "id": 20, "name": "dog"},
]

dataset.export(
export_dir=export_dir,
dataset_type=fo.types.COCODetectionDataset,
categories=categories,
)

dataset2 = fo.Dataset.from_dir(
dataset_dir=export_dir,
dataset_type=fo.types.COCODetectionDataset,
label_types="detections",
label_field="predictions",
)
categories2 = dataset2.info["categories"]

self.assertSetEqual(
{c["id"] for c in categories},
{c["id"] for c in categories2},
)

@drop_datasets
def test_voc_detection_dataset(self):
dataset = self._make_dataset()
Expand Down

0 comments on commit b9ed7e9

Please sign in to comment.