# Faster R-CNN 모델을 위한 코드 (FasterRCNN_Object_Detection.ipynb)

# In [None]:

# 1. 라이브러리 임포트
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import os
import cv2
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


# In [None]:
# 2. Choose the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# In [None]:
# 3. Data loading
class PidrayDataset(Dataset):
    """PIDray detection split laid out as ``root_dir/images`` + ``root_dir/labels``.

    Each label file is parsed as YOLO-style text: one object per line,
    ``class x_center y_center width height``, with coordinates normalized by
    the image width/height. Targets are returned in the dict format that
    torchvision detection models consume: ``boxes`` (absolute
    ``[x1, y1, x2, y2]`` pixels, float32, shape ``(N, 4)``), ``labels``
    (int64), ``image_id``, ``area`` and ``iscrowd``.

    Args:
        root_dir: Split directory containing ``images/`` and ``labels/``.
        transforms: Optional callable applied to the image, which is passed in
            as an ``H x W x 3`` uint8 numpy array.
        label_offset: Added to every class id read from disk. torchvision's
            detectors reserve label 0 for background, so 0-indexed YOLO class
            ids are shifted by 1 by default.
    """

    def __init__(self, root_dir, transforms=None, label_offset=1):
        self.root_dir = root_dir
        self.transforms = transforms
        self.label_offset = label_offset
        self.imgs = sorted(os.listdir(os.path.join(root_dir, "images")))
        self.labels = sorted(os.listdir(os.path.join(root_dir, "labels")))
        # Look up label files by file stem so each image gets its own
        # annotations even if some images have no label file.
        self._label_by_stem = {os.path.splitext(name)[0]: name for name in self.labels}

    @staticmethod
    def _parse_label_lines(lines, img_w, img_h, label_offset=1):
        """Convert YOLO label lines into ``(boxes, labels)`` tensors.

        Blank lines are skipped. Boxes are clipped to the image, and boxes with
        non-positive width or height are dropped, because torchvision's
        detection models reject such boxes during training. Boxes always have
        shape ``(N, 4)`` (float32) and labels shape ``(N,)`` (int64), even
        when ``N == 0``.
        """
        boxes, labels = [], []
        for line in lines:
            parts = line.split()
            if not parts:
                continue
            cls_id = int(parts[0])
            cx, cy, w, h = (float(v) for v in parts[1:5])
            # YOLO stores the box centre; convert to corner coordinates.
            x1 = max(0.0, (cx - w / 2) * img_w)
            y1 = max(0.0, (cy - h / 2) * img_h)
            x2 = min(float(img_w), (cx + w / 2) * img_w)
            y2 = min(float(img_h), (cy + h / 2) * img_h)
            if x2 <= x1 or y2 <= y1:
                continue
            boxes.append([x1, y1, x2, y2])
            labels.append(cls_id + label_offset)
        # reshape keeps a 2-D (0, 4) tensor for images without objects, so
        # column indexing such as boxes[:, 3] still works.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        return boxes, labels

    def __getitem__(self, idx):
        img_name = self.imgs[idx]
        img_path = os.path.join(self.root_dir, "images", img_name)
        # The context manager closes the image file instead of leaving it to GC.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert("RGB"))
        img_h, img_w = img.shape[:2]

        # Pair image and label by file stem, not by sorted position. An image
        # without a label file is treated as containing no objects.
        label_name = self._label_by_stem.get(os.path.splitext(img_name)[0])
        if label_name is None:
            boxes, labels = self._parse_label_lines([], img_w, img_h, self.label_offset)
        else:
            label_path = os.path.join(self.root_dir, "labels", label_name)
            with open(label_path) as f:
                boxes, labels = self._parse_label_lines(f, img_w, img_h, self.label_offset)

        target = {
            "boxes": boxes,
            "labels": labels,
            "image_id": torch.tensor([idx]),
            "area": (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]),
            "iscrowd": torch.zeros((len(boxes),), dtype=torch.int64),
        }

        if self.transforms is not None:
            img = self.transforms(img)

        return img, target

    def __len__(self):
        return len(self.imgs)

def get_transform():
    """Return the image pipeline applied to every sample.

    It holds a single step, ``ToTensor``, which turns the ``H x W x 3`` uint8
    array from the dataset into a ``3 x H x W`` float tensor scaled to [0, 1].
    """
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

# In [None]:
# Load the train / valid / test splits.
train_dataset = PidrayDataset('pidray/train', get_transform())
valid_dataset = PidrayDataset('pidray/valid', get_transform())
test_dataset = PidrayDataset('pidray/test', get_transform())

BATCH_SIZE = 12
NUM_WORKERS = 4


def collate_fn(batch):
    """Regroup a list of ``(image, target)`` pairs into ``(images, targets)`` tuples.

    Detection images may differ in size, so they are passed to the model as a
    sequence instead of being stacked into one tensor. A named function can be
    pickled by reference, unlike a lambda. Worker processes started with the
    'spawn' method need that. NOTE(review): on spawn platforms such as
    Windows/macOS, this function and the Dataset class must also live in an
    importable module for num_workers > 0 to work from a notebook.
    """
    return tuple(zip(*batch))


train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, collate_fn=collate_fn)
valid_dataloader = DataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, collate_fn=collate_fn)


# In [None]:
# 4. Network: Faster R-CNN (ResNet-50 FPN) loaded with pretrained weights.
num_classes = 12 + 1  # 12 object classes + background
model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
# Replace the box predictor with one sized for `num_classes`, keeping the
# input width that the existing ROI head feeds into it.
in_features = model.roi_heads.box_predictor.cls_score.in_features
predictor_cls = torchvision.models.detection.faster_rcnn.FastRCNNPredictor
model.roi_heads.box_predictor = predictor_cls(in_features, num_classes)
model.to(device)

# In [None]:
# 5. Optimizer and learning-rate schedule (the detection losses are computed by the model itself).
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005,
)
# Multiply the learning rate by 0.1 after every 3 calls to lr_scheduler.step().
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# In [None]:
# 6. Train the model and track per-epoch training and validation loss.
num_epochs = 10
train_losses = []
valid_losses = []

freq = 10  # print the training loss every `freq` iterations


def _train_one_epoch(model, dataloader, optimizer, device, epoch, print_freq):
    """Run one optimization pass over `dataloader`; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for i, (images, targets) in enumerate(dataloader):
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In train mode, torchvision detection models return a dict of losses.
        loss_dict = model(images, targets)
        losses = sum(loss_dict.values())

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        running_loss += losses.item()
        if i % print_freq == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {losses.item()}")
    return running_loss / max(len(dataloader), 1)


def _validation_loss(model, dataloader, device):
    """Return the mean detection loss over `dataloader` without updating weights.

    torchvision detection models return losses only in train mode. In eval mode
    they return a list of detections, so the loss dict cannot be read. The model
    is therefore put in train mode while gradients are disabled.
    NOTE(review): this assumes the backbone uses frozen BatchNorm (the
    torchvision default for pretrained Faster R-CNN), so no running statistics
    change. The RPN/ROI heads also subsample proposals randomly in train mode,
    so this loss is somewhat noisy.
    """
    model.train()
    total = 0.0
    with torch.no_grad():
        for images, targets in dataloader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            loss_dict = model(images, targets)
            total += sum(loss_dict.values()).item()
    return total / max(len(dataloader), 1)


for epoch in range(num_epochs):
    epoch_loss = _train_one_epoch(model, train_dataloader, optimizer, device, epoch, freq)
    train_losses.append(epoch_loss)

    lr_scheduler.step()

    epoch_valid_loss = _validation_loss(model, valid_dataloader, device)
    valid_losses.append(epoch_valid_loss)
    print(f"Epoch: {epoch}, Validation Loss: {epoch_valid_loss}")

# Leave the model in inference mode once training is done.
model.eval()

# Plot the loss curves once, after all epochs, rather than one figure per epoch.
plt.figure(figsize=(10, 5))
plt.title("Training and Validation Loss")
plt.plot(train_losses, label="train")
plt.plot(valid_losses, label="valid")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.show()

# In [None]:
# 7. Evaluate detections on the test set.
# Per-image "accuracy" is the fraction of ground-truth boxes matched by a
# prediction of the same class with IoU >= IOU_THRESHOLD. Matching is greedy:
# highest-scoring predictions first, and each ground-truth box is matched at
# most once. Predictions scoring below SCORE_THRESHOLD are ignored.
IOU_THRESHOLD = 0.5
SCORE_THRESHOLD = 0.5


def _matched_fraction(output, target, iou_threshold=IOU_THRESHOLD, score_threshold=SCORE_THRESHOLD):
    """Return the matched fraction of ground-truth boxes, or None if the image has none."""
    true_boxes = target["boxes"].cpu()
    true_labels = target["labels"].cpu()
    if true_boxes.shape[0] == 0:
        return None
    keep = output["scores"] >= score_threshold
    pred_boxes = output["boxes"][keep].cpu()
    pred_labels = output["labels"][keep].cpu()
    pred_scores = output["scores"][keep].cpu()

    matched = torch.zeros(true_boxes.shape[0], dtype=torch.bool)
    if pred_boxes.shape[0] > 0:
        ious = torchvision.ops.box_iou(pred_boxes, true_boxes)  # (num_pred, num_true)
        for p in torch.argsort(pred_scores, descending=True).tolist():
            eligible = (true_labels == pred_labels[p]) & ~matched
            if not eligible.any():
                continue
            candidate_ious = ious[p].masked_fill(~eligible, -1.0)
            best = int(candidate_ious.argmax())
            if candidate_ious[best] >= iou_threshold:
                matched[best] = True
    return matched.float().mean().item()


model.eval()
detection_results = []
accuracy = []

with torch.no_grad():
    for images, targets in test_dataloader:
        images = [image.to(device) for image in images]
        outputs = model(images)

        for target, output in zip(targets, outputs):
            # Move results to the CPU so the whole test set does not pile up in GPU memory.
            output = {k: v.cpu() for k, v in output.items()}
            score = _matched_fraction(output, target)
            # Accuracy is undefined for images without ground-truth objects.
            if score is not None:
                accuracy.append(score)
            detection_results.append(output)

if accuracy:
    print(f"Mean test accuracy: {sum(accuracy) / len(accuracy):.4f}")

# Plot per-image accuracy.
plt.figure(figsize=(10, 5))
plt.title("Model Accuracy on Test Data")
plt.plot(accuracy, label="accuracy")
plt.xlabel("Samples")
plt.ylabel("Accuracy")
plt.legend()
plt.show()

# In [None]:
# 8. Save the trained weights. Only the state_dict is stored, so rebuild the
# same architecture before loading it.
state = model.state_dict()
torch.save(state, 'fasterrcnn_pidray.pth')