<a href="https://colab.research.google.com/github/skywalker0803r/base_ball_detect_lab/blob/main/%E6%A3%92%E7%90%83%E8%BB%8C%E8%B7%A1%E5%81%B5%E6%B8%AC_%E7%90%83%E8%B7%AF%E8%BB%8C%E8%B7%A1%E9%85%8D%E6%A8%99%E7%B1%A4_%E6%89%B9%E9%87%8F%E8%99%95%E7%90%86.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [30]:
#我已經有一個函數extract_baseball_trajectory_from_video將影片pitch_0001.mp4丟進去
#trajectory = extract_baseball_trajectory_from_video(video_path='/content/pitch_0001.mp4', model=model, conf_threshold=0.5, baseball_class_name="baseball")
#他會輸出trajectory再丟到extract_longest_valid_segment(trajectory)會回傳longest_segment(代表真正的投球軌跡)
#我想做的事情是我有一個dir裡面放很多的mp4影片從pitch_0001.mp4到pitch_xxxx.mp4都有
#請你自動幫我把每一筆影片的longest_segment都抽出來當作這筆樣本的X
#再根據影片名例如pitch_0001.mp4 去我指定的csv檔案裏面
#找到Filename等於影片名的那一行對應的description當作y
#最後製作成batch_X(list of 每一個影片的投球軌跡),batch_y(list of 每一個影片的description)


In [31]:
#!pip install baseballcv ultralytics

In [32]:

import cv2
import matplotlib.pyplot as plt
import numpy as np

def extract_baseball_trajectory_from_video(video_path, model, conf_threshold=0.5, baseball_class_name="baseball"):
    """
    Run per-frame ball detection on a video and return the detections in frame order.

    Args:
        video_path: str, path of the video file to read.
        model: loaded YOLO model object. It is called as ``model(frame, verbose=False)``
            and only the first element of what it returns is used.
        conf_threshold: float, detections scoring below this value are ignored.
        baseball_class_name: str, class label to keep (depends on the model's classes).

    Returns:
        list with exactly one tuple per decoded frame:
            ``(frame_index, x1, y1, x2, y2)`` - bounding-box corners (cast to int) of the
            first detection in that frame that passes the threshold and has the
            requested label;
            ``(frame_index, None, None)`` - when the frame has no such detection.
        Note that the two tuple shapes differ in length; consumers must not assume a
        fixed arity.

    Raises:
        RuntimeError: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    trajectory = []
    frame_idx = 0

    # try/finally guarantees the capture handle is released even when inference
    # (or anything else inside the loop) raises.
    try:
        if not cap.isOpened():
            raise RuntimeError(f"無法開啟影片檔：{video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break  # end of stream

            # Single image in, so the single Results object is at index 0.
            res = model(frame, verbose=False)[0]

            names = res.names
            boxes = res.boxes.xyxy.cpu().numpy()
            scores = res.boxes.conf.cpu().numpy()
            class_ids = res.boxes.cls.cpu().numpy().astype(int)

            # Keep the first box that is confident enough and carries the wanted label.
            for box, score, cls_id in zip(boxes, scores, class_ids):
                if score < conf_threshold:
                    continue
                if names.get(cls_id, str(cls_id)) != baseball_class_name:
                    continue
                x1, y1, x2, y2 = box.astype(int)
                trajectory.append((frame_idx, x1, y1, x2, y2))
                break
            else:
                # No matching detection: record a placeholder so frame indices stay contiguous.
                trajectory.append((frame_idx, None, None))

            frame_idx += 1
    finally:
        cap.release()

    return trajectory

In [33]:
def extract_longest_valid_segment(trajectory):
    """
    Return the longest run of consecutive entries that carry a detection.

    An entry is a tuple whose first element is the frame index. It counts as valid
    when it has at least one further element and none of those further elements is
    None, so both ``(frame, x, y)`` and ``(frame, x1, y1, x2, y2)`` entries are
    accepted, and ``(frame, None, None)`` placeholders break a run.

    Args:
        trajectory: iterable of per-frame tuples as described above.

    Returns:
        list of the entries in the longest valid run (the earliest one when several
        runs tie), or an empty list when no entry is valid. A one-line summary of
        the choice is printed.
    """
    best = []
    current = []

    for item in trajectory:
        # Do not unpack to a fixed arity: detected frames and placeholders differ in length.
        has_detection = len(item) > 1 and all(value is not None for value in item[1:])
        if has_detection:
            current.append(item)
            continue
        # A gap closes the current run; strict '>' keeps the earliest run on ties.
        if len(current) > len(best):
            best = current
        current = []

    # The trajectory may end inside the longest run.
    if len(current) > len(best):
        best = current

    if best:
        start_idx = best[0][0]
        end_idx = best[-1][0]
        print(f"選擇的段落：從 frame {start_idx} 到 frame {end_idx}，共 {len(best)} 幀")
    else:
        print("找不到有效段落")
    return best

In [34]:
import cv2
from ultralytics import YOLO
from baseballcv.functions import LoadTools
from tqdm import tqdm
import cv2

# Load the model: ask baseballcv's LoadTools for the "ball_tracking" model and wrap
# whatever it returns in an Ultralytics YOLO object used by the cells below.
load_tools = LoadTools()
# NOTE(review): the captured log suggests load_model returns a local .pt weights path — confirm.
model_path = load_tools.load_model("ball_tracking")
model = YOLO(model_path)

2025-07-01 17:36:59,292 - LoadTools - INFO - Model found at models/od/YOLO/ball_tracking/model_weights/ball_tracking.pt


INFO:LoadTools:Model found at models/od/YOLO/ball_tracking/model_weights/ball_tracking.pt


In [35]:
import os
import pandas as pd
import pickle
from tqdm import tqdm
import re

def load_data_from_videos_and_csv(video_dir, csv_path, output_dir, model, conf_threshold=0.5, baseball_class_name="baseball", max_video_number=200):
    """
    Build one labelled trajectory pickle per video found in ``video_dir``.

    For every ``.mp4`` file (processed in sorted order) the matching CSV row is looked
    up by ``Filename``, the per-frame trajectory is extracted with
    ``extract_baseball_trajectory_from_video`` and a dict with the keys
    ``'trajectory'``, ``'pitch_type'`` and ``'description'`` is pickled to
    ``<output_dir>/<video stem>_baseball_trajectory_with_label.pkl``.

    Args:
        video_dir: str, directory containing the ``.mp4`` files.
        csv_path: str, CSV file with ``Filename``, ``pitch_type`` and ``description`` columns.
        output_dir: str, directory for the pickles (created if missing).
        model: detection model forwarded to the trajectory extractor.
        conf_threshold: float, forwarded to the trajectory extractor.
        baseball_class_name: str, forwarded to the trajectory extractor.
        max_video_number: int or None. Videos whose file stem contains a number larger
            than this are skipped; None disables the filter. Defaults to 200.

    Returns:
        str, ``"<video_dir>:處理完成"``.

    Videos are skipped (with a printed message where applicable) when the output file
    already exists, when the CSV has no row for the file, when the label columns are
    missing, or when the extracted trajectory is empty.
    """
    labels = pd.read_csv(csv_path)
    os.makedirs(output_dir, exist_ok=True)

    video_files = sorted(f for f in os.listdir(video_dir) if f.lower().endswith('.mp4'))

    for vf in tqdm(video_files):
        stem = os.path.splitext(vf)[0]

        # Number filter works on the stem so the "4" of ".mp4" is never considered.
        if max_video_number is not None and any(
            int(num) > max_video_number for num in re.findall(r'\d+', stem)
        ):
            continue

        # Resume support: finished videos are not recomputed.
        save_path = os.path.join(output_dir, f"{stem}_baseball_trajectory_with_label.pkl")
        if os.path.exists(save_path):
            print(f"已存在: {save_path}，跳過")
            continue

        # Resolve the label first: it is cheap, and it avoids running inference on
        # every frame of a video that cannot be labelled anyway.
        row = labels[labels['Filename'] == vf]
        if row.empty:
            print(f"警告: {vf} 在 CSV 中找不到對應標籤，跳過")
            continue
        try:
            pitch_type = row.iloc[0]['pitch_type']
            description = row.iloc[0]['description']
        except KeyError as e:
            print(f"錯誤讀取標籤: {csv_path}, row: {row}, 錯誤: {e}")
            continue

        video_path = os.path.join(video_dir, vf)
        trajectory = extract_baseball_trajectory_from_video(video_path, model, conf_threshold, baseball_class_name)

        # For now the complete per-frame trajectory is stored, not only the longest segment.
        longest_segment = trajectory

        if not longest_segment:
            print(f"警告: {vf} 找不到有效投球軌跡段，跳過")
            continue

        # Write to a temporary name and rename afterwards, so an interrupted run can
        # never leave a truncated pickle that the "already exists" check would then
        # skip forever.
        tmp_path = f"{save_path}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump({
                'trajectory': longest_segment,
                'pitch_type': pitch_type,
                'description': description,
            }, f)
        os.replace(tmp_path, save_path)

    return f"{video_dir}:處理完成"

In [36]:
# Example usage: build the labelled trajectory pickles for each pitcher / pitch-type folder.
player_name_list = [
    'Shohei_Ohtani_SL', 'Shohei_Ohtani_FS', 'Shohei_Ohtani_FF',
    'Gerrit_Cole_CH', 'Gerrit_Cole_FF', 'Gerrit_Cole_SL',
    'Yu_Darvish_FF', 'Yu_Darvish_FS', 'Yu_Darvish_SL',
]
# The full roster above is overridden here, so only this single folder is processed.
player_name_list = ['Yu_Darvish_SL']

for player_name in tqdm(player_name_list):
    print(f"{player_name}:開始處理")

    # Input videos, label CSV and output folder all live under the same Drive root.
    video_dir = f'/content/drive/MyDrive/Baseball Movies/{player_name}_videos_4S'
    csv_path = f'/content/drive/MyDrive/Baseball Movies/data_csv/{player_name}.csv'
    output_dir = f'/content/drive/MyDrive/Baseball Movies/{player_name}_videos_4S/baseball_trajectory_bbox_with_pitch_type_description'

    load_data_from_videos_and_csv(
        video_dir=video_dir,
        csv_path=csv_path,
        output_dir=output_dir,
        model=model,
    )
    print(f"{player_name}:處理完成")

 14%|█▍        | 56/404 [00:50<05:15,  1.10it/s]
  0%|          | 0/1 [00:52<?, ?it/s]


KeyboardInterrupt: 