In [11]:
#!pip install baseballcv ultralytics
import cv2
from ultralytics import YOLO
from baseballcv.functions import LoadTools
from tqdm import tqdm
import cv2

# 載入模型
load_tools = LoadTools()
model_path = load_tools.load_model("ball_tracking")
model = YOLO(model_path)
def predict_pitch_boxes_from_video_batch(video_path, batch_size=16,model=model):
    cap = cv2.VideoCapture(video_path)
    frame_idx = 0
    box_results = []

    batch_frames = []
    frame_indices = []

    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break

        batch_frames.append(frame)
        frame_indices.append(frame_idx)
        frame_idx += 1

        # 每 batch_size 張推一次
        if len(batch_frames) == batch_size:
            results = model.predict(source=batch_frames, imgsz=640, device='cuda:0', verbose=False)

            for idx, result in enumerate(results):
                boxes = result.boxes
                if boxes is not None and len(boxes) > 0:
                    best_box = boxes[0]
                    x1, y1, x2, y2 = best_box.xyxy[0].tolist()
                    box_results.append((frame_indices[idx], (x1, y1, x2, y2)))
                else:
                    box_results.append((frame_indices[idx], None))

            batch_frames = []
            frame_indices = []

    # 處理最後不足 batch_size 的幀
    if batch_frames:
        results = model.predict(source=batch_frames, imgsz=640, device='cuda:0', verbose=False)
        for idx, result in enumerate(results):
            boxes = result.boxes
            if boxes is not None and len(boxes) > 0:
                best_box = boxes[0]
                x1, y1, x2, y2 = best_box.xyxy[0].tolist()
                box_results.append((frame_indices[idx], (x1, y1, x2, y2)))
            else:
                box_results.append((frame_indices[idx], None))

    cap.release()
    return box_results

def draw_boxes_on_video_batch(input_path, output_path, box_results, batch_size=16):
    cap = cv2.VideoCapture(input_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_idx = 0

    while frame_idx < total_frames:
        frames_batch = []
        indices_batch = []

        # 批次讀取 batch_size 幀
        for _ in range(batch_size):
            ret, frame = cap.read()
            if not ret:
                break
            frames_batch.append(frame)
            indices_batch.append(frame_idx)
            frame_idx += 1

        # 批次畫框
        for i, frame in enumerate(frames_batch):
            if indices_batch[i] < len(box_results):
                _, box = box_results[indices_batch[i]]
                if box is not None:
                    x1, y1, x2, y2 = map(int, box)
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            out.write(frame)

    cap.release()
    out.release()

2025-06-22 01:19:26,744 - LoadTools - INFO - Model found at models/od/YOLO/ball_tracking/model_weights/ball_tracking.pt


INFO:LoadTools:Model found at models/od/YOLO/ball_tracking/model_weights/ball_tracking.pt


In [12]:
import time
s = time.time()
boxes = predict_pitch_boxes_from_video_batch("/content/drive/MyDrive/Baseball Movies/CH_videos_4s/pitch_0001.mp4")
e = time.time()
print(e-s)

10.063778162002563


In [13]:
boxes

[(0,
  (565.8788452148438, 580.109130859375, 694.6965942382812, 591.2823486328125)),
 (1,
  (566.6822509765625,
   579.8925170898438,
   694.6683349609375,
   591.2116088867188)),
 (2, (566.386962890625, 579.826171875, 694.701171875, 591.282958984375)),
 (3,
  (566.6749267578125,
   579.7977294921875,
   694.5841064453125,
   591.3221435546875)),
 (4,
  (566.90283203125, 579.7598876953125, 694.602783203125, 591.3731689453125)),
 (5, (566.178466796875, 579.8905029296875, 694.836181640625, 591.65771484375)),
 (6,
  (566.8365478515625,
   579.8829345703125,
   694.6395263671875,
   591.5504150390625)),
 (7, (566.21630859375, 579.8629150390625, 695.171875, 591.4456787109375)),
 (8,
  (565.7723388671875,
   579.7650756835938,
   694.3687744140625,
   591.2881469726562)),
 (9, (564.2841796875, 579.68017578125, 694.39306640625, 591.230224609375)),
 (10,
  (564.0563354492188,
   579.6151123046875,
   694.3644409179688,
   591.2008056640625)),
 (11,
  (563.7749633789062, 579.513427734375, 695.6

In [9]:
draw_boxes_on_video_batch("/content/drive/MyDrive/Baseball Movies/CH_videos_4s/pitch_0001.mp4", "output_with_boxes.mp4", boxes)