In [5]:
import tkinter as tk
from tkinter import filedialog, messagebox
from PIL import Image, ImageTk
import cv2
import threading
import numpy as np
import torch
from torchvision import models, transforms
from sklearn.cluster import DBSCAN  # 修改：从 sklearn 导入 DBSCAN
from ultralytics import YOLO

class VideoPlayer:
    def __init__(self, root):
        self.root = root
        self.root.title("Video Player")

        # Create frames for layout
        self.left_frame = tk.Frame(root)
        self.left_frame.grid(row=0, column=0, padx=10, pady=10)

        self.right_frame = tk.Frame(root)
        self.right_frame.grid(row=0, column=1, padx=10, pady=10)

        # 顯示影片
        self.canvas = tk.Canvas(self.left_frame, width=640, height=480, bg="grey")
        self.canvas.grid(row=0, column=0, columnspan=2, padx=5, pady=5)

        self.load_button = tk.Button(self.left_frame, text="Select a Video", command=self.load_video)
        self.load_button.grid(row=1, column=0, columnspan=2, pady=10)

        # 控制顯示A/B隊
        self.effects_label = tk.Label(self.right_frame, text="Effects on")
        self.effects_label.grid(row=0, column=0, pady=10)

        self.team_a_var = tk.BooleanVar(value=True)
        self.team_b_var = tk.BooleanVar(value=True)
        self.all_var = tk.BooleanVar(value=True)

        self.team_a_check = tk.Checkbutton(self.right_frame, text="Team A", variable=self.team_a_var)
        self.team_a_check.grid(row=1, column=0, sticky="w", padx=5, pady=5)

        self.team_b_check = tk.Checkbutton(self.right_frame, text="Team B", variable=self.team_b_var)
        self.team_b_check.grid(row=2, column=0, sticky="w", padx=5, pady=5)

        self.all_check = tk.Checkbutton(self.right_frame, text="All", variable=self.all_var, command=self.toggle_all)
        self.all_check.grid(row=3, column=0, sticky="w", padx=5, pady=5)

        self.video_path = None
        self.cap = None
        self.playing = False
        self.paused = False
        self.stop_flag = False

        # YOLOv8模型路径
        self.model_path = 'C:/Users/a3221/deltable_data/yolov8_teamDiff/yolov8_trained2_0719.pt'
        self.model1 = YOLO(self.model_path)
        self.model1.model.names = {0: 'Ball', 1: 'Hoop', 2: 'Player'}
        
        # 加载 ResNet 模型
        self.resnet = models.resnet152(pretrained=True)
        self.resnet = torch.nn.Sequential(*list(self.resnet.children())[:-1])  # 移除最后一层全连接层
        self.resnet.eval()
        
        # 定义图像预处理步骤
        self.preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def toggle_all(self):
        state = self.all_var.get()
        self.team_a_var.set(state)
        self.team_b_var.set(state)

    def load_video(self):
        self.video_path = filedialog.askopenfilename()
        if self.video_path:
            self.cap = cv2.VideoCapture(self.video_path)
            self.play_video()
        else:
            messagebox.showerror("Error", "Failed to Load Video")

    def play_video(self):
        if self.video_path and not self.playing:
            self.playing = True
            self.paused = False
            self.stop_flag = False
            threading.Thread(target=self._play).start()
        elif self.paused:
            self.paused = False

    def _play(self):
        while self.cap.isOpened() and self.playing and not self.stop_flag:
            if not self.paused:
                ret, frame = self.cap.read()
                if not ret:
                    break

                # Process frame for displaying team bounding boxes
                frame = self.process_frame(frame)

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, (640, 480))
                img = Image.fromarray(frame)
                imgtk = ImageTk.PhotoImage(image=img)
                self.canvas.create_image(0, 0, anchor=tk.NW, image=imgtk)
                self.canvas.image = imgtk
                self.root.update()

                # 控制播放速度
                if cv2.waitKey(10) & 0xFF == ord('q'):
                    break

        self.playing = False
        self.cap.release()

    def extract_features(self, image):
        image_tensor = self.preprocess(image).unsqueeze(0)
        with torch.no_grad():
            features = self.resnet(image_tensor)
        return features.squeeze().numpy()

    def process_image(self, img):
        results = self.model1.predict(img)
        
        features_list = []
        for result in results[0].boxes:
            class_id = int(result.cls[0])
            if class_id == 2:  # 仅处理球员
                x1, y1, x2, y2 = map(int, result.xyxy[0])
                cropped_img = img[y1:y2, x1:x2]
                cropped_pil_img = Image.fromarray(cv2.cvtColor(cropped_img, cv2.COLOR_BGR2RGB))
                features = self.extract_features(cropped_pil_img)
                features_list.append(features)
        
        n = len(features_list)
        if n < 2:
            print("特征向量数量不足，无法进行聚类。")
        else:
            features_2d = np.array(features_list)
            
            # 进行DBSCAN聚类
            dbscan = DBSCAN(eps=0.9, min_samples=1, metric='euclidean').fit(features_2d)
            labels = dbscan.labels_
            print(labels)
        
            colors = [(255, 0, 0), (0, 0, 255)]  # Red for B队, Blue for A队
            color_mapping = {}
            
            for i, label in enumerate(labels):
                if label not in color_mapping:
                    color_mapping[label] = colors[len(color_mapping) % len(colors)]
            
            print("Color Mapping:", color_mapping)
        
            j = 0
            for i, result in enumerate(results[0].boxes):
                if j == n:
                    break
                class_id = int(result.cls[0])
                if class_id == 2:
                    label = labels[j]
                    if label == -1:
                        continue  # 忽略噪点
                    
                    x1, y1, x2, y2 = map(int, result.xyxy[0])
                    score = result.conf[0]
                    team_label = '(A)' if color_mapping[label] == colors[1] else '(B)'
                    color = color_mapping[label]
        
                    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
                    cv2.putText(img, f'Player {team_label}', (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
                    cv2.putText(img, f'{score:.2f}', (x1, y2 + 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 1)
                    j += 1

        return img

    def process_frame(self, frame):
        # 处理当前帧
        frame = self.process_image(frame)
        return frame

    def pause_video(self):
        if self.playing:
            self.paused = True

    def stop_video(self):
        if self.playing:
            self.stop_flag = True
            self.playing = False
            self.cap.release()
            self.canvas.delete("all")

if __name__ == "__main__":
    root = tk.Tk()
    app = VideoPlayer(root)
    root.mainloop()





0: 288x512 5 Players, 81.0ms
Speed: 3.0ms preprocess, 81.0ms inference, 3.0ms postprocess per image at shape (1, 3, 288, 512)
[0 1 2 3 4]
Color Mapping: {0: (255, 0, 0), 1: (0, 0, 255), 2: (255, 0, 0), 3: (0, 0, 255), 4: (255, 0, 0)}

0: 288x512 1 Hoop, 7 Players, 19.0ms
Speed: 2.0ms preprocess, 19.0ms inference, 3.0ms postprocess per image at shape (1, 3, 288, 512)
[0 1 2 3 4 5 6]
Color Mapping: {0: (255, 0, 0), 1: (0, 0, 255), 2: (255, 0, 0), 3: (0, 0, 255), 4: (255, 0, 0), 5: (0, 0, 255), 6: (255, 0, 0)}

0: 288x512 1 Hoop, 7 Players, 14.0ms
Speed: 2.0ms preprocess, 14.0ms inference, 2.0ms postprocess per image at shape (1, 3, 288, 512)
[0 1 2 3 4 5 6]
Color Mapping: {0: (255, 0, 0), 1: (0, 0, 255), 2: (255, 0, 0), 3: (0, 0, 255), 4: (255, 0, 0), 5: (0, 0, 255), 6: (255, 0, 0)}

0: 288x512 1 Hoop, 7 Players, 17.0ms
Speed: 1.0ms preprocess, 17.0ms inference, 3.0ms postprocess per image at shape (1, 3, 288, 512)
[0 1 2 3 4 5 6]
Color Mapping: {0: (255, 0, 0), 1: (0, 0, 255), 2: (255,

Exception in thread Thread-9 (_play):
Traceback (most recent call last):
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\threading.py", line 1045, in _bootstrap_inner
    self.run()
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\threading.py", line 982, in run
    self._target(*self._args, **self._kwargs)
  File "C:\Users\a3221\AppData\Local\Temp\ipykernel_57084\3941482025.py", line 106, in _play
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\site-packages\PIL\ImageTk.py", line 126, in __init__
    self.__photo = tkinter.PhotoImage(**kw)
                   ^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\tkinter\__init__.py", line 4150, in __init__
    Image.__init__(self, 'photo', name, cnf, master, **kw)
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\tkinter\__init__.py", line 4087, in __init__
    master = _get_default_root('create image')
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\a3221\anaconda3\envs\Py311\Lib\tkinter\__init__.py", l

[0 1 2 3 4 5 6 7]
Color Mapping: {0: (255, 0, 0), 1: (0, 0, 255), 2: (255, 0, 0), 3: (0, 0, 255), 4: (255, 0, 0), 5: (0, 0, 255), 6: (255, 0, 0), 7: (0, 0, 255)}
