<a href="https://colab.research.google.com/github/skywalker0803r/x3d/blob/main/x3d.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pytorchvideo.models.hub import x3d_xs
from tqdm import tqdm
import torch.nn.functional as F
from torch.utils.data import Dataset
import pandas as pd
from torchvision.io import read_video
from torchvision import transforms as T
import cv2
import numpy as np
from torchvision import transforms
import random
from PIL import Image
from torch.utils.data import ConcatDataset
from torch.utils.data import Subset
import copy
from PIL import Image, ImageFilter

def read_video_cv2(path, max_frames=240, sample_frames=120):
    """Decode a video with OpenCV and return a fixed-length, evenly sampled clip.

    The first ``max_frames`` frames are decoded and converted from BGR to RGB.
    A shorter video is treated as if its last frame were repeated until it
    reaches ``max_frames`` frames. ``sample_frames`` frames are then picked at
    evenly spaced positions over that ``max_frames``-long timeline.

    Args:
        path: Path of the video file passed to ``cv2.VideoCapture``.
        max_frames: Length of the timeline that sampling is performed over.
        sample_frames: Number of frames in the returned clip.

    Returns:
        A tensor laid out as (C, T, H, W) with ``T == sample_frames``, built
        from the decoded frames without any value scaling.

    Raises:
        ValueError: If ``max_frames`` or ``sample_frames`` is smaller than 1.
        RuntimeError: If no frame could be decoded from ``path``.
    """
    if max_frames < 1 or sample_frames < 1:
        raise ValueError(
            f"max_frames and sample_frames must be >= 1, "
            f"got {max_frames} and {sample_frames}"
        )

    cap = cv2.VideoCapture(path)
    frames = []
    try:
        # Frames beyond max_frames are never used, so stop decoding early
        # instead of holding the whole video in memory.
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always free the capture handle, even if decoding raised.
        cap.release()

    if not frames:
        raise RuntimeError(f"Cannot read video {path}")

    # Evenly spaced positions on the max_frames-long timeline.
    indices = np.linspace(0, max_frames - 1, sample_frames).astype(int)
    # Clamping to the last decoded frame is equivalent to padding a short
    # video with copies of its last frame, without allocating the copies.
    indices = np.minimum(indices, len(frames) - 1)

    video_np = np.stack([frames[i] for i in indices], axis=0)  # (T, H, W, C)
    video_t = torch.from_numpy(video_np).permute(3, 0, 1, 2)  # (C, T, H, W)
    return video_t

# 資料模型
class Normalize(torch.nn.Module):
    """Channel-wise normalisation ``(x - mean) / std`` for channel-first clips.

    ``mean`` and ``std`` are reshaped to ``(-1, 1, 1, 1)`` so that they
    broadcast over the three trailing dimensions of the input, e.g. a
    (C, T, H, W) video tensor.

    Args:
        mean: Per-channel means (sequence of numbers or a tensor).
        std: Per-channel standard deviations (sequence of numbers or a tensor).
    """

    def __init__(self, mean, std):
        super().__init__()
        mean_t = torch.as_tensor(mean, dtype=torch.float32).reshape(-1, 1, 1, 1)
        std_t = torch.as_tensor(std, dtype=torch.float32).reshape(-1, 1, 1, 1)
        # Register as buffers so that .to(device) / .double() / .half() move
        # the statistics together with the module. They are constants rather
        # than learned state, so keep them out of the state_dict.
        self.register_buffer("mean", mean_t.clone(), persistent=False)
        self.register_buffer("std", std_t.clone(), persistent=False)

    def forward(self, x):
        """Return ``x`` normalised with the stored per-channel statistics."""
        return (x - self.mean) / self.std

# 🔧 安全資料增強
class SafeVideoAugmentation:
    """Resize, optionally blur/brighten, and normalise a clip of frames.

    The blur and brightness decisions (and the brightness factor) are drawn
    once per call, so every frame of a clip receives the same augmentation.
    Calling an instance returns a (C, T, H, W) tensor that has been passed
    through ``Normalize`` with mean 0.45 and std 0.225 on each channel.
    """

    def __init__(self, resize=(224, 224), apply_blur_prob=0.3, apply_brightness_prob=0.3):
        self.resize = resize
        self.apply_blur_prob = apply_blur_prob
        self.apply_brightness_prob = apply_brightness_prob
        self.to_tensor = transforms.ToTensor()
        self.normalize = Normalize(mean=[0.45, 0.45, 0.45], std=[0.225, 0.225, 0.225])

    def __call__(self, frames):
        """Augment an iterable of image arrays and return a normalised clip."""
        # Draw all random decisions up front so the whole clip stays consistent.
        use_blur = random.random() < self.apply_blur_prob
        use_brightness = random.random() < self.apply_brightness_prob
        factor = random.uniform(0.8, 1.2)

        def _process(frame):
            # Resize with OpenCV, then continue with PIL-based operations.
            image = Image.fromarray(cv2.resize(frame, self.resize))
            if use_blur:
                # The radius controls how strong the blur is.
                image = image.filter(ImageFilter.GaussianBlur(radius=1.5))
            if use_brightness:
                image = transforms.functional.adjust_brightness(image, factor)
            return self.to_tensor(image)

        clip = torch.stack([_process(frame) for frame in frames])  # (T, C, H, W)
        clip = clip.permute(1, 0, 2, 3)  # (C, T, H, W)
        return self.normalize(clip)
class VideoDataset(Dataset):
    """Dataset of video clips listed in a CSV file, with a binary label.

    The CSV must provide a ``filename`` column (file name relative to
    ``video_dir``) and a ``description`` column. Rows whose video file does
    not exist are dropped. The label is 1 when the description contains
    ``strike`` (case-insensitive) and 0 otherwise; a missing description
    counts as 0.

    Args:
        csv_path: Path of the CSV file describing the clips.
        video_dir: Directory that contains the video files.
        original_frames: Maximum number of frames read per video; forwarded
            to ``read_video_cv2`` as ``max_frames``.
        sample_frames: Number of evenly spaced frames kept per video.
        transform: Callable applied to the (T, H, W, C) frame array. When
            ``None``, a default ``SafeVideoAugmentation()`` is used.
    """

    def __init__(self, csv_path, video_dir, original_frames=240, sample_frames=120, transform=None):
        self.video_dir = video_dir
        self.original_frames = original_frames  # maximum frames read per video
        self.sample_frames = sample_frames      # frames kept after even sampling
        # Compare against None so that a falsy-but-valid callable
        # (e.g. an empty nn.Sequential) is not silently replaced.
        self.transform = transform if transform is not None else SafeVideoAugmentation()
        self.data = pd.read_csv(csv_path)

        # Keep only rows whose video file actually exists on disk.
        def file_exists(filename):
            # A blank filename cell is read as NaN (a float); treat as missing.
            return isinstance(filename, str) and os.path.isfile(os.path.join(video_dir, filename))

        mask = self.data['filename'].apply(file_exists).astype(bool)
        filtered_data = self.data[mask].reset_index(drop=True)
        num_removed = len(self.data) - len(filtered_data)
        if num_removed > 0:
            print(f"Warning: removed {num_removed} entries because video files not found")
        self.data = filtered_data

        # na=False: an empty description becomes label 0 instead of producing
        # a NaN that makes astype(int) raise.
        self.data['label'] = (
            self.data['description'].str.contains('strike', case=False, na=False).astype(int)
        )

    def __len__(self):
        """Return the number of clips whose video file exists."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(video, label)`` for the clip at position ``idx``."""
        row = self.data.iloc[idx]
        video_path = os.path.join(self.video_dir, row['filename'])
        label = row['label']

        video = read_video_cv2(video_path, self.original_frames, self.sample_frames)
        # The transform works on per-frame image arrays, so convert the
        # (C, T, H, W) tensor back to a (T, H, W, C) array first.
        video = video.permute(1, 2, 3, 0).numpy()  # (T, H, W, C)
        if self.transform is not None:
            video = self.transform(video)
        return video, label

# ------- 1. 訓練函數 --------
def train_one_epoch(model, loader, criterion, optimizer, device, scaler):
    """Train ``model`` for one pass over ``loader`` with mixed precision.

    Args:
        model: Network to train; switched to training mode.
        loader: Iterable of ``(videos, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.
        scaler: Gradient scaler used to scale the loss and step the optimizer.

    Returns:
        A tuple ``(mean_loss, accuracy)``: the sum of the batch losses divided
        by the number of batches, and the fraction of correctly predicted
        samples.
    """
    model.train()
    total_loss, correct, total = 0, 0, 0

    # torch.cuda.amp.autocast() is deprecated; torch.autocast is its
    # replacement. Mixed precision is only turned on when running on CUDA.
    use_amp = torch.device(device).type == "cuda"

    for videos, labels in tqdm(loader, desc="Training", leave=False):
        videos, labels = videos.to(device), labels.to(device).long()

        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(videos)
            loss = criterion(outputs, labels)

        # Scale the loss before backward; step() and update() are the
        # scaler's counterparts of optimizer.step().
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        total_loss += loss.item()
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    return total_loss / len(loader), correct / total

# ------- 2. 驗證函數 --------
def evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        loader: Iterable of ``(videos, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        A tuple ``(mean_loss, accuracy)``: the sum of the batch losses divided
        by the number of batches, and the fraction of correctly predicted
        samples.
    """
    model.eval()
    total_loss, correct, total = 0, 0, 0

    # torch.cuda.amp.autocast() is deprecated; torch.autocast is its
    # replacement. Mixed precision is only turned on when running on CUDA.
    use_amp = torch.device(device).type == "cuda"

    with torch.no_grad(), torch.autocast(device_type="cuda", enabled=use_amp):
        for videos, labels in tqdm(loader, desc="Validating", leave=False):
            videos, labels = videos.to(device), labels.to(device).long()
            outputs = model(videos)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return total_loss / len(loader), correct / total

# ------- 3. 主訓練流程 --------
def main():
    """Fine-tune a pretrained X3D-XS model as a two-class video classifier.

    Steps:
      1. Build one ``VideoDataset`` per pitch-type folder and concatenate them.
      2. Randomly split the sample indices into training and validation parts.
      3. Give the training copy random augmentation and the validation copy
         a deterministic transform (blur/brightness probabilities set to 0).
      4. Train for ``num_epochs`` epochs, saving the model weights to Google
         Drive whenever the validation accuracy improves.
    """
    batch_size = 10      # clips are sampled down to 120 frames, so a larger batch may fit
    num_epochs = 20
    original_frames = 240
    sample_frames = 120
    val_split = 0.2
    lr = 1e-4
    dataset_paths = [
        {
            "csv_path": "/content/drive/MyDrive/Baseball Movies/CH_videos_4s/CH.csv",
            "video_dir": "/content/drive/MyDrive/Baseball Movies/CH_videos_4s"
        },
        {
            "csv_path": "/content/drive/MyDrive/Baseball Movies/FF_videos_4s/FF.csv",
            "video_dir": "/content/drive/MyDrive/Baseball Movies/FF_videos_4s"
        },
        {
            "csv_path": "/content/drive/MyDrive/Baseball Movies/SL_videos_4s/SL.csv",
            "video_dir": "/content/drive/MyDrive/Baseball Movies/SL_videos_4s"
        },
    ]
    datasets = []
    for item in dataset_paths:
        dataset = VideoDataset(
            csv_path=item["csv_path"],
            video_dir=item["video_dir"],
            original_frames=original_frames,
            sample_frames=sample_frames,
            transform=None  # the real transform is assigned per split below
        )
        datasets.append(dataset)

    full_dataset = ConcatDataset(datasets)

    # 1. Split the sample indices into validation and training parts.
    dataset_len = len(full_dataset)
    indices = list(range(dataset_len))
    split = int(val_split * dataset_len)
    random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # 2. Copy the dataset so the two splits can carry different transforms.
    train_dataset = copy.deepcopy(full_dataset)
    val_dataset = copy.deepcopy(full_dataset)

    # 3. Random augmentation for training, none for validation.
    for d in train_dataset.datasets:
        d.transform = SafeVideoAugmentation()

    for d in val_dataset.datasets:
        d.transform = SafeVideoAugmentation(apply_blur_prob=0.0, apply_brightness_prob=0.0)

    # 4. Restrict each copy to its own indices.
    train_set = Subset(train_dataset, train_indices)
    val_set = Subset(val_dataset, val_indices)

    # 5. Build the data loaders.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    model = x3d_xs(pretrained=True)
    # Replace the final projection layer with a 2-class output.
    model.blocks[-1].proj = nn.Linear(model.blocks[-1].proj.in_features, 2)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Gradient scaler for mixed-precision training; explicitly disabled when
    # CUDA is not in use so it acts as a pass-through without warnings.
    scaler = torch.amp.GradScaler('cuda', enabled=device.type == 'cuda')

    best_val_acc = 0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, scaler)
        val_loss, val_acc = evaluate(model, val_loader, criterion, device)

        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Val   Acc: {val_acc:.4f}")

        # Save a checkpoint whenever the validation accuracy improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # The file name encodes the epoch and the validation accuracy.
            save_dir = "/content/drive/MyDrive"
            os.makedirs(save_dir, exist_ok=True)
            filename = f"epoch_{epoch+1}_valacc_{val_acc:.4f}.pt"
            model_save_path = os.path.join(save_dir, filename)
            torch.save(model.state_dict(), model_save_path)
            print(f"✅ Saved best model to Google Drive: {model_save_path}")

# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()


Epoch 1/20


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


Train Loss: 0.6182 | Train Acc: 0.6949
Val   Loss: 0.5595 | Val   Acc: 0.7549
✅ Saved best model to Google Drive: /content/drive/MyDrive/epoch_1_valacc_0.7549.pt

Epoch 2/20


Training:  26%|██▌       | 37/144 [02:10<07:47,  4.37s/it]