Skip to content

Commit

Permalink
florence_2: Clean up task selector
Browse files Browse the repository at this point in the history
  • Loading branch information
LinasKo committed Jun 21, 2024
1 parent 23e6350 commit ec51712
Showing 1 changed file with 15 additions and 25 deletions.
40 changes: 15 additions & 25 deletions supervision/detection/lmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,33 +112,32 @@ def from_florence_2(
obb_boxes: (Optional[np.ndarray]): An array of shape `(n, 4, 2)` containing
oriented bounding boxes.
"""
for task in ["<OD>", "<CAPTION_TO_PHRASE_GROUNDING>", "<DENSE_REGION_CAPTION>"]:
if task not in result:
continue
result = result[task]
assert len(result) == 1, f"Expected result with a single element. Got: {result}"
task = list(result.keys())[0]
if task not in SUPPORTED_TASKS_FLORENCE_2:
raise ValueError(
f"{task} not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}"
)
result = result[task]

if task in ["<OD>", "<CAPTION_TO_PHRASE_GROUNDING>", "<DENSE_REGION_CAPTION>"]:
xyxy = np.array(result["bboxes"], dtype=np.float32)
labels = np.array(result["labels"])
return xyxy, labels, None, None

if "<REGION_PROPOSAL>" in result:
result = result["<REGION_PROPOSAL>"]
if task == "<REGION_PROPOSAL>":
xyxy = np.array(result["bboxes"], dtype=np.float32)
# provides labels, but they are ["", "", "", ...]
return xyxy, None, None, None

if "<OCR_WITH_REGION>" in result:
result = result["<OCR_WITH_REGION>"]
if task == "<OCR_WITH_REGION>":
xyxyxyxy = np.array(result["quad_boxes"], dtype=np.float32)
xyxyxyxy = xyxyxyxy.reshape(-1, 4, 2)
xyxy = np.array([polygon_to_xyxy(polygon) for polygon in xyxyxyxy])
labels = np.array(result["labels"])
return xyxy, labels, None, xyxyxyxy

for task in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
if task not in result:
continue

result = result[task]
if task in ["<REFERRING_EXPRESSION_SEGMENTATION>", "<REGION_TO_SEGMENTATION>"]:
xyxy_list = []
masks_list = []
for polygons_of_same_class in result["polygons"]:
Expand All @@ -155,18 +154,13 @@ def from_florence_2(
masks = np.array(masks_list)
return xyxy, None, masks, None

if "<OPEN_VOCABULARY_DETECTION>" in result:
result = result["<OPEN_VOCABULARY_DETECTION>"]
if task == "<OPEN_VOCABULARY_DETECTION>":
xyxy = np.array(result["bboxes"], dtype=np.float32)
labels = np.array(result["bboxes_labels"])
# Also has "polygons" and "polygons_labels", but they don't seem to be used
return xyxy, labels, None, None

for task in ["<REGION_TO_CATEGORY>", "<REGION_TO_DESCRIPTION>"]:
if task not in result:
continue

result = result[task]
if task in ["<REGION_TO_CATEGORY>", "<REGION_TO_DESCRIPTION>"]:
assert isinstance(
result, str
), f"Expected string as <REGION_TO_CATEGORY> result, got {type(result)}"
Expand All @@ -182,8 +176,4 @@ def from_florence_2(
labels = np.array([result_string])
return xyxy, labels, None, None

task = list(result.keys())[0]
assert task not in SUPPORTED_TASKS_FLORENCE_2, f"Expected to support task {task}"
raise NotImplementedError(
f"{task} task not supported. Supported tasks are: {SUPPORTED_TASKS_FLORENCE_2}"
)
assert False, f"Unimplemented task: {task}"

0 comments on commit ec51712

Please sign in to comment.