<span style="font-size:150%">라이브러리</span>

In [1]:
import os,sys
# Append the parent of the current working directory to the import path; presumably
# this is the repository root that contains the `day1` package imported below.
sys.path.append(os.path.abspath('./../'))

# Show which directory was added to sys.path.
print(os.path.abspath('./../'))
import torch
import torchvision
import torchvision.models.detection as detection
from torch.utils.data import DataLoader
from torchvision.models import ResNet50_Weights

# Transforms come from the project's day1.coco package instead of torchvision.transforms.
from day1.coco import transforms as T
from day1.coco.engine import train_one_epoch
# NOTE(review): wildcard import; no torchvision.utils name is visibly used in this notebook — consider removing.
from torchvision.utils import * 
from day1.datasets import PennFudanDataset

import matplotlib.pyplot as plt
import cv2
%matplotlib inline

/AILAB-summer-school-2025


<span style="font-size:150%">경로 및 파라미터 설정</span>

In [None]:
data_path = './data/PennFudanPed'   # dataset root passed to PennFudanDataset
save_path = './parameters'          # directory the trained weights are saved into
os.makedirs(save_path, exist_ok=True)
num_epoch = 10

# Seed torch's global RNG so the DataLoader shuffle, any torch-based random augmentation
# and the random initialisation of the detection heads are repeatable across
# Restart & Run All. (Some CUDA kernels may still be non-deterministic.)
seed = 42
torch.manual_seed(seed)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

<span style="font-size:150%">어그멘테이션 설정</span>

In [None]:
def get_transform(train):
    """Build the transform pipeline applied to every dataset sample.

    Args:
        train: when truthy, a random horizontal flip (p=0.5) is appended for
            training-time augmentation.

    Returns:
        ``T.Compose`` of ``T.PILToTensor()``, ``T.ConvertImageDtype(torch.float)`` and,
        for training only, ``T.RandomHorizontalFlip(0.5)`` — in that order.
    """
    pipeline = [T.PILToTensor(), T.ConvertImageDtype(torch.float)]
    if train:
        pipeline.append(T.RandomHorizontalFlip(0.5))
    return T.Compose(pipeline)

<span style="font-size:150%">데이터셋</span>

In [None]:
def collate_fn(batch):
    """Collate function for object detection.

    Transposes a batch of ``(image, target)`` pairs into two parallel lists.
    Unlike the DataLoader's default collate, nothing is stacked, so images of
    different sizes and per-image target dicts pass through unchanged.

    Args:
        batch: ``[(image1, target1), (image2, target2), ...]``.

    Returns:
        ``(images, targets)`` where ``images == [image1, image2, ...]`` and
        ``targets == [target1, target2, ...]``, both plain lists.
    """
    images, targets = map(list, zip(*batch))
    return images, targets

In [None]:
# Two views of the same data: one with training augmentation, one without.
trainset = PennFudanDataset(data_path, get_transform(train=True))
testset = PennFudanDataset(data_path, get_transform(train=False))

# Train on every sample except the last; hold the last one out for visualisation.
indices = list(range(len(trainset)))
train_indices, test_indices = indices[:-1], indices[-1:]
dataset = torch.utils.data.Subset(trainset, train_indices)
dataset_test = torch.utils.data.Subset(testset, test_indices)

trainLoader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)
testLoader = DataLoader(dataset_test, batch_size=1, shuffle=False, collate_fn=collate_fn)

<span style="font-size:150%">모델 선언</span>

In [None]:
# ResNet-50 + FPN backbone initialised from ImageNet weights.
backbone = detection.backbone_utils.resnet_fpn_backbone(
    backbone_name='resnet50', weights=ResNet50_Weights.IMAGENET1K_V2
)

# One anchor size per backbone feature map (5 of them), each with three aspect ratios.
anchor_generator = detection.rpn.AnchorGenerator(
    sizes=((32,), (64,), (128,), (256,), (512,)),
    aspect_ratios=((0.5, 1.0, 2.0),) * 5,
)

# RoIAlign over the feature maps named '0'..'3'.
roi_pooler = torchvision.ops.MultiScaleRoIAlign(
    featmap_names=['0', '1', '2', '3'],
    output_size=7,
    sampling_ratio=2,
)

# num_classes counts the background class, so 2 = background + one object class
# (presumably pedestrian for PennFudanPed).
model = detection.FasterRCNN(
    backbone,
    num_classes=2,
    rpn_anchor_generator=anchor_generator,
    box_roi_pool=roi_pooler,
)
model = model.to(device)

# Only give the optimizer parameters that are actually trained: resnet_fpn_backbone
# freezes the early ResNet stages by default (trainable_layers=3).
params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(params, lr=0.01, weight_decay=0.0005)

<span style="font-size:150%">모델 학습</span>

In [None]:
# Train for `num_epoch` epochs (logging every 10 iterations), then keep only the final weights.
for epoch in range(num_epoch):
    train_one_epoch(
        model, optimizer, trainLoader, device, epoch,
        print_freq=10,
    )

checkpoint_path = os.path.join(save_path, 'detector.pth')
torch.save(model.state_dict(), checkpoint_path)

<span style="font-size:150%">모델 출력 시각화</span>

In [None]:
# Grab the single held-out sample from the test loader.
imgs, targets = next(iter(testLoader))
img = imgs[0]
target = targets[0]

# HWC copy of the image for OpenCV / matplotlib. permute() only returns a strided view,
# so .contiguous() is needed: cv2 drawing functions reject non-contiguous arrays, and a
# separate buffer means drawing on `sample` never writes back into `img`.
sample = img.permute(1, 2, 0).contiguous().cpu().numpy()

boxes = target['boxes'].cpu().numpy().astype(int)
print(boxes)

In [None]:
# Inference on the held-out image. Keep the global `device` untouched: rebinding it to
# CPU here would make a later re-run of the training cell silently train on CPU.
model.eval()
cpu_device = torch.device('cpu')

# no_grad: no autograd graph is needed for inference, which saves memory.
with torch.no_grad():
    outputs = model([img.to(device)])

# Bring predictions back to CPU for drawing with OpenCV / matplotlib.
outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]

In [None]:
score_threshold = 0.5  # predictions at or below this score are not drawn

# `sample` is expected to be float in [0, 1] (ConvertImageDtype(torch.float)), so colours
# use that range too; 0-255 values would only be clipped by imshow (with a warning).
pred_color = (1.0, 0.0, 0.0)  # red: predictions
gt_color = (0.0, 0.0, 1.0)    # blue: ground truth

# Draw on a fresh C-contiguous copy: cv2 needs a contiguous buffer, and re-running this
# cell must not stack rectangles onto `sample`.
canvas = sample.copy()

for box, score in zip(outputs[0]['boxes'].int(), outputs[0]['scores']):
    print(box, score)
    if score > score_threshold:
        x1, y1, x2, y2 = box.tolist()
        cv2.rectangle(canvas, (x1, y1), (x2, y2), pred_color, 3)

for box in targets[0]['boxes'].int():
    x1, y1, x2, y2 = box.tolist()
    cv2.rectangle(canvas, (x1, y1), (x2, y2), gt_color, 3)

fig, ax = plt.subplots(1, 1, figsize=(16, 8))
ax.set_axis_off()
ax.set_title(f'Predictions (red, score > {score_threshold}) vs. ground truth (blue)')
ax.imshow(canvas);