99 changes: 99 additions & 0 deletions supervision/dataset/core.py
@@ -10,6 +10,10 @@
import numpy as np

from supervision.classification.core import Classifications
from supervision.dataset.formats.coco import (
load_coco_annotations,
save_coco_annotations,
)
from supervision.dataset.formats.pascal_voc import (
detections_to_pascal_voc,
load_pascal_voc_annotations,
@@ -333,6 +337,101 @@ def as_yolo(
if data_yaml_path is not None:
save_data_yaml(data_yaml_path=data_yaml_path, classes=self.classes)

@classmethod
def from_coco(
cls,
images_directory_path: str,
annotations_path: str,
force_masks: bool = False,
) -> DetectionDataset:
"""
Creates a DetectionDataset instance from COCO formatted data.

Args:
images_directory_path (str): The path to the directory containing the images.
annotations_path (str): The path to the COCO JSON annotation file.
force_masks (bool, optional): If True, forces masks to be loaded for all annotations, regardless of whether they are present.

Returns:
DetectionDataset: A DetectionDataset instance containing the loaded images and annotations.

Example:
```python
>>> import roboflow
>>> from roboflow import Roboflow
>>> import supervision as sv

>>> roboflow.login()

>>> rf = Roboflow()

>>> project = rf.workspace(WORKSPACE_ID).project(PROJECT_ID)
>>> dataset = project.version(PROJECT_VERSION).download("coco")

>>> ds = sv.DetectionDataset.from_coco(
... images_directory_path=f"{dataset.location}/train",
... annotations_path=f"{dataset.location}/train/_annotations.coco.json",
... )

>>> ds.classes
['dog', 'person']
```
"""
classes, images, annotations = load_coco_annotations(
images_directory_path=images_directory_path,
annotations_path=annotations_path,
force_masks=force_masks,
)
return DetectionDataset(classes=classes, images=images, annotations=annotations)

def as_coco(
self,
images_directory_path: Optional[str] = None,
annotations_path: Optional[str] = None,
min_image_area_percentage: float = 0.0,
max_image_area_percentage: float = 1.0,
approximation_percentage: float = 0.0,
licenses: Optional[list] = None,
info: Optional[dict] = None,
) -> None:
"""
Exports the dataset to COCO format. This method saves the images and their corresponding
annotations in COCO format, where all annotations are stored in a single JSON file that
describes the objects in each image. The annotation JSON file also includes the category map.

The method allows filtering the detections based on their area percentage and offers an option for polygon approximation.

Args:
images_directory_path (Optional[str]): The path to the directory where the images should be saved.
If not provided, images will not be saved.
annotations_path (Optional[str]): The path to the file where the annotations in
COCO format should be saved. If not provided, annotations will not be saved.
min_image_area_percentage (float): The minimum percentage of detection area relative to
the image area for a detection to be included.
max_image_area_percentage (float): The maximum percentage of detection area relative to
the image area for a detection to be included.
approximation_percentage (float): The percentage of polygon points to be removed from the input polygon,
in the range [0, 1). This is useful for simplifying the annotations.
licenses (Optional[list]): A list of COCO license entries to include in the annotation file.
info (Optional[dict]): A dictionary of dataset metadata for the COCO `info` field.
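
Example:
    A minimal usage sketch; the paths below are placeholders and `ds` is assumed to be
    a dataset previously loaded, for instance with `from_coco`.

    ```python
    >>> import supervision as sv

    >>> ds = sv.DetectionDataset.from_coco(
    ...     images_directory_path="dataset/train",
    ...     annotations_path="dataset/train/_annotations.coco.json",
    ... )

    >>> ds.as_coco(
    ...     images_directory_path="dataset_export/images",
    ...     annotations_path="dataset_export/annotations.coco.json",
    ... )
    ```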
"""
if images_directory_path is not None:
save_dataset_images(
images_directory_path=images_directory_path, images=self.images
)
if annotations_path is not None:
save_coco_annotations(
annotation_path=annotations_path,
images=self.images,
annotations=self.annotations,
classes=self.classes,
min_image_area_percentage=min_image_area_percentage,
max_image_area_percentage=max_image_area_percentage,
approximation_percentage=approximation_percentage,
licenses=licenses,
info=info,
)


@dataclass
class ClassificationDataset(BaseDataset):
221 changes: 221 additions & 0 deletions supervision/dataset/formats/coco.py
@@ -0,0 +1,221 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import cv2
import numpy as np

from supervision.dataset.ultils import approximate_mask_with_polygons
from supervision.detection.core import Detections
from supervision.detection.utils import polygon_to_mask
from supervision.utils.file import read_json_file, save_json_file


def coco_categories_to_classes(coco_categories: List[dict]) -> List[str]:
return [
category["name"]
for category in sorted(coco_categories, key=lambda category: category["id"])
if category["supercategory"] != "none"
]


def classes_to_coco_categories(classes: List[str]) -> List[dict]:
return [
{
"id": class_id,
"name": class_name,
"supercategory": "common-objects",
}
for class_id, class_name in enumerate(classes)
]


def group_coco_annotations_by_image_id(
coco_annotations: List[dict],
) -> Dict[int, List[dict]]:
annotations = {}
for annotation in coco_annotations:
image_id = annotation["image_id"]
if image_id not in annotations:
annotations[image_id] = []
annotations[image_id].append(annotation)
return annotations


def _polygons_to_masks(
polygons: List[np.ndarray], resolution_wh: Tuple[int, int]
) -> np.ndarray:
return np.array(
[
polygon_to_mask(polygon=polygon, resolution_wh=resolution_wh)
for polygon in polygons
],
dtype=bool,
)


def coco_annotations_to_detections(
image_annotations: List[dict], resolution_wh: Tuple[int, int], with_masks: bool
) -> Detections:
if not image_annotations:
return Detections.empty()

class_ids = [
image_annotation["category_id"] for image_annotation in image_annotations
]
xyxy = [image_annotation["bbox"] for image_annotation in image_annotations]
xyxy = np.asarray(xyxy)
# COCO stores boxes as [x_min, y_min, width, height]; convert to [x1, y1, x2, y2]
xyxy[:, 2:4] += xyxy[:, 0:2]

if with_masks:
polygons = [
np.reshape(
np.asarray(image_annotation["segmentation"], dtype=np.int32), (-1, 2)
)
for image_annotation in image_annotations
]
mask = _polygons_to_masks(polygons=polygons, resolution_wh=resolution_wh)
return Detections(
class_id=np.asarray(class_ids, dtype=int), xyxy=xyxy, mask=mask
)

return Detections(xyxy=xyxy, class_id=np.asarray(class_ids, dtype=int))


def detections_to_coco_annotations(
detections: Detections,
image_id: int,
annotation_id: int,
min_image_area_percentage: float = 0.0,
max_image_area_percentage: float = 1.0,
approximation_percentage: float = 0.75,
) -> Tuple[List[Dict], int]:
coco_annotations = []
for xyxy, mask, _, class_id, _ in detections:
# convert [x1, y1, x2, y2] back to COCO's [x_min, y_min, width, height]
box_width, box_height = xyxy[2] - xyxy[0], xyxy[3] - xyxy[1]
polygon = []
if mask is not None:
polygon = list(
approximate_mask_with_polygons(
mask=mask,
min_image_area_percentage=min_image_area_percentage,
max_image_area_percentage=max_image_area_percentage,
approximation_percentage=approximation_percentage,
)[0].flatten()
)
coco_annotation = {
"id": annotation_id,
"image_id": image_id,
"category_id": int(class_id),
"bbox": [xyxy[0], xyxy[1], box_width, box_height],
"area": box_width * box_height,
"segmentation": polygon,
"iscrowd": 0,
}
coco_annotations.append(coco_annotation)
annotation_id += 1
return coco_annotations, annotation_id


def load_coco_annotations(
images_directory_path: str,
annotations_path: str,
force_masks: bool = False,
) -> Tuple[List[str], Dict[str, np.ndarray], Dict[str, Detections]]:
coco_data = read_json_file(file_path=annotations_path)
classes = coco_categories_to_classes(coco_categories=coco_data["categories"])
coco_images = coco_data["images"]
coco_annotations_groups = group_coco_annotations_by_image_id(
coco_annotations=coco_data["annotations"]
)

images = {}
annotations = {}

for coco_image in coco_images:
image_name, image_width, image_height = (
coco_image["file_name"],
coco_image["width"],
coco_image["height"],
)
image_annotations = coco_annotations_groups.get(coco_image["id"], [])
image_path = os.path.join(images_directory_path, image_name)

image = cv2.imread(str(image_path))
annotation = coco_annotations_to_detections(
image_annotations=image_annotations,
resolution_wh=(image_width, image_height),
with_masks=force_masks,
)

images[image_name] = image
annotations[image_name] = annotation

return classes, images, annotations


def save_coco_annotations(
annotation_path: str,
images: Dict[str, np.ndarray],
annotations: Dict[str, Detections],
classes: List[str],
min_image_area_percentage: float = 0.0,
max_image_area_percentage: float = 1.0,
approximation_percentage: float = 0.75,
licenses: List[dict] = None,
info: dict = None,
) -> None:
Path(annotation_path).parent.mkdir(parents=True, exist_ok=True)
if not info:
info = {}
if not licenses:
licenses = [
{
"id": 1,
"url": "https://creativecommons.org/licenses/by/4.0/",
"name": "CC BY 4.0",
}
]

coco_annotations = []
coco_images = []
coco_categories = classes_to_coco_categories(classes=classes)

image_id = 0
annotation_id = 0
for image_name, image in images.items():
image_height, image_width, _ = image.shape

coco_image = {
"id": image_id,
"license": 1,
"file_name": image_name,
"height": image_height,
"width": image_width,
"date_captured": datetime.now().strftime("%m/%d/%Y,%H:%M:%S"),
}

coco_images.append(coco_image)
detections = annotations[image_name]

# carry the running annotation id forward so ids stay unique across images
coco_annotation, annotation_id = detections_to_coco_annotations(
detections=detections,
image_id=image_id,
annotation_id=annotation_id,
min_image_area_percentage=min_image_area_percentage,
max_image_area_percentage=max_image_area_percentage,
approximation_percentage=approximation_percentage,
)

coco_annotations.extend(coco_annotation)
image_id += 1

annotation_dict = {
"info": info,
"licenses": licenses,
"categories": coco_categories,
"images": coco_images,
"annotations": coco_annotations,
}
save_json_file(annotation_dict, file_path=annotation_path)
14 changes: 0 additions & 14 deletions supervision/detection/core.py
@@ -355,20 +355,6 @@ def from_sam(cls, sam_result: List[dict]) -> Detections:

return Detections(xyxy=xywh_to_xyxy(boxes_xywh=xywh), mask=mask)

@classmethod
@deprecated(
"Dataset loading and saving is going to be executed by supervision.dataset.core.Dataset"
)
def from_coco_annotations(cls, coco_annotation: dict) -> Detections:
xyxy, class_id = [], []

for annotation in coco_annotation:
x_min, y_min, width, height = annotation["bbox"]
xyxy.append([x_min, y_min, x_min + width, y_min + height])
class_id.append(annotation["category_id"])

return cls(xyxy=np.array(xyxy), class_id=np.array(class_id))

@classmethod
def empty(cls) -> Detections:
"""