Skip to content

Commit

Permalink
OBB: Fix plot_images (#7592)
Browse files Browse the repository at this point in the history
  • Loading branch information
Laughing-q committed Jan 15, 2024
1 parent 2f11ab5 commit cf50bd9
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 16 deletions.
9 changes: 4 additions & 5 deletions ultralytics/models/yolo/obb/val.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ def pred_to_json(self, predn, filename):

def save_one_txt(self, predn, save_conf, shape, file):
"""Save YOLO detections to a txt file in normalized coordinates in a specific format."""
gn = torch.tensor(shape)[[1, 0, 1, 0]] # normalization gain whwh
for *xyxy, conf, cls, angle in predn.tolist():
xywha = torch.tensor([*xyxy, angle]).view(1, 5)
xywha[:, :4] /= gn
xyxyxyxy = ops.xywhr2xyxyxyxy(xywha).view(-1).tolist() # normalized xywh
gn = torch.tensor(shape)[[1, 0]] # normalization gain whwh
for *xywh, conf, cls, angle in predn.tolist():
xywha = torch.tensor([*xywh, angle]).view(1, 5)
xyxyxyxy = (ops.xywhr2xyxyxyxy(xywha) / gn).view(-1).tolist() # normalized xywh
line = (cls, *xyxyxyxy, conf) if save_conf else (cls, *xyxyxyxy) # label format
with open(file, "a") as f:
f.write(("%g " * len(line)).rstrip() % line + "\n")
Expand Down
8 changes: 4 additions & 4 deletions ultralytics/utils/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ def xyxyxyxy2xywhr(corners):
) # rboxes


def xywhr2xyxyxyxy(center):
def xywhr2xyxyxyxy(rboxes):
"""
Convert batched Oriented Bounding Boxes (OBB) from [xywh, rotation] to [xy1, xy2, xy3, xy4]. Rotation values should
be in degrees from 0 to 90.
Expand All @@ -552,11 +552,11 @@ def xywhr2xyxyxyxy(center):
Returns:
(numpy.ndarray | torch.Tensor): Converted corner points of shape (n, 4, 2) or (b, n, 4, 2).
"""
is_numpy = isinstance(center, np.ndarray)
is_numpy = isinstance(rboxes, np.ndarray)
cos, sin = (np.cos, np.sin) if is_numpy else (torch.cos, torch.sin)

ctr = center[..., :2]
w, h, angle = (center[..., i : i + 1] for i in range(2, 5))
ctr = rboxes[..., :2]
w, h, angle = (rboxes[..., i : i + 1] for i in range(2, 5))
cos_value, sin_value = cos(angle), sin(angle)
vec1 = [w / 2 * cos_value, w / 2 * sin_value]
vec2 = [-h / 2 * sin_value, h / 2 * cos_value]
Expand Down
14 changes: 7 additions & 7 deletions ultralytics/utils/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,16 +706,16 @@ def plot_images(
if len(bboxes):
boxes = bboxes[idx]
conf = confs[idx] if confs is not None else None # check for confidence presence (label vs pred)
is_obb = boxes.shape[-1] == 5 # xywhr
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
if len(boxes):
if boxes[:, :4].max() <= 1.1: # if normalized with tolerance 0.1
boxes[:, [0, 2]] *= w # scale to pixels
boxes[:, [1, 3]] *= h
boxes[..., 0::2] *= w # scale to pixels
boxes[..., 1::2] *= h
elif scale < 1: # absolute coords need scale if image scales
boxes[:, :4] *= scale
boxes[:, 0] += x
boxes[:, 1] += y
is_obb = boxes.shape[-1] == 5 # xywhr
boxes = ops.xywhr2xyxyxyxy(boxes) if is_obb else ops.xywh2xyxy(boxes)
boxes[..., :4] *= scale
boxes[..., 0::2] += x
boxes[..., 1::2] += y
for j, box in enumerate(boxes.astype(np.int64).tolist()):
c = classes[j]
color = colors(c)
Expand Down

0 comments on commit cf50bd9

Please sign in to comment.