<a href="https://colab.research.google.com/github/skywalker0803r/x3d/blob/main/x3d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pytorchvideo.models.hub import x3d_xs
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_video
from torchvision import transforms as T
import cv2
import numpy as np

def read_video_cv2(path, max_frames=16):
    """Decode the leading frames of a video file into a ``(C, T, H, W)`` tensor.

    At most ``max_frames`` frames are read sequentially from the start of the
    file and converted from OpenCV's BGR channel order to RGB.  When the video
    holds fewer than ``max_frames`` frames, the last decoded frame is repeated
    until the clip has exactly ``max_frames`` frames.

    Args:
        path: Path of the video file, passed straight to ``cv2.VideoCapture``.
        max_frames: Number of frames the returned clip must contain.

    Returns:
        A tensor laid out as (channels, frames, height, width).  The dtype is
        whatever OpenCV decodes to (presumably uint8 -- confirm against the
        videos in use).

    Raises:
        RuntimeError: If no frame could be decoded from ``path``.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the decoder handle, even if decoding/conversion raised.
        cap.release()

    if not frames:
        raise RuntimeError(f"Cannot read video {path}")

    # Pad short clips by repeating the last decoded frame.
    last_frame = frames[-1]
    frames.extend(last_frame.copy() for _ in range(max_frames - len(frames)))

    video_np = np.stack(frames, axis=0)  # (T, H, W, C)
    video_t = torch.from_numpy(video_np).permute(3, 0, 1, 2)  # (C, T, H, W)
    return video_t

# 資料模型
class Normalize(torch.nn.Module):
    """Channel-wise standardisation ``(x - mean) / std`` for video tensors.

    ``mean`` and ``std`` are reshaped to ``(-1, 1, 1, 1)`` so they broadcast
    over the three trailing dimensions of a channel-first input.

    Args:
        mean: Per-channel means (sequence of numbers).
        std: Per-channel standard deviations (sequence of numbers).
    """

    def __init__(self, mean, std):
        super().__init__()
        # Registered as buffers so they follow the module through
        # .to(device) / .cuda() / .double(); a plain tensor attribute would
        # stay on the CPU and break once the input lives on another device.
        # persistent=False keeps state_dict() identical to the previous
        # attribute-based implementation (no new checkpoint keys).
        self.register_buffer(
            "mean", torch.tensor(mean).view(-1, 1, 1, 1), persistent=False
        )
        self.register_buffer(
            "std", torch.tensor(std).view(-1, 1, 1, 1), persistent=False
        )

    def forward(self, x):
        """Return ``x`` standardised with the stored mean and std."""
        return (x - self.mean) / self.std

class VideoTransform:
    """Per-frame spatial preprocessing followed by clip-level normalisation.

    Every frame is resized to 182x182 and centre-cropped to 160x160; the
    reassembled clip is then converted to float, scaled by 1/255 and
    standardised with fixed per-channel mean/std values.
    """

    def __init__(self):
        channel_mean = [0.43216, 0.394666, 0.37645]
        channel_std = [0.22803, 0.22145, 0.216989]
        self.resize = T.Resize((182, 182))
        self.center_crop = T.CenterCrop(160)
        self.normalize = Normalize(mean=channel_mean, std=channel_std)

    def __call__(self, video):
        """Transform a ``(C, T, H, W)`` clip and return it in the same layout."""
        # Unpacking also enforces that the input is exactly 4-dimensional.
        _, num_frames, _, _ = video.shape

        # Resize + crop are applied one (C, H, W) frame at a time.
        processed = [
            self.center_crop(self.resize(video[:, idx, :, :]))
            for idx in range(num_frames)
        ]

        clip = torch.stack(processed, dim=1)  # back to (C, T, H, W)
        clip = clip.float() / 255.0
        return self.normalize(clip)

class VideoDataset(Dataset):
    """Binary video dataset driven by a CSV index.

    The CSV must provide a ``filename`` column (relative to ``video_dir``) and
    a ``description`` column.  Rows whose file does not exist are dropped with
    a warning.  A row is labelled 1 when its description contains "strike"
    (case-insensitive) and 0 otherwise; a missing description counts as 0.

    Args:
        csv_path: Path of the CSV index file.
        video_dir: Directory holding the video files.
        frames: Number of frames loaded per video.
        transform: Callable applied to each loaded clip.  Defaults to a fresh
            ``VideoTransform()``; pass ``None`` to return the raw clip.
    """

    # Sentinel for "transform not supplied": None already means "no transform",
    # and building VideoTransform() in the signature would create one shared
    # instance at class-definition time.
    _DEFAULT_TRANSFORM = object()

    def __init__(self, csv_path, video_dir, frames=16, transform=_DEFAULT_TRANSFORM):
        if transform is VideoDataset._DEFAULT_TRANSFORM:
            transform = VideoTransform()

        self.video_dir = video_dir
        self.frames = frames
        self.transform = transform

        self.data = pd.read_csv(csv_path)

        # Keep only rows whose video file actually exists on disk.
        def file_exists(filename):
            # A blank filename cell is read as NaN (a float); treat it as missing.
            return isinstance(filename, str) and os.path.isfile(
                os.path.join(video_dir, filename)
            )

        mask = self.data['filename'].apply(file_exists).astype(bool)
        filtered_data = self.data[mask].reset_index(drop=True)
        num_removed = len(self.data) - len(filtered_data)
        if num_removed > 0:
            print(f"Warning: removed {num_removed} entries because video files not found")
        self.data = filtered_data

        # fillna/astype guard against empty description cells, which would
        # otherwise yield NaN from str.contains and make astype(int) raise.
        descriptions = self.data['description'].fillna('').astype(str)
        self.data['label'] = descriptions.str.contains('strike', case=False).astype(int)

    def __len__(self):
        """Return the number of usable (existing) videos."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the ``idx``-th video and return ``(video, label)``."""
        row = self.data.iloc[idx]
        video_path = os.path.join(self.video_dir, row['filename'])
        label = row['label']

        video = read_video_cv2(video_path, self.frames)

        if self.transform is not None:
            video = self.transform(video)

        return video, label

# ------- 1. 訓練函數 --------
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Network to train; switched to training mode here.
        loader: Iterable of ``(videos, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the number
        of batches, and the fraction of correctly classified samples.
    """
    model.train()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    progress = tqdm(loader, desc="Training", leave=False)
    for batch_videos, batch_labels in progress:
        batch_videos = batch_videos.to(device)
        batch_labels = batch_labels.to(device).long()

        logits = model(batch_videos)
        batch_loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Bookkeeping for the epoch-level statistics.
        loss_sum += batch_loss.item()
        predicted = logits.argmax(dim=1)
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = n_correct / n_seen
    return mean_loss, accuracy

# ------- 2. 驗證函數 --------
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without updating any weights.

    The model is put in eval mode and the whole pass runs under
    ``torch.no_grad()``.

    Returns:
        ``(mean_loss, accuracy)``: the per-batch loss averaged over the number
        of batches, and the fraction of correctly classified samples.
    """
    model.eval()
    running_loss, hits, seen = 0, 0, 0

    with torch.no_grad():
        batches = tqdm(loader, desc="Validating", leave=False)
        for clips, targets in batches:
            clips = clips.to(device)
            targets = targets.to(device).long()

            scores = model(clips)
            running_loss += criterion(scores, targets).item()

            hits += (scores.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)

    return running_loss / len(loader), hits / seen

# ------- 3. 主訓練流程 --------
def main():
    """Fine-tune a pretrained X3D-XS model as a two-class video classifier.

    Builds the dataset from the configured CSV/video directory, splits it
    randomly into train/validation subsets, replaces the model's final
    projection with a 2-output linear layer, trains for ``num_epochs`` epochs
    and saves the weights whenever validation accuracy improves.

    Raises:
        RuntimeError: If fewer than two usable videos are found, since both
            the training and the validation subset need at least one sample.
    """
    # Configuration
    csv_path = "/content/drive/MyDrive/Baseball Movies/CH_videos_4s/CH.csv"
    video_dir = "/content/drive/MyDrive/Baseball Movies/CH_videos_4s"
    batch_size = 4
    num_epochs = 20
    frames = 240
    val_split = 0.2
    lr = 1e-4
    model_save_path = "best_x3d_model.pth"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"use device:{device}")

    # Data loading
    full_dataset = VideoDataset(csv_path, video_dir, frames=frames)
    num_samples = len(full_dataset)
    if num_samples < 2:
        raise RuntimeError(
            f"Need at least 2 usable videos to build train/val splits, found {num_samples}"
        )
    # int() truncates to 0 for tiny datasets, which would leave the validation
    # loader empty and make evaluate() divide by zero; keep at least 1 sample.
    val_size = max(1, int(num_samples * val_split))
    train_size = num_samples - val_size
    train_set, val_set = random_split(full_dataset, [train_size, val_size])

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)

    # Load the pretrained model and swap its output layer for a 2-class head
    model = x3d_xs(pretrained=True)
    model.blocks[-1].proj = nn.Linear(model.blocks[-1].proj.in_features, 2)
    model = model.to(device)

    # Loss and optimiser
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_val_acc = 0

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

        # Save the weights whenever validation accuracy reaches a new best
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            print("✅ Saved best model!")

# Start training only when this file is executed directly (as a script or a
# notebook cell), not when it is imported as a module.
if __name__ == "__main__":
    main()

use device:cuda


Downloading: "https://dl.fbaipublicfiles.com/pytorchvideo/model_zoo/kinetics/X3D_XS.pyth" to /root/.cache/torch/hub/checkpoints/X3D_XS.pyth
100%|██████████| 29.4M/29.4M [00:00<00:00, 75.2MB/s]



Epoch 1/20


Training:  52%|█████▎    | 63/120 [08:20<08:16,  8.70s/it]