In [4]:
import torch
from torchvision import transforms
from PIL import Image

class TrafficDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing image files with their annotation targets.

    Each item is ``(image, target)``: ``image`` is the RGB-converted image
    after ``transform`` has been applied, and ``target`` is ``labels[idx]``
    returned unchanged.

    Args:
        image_paths: Sized, indexable collection of image paths (anything
            ``PIL.Image.open`` accepts).
        labels: Sized collection of per-image targets, indexed with the same
            integer index as ``image_paths``.
        transform: Optional callable applied to the PIL image. When omitted,
            the image is resized to 512x512 and converted with ``ToTensor``.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast: a mismatch would otherwise surface as an IndexError (or
        # silently ignored labels) somewhere in the middle of an epoch.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        if transform is None:
            # NOTE(review): Resize changes the image geometry, but targets are
            # returned as-is. If the targets hold pixel-space boxes for the
            # original image size they would need rescaling too -- confirm
            # against the label format.
            transform = transforms.Compose([
                transforms.Resize((512, 512)),
                transforms.ToTensor()
            ])
        self.transform = transform

    def __getitem__(self, idx):
        """Return the transformed image and its target for position ``idx``."""
        # The context manager closes the underlying file handle; convert()
        # returns an independent in-memory image, so it stays usable afterwards.
        with Image.open(self.image_paths[idx]) as source:
            img = source.convert("RGB")
        img = self.transform(img)
        target = self.labels[idx]  # expected to carry bbox, category, etc.
        return img, target

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.image_paths)


In [5]:
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Load the model through the newer `weights` enum argument.
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=weights)

# Replace the box-predictor head so its output size matches our classes.
num_classes = 2  # 1 (vehicle) + 1 (background)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# Move the model to its target device BEFORE constructing the optimizer, as
# recommended by the torch.optim documentation, so the optimizer is built
# against parameters that already live where they will be trained.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# Training hyper-parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# Switch to training mode.
model.train()

# Basic training-loop skeleton.
def train_one_epoch(model, optimizer, data_loader, device):
    """Run one optimisation pass over ``data_loader``.

    For every batch the images and target dicts are moved to ``device``, the
    model is called as ``model(images, targets)`` and is expected to return a
    mapping of named losses; their sum is back-propagated and one optimizer
    step is taken.

    Args:
        model: Module that returns a dict of loss tensors when called with
            ``(images, targets)``.
        optimizer: Optimizer managing ``model``'s parameters.
        data_loader: Iterable yielding ``(images, targets)`` batches, where
            ``images`` is an iterable of tensors and ``targets`` an iterable
            of dicts.
        device: Device the batch is moved to before the forward pass.

    Returns:
        The summed loss of the last batch as a detached tensor.

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    # Do not rely on the caller's global state: torchvision detection models
    # only return the loss dict while in training mode.
    model.train()

    last_loss = None
    for images, targets in data_loader:
        images = [image.to(device) for image in images]
        # Only tensors can be moved; leave any non-tensor metadata untouched.
        targets = [
            {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
            for t in targets
        ]

        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # Detach so the returned value does not keep the autograd graph alive.
        last_loss = losses.detach()

    if last_loss is None:
        # Previously this surfaced as an UnboundLocalError on `return losses`.
        raise ValueError("data_loader yielded no batches; nothing to train on")
    return last_loss


In [6]:
import cv2

def count_vehicles(tracked_objects, counting_line):
    """Count tracked objects whose box centre lies past ``counting_line``.

    Args:
        tracked_objects: Iterable of mappings, each with a ``"bbox"`` entry
            that unpacks into four numbers treated as ``(x, y, w, h)``.
        counting_line: Threshold compared against the centre's y value.

    Returns:
        Number of objects whose integer-truncated centre y is strictly
        greater than ``counting_line``.
    """
    def _center(bbox):
        # Centre of the box: origin plus half the extent, truncated to int.
        left, top, width, height = bbox
        return int(left + width / 2), int(top + height / 2)

    return sum(
        1
        for tracked in tracked_objects
        if _center(tracked["bbox"])[1] > counting_line
    )
