diff --git a/anylabeling/services/auto_labeling/segment_anything.py b/anylabeling/services/auto_labeling/segment_anything.py index 9b40677..5ad9632 100644 --- a/anylabeling/services/auto_labeling/segment_anything.py +++ b/anylabeling/services/auto_labeling/segment_anything.py @@ -1,6 +1,5 @@ import logging import os -import sys from copy import deepcopy import cv2 @@ -11,11 +10,10 @@ from anylabeling.utils import GenericWorker from anylabeling.views.labeling.shape import Shape -from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img - +from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img, qt_img_to_rgb_cv_img +from .lru_cache import LRUCache from .model import Model from .types import AutoLabelingResult -from .lru_cache import LRUCache class SegmentAnything(Model): @@ -178,7 +176,7 @@ def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int): return (newh, neww) def apply_coords( - self, coords: np.ndarray, original_size, target_length + self, coords: np.ndarray, original_size, target_length ) -> np.ndarray: """ Expects a numpy array of length 2 in the final dimension. Requires the @@ -201,8 +199,8 @@ def run_decoder(self, image_embedding, resized_ratio): [input_points, np.array([[0.0, 0.0]])], axis=0 )[None, :, :] onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[ - None, : - ].astype(np.float32) + None, : + ].astype(np.float32) onnx_coord = self.apply_coords( onnx_coord, self.size_after_apply_max_width_height, self.input_size ).astype(np.float32) @@ -336,7 +334,7 @@ def predict_shapes(self, image, filename=None) -> AutoLabelingResult: image_embedding, ) = cached_data else: - cv_image = qt_img_to_cv_img(image) + cv_image = qt_img_to_rgb_cv_img(image, filename) encoder_inputs, resized_ratio = self.pre_process(cv_image) if self.stop_inference: return AutoLabelingResult([], replace=False) @@ -394,8 +392,8 @@ def on_next_files_changed(self, next_files): and run inference to save time for user. """ if ( - self.pre_inference_thread is None - or not self.pre_inference_thread.isRunning() + self.pre_inference_thread is None + or not self.pre_inference_thread.isRunning() ): self.pre_inference_thread = QThread() self.pre_inference_worker = GenericWorker( diff --git a/anylabeling/services/auto_labeling/yolov5.py b/anylabeling/services/auto_labeling/yolov5.py index 601956e..7db7ed2 100644 --- a/anylabeling/services/auto_labeling/yolov5.py +++ b/anylabeling/services/auto_labeling/yolov5.py @@ -6,8 +6,7 @@ from PyQt5 import QtCore from anylabeling.views.labeling.shape import Shape -from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img - +from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img from .model import Model from .types import AutoLabelingResult @@ -157,7 +156,7 @@ def predict_shapes(self, image, image_path=None): return [] try: - image = qt_img_to_cv_img(image) + image = qt_img_to_rgb_cv_img(image, image_path) except Exception as e: # noqa logging.warning("Could not inference model") logging.warning(e) diff --git a/anylabeling/services/auto_labeling/yolov8.py b/anylabeling/services/auto_labeling/yolov8.py index c869a95..904153c 100644 --- a/anylabeling/services/auto_labeling/yolov8.py +++ b/anylabeling/services/auto_labeling/yolov8.py @@ -6,8 +6,7 @@ from PyQt5 import QtCore from anylabeling.views.labeling.shape import Shape -from anylabeling.views.labeling.utils.opencv import qt_img_to_cv_img - +from anylabeling.views.labeling.utils.opencv import qt_img_to_rgb_cv_img from .model import Model from .types import AutoLabelingResult @@ -153,7 +152,7 @@ def predict_shapes(self, image, image_path=None): return [] try: - image = qt_img_to_cv_img(image) + image = qt_img_to_rgb_cv_img(image, image_path) except Exception as e: # noqa logging.warning("Could not inference model") logging.warning(e) diff --git a/anylabeling/views/labeling/utils/opencv.py b/anylabeling/views/labeling/utils/opencv.py index cb75506..99b0c4a 100644 --- a/anylabeling/views/labeling/utils/opencv.py +++ b/anylabeling/views/labeling/utils/opencv.py @@ -1,5 +1,33 @@ +import os.path + +import cv2 +import numpy as np import qimage2ndarray from PyQt5 import QtGui +from PyQt5.QtGui import QImage + + +def qt_img_to_rgb_cv_img(qt_img, img_path=None): + """ + Convert 8bit/16bit RGB image or 8bit/16bit Gray image to 8bit RGB image + """ + if img_path is not None and os.path.exists(img_path): + # Load Image From Path Directly + cv_image = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), -1) + cv_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB) + else: + if qt_img.format() == QImage.Format_RGB32 or qt_img.format() == QImage.Format_ARGB32: + cv_image = qimage2ndarray.rgb_view(qt_img) + else: + cv_image = qimage2ndarray.raw_view(qt_img) + # To uint8 + if cv_image.dtype != np.uint8: + cv2.normalize(cv_image, cv_image, 0, 255, cv2.NORM_MINMAX) + cv_image = np.array(cv_image, dtype=np.uint8) + # To RGB + if len(cv_image.shape) == 2 or cv_image.shape[2] == 1: + cv_image = cv2.merge([cv_image, cv_image, cv_image]) + return cv_image def qt_img_to_cv_img(in_image):