Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions anylabeling/services/auto_labeling/segment_anything.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
import sys
from copy import deepcopy

import cv2
Expand All @@ -11,11 +10,10 @@

from anylabeling.utils import GenericWorker
from anylabeling.views.labeling.shape import Shape
from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img

from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img, qt_img_to_rgb_cv_img
from .lru_cache import LRUCache
from .model import Model
from .types import AutoLabelingResult
from .lru_cache import LRUCache


class SegmentAnything(Model):
Expand Down Expand Up @@ -178,7 +176,7 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
return (newh, neww)

def apply_coords(
self, coords: np.ndarray, original_size, target_length
self, coords: np.ndarray, original_size, target_length
) -> np.ndarray:
"""
Expects a numpy array of length 2 in the final dimension. Requires the
Expand All @@ -201,8 +199,8 @@ def run_decoder(self, image_embedding, resized_ratio):
[input_points, np.array([[0.0, 0.0]])], axis=0
)[None, :, :]
onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
None, :
].astype(np.float32)
None, :
].astype(np.float32)
onnx_coord = self.apply_coords(
onnx_coord, self.size_after_apply_max_width_height, self.input_size
).astype(np.float32)
Expand Down Expand Up @@ -336,7 +334,7 @@ def predict_shapes(self, image, filename=None) -> AutoLabelingResult:
image_embedding,
) = cached_data
else:
cv_image = qt_img_to_cv_img(image)
cv_image = qt_img_to_rgb_cv_img(image, filename)
encoder_inputs, resized_ratio = self.pre_process(cv_image)
if self.stop_inference:
return AutoLabelingResult([], replace=False)
Expand Down Expand Up @@ -394,8 +392,8 @@ def on_next_files_changed(self, next_files):
and run inference to save time for user.
"""
if (
self.pre_inference_thread is None
or not self.pre_inference_thread.isRunning()
self.pre_inference_thread is None
or not self.pre_inference_thread.isRunning()
):
self.pre_inference_thread = QThread()
self.pre_inference_worker = GenericWorker(
Expand Down
5 changes: 2 additions & 3 deletions anylabeling/services/auto_labeling/yolov5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from PyQt5 import QtCore

from anylabeling.views.labeling.shape import Shape
from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img

from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img
from .model import Model
from .types import AutoLabelingResult

Expand Down Expand Up @@ -157,7 +156,7 @@ def predict_shapes(self, image, image_path=None):
return []

try:
image = qt_img_to_cv_img(image)
image = qt_img_to_rgb_cv_img(image, image_path)
except Exception as e: # noqa
logging.warning("Could not inference model")
logging.warning(e)
Expand Down
5 changes: 2 additions & 3 deletions anylabeling/services/auto_labeling/yolov8.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from PyQt5 import QtCore

from anylabeling.views.labeling.shape import Shape
from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img

from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img
from .model import Model
from .types import AutoLabelingResult

Expand Down Expand Up @@ -153,7 +152,7 @@ def predict_shapes(self, image, image_path=None):
return []

try:
image = qt_img_to_cv_img(image)
image = qt_img_to_rgb_cv_img(image, image_path)
except Exception as e: # noqa
logging.warning("Could not inference model")
logging.warning(e)
Expand Down
28 changes: 28 additions & 0 deletions anylabeling/views/labeling/utils/opencv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,33 @@
import os.path

import cv2
import numpy as np
import qimage2ndarray
from PyQt5 import QtGui
from PyQt5.QtGui import QImage


def qt_img_to_rgb_cv_img(qt_img, img_path=None):
"""
Convert 8bit/16bit RGB image or 8bit/16bit Gray image to 8bit RGB image
"""
if img_path is not None and os.path.exists(img_path):
# Load Image From Path Directly
cv_image = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1)
cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
else:
if qt_img.format() == QImage.Format_RGB32 or qt_img.format() == QImage.Format_ARGB32:
cv_image = qimage2ndarray.rgb_view(qt_img)
else:
cv_image = qimage2ndarray.raw_view(qt_img)
# To uint8
if cv_image.dtype != np.uint8:
cv2.normalize(cv_image, cv_image, 0, 255, cv2.NORM_MINMAX)
cv_image = np.array(cv_image, dtype=np.uint8)
# To RGB
if len(cv_image.shape) == 2 or cv_image.shape[2] == 1:
cv_image = cv2.merge([cv_image, cv_image, cv_image])
return cv_image


def qt_img_to_cv_img(in_image):
Expand Down