In [1]:
import os
import torch
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, average_precision_score
import random
import warnings
# Install a blanket "ignore" filter: from this point on every warning
# (all categories, all modules) is suppressed in this kernel.
warnings.filterwarnings("ignore")

In [2]:
# Configuration parameters
class Config:
    """Central settings for the hand-detection Faster R-CNN notebook.

    Every attribute is a class-level constant meant to be read as
    ``Config.NAME``; the class is never instantiated in this cell.
    """

    # Path configuration
    # Dataset root. Set the HAND_DATA_ROOT environment variable to run on a
    # different machine without editing this cell; the original location is
    # kept as the default so existing setups behave exactly as before.
    DATA_ROOT = os.environ.get("HAND_DATA_ROOT", "D:/Code_pytorch/xianyu/hand detection")
    TRAIN_IMAGES = os.path.join(DATA_ROOT, "training_dataset/training_dataset/training_data/images")
    TRAIN_LABELS = os.path.join(DATA_ROOT, "labels_fast_rcnn")
    # NOTE(review): the test paths below are identical to the training paths.
    # Unless a train/test split is applied elsewhere, evaluation will run on
    # the training images -- confirm this is intended.
    TEST_IMAGES = os.path.join(DATA_ROOT, "training_dataset/training_dataset/training_data/images")
    TEST_LABELS = os.path.join(DATA_ROOT, "labels_fast_rcnn")
    OUTPUT_DIR = os.path.join(DATA_ROOT, "output")
    MODEL_SAVE_PATH = os.path.join(OUTPUT_DIR, "faster_rcnn_model.pth")

    # Training parameters
    NUM_CLASSES = 2  # Background + target class count
    BATCH_SIZE = 10
    NUM_EPOCHS = 20
    LEARNING_RATE = 0.005
    MOMENTUM = 0.9
    WEIGHT_DECAY = 0.0005

    # Device configuration: use CUDA when torch reports it, otherwise the CPU.
    DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Custom dataset class
class HandDataset(Dataset):
    """Detection dataset pairing each image with a same-named ``.txt`` box file.

    Every image is resized to a fixed size and its boxes are rescaled by the
    same ratios. A label file holds one box per line as
    ``class_id x_min y_min x_max y_max`` (absolute coordinates in the original
    image, since they are multiplied by the resize ratio). Lines that do not
    have exactly five whitespace-separated fields are skipped.

    Args:
        image_dir: Directory scanned (non-recursively) for .png/.jpg/.jpeg files.
        label_dir: Directory holding ``<image stem>.txt`` annotation files.
        train: When True, each sample is flipped horizontally with probability 0.5.
        transforms: Optional callable applied to the image tensor only; the
            boxes are not passed to it.
        image_size: ``(width, height)`` every image is resized to.
    """

    def __init__(self, image_dir, label_dir, train=False, transforms=None, image_size=(112, 112)):
        self.image_dir = image_dir
        self.label_dir = label_dir
        self.transforms = transforms
        self.train = train
        self.image_size = tuple(image_size)
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # idx -> file mapping (and therefore image_id) stable across runs.
        self.image_files = sorted(
            f for f in os.listdir(image_dir)
            if f.lower().endswith(('.png', '.jpg', '.jpeg'))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_files)

    @staticmethod
    def _load_boxes(label_path, width_scale, height_scale):
        """Parse one label file into scaled ``(N, 4)`` boxes and ``(N,)`` labels."""
        boxes = []
        labels = []
        with open(label_path, 'r') as f:
            for line in f:
                parts = line.split()
                if len(parts) != 5:
                    continue

                class_id = int(parts[0])
                x_min, y_min, x_max, y_max = (float(value) for value in parts[1:])
                boxes.append([
                    x_min * width_scale,
                    y_min * height_scale,
                    x_max * width_scale,
                    y_max * height_scale,
                ])
                labels.append(class_id + 1)  # 0 reserved for background

        # An empty list converts to shape (0,); force (N, 4) so column indexing
        # keeps working for images that have no annotated boxes.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        """Load, resize and (optionally) flip sample ``idx``.

        Returns:
            ``(image, target)`` where ``image`` is the output of ``F.to_tensor``
            (after ``transforms``, if given) and ``target`` is a dict with
            ``boxes`` (N, 4) as ``[x_min, y_min, x_max, y_max]`` in resized
            coordinates, ``labels`` (N,), ``image_id``, ``area`` (N,) and
            ``iscrowd`` (N,).

        Raises:
            FileNotFoundError: If the image has no matching label file.
            ValueError: If a five-field label line contains non-numeric values.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # The context manager releases the file handle once the RGB copy exists.
        with Image.open(img_path) as raw_image:
            image = raw_image.convert("RGB")
        original_width, original_height = image.size

        new_width, new_height = self.image_size
        # torchvision's resize expects (height, width), the reverse of PIL's size.
        image = F.resize(image, (new_height, new_width))

        label_path = os.path.join(self.label_dir, os.path.splitext(img_name)[0] + ".txt")
        boxes, labels = self._load_boxes(
            label_path,
            new_width / original_width,
            new_height / original_height,
        )

        # Data augmentation: random horizontal flip
        if self.train and random.random() < 0.5:
            image = F.hflip(image)
            # Mirroring swaps the roles of the x edges:
            # new x_min = W - old x_max, new x_max = W - old x_min.
            boxes[:, [0, 2]] = new_width - boxes[:, [2, 0]]

        image = F.to_tensor(image)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64)
        }

        if self.transforms:
            image = self.transforms(image)

        return image, target