In [27]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import cv2
import threading
import numpy as np
import torch
from torchvision import models, transforms
from sklearn.cluster import DBSCAN  # 修改：從 sklearn 導入 DBSCAN
from ultralytics import YOLO

class VideoPlayer:
    def __init__(self, root):
        self.root = root
        self.root.title("Team_division")

        # Create frames for layout
        self.left_frame = tk.Frame(root)
        self.left_frame.grid(row=0, column=0, padx=10, pady=10)

        self.right_frame = tk.Frame(root)
        self.right_frame.grid(row=0, column=1, padx=10, pady=10)

        # 顯示影片
        self.canvas = tk.Canvas(self.left_frame, width=800, height=480, bg="grey")
        self.canvas.grid(row=0, column=0, columnspan=2, padx=5, pady=5)

        self.load_button = tk.Button(self.left_frame, text="Select a Video", command=self.load_video)
        self.load_button.grid(row=1, column=0, columnspan=2, pady=10)

        # 控制顯示A/B隊
        self.effects_label = tk.Label(self.right_frame, text="Effects on")
        self.effects_label.grid(row=0, column=0, pady=10)

        self.team_a_var = tk.BooleanVar(value=True)
        self.team_b_var = tk.BooleanVar(value=True)
        self.hoop_var = tk.BooleanVar(value=True)
        self.ball_var = tk.BooleanVar(value=True)

        self.team_a_check = tk.Checkbutton(self.right_frame, text="Team A", variable=self.team_a_var)
        self.team_a_check.grid(row=1, column=0, sticky="w", padx=5, pady=5)

        self.team_b_check = tk.Checkbutton(self.right_frame, text="Team B", variable=self.team_b_var)
        self.team_b_check.grid(row=2, column=0, sticky="w", padx=5, pady=5)

        self.hoop_check = tk.Checkbutton(self.right_frame, text="Hoop", variable=self.hoop_var)
        self.hoop_check.grid(row=3, column=0, sticky="w", padx=5, pady=5)

        self.ball_check = tk.Checkbutton(self.right_frame, text="Ball", variable=self.ball_var)
        self.ball_check.grid(row=4, column=0, sticky="w", padx=5, pady=5)

        self.video_path = None
        self.cap = None
        self.playing = False
        self.paused = False
        self.stop_flag = False

        # YOLOv8模型路徑
        self.model_path = 'C:/Users/a3221/deltable_data/yolov8_teamDiff/yolov8_trained2_0719.pt'
        self.model1 = YOLO(self.model_path)
        self.model1.model.names = {0: 'Ball', 1: 'Hoop', 2: 'Player'}
        
        # 載入 ResNet 模型
        self.resnet = models.resnet50(pretrained=True)
        self.resnet = torch.nn.Sequential(*list(self.resnet.children())[:-1])  # 移除最後一層全連接層
        self.resnet.eval()
        
        # 定義影像預處理步驟
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def toggle_all(self):
        state = self.all_var.get()
        self.team_a_var.set(state)
        self.team_b_var.set(state)
        
    def load_video(self):
        self.video_path = filedialog.askopenfilename()
        if self.video_path:
            self.cap = cv2.VideoCapture(self.video_path)
            self.play_video()
        else:
            messagebox.showerror("Error", "Failed to Load Video")

    def play_video(self):
        if self.video_path and not self.playing:
            self.playing = True
            self.paused = False
            self.stop_flag = False
            threading.Thread(target=self._play).start()
        elif self.paused:
            self.paused = False

    def _play(self):
        while self.cap.isOpened() and self.playing and not self.stop_flag:
            if not self.paused:
                ret, frame = self.cap.read()
                if not ret:
                    break

                # Process frame for displaying team bounding boxes
                frame = self.process_frame(frame)

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (800, 480))
                img = Image.fromarray(frame)
                imgtk = ImageTk.PhotoImage(image=img)
                self.canvas.create_image(0, 0, anchor=tk.NW, image=imgtk)
                self.canvas.image = imgtk
                self.root.update()

                # 控制播放速度
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break

        self.playing = False
        self.cap.release()

    def extract_features(self, image):
        image_tensor = self.preprocess(image).unsqueeze(0)
        with torch.no_grad():
            features = self.resnet(image_tensor)
        return features.squeeze().numpy()

    def process_image(self, img):
        results = self.model1.predict(img)[0]
        
        features_list = []
        for result in results.boxes:
            class_id = int(result.cls[0])
            
            if class_id == 2:  # player
                x1, y1, x2, y2 = map(int, result.xyxy[0])
                cropped_img = img[y1:y2, x1:x2]
                cropped_pil_img = Image.fromarray(cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB))
                features = self.extract_features(cropped_pil_img)
                features_list.append(features)
        
        n = len(features_list)
        if n < 2:
            print("特徵向量數量不足，無法進行聚類。")
            labels = []
        else:
            features_2d = np.array(features_list)
            
            # 進行DBSCAN聚類
            dbscan = DBSCAN(eps=0.9, min_samples=1, metric='euclidean').fit(features_2d)
            labels = dbscan.labels_
    
        colors = [(255, 0, 0), (0, 0, 255)]  # Red for B隊, Blue for A隊
        color_mapping = {}
        
        for label in set(labels):
            if label != -1 and label not in color_mapping:
                color_mapping[label] = colors[len(color_mapping) % len(colors)]
    
        j = 0
        for result in results.boxes:
            x1, y1, x2, y2 = map(int, result.xyxy[0])
            class_id = int(result.cls[0])
            score = float(result.conf[0])
            
            if (class_id == 1 and self.hoop_var.get()):  # Hoop
                color = (0, 255, 0)  # 綠色框
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                cv2.putText(img, f'{self.model1.model.names[class_id]} {score:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
            elif class_id == 2:  # Player
                if j < len(labels):
                    label = labels[j]
                    if label == -1:
                        color = (128, 128, 128)  # 灰色用于未分類的球員
                    else:
                        color = color_mapping[label]
                    team_label = '(A)' if color == colors[1] else '(B)'
                    j += 1

                    # Add conditions to show/hide boxes based on team selection
                    if (team_label == '(A)' and self.team_a_var.get()) or (team_label == '(B)' and self.team_b_var.get()):
                        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                        cv2.putText(img, f'{self.model1.model.names[class_id]} {score:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
                    elif not self.team_a_var.get() and not self.team_b_var.get():
                        cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                        cv2.putText(img, f'{self.model1.model.names[class_id]} {score:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
            elif (class_id == 0 and self.ball_var.get()):  # Ball
                color = (0, 255, 0)  # 綠色框
                cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                cv2.putText(img, f'{self.model1.model.names[class_id]} {score:.2f}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    
        return img

    def process_frame(self, frame):
        # 處理當前帧
        frame = self.process_image(frame)
        return frame

    def pause_video(self):
        if self.playing:
            self.paused = True

    def stop_video(self):
        if self.playing:
            self.stop_flag = True
            self.playing = False
            self.cap.release()
            self.canvas.delete("all")

if __name__ == "__main__":
    root = tk.Tk()
    app = VideoPlayer(root)
    root.mainloop()





0: 288x512 5 Players, 94.0ms
Speed: 4.0ms preprocess, 94.0ms inference, 2.0ms postprocess per image at shape (1, 3, 288, 512)

0: 288x512 1 Hoop, 7 Players, 21.0ms
Speed: 2.0ms preprocess, 21.0ms inference, 6.0ms postprocess per image at shape (1, 3, 288, 512)

0: 288x512 1 Hoop, 7 Players, 18.0ms
Speed: 3.0ms preprocess, 18.0ms inference, 4.0ms postprocess per image at shape (1, 3, 288, 512)

0: 288x512 1 Hoop, 7 Players, 19.0ms
Speed: 2.0ms preprocess, 19.0ms inference, 4.0ms postprocess per image at shape (1, 3, 288, 512)

0: 288x512 1 Hoop, 7 Players, 26.0ms
Speed: 3.0ms preprocess, 26.0ms inference, 4.0ms postprocess per image at shape (1, 3, 288, 512)


Exception in thread Thread-35 (_play):
Traceback (most recent call last):
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\a3221\AppData\Local\Temp\ipykernel_37264\992342132.py", line 105, in _play
  File "C:\Users\a3221\AppData\Local\Temp\ipykernel_37264\992342132.py", line 196, in process_frame
  File "C:\Users\a3221\AppData\Local\Temp\ipykernel_37264\992342132.py", line 181, in process_image
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\tkinter\__init__.py", line 643, in get
    return self._tk.getboolean(self._tk.globalgetvar(self._name))
                               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: main thread is not in main loop
