In [19]:
import os
import torch
import xml.etree.ElementTree as ET
from PIL import Image
from torchvision.transforms import functional as F

class VOCDataset(torch.utils.data.Dataset):
    """Object-detection dataset pairing images with VOC-style XML annotations.

    Directory layout scanned by ``__init__``::

        image_dir/{subfolder}/{base}.png|.jpg|.jpeg
        label_dir/{subfolder}_label/{base}.xml

    Images that have no matching XML file are skipped silently.

    Each item is ``(image, target)``: ``image`` is the RGB picture passed
    through ``F.to_tensor`` and ``target`` is a dict with

    * ``"boxes"``  -- float32 tensor of shape ``(N, 4)`` holding
      ``[xmin, ymin, xmax, ymax]`` per object,
    * ``"labels"`` -- int64 tensor of shape ``(N,)``.

    ``N`` may be 0 for an annotation that lists no usable objects; the
    tensors keep the shapes/dtypes above in that case too.

    Args:
        image_dir: Root folder that contains one sub-folder per image group.
        label_dir: Root folder that contains the ``{subfolder}_label`` folders.
        class_to_idx: Optional mapping from the text of an object's ``name``
            element to an integer label. When ``None`` (the default) every
            object is labelled ``1``, i.e. a single foreground class.
    """

    # Compared against the lower-cased file name, so matching is
    # case-insensitive.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, label_dir, class_to_idx=None):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.class_to_idx = class_to_idx
        self.images = []
        self.xmls = []

        # Sorted traversal keeps sample indices stable between runs;
        # os.listdir() returns entries in arbitrary order.
        for subfolder in sorted(os.listdir(image_dir)):
            subfolder_path = os.path.join(image_dir, subfolder)
            # Stray files directly under image_dir (.DS_Store, Thumbs.db, ...)
            # would make os.listdir() raise NotADirectoryError.
            if not os.path.isdir(subfolder_path):
                continue

            for file in sorted(os.listdir(subfolder_path)):
                if not file.lower().endswith(self.IMAGE_EXTENSIONS):
                    continue

                # The annotation shares the image's base name.
                base_name = os.path.splitext(file)[0]
                xml_path = os.path.join(
                    label_dir, subfolder + "_label", base_name + '.xml'
                )

                # Keep only images that actually have an annotation file.
                if os.path.isfile(xml_path):
                    self.images.append(os.path.join(subfolder_path, file))
                    self.xmls.append(xml_path)

    def _label_for(self, obj, xml_path):
        """Return the integer label for one ``object`` element."""
        if self.class_to_idx is None:
            # Single-class setup: every annotated object is class 1.
            return 1

        name = (obj.findtext("name") or "").strip()
        if name not in self.class_to_idx:
            raise KeyError(
                f"Class name {name!r} in {xml_path} is missing from class_to_idx"
            )
        return self.class_to_idx[name]

    def _parse_annotation(self, xml_path):
        """Parse one XML file and return ``(boxes, labels)`` tensors.

        Boxes whose width or height is not strictly positive are dropped:
        torchvision detection models reject such targets during training.
        """
        root = ET.parse(xml_path).getroot()

        boxes = []
        labels = []
        for obj in root.findall("object"):
            bbox = obj.find("bndbox")
            xmin = float(bbox.find("xmin").text)
            ymin = float(bbox.find("ymin").text)
            xmax = float(bbox.find("xmax").text)
            ymax = float(bbox.find("ymax").text)

            if xmax <= xmin or ymax <= ymin:
                continue

            boxes.append([xmin, ymin, xmax, ymax])
            labels.append(self._label_for(obj, xml_path))

        # reshape(-1, 4) turns the empty case into shape (0, 4) rather than
        # (0,); the explicit int64 keeps an empty label list from defaulting
        # to float32.
        boxes = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.tensor(labels, dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, target_dict)``."""
        # convert() produces a fully loaded copy, so the underlying file can
        # be closed as soon as the with-block ends.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert("RGB")

        boxes, labels = self._parse_annotation(self.xmls[idx])
        target = {
            "boxes": boxes,
            "labels": labels,
        }

        return F.to_tensor(img), target

    def __len__(self):
        """Number of (image, annotation) pairs found at construction time."""
        return len(self.images)


In [18]:
import torchvision
import torch

def get_model(num_classes):
    """Build a Faster R-CNN (ResNet-50 FPN) detector with a fresh box head.

    The network is created with ``weights="DEFAULT"`` and its
    ``roi_heads.box_predictor`` is then swapped for a new
    ``FastRCNNPredictor`` that outputs ``num_classes`` classes.

    Args:
        num_classes: Number of classes the new predictor should output.

    Returns:
        The modified detection model.
    """
    detection = torchvision.models.detection
    detector = detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")

    # The replacement head must accept the same feature width as the old one.
    head_in_features = detector.roi_heads.box_predictor.cls_score.in_features
    new_head = detection.faster_rcnn.FastRCNNPredictor(head_in_features, num_classes)
    detector.roi_heads.box_predictor = new_head

    return detector


In [20]:
from torch.utils.data import DataLoader
import torch

def _collate_detection_batch(batch):
    """Turn a list of (image, target) samples into (images, targets) tuples.

    Samples are kept as separate items instead of being stacked into one
    tensor.
    """
    return tuple(zip(*batch))


dataset = VOCDataset(image_dir="dataset", label_dir="label")
data_loader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=_collate_detection_batch,
)

# Use the GPU when one is available, otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

num_classes = 2  # one object class + background
model = get_model(num_classes)
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

for epoch in range(10):
    model.train()
    # Sum (not mean) of the per-batch losses over the epoch.
    total_loss = 0

    for imgs, targets in data_loader:
        # Move every image and every tensor inside each target dict to the
        # training device.
        imgs = [image.to(device) for image in imgs]
        targets = [
            {key: value.to(device) for key, value in target.items()}
            for target in targets
        ]

        # Called with targets, the model returns a dict of losses; they are
        # summed into one scalar for the backward pass.
        loss_dict = model(imgs, targets)
        loss = sum(loss_dict.values())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}")


KeyboardInterrupt: 