In [2]:
import sys
import os
import cv2
import numpy as np
import platform
import tkinter

from PyQt5.QtWidgets import (
    QApplication, QMainWindow, QLabel, QPushButton, QVBoxLayout, QHBoxLayout,
    QWidget, QFileDialog, QSlider, QListWidget, QListWidgetItem,
    QComboBox, QScrollArea, QMenuBar, QAction, QMessageBox, QDialog, QProgressBar,
    QOpenGLWidget, QStatusBar
)
from PyQt5.QtGui import QPixmap, QImage, QMouseEvent
from PyQt5.QtCore import Qt

from utils import find_images_and_texts_yolo_format, find_images_and_texts_same_folder

class KeypointVisualizer(QMainWindow):
    """Interactive viewer/editor for per-object bounding boxes and keypoints.

    Works on (image, label) pairs. Each non-blank label line holds
    ``name cx cy w h`` followed by one ``x y v`` triple per entry of
    ``KEYPOINT_NAMES``. Coordinates are stored normalised by the image size:
    they are multiplied by width/height on load and divided on save. A
    keypoint with ``v == 0`` is treated as absent and is not drawn.

    Mouse (positions are mapped into image pixel coordinates):
      * Left-press within 20 px of a visible keypoint, then drag: move it.
      * Right-press within 20 px of a visible keypoint: remove it (0, 0, 0).
      * Otherwise, press within 10 px of a box corner, then drag: resize box.
      * Ctrl+click inside a box: add the selected keypoint type with
        visibility 2 (Ctrl+Shift+click: visibility 1) if it is absent.
      * Ctrl+wheel: zoom in/out.
    Keyboard: digit keys select the keypoint type; PageUp/PageDown load the
    next/previous file; Ctrl+S saves.
    """

    def __init__(self, window_width=1280, window_height=720):
        """Create the window sized relative to the given screen resolution.

        Args:
            window_width: screen width in pixels; the canvas uses 70% of it.
            window_height: screen height in pixels; the canvas uses 70% of it.
        """
        super().__init__()

        os_name = platform.system()
        print(f"Current OS: {os_name}")
        print(f"Current resolution: {window_width} x {window_height}")

        self.canvas_width = int(window_width * 70 / 100)
        self.canvas_height = int(window_height * 70 / 100)

        # One colour per KEYPOINT_NAMES entry, in the same order. The triples
        # are passed straight to cv2 drawing calls, which read them as BGR, so
        # the on-screen colour is the reversed triple (noted per line).
        self.colors = [
            [255, 0, 0],         # nose        -> shown blue
            [255, 192, 203],     # head        -> shown light lavender
            [255, 165, 0],       # ass         -> shown azure
            [235, 206, 135],     # torso       -> shown sky blue
            [128, 0, 128],       # right_foot  -> shown purple
            [0, 255, 255],       # left_foot   -> shown yellow
            [0, 0, 255],         # right_hand  -> shown red
            [0, 128, 0],         # left_hand   -> shown green
            [0, 0, 0],           # tail        -> shown black
        ]
        self.currentTxtPath = None   # label file of the loaded pair
        self.currentImgPath = None   # image file of the loaded pair
        self.original_image = None   # unannotated image; every redraw starts from a copy

        self.name = None        # per-object name token (first column of each label line)
        self.keypoints = None   # per-object (len(KEYPOINT_NAMES), 3) arrays: x, y (pixels), v
        self.torso_box = None   # per-object [cx, cy, w, h] in pixels
        self.file_idx = 0       # index into image_text_list of the loaded pair

        self.selected_keypoint = None  # [object_idx, keypoint_idx] being dragged
        self.selected_box = False      # not read within this class
        self.box_dragging = False      # reset on mouse release; not otherwise read
        self.resize_corner = None      # corner of box `resize_idx` being dragged
        self.resize_idx = None         # object index whose box is being resized

        self.current_image = None      # image currently shown (with annotations drawn)
        self.current_name = 0          # keypoint index used for Ctrl+click insertion
        self.image_text_list = []      # (image_path, label_path) pairs

        # Keypoint labels, in the order they appear in each label line.
        self.KEYPOINT_NAMES = [
            "nose", "head", "ass", "torso", "right_foot", "left_foot",
            "right_hand", "left_hand", "tail"
        ]

        self.initUI()

    def initUI(self):
        """Build widgets, wire signals and lay out the main window."""
        self.setWindowTitle("Keypoint Visualizer")
        self.setGeometry(10, 40, int(self.canvas_width * 1.2), int(self.canvas_height * 1.1))

        self.scroll_area = QScrollArea(self)
        self.scroll_area.setWidgetResizable(False)
        self.scroll_area.setFixedSize(self.canvas_width + 10, self.canvas_height + 10)

        # The label is resized to (image size * zoom) and scales its pixmap
        # to fit, which is how zoom is rendered.
        self.canvas = QLabel(self)
        self.canvas.resize(self.canvas_width, self.canvas_height)  # initial size
        self.canvas.setScaledContents(True)
        self.scroll_area.setWidget(self.canvas)

        self.fileList = QListWidget()
        self.fileList.clicked.connect(self.loadSelectedImage)

        # Folder selection button
        self.btnLoadFolder = QPushButton("Select Folder", self)
        self.btnLoadFolder.clicked.connect(self.selectFolder)

        # Left/right swap buttons
        self.btnHandSwitch = QPushButton("Hand Switch", self)
        self.btnHandSwitch.clicked.connect(self.switchHands)

        self.btnFootSwitch = QPushButton("Foot Switch", self)
        self.btnFootSwitch.clicked.connect(self.switchFeet)

        # Save button
        self.btnSave = QPushButton("Save", self)
        self.btnSave.clicked.connect(self.saveKeypoints)
        self.btnSave.setShortcut("Ctrl+S")

        ### Status bar ########################
        self.status = self.statusBar()
        self.status.showMessage("Ready", 3000)  # show the message for 3 seconds

        self.name_selector = QComboBox()
        self.name_selector.addItems(self.KEYPOINT_NAMES)
        self.name_selector.currentIndexChanged.connect(self.name_selected)

        # Layout: [folder button] / [canvas | selector + file list] / [buttons]
        fileLayout = QVBoxLayout()
        fileLayout.addWidget(self.name_selector)
        fileLayout.addWidget(self.fileList)

        layout = QHBoxLayout()
        layout.addWidget(self.scroll_area)
        layout.addLayout(fileLayout)

        buttonLayout = QHBoxLayout()
        buttonLayout.addWidget(self.btnHandSwitch)
        buttonLayout.addWidget(self.btnFootSwitch)
        buttonLayout.addWidget(self.btnSave)

        mainLayout = QVBoxLayout()
        mainLayout.addWidget(self.btnLoadFolder)
        mainLayout.addLayout(layout)
        mainLayout.addLayout(buttonLayout)

        central_widget = QWidget()
        central_widget.setLayout(mainLayout)
        self.setCentralWidget(central_widget)

        self.h, self.w = 0, 0

        # Zoom in/out state
        self.h_scroll = self.scroll_area.horizontalScrollBar().value()
        self.v_scroll = self.scroll_area.verticalScrollBar().value()
        self.zoom_scale = 1.0
        self.original_scale = 1.0

    def selectFolder(self):
        """Ask for a dataset folder and fill the file list with its image/label pairs.

        A folder containing an ``images`` subfolder is searched in YOLO layout
        (``images``/``labels`` trees); otherwise images and labels are looked
        up side by side in the chosen folder.
        """
        folderPath = QFileDialog.getExistingDirectory(self, "Select Folder")

        if not folderPath:
            QMessageBox.warning(self, "QMessageBox", "Please choose appropriate folder!")
            return

        self.fileList.clear()
        image_root = os.path.join(folderPath, "images")
        if os.path.exists(image_root):
            # YOLO layout: train/val folders under images/ and labels/
            self.image_text_list = find_images_and_texts_yolo_format(
                image_root=image_root,
                label_root=os.path.join(folderPath, "labels"),
            )
        else:
            # Images and label files in the same folder
            self.image_text_list = find_images_and_texts_same_folder(folderPath)
        self.file_idx = 0

        for image_path, _ in self.image_text_list:
            self.fileList.addItem(image_path)

    def loadSelectedImage(self, selected, key_event=None):
        """Load the image/label pair at a list index and draw its annotations.

        Args:
            selected: QModelIndex emitted by ``fileList.clicked``; its
                ``row()`` is used when ``key_event`` is None.
            key_event: explicit list index (used by PageUp/PageDown).

        Does nothing if the index is out of range. Shows a message instead of
        raising when the label file is missing, the image cannot be decoded,
        or a label line is malformed (an exception escaping a Qt slot would
        abort the application).
        """
        idx = selected.row() if key_event is None else key_event
        if not 0 <= idx < len(self.image_text_list):
            return
        self.file_idx = idx

        imagePath = self.image_text_list[idx][0]
        txtPath = self.image_text_list[idx][1]

        if not os.path.exists(txtPath):
            self.status.showMessage(f"Label file not found: {txtPath}", 3000)
            return

        image = cv2.imread(imagePath)
        if image is None:  # cv2.imread signals failure by returning None
            QMessageBox.warning(self, "Image error", f"Could not read image: {imagePath}")
            return
        h, w = image.shape[:2]

        try:
            names, boxes, keypoints = self._read_labels(txtPath, w, h)
        except (ValueError, IndexError) as err:
            QMessageBox.warning(self, "Label error", f"Could not parse {txtPath}:\n{err}")
            return

        # Commit the new state only after everything was read successfully.
        self.currentTxtPath = txtPath
        self.currentImgPath = imagePath
        self.original_image = image
        self.current_image = image.copy()
        self.h, self.w = h, w
        self.name, self.torso_box, self.keypoints = names, boxes, keypoints

        # Initial zoom fits the whole image inside the canvas.
        self.original_scale = min(self.canvas_width / self.w,
                                  self.canvas_height / self.h)
        self.zoom_scale = self.original_scale

        self.visualizeKeypoints(self.current_image)

    def _read_labels(self, txtPath, w, h):
        """Parse a label file into (names, pixel boxes, pixel keypoint arrays)."""
        num_kpts = len(self.KEYPOINT_NAMES)
        names, boxes, keypoints = [], [], []
        with open(txtPath, "r") as f:
            for line in f:
                data = line.split()
                if not data:
                    continue  # tolerate blank lines (e.g. trailing newline)
                names.append(data[0])
                cx, cy, bw, bh = map(float, data[1:5])  # box info
                boxes.append([cx * w, cy * h, bw * w, bh * h])
                values = [float(v) for v in data[5:5 + num_kpts * 3]]
                # Raises ValueError if the line has too few keypoint values.
                kpts = np.array(values).reshape(num_kpts, 3)
                kpts[:, 0] *= w
                kpts[:, 1] *= h
                keypoints.append(kpts)
        return names, boxes, keypoints

    def _box_corners(self, i):
        """Return integer (x1, y1, x2, y2) corners of object ``i``'s box."""
        cx, cy, bw, bh = self.torso_box[i]
        return (int(cx - bw / 2), int(cy - bh / 2),
                int(cx + bw / 2), int(cy + bh / 2))

    def visualizeKeypoints(self, image):
        """Draw boxes and visible keypoints onto ``image`` (in place) and display it.

        ``image`` is a BGR array as returned by ``cv2.imread``; it is
        converted for display with an RGB/BGR channel swap.
        """
        for i, obj_kpts in enumerate(self.keypoints):
            # Object bounding box and its index
            x1, y1, x2, y2 = self._box_corners(i)
            cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 1)
            cv2.putText(image, f"{i}", (min(x1 + 5, self.w), min(y1 + 20, self.h)),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)

            # Keypoints (only those with non-zero visibility)
            for j, (x, y, v) in enumerate(obj_kpts):
                if v != 0:
                    color = self.colors[j]
                    cv2.circle(image, (int(x), int(y)), 5, color, -1)
                    cv2.putText(image, f"{self.KEYPOINT_NAMES[j]}({int(v)})",
                                (min(int(x) + 5, self.w), max(int(y) - 5, 0)),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1, cv2.LINE_AA)

        # OpenCV image -> PyQt5 QPixmap (rgbSwapped copies, so the QImage does
        # not keep referencing the numpy buffer)
        height, width, channel = image.shape
        qImg = QImage(image.data, width, height, image.strides[0], QImage.Format_RGB888).rgbSwapped()
        pixmap = QPixmap.fromImage(qImg)
        self.canvas.setPixmap(pixmap)
        self.canvas.resize(int(self.w * self.zoom_scale), int(self.h * self.zoom_scale))

    def _swap_keypoints(self, a, b):
        """Swap keypoint rows ``a`` and ``b`` in every loaded object."""
        for obj_kpts in self.keypoints:
            # Fancy indexing on the right-hand side yields a copy, so this is a safe swap.
            obj_kpts[[a, b]] = obj_kpts[[b, a]]

    def switchHands(self):
        """Swap right_hand <-> left_hand keypoints for every object."""
        if self.keypoints is not None:
            self._swap_keypoints(self.KEYPOINT_NAMES.index("right_hand"),
                                 self.KEYPOINT_NAMES.index("left_hand"))
            print("🔄 Switched Hands")
            self.refreshImage()

    def switchFeet(self):
        """Swap right_foot <-> left_foot keypoints for every object."""
        if self.keypoints is not None:
            self._swap_keypoints(self.KEYPOINT_NAMES.index("right_foot"),
                                 self.KEYPOINT_NAMES.index("left_foot"))
            print("🔄 Switched Feet")
            self.refreshImage()

    def saveKeypoints(self):
        """Overwrite the current label file with the edited boxes and keypoints.

        Box values are written with ``str`` precision and keypoint coordinates
        with 6 decimals, all normalised by image width/height. Visibility is
        written as an int.
        """
        if not (self.currentTxtPath and self.keypoints is not None):
            return

        with open(self.currentTxtPath, "w") as file:
            for name, box, obj_kpts in zip(self.name, self.torso_box, self.keypoints):
                normalized_torso = [box[0] / self.w, box[1] / self.h,
                                    box[2] / self.w, box[3] / self.h]
                file.write(f"{name} {' '.join(map(str, normalized_torso))} ")
                file.write(" ".join(f"{x/self.w:.6f} {y/self.h:.6f} {int(v)}" for x, y, v in obj_kpts))
                file.write("\n")

        self.status.showMessage(f"💾 Saved: {self.currentTxtPath}", 3000)

    def _event_to_image_coords(self, event):
        """Map a main-window mouse event position to image pixel coordinates.

        ``mapFrom`` accounts for where the canvas sits in the layout and for
        the scroll offset (the scroll area moves the canvas widget). Dividing
        by the zoom factor then gives coordinates in the original image.
        """
        pos = self.canvas.mapFrom(self, event.pos())
        return int(pos.x() / self.zoom_scale), int(pos.y() / self.zoom_scale)

    def _clamp_to_image(self, x, y):
        """Clamp a pixel position to the [0, w] x [0, h] image rectangle."""
        return min(max(x, 0), self.w), min(max(y, 0), self.h)

    def mousePressEvent(self, event: QMouseEvent):
        """Add, select, remove keypoints or start a box-corner resize (see class doc)."""
        if not self.keypoints:
            return

        click_x, click_y = self._event_to_image_coords(event)
        print(f"{click_x}, {click_y}")

        if event.modifiers() & Qt.ControlModifier:
            visibility = 2
            if event.modifiers() & Qt.ShiftModifier:
                print("Shift key pressed")
                visibility = 1

            for i in range(len(self.torso_box)):
                x1, y1, x2, y2 = self._box_corners(i)
                if x1 <= click_x <= x2 and y1 <= click_y <= y2:  # click is inside object i
                    if self.keypoints[i][self.current_name][2] == 0:  # keypoint currently absent
                        self.keypoints[i][self.current_name] = click_x, click_y, visibility
                        self.refreshImage()
                        break
                    QMessageBox.warning(self, "Keypoint error",
                                        f"{self.KEYPOINT_NAMES[self.current_name]} already exists in object {i}")
            return

        # Nearest visible keypoint across all objects
        min_dist = float("inf")
        nearest_index = None
        for i, obj_kpts in enumerate(self.keypoints):
            for j, (x, y, v) in enumerate(obj_kpts):
                if v != 0:
                    dist = np.hypot(x - click_x, y - click_y)
                    if dist < min_dist:
                        min_dist = dist
                        nearest_index = [i, j]

        if min_dist < 20:  # only pick keypoints within 20 image pixels
            if event.button() == Qt.LeftButton:
                self.selected_keypoint = nearest_index
                print(f"🎯 Selected Keypoint: {self.KEYPOINT_NAMES[self.selected_keypoint[1]]}")
            elif event.button() == Qt.RightButton:
                self.keypoints[nearest_index[0]][nearest_index[1]] = 0.0, 0.0, 0
                print(f"{nearest_index} is removed!")
                self.refreshImage()
            return

        corner_threshold = 10  # corner hit distance in image pixels
        for i in range(len(self.torso_box)):
            x1, y1, x2, y2 = self._box_corners(i)
            # Later boxes win if several corners are within reach.
            near_x1 = abs(click_x - x1) < corner_threshold
            near_x2 = abs(click_x - x2) < corner_threshold
            near_y1 = abs(click_y - y1) < corner_threshold
            near_y2 = abs(click_y - y2) < corner_threshold
            if near_x1 and near_y1:
                self.resize_idx, self.resize_corner = i, "top_left"
            elif near_x2 and near_y1:
                self.resize_idx, self.resize_corner = i, "top_right"
            elif near_x1 and near_y2:
                self.resize_idx, self.resize_corner = i, "bottom_left"
            elif near_x2 and near_y2:
                self.resize_idx, self.resize_corner = i, "bottom_right"

        print(f" Selected box: {self.resize_idx}, {self.resize_corner}")

    def mouseMoveEvent(self, event: QMouseEvent):
        """Drag the selected keypoint or box corner, clamped inside the image."""
        if not self.keypoints:
            return

        new_x, new_y = self._clamp_to_image(*self._event_to_image_coords(event))

        if self.selected_keypoint is not None:
            obj_idx, kpt_idx = self.selected_keypoint
            self.keypoints[obj_idx][kpt_idx][0] = new_x
            self.keypoints[obj_idx][kpt_idx][1] = new_y
            self.refreshImage()
        elif self.resize_idx is not None:
            center_x, center_y, box_w, box_h = self.torso_box[self.resize_idx]
            top_left_x = center_x - box_w / 2
            top_left_y = center_y - box_h / 2
            bottom_right_x = center_x + box_w / 2
            bottom_right_y = center_y + box_h / 2

            if self.resize_corner == "top_left":
                top_left_x, top_left_y = new_x, new_y
            elif self.resize_corner == "top_right":
                bottom_right_x, top_left_y = new_x, new_y
            elif self.resize_corner == "bottom_left":
                top_left_x, bottom_right_y = new_x, new_y
            elif self.resize_corner == "bottom_right":
                bottom_right_x, bottom_right_y = new_x, new_y

            # Recompute centre and size; abs() tolerates dragging past the opposite corner.
            self.torso_box[self.resize_idx] = [
                (top_left_x + bottom_right_x) / 2,
                (top_left_y + bottom_right_y) / 2,
                abs(bottom_right_x - top_left_x),
                abs(bottom_right_y - top_left_y),
            ]
            self.refreshImage()

    def mouseReleaseEvent(self, event: QMouseEvent):
        """End any keypoint drag or box resize."""
        if self.keypoints:
            self.selected_keypoint = None
            self.box_dragging = False
            self.resize_idx = None
            self.resize_corner = None

    # ======== Zoom in/out ========
    def wheelEvent(self, event):
        """Ctrl+wheel zooms between the fit-to-canvas scale and 5x that scale."""
        if (event.modifiers() & Qt.ControlModifier) and self.current_image is not None:
            delta = event.angleDelta().y()
            if delta > 0:
                self.zoom_scale = min(self.original_scale * 5.0, self.zoom_scale + 0.1)
            else:
                self.zoom_scale = max(self.original_scale, self.zoom_scale - 0.1)

            self.get_scroll_position()
            self.refreshImage()
        else:
            super().wheelEvent(event)

    def get_scroll_position(self):
        """Cache the current horizontal/vertical scroll bar values."""
        self.h_scroll = self.scroll_area.horizontalScrollBar().value()
        self.v_scroll = self.scroll_area.verticalScrollBar().value()

    def refreshImage(self):
        """Redraw all annotations on a fresh copy of the loaded image."""
        if self.original_image is None:
            self.original_image = cv2.imread(self.currentImgPath)
        self.current_image = self.original_image.copy()
        self.visualizeKeypoints(self.current_image)

    def name_selected(self, index):
        """Slot for the keypoint combo box: set the type used for Ctrl+click."""
        self.current_name = index

    def keyPressEvent(self, event):
        """Digits pick the keypoint type; PageUp/PageDown load next/previous file."""
        key = event.key()
        if Qt.Key_0 <= key < Qt.Key_0 + len(self.KEYPOINT_NAMES):
            self.current_name = key - Qt.Key_0
            self.name_selector.setCurrentIndex(self.current_name)
        elif key in (Qt.Key_PageUp, Qt.Key_PageDown):
            if not self.image_text_list:
                return
            step = 1 if key == Qt.Key_PageUp else -1
            idx = min(max(self.file_idx + step, 0), len(self.image_text_list) - 1)
            self.fileList.setCurrentRow(idx)  # keep the list highlight in sync
            self.loadSelectedImage(None, key_event=idx)
        else:
            super().keyPressEvent(event)

if __name__ == "__main__":
    # Ask a short-lived Tk root for the primary screen size, then discard it.
    tk_root = tkinter.Tk()
    width, height = tk_root.winfo_screenwidth(), tk_root.winfo_screenheight()
    tk_root.destroy()

    # Reuse an already-running QApplication (e.g. when re-executed in Jupyter).
    app = QApplication.instance()
    app = QApplication(sys.argv) if app is None else app

    window = KeypointVisualizer(width, height)
    window.show()

    try:
        app.exec_()
    except SystemExit:
        print("[Info] PyQt5 Application exited cleanly.")
    finally:
        # Tear down explicitly so a notebook kernel can start a fresh app later.
        app.quit()
        del app
        print("[Info] QApplication resources have been cleaned up.")


Current OS: Windows
Current resolution: 1920 x 1080
[Info] QApplication resources have been cleaned up.
