Skip to content

Commit

Permalink
Fixup draw segm w/ non-rgb masks + pred/gt bypass
Browse files Browse the repository at this point in the history
  • Loading branch information
plstcharles committed Jan 15, 2019
1 parent f13e8f6 commit 474e70f
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion thelper/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,17 +897,23 @@ def draw_segments(images, # type: Union[List[np.ndarray], np.nda
if img_shape is None:
img_shape = image.shape
if img_grid_shape is None:
img_grid_shape = (img_shape[0] * grid_size_y, img_shape[1] * grid_size_x, img_shape[2])
img_grid_shape = (img_shape[0] * grid_size_y, img_shape[1] * grid_size_x, 3)
if img_grid is None or img_grid.shape != img_grid_shape:
img_grid = np.zeros(img_grid_shape, dtype=np.uint8)
mask = None
if masks_pred is not None:
mask = masks_pred[img_idx] if isinstance(masks_pred, list) else masks_pred[img_idx, ...]
elif masks_gt is not None:
mask = masks_gt[img_idx] if isinstance(masks_gt, list) else masks_gt[img_idx, ...]
if mask is not None:
if labels_color_map is not None:
mask = apply_color_map(mask, labels_color_map)
if image.ndim == 2 or image.shape[2] != 3:
image = cv.cvtColor(image, cv.COLOR_GRAY2BGR)
image = cv.addWeighted(image, 0.5, mask, 0.5, 0)
offsets = (img_idx // grid_size_x) * img_shape[0], (img_idx % grid_size_x) * img_shape[1]
if image.ndim < 3 or image.shape[2] == 1:
image = cv.cvtColor(image, cv.COLOR_GRAY2BGR)
np.copyto(img_grid[offsets[0]:(offsets[0] + img_shape[0]), offsets[1]:(offsets[1] + img_shape[1]), :], image)
win_name = "segments" if redraw is None else redraw[0]
if img_grid is not None:
Expand Down

0 comments on commit 474e70f

Please sign in to comment.