# RetinaNet

In [27]:
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from bs4 import BeautifulSoup
from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from matplotlib import pyplot as plt

def generate_box(obj):
    """Read one bounding box from an annotation node.

    Args:
        obj: Node exposing ``find(tag)`` whose result has a ``text``
            attribute for the tags ``xmin``, ``ymin``, ``xmax``, ``ymax``.

    Returns:
        ``[xmin, ymin, xmax, ymax]`` as a list of floats.
    """
    corner_tags = ("xmin", "ymin", "xmax", "ymax")
    return [float(obj.find(tag).text) for tag in corner_tags]

def generate_label(obj):
    """Map the class name stored in an annotation node to an integer id.

    Args:
        obj: Node exposing ``find("name")`` whose result has a ``text``
            attribute.

    Returns:
        1 for ``"with_mask"``, 2 for ``"mask_weared_incorrect"``, and 0 for
        any other name.
    """
    ids_by_name = {"with_mask": 1, "mask_weared_incorrect": 2}
    class_name = obj.find("name").text
    return ids_by_name.get(class_name, 0)

def generate_target(file):
    """Parse one annotation file into a detection target dictionary.

    Args:
        file: Path of the annotation file to read.

    Returns:
        A dict with two entries:
        ``"boxes"``  - float32 tensor of shape (N, 4), one
        ``[xmin, ymin, xmax, ymax]`` row per ``<object>`` element;
        ``"labels"`` - int64 tensor with the N class ids produced by
        ``generate_label``.
    """
    with open(file) as f:
        data = f.read()

    # The file handle is no longer needed once the text is in memory.
    soup = BeautifulSoup(data, "html.parser")
    objects = soup.find_all("object")

    boxes = [generate_box(obj) for obj in objects]
    labels = [generate_label(obj) for obj in objects]

    return {
        # reshape(-1, 4) is a no-op when there is at least one box, and turns
        # the empty case into shape (0, 4) instead of (0,), so an image with
        # no annotated objects still yields a well-formed box tensor.
        "boxes": torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4),
        "labels": torch.as_tensor(labels, dtype=torch.int64),
    }
    
def plot_image_from_output(img, annotation):
    """Prepare an image and its box outlines for drawing.

    Args:
        img: Object supporting ``permute``; its axes are reordered with
            ``permute(1, 2, 0)``.
        annotation: Mapping with ``"boxes"`` (rows of
            ``xmin, ymin, xmax, ymax``) and ``"labels"`` (one id per box).

    Returns:
        A tuple ``(permuted_image, outlines)`` where ``outlines`` is a list of
        unfilled ``patches.Rectangle`` objects: red for label 0, green for
        label 1 and orange for every other label.
    """
    permuted = img.permute(1, 2, 0)

    outlines = []
    for position, box in enumerate(annotation["boxes"]):
        xmin, ymin, xmax, ymax = box

        # Pick the edge colour first so the rectangle is built in one place.
        if annotation['labels'][position] == 0:
            colour = 'r'
        elif annotation['labels'][position] == 1:
            colour = 'g'
        else:
            colour = 'orange'

        outlines.append(
            patches.Rectangle(
                (xmin, ymin),
                (xmax - xmin),
                (ymax - ymin),
                linewidth=1,
                edgecolor=colour,
                facecolor='none',
            )
        )

    return permuted, outlines

class MaskDataset(Dataset):
    """Detection dataset pairing each image file with an XML annotation file.

    Images are read from ``path``. The annotation for an image is looked up
    by file stem in ``test_annotations/`` when ``path`` contains the substring
    ``"test"`` and in ``annotations/`` otherwise.
    """

    def __init__(self, path, transform=None):
        """Index the files found in ``path``.

        Args:
            path: Directory holding the image files.
            transform: Optional callable invoked as
                ``transform(image_array, boxes_array)`` that returns
                ``(image, boxes)``.
        """
        self.path = path
        # sorted() already returns a list; the fixed order keeps idx stable.
        self.imgs = sorted(os.listdir(self.path))
        self.transform = transform

    def __len__(self):
        """Return the number of files listed in the image directory."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` together with its target.

        Returns:
            ``(img, target)`` where ``img`` is the output of ``ToTensor`` and
            ``target`` is the dict built by ``generate_target``.
        """
        file_image = self.imgs[idx]
        # Swap the extension with splitext rather than slicing off three
        # characters, so names such as "x.jpeg" map to "x.xml" as well.
        file_label = os.path.splitext(file_image)[0] + ".xml"
        img_path = os.path.join(self.path, file_image)

        if "test" in self.path:
            label_path = os.path.join("test_annotations/", file_label)
        else:
            label_path = os.path.join("annotations/", file_label)

        img = Image.open(img_path).convert("RGB")
        target = generate_target(label_path)

        to_tensor = torchvision.transforms.ToTensor()

        if self.transform:
            # The transform receives plain arrays and returns the new image
            # together with the correspondingly transformed boxes.
            img, transform_target = self.transform(np.array(img), np.array(target["boxes"]))
            target["boxes"] = torch.as_tensor(transform_target)

        img = to_tensor(img)

        return img, target

def collate_fn(batch):
    """Regroup a batch of samples into one tuple per sample field.

    Args:
        batch: Iterable of samples, each itself an iterable of fields
            (here ``(image, target)`` pairs).

    Returns:
        A tuple of tuples: the first holds every sample's first field, the
        second every sample's second field, and so on.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Training and test datasets; no transform is supplied to either.
dataset = MaskDataset(path="images/")
test_dataset = MaskDataset(path="test_images/")

# collate_fn keeps each batch as tuples instead of stacking, because images
# and targets are returned per sample rather than as one batched tensor.
data_loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=4,
    collate_fn=collate_fn,
)
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=2,
    collate_fn=collate_fn,
)


### 모델 불러오기

In [None]:
# Install torch, torchvision and torchaudio; -f adds the PyTorch wheel index
# page as an extra location where pip looks for packages.
!pip install torch torchvision torchaudio -f https://download.pytorch.org/whl/torch_stable.html

In [7]:
import torchvision
import torch
# Show the installed torchvision version as the cell output.
torchvision.__version__

'0.17.1'

In [13]:
# Build a RetinaNet (ResNet-50 + FPN) with a 3-class head.
# weights=None leaves the detection head untrained, while weights_backbone
# loads ImageNet weights into the ResNet-50 backbone. This is the current
# spelling of pretrained=False / pretrained_backbone=True; the previous
# misspelt keyword "pretrained_backnone" was forwarded to the RetinaNet
# constructor and rejected as an unexpected argument.
retina = torchvision.models.detection.retinanet_resnet50_fpn(
    weights=None,
    weights_backbone=torchvision.models.ResNet50_Weights.IMAGENET1K_V1,
    num_classes=3,
)

### 전이 학습

In [None]:
# Use the fastest available backend: CUDA first, then MPS, then the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

retina.to(device)

num_epochs = 30

# Optimise only the parameters that are not frozen.
params = [p for p in retina.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(
    params,
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)
len_dataloader = len(data_loader)

for epoch in range(num_epochs):
    start = time.time()
    retina.train()

    i = 0  # number of batches processed in this epoch
    epoch_loss = 0.0
    for images, targets in data_loader:

        # Inputs AND targets must live on the same device as the model.
        images = [image.to(device) for image in images]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # In training mode the model returns a dict of loss tensors.
        loss_dict = retina(images, targets)

        losses = sum(loss for loss in loss_dict.values())

        i += 1

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # .item() stores a plain float, so the running total does not hold
        # on to the autograd graph of every batch.
        epoch_loss += losses.item()
    print(epoch_loss, f"time: {time.time() - start}")

모델을 재사용할 수 있도록 학습된 가중치(`state_dict`)를 파일로 저장한다.

In [None]:
# Write only the model's state_dict to a file named after the epoch count.
torch.save(
    obj=retina.state_dict(),
    f=f"retina_{num_epochs}.pt",
)

In [None]:
# map_location="cpu" lets a checkpoint saved on a GPU be read on a machine
# without that device; load_state_dict then copies the values into the
# model's own parameters wherever they currently live.
# weights_only=True restricts unpickling to tensors and plain containers,
# which is all a state_dict holds.
retina.load_state_dict(
    torch.load(f"retina_{num_epochs}.pt", map_location="cpu", weights_only=True)
)

저장한 가중치를 불러올 때는 `torch.load`로 파일을 읽은 뒤 `load_state_dict`로 모델에 적용한다.

In [None]:
# Choose the backend in priority order CUDA -> MPS -> CPU; the conditional
# expression evaluates the availability checks in that same order.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

# Move the model onto the selected device.
retina.to(device)