Skip to content

Commit

Permalink
Add segmentation and region methods
Browse files Browse the repository at this point in the history
  • Loading branch information
LinasKo committed Jun 20, 2024
1 parent 04c5f18 commit bb513f8
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 14 deletions.
5 changes: 3 additions & 2 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ def from_lmm(

if lmm == LMM.FLORENCE_2:
assert isinstance(result, dict)
xyxy, labels, xyxyxyxy = from_florence_2(result, **kwargs)
xyxy, labels, mask, xyxyxyxy = from_florence_2(result, **kwargs)
if len(xyxy) == 0:
return cls.empty()

Expand All @@ -873,7 +873,8 @@ def from_lmm(
data[CLASS_NAME_DATA_FIELD] = labels
if xyxyxyxy is not None:
data[ORIENTED_BOX_COORDINATES] = xyxyxyxy
return cls(xyxy=xyxy, data=data)

return cls(xyxy=xyxy, mask=mask, data=data)

raise ValueError(f"Unsupported LMM: {lmm}")

Expand Down
82 changes: 70 additions & 12 deletions supervision/detection/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from supervision.detection.utils import polygon_to_xyxy
from supervision.detection.utils import polygon_to_mask, polygon_to_xyxy


class LMM(Enum):
Expand All @@ -16,12 +16,12 @@ class LMM(Enum):

REQUIRED_ARGUMENTS: Dict[LMM, List[str]] = {
LMM.PALIGEMMA: ["resolution_wh"],
LMM.FLORENCE_2: [],
LMM.FLORENCE_2: ["resolution_wh"],
}

ALLOWED_ARGUMENTS: Dict[LMM, List[str]] = {
LMM.PALIGEMMA: ["resolution_wh", "classes"],
LMM.FLORENCE_2: [],
LMM.FLORENCE_2: ["resolution_wh"],
}

SUPPORTED_TASKS_FLORENCE_2 = [
Expand All @@ -30,6 +30,11 @@ class LMM(Enum):
"<DENSE_REGION_CAPTION>",
"<REGION_PROPOSAL>",
"<OCR_WITH_REGION>",
"<REFERRING_EXPRESSION_SEGMENTATION>",
"<REGION_TO_SEGMENTATION>",
"<OPEN_VOCABULARY_DETECTION>",
"<REGION_TO_CATEGORY>",
"<REGION_TO_DESCRIPTION>",
]


Expand Down Expand Up @@ -86,8 +91,10 @@ def from_paligemma(


def from_florence_2(
    result: dict, resolution_wh: Tuple[int, int]
) -> Tuple[
    np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]
]:
    """
    Parse results from the Florence 2 multi-modal model.
    https://huggingface.co/microsoft/Florence-2-large

    Args:
        result (dict): The Florence 2 output, keyed by the task token that
            produced it (for example `"<OD>"`). Only the first supported task
            found in the dictionary is parsed.
        resolution_wh (Tuple[int, int]): The `(width, height)` of the image.
            Used to rasterize segmentation polygons into masks and to convert
            `<loc_N>` location tokens into pixel coordinates.

    Returns:
        xyxy (np.ndarray): An array of shape `(n, 4)` containing
            the bounding boxes coordinates in format `[x1, y1, x2, y2]`
        labels: (Optional[np.ndarray]): An array of shape `(n,)` containing
            the class labels for each bounding box
        masks: (Optional[np.ndarray]): An array of shape `(n, h, w)` containing
            the segmentation masks for each bounding box
        obb_boxes: (Optional[np.ndarray]): An array of shape `(n, 4, 2)` containing
            oriented bounding boxes.

    Raises:
        ValueError: If `result` is empty.
        NotImplementedError: If `result` contains no supported task.
        AssertionError: If a region-to-text result is not a string, or does
            not contain four `<loc_N>` location tags.
    """
    for task in ["<OD>", "<CAPTION_TO_PHRASE_GROUNDING>", "<DENSE_REGION_CAPTION>"]:
        if task not in result:
            continue
        task_result = result[task]
        # reshape keeps the documented (n, 4) shape even when nothing is detected
        xyxy = np.array(task_result["bboxes"], dtype=np.float32).reshape(-1, 4)
        labels = np.array(task_result["labels"])
        return xyxy, labels, None, None

    if "<REGION_PROPOSAL>" in result:
        task_result = result["<REGION_PROPOSAL>"]
        xyxy = np.array(task_result["bboxes"], dtype=np.float32).reshape(-1, 4)
        # provides labels, but they are ["", "", "", ...]
        return xyxy, None, None, None

    if "<OCR_WITH_REGION>" in result:
        task_result = result["<OCR_WITH_REGION>"]
        xyxyxyxy = np.array(task_result["quad_boxes"], dtype=np.float32)
        xyxyxyxy = xyxyxyxy.reshape(-1, 4, 2)
        xyxy = np.array(
            [polygon_to_xyxy(polygon) for polygon in xyxyxyxy], dtype=np.float32
        ).reshape(-1, 4)
        labels = np.array(task_result["labels"])
        return xyxy, labels, None, xyxyxyxy

    for task in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
        if task not in result:
            continue

        task_result = result[task]
        xyxy_list = []
        masks_list = []
        for polygons_of_same_class in task_result["polygons"]:
            for polygon in polygons_of_same_class:
                # Florence 2 emits each polygon as a flat [x1, y1, x2, y2, ...]
                # list; the polygon helpers work on (k, 2) point arrays.
                points = np.reshape(polygon, (-1, 2))
                # NOTE(review): polygon_to_mask is assumed to rasterize with
                # OpenCV, which needs integer points - confirm against utils.
                mask = polygon_to_mask(points.astype(np.int32), resolution_wh)
                masks_list.append(mask)
                xyxy_list.append(polygon_to_xyxy(points))
            # per-class labels also provided, but they are ["", "", "", ...]
            # when we figure out how to set class names, we can do
            # zip(result["labels"], result["polygons"])
        if not masks_list:
            width, height = resolution_wh
            empty_xyxy = np.empty((0, 4), dtype=np.float32)
            empty_masks = np.empty((0, height, width), dtype=bool)
            return empty_xyxy, None, empty_masks, None
        xyxy = np.array(xyxy_list, dtype=np.float32).reshape(-1, 4)
        masks = np.array(masks_list)
        return xyxy, None, masks, None

    if "<OPEN_VOCABULARY_DETECTION>" in result:
        task_result = result["<OPEN_VOCABULARY_DETECTION>"]
        xyxy = np.array(task_result["bboxes"], dtype=np.float32).reshape(-1, 4)
        labels = np.array(task_result["bboxes_labels"])
        # Also has "polygons" and "polygons_labels", but they don't seem to be used
        return xyxy, labels, None, None

    for task in ["<REGION_TO_CATEGORY>", "<REGION_TO_DESCRIPTION>"]:
        if task not in result:
            continue

        task_result = result[task]
        assert isinstance(
            task_result, str
        ), f"Expected string as {task} result, got {type(task_result)}"

        pattern = re.compile(r"<loc_(\d+)><loc_(\d+)><loc_(\d+)><loc_(\d+)>")
        match = pattern.search(task_result)
        assert (
            match is not None
        ), f"Expected string to end in location tags, but got {task_result}"

        # Location tokens are bin indices on a 0-999 grid relative to the image,
        # not pixels, so scale them to the image resolution. Wrapping the groups
        # in a list yields the documented (1, 4) shape instead of (4,).
        width, height = resolution_wh
        bins = np.array([match.groups()], dtype=np.float64)
        scale = np.array([width, height, width, height], dtype=np.float64) / 1000
        xyxy = (bins * scale).astype(np.float32)
        result_string = task_result[: match.start()]
        labels = np.array([result_string])
        return xyxy, labels, None, None

    if not result:
        raise ValueError(
            "Florence 2 result is empty. Expected a dictionary keyed by a task token."
        )

    task = next(iter(result))
    assert task not in SUPPORTED_TASKS_FLORENCE_2, f"Expected to support task {task}"
    raise NotImplementedError(
        f"{task} task not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}"
    )

0 comments on commit bb513f8

Please sign in to comment.