In [0]:
!rm -r ssd_object_detection
!git clone https://github.com/pai-plznw4me/ssd_object_detection
!pip install wget

In [0]:
import sys 
sys.path.append('./ssd_object_detection')
from model import simple_detection_netowrk
from tensorflow.python.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from prior import PriorBoxes
from dataset import DetectionDataset
from generator import DetectionGenerator
from loss import SSDLoss


# --- Configuration -----------------------------------------------------------
n_classes = 11  # number of classes, including the background class
# Anchors per feature-map location.
# NOTE(review): presumably this must equal len(ratios) below (one prior per
# aspect ratio) — confirm against PriorBoxes in prior.py.
n_anchors = 5
image_shape = (128, 128)  # spatial size of the input images
batch_size = 64
n_epochs = 50

# --- Detection network -------------------------------------------------------
# Derive the network input shape from image_shape (3 colour channels) so the
# network and the prior generator below cannot silently disagree on image size.
inputs, pred = simple_detection_netowrk(image_shape + (3,), n_anchors, n_classes)

# --- Prior boxes -------------------------------------------------------------
strides = [4, 8, 16]
scales = [10, 25, 40]
ratios = [(1, 1),
          (1.5, 0.5),
          (1.2, 0.8),
          (0.8, 1.2),
          (1.4, 1.4)]
prior = PriorBoxes(strides, scales, ratios)
prior_boxes = prior.generate(image_shape)

# --- Datasets / generators ---------------------------------------------------
trainset = DetectionDataset(data_type='train')
validset = DetectionDataset(data_type='validation')
traingen = DetectionGenerator(trainset.config,
                              prior.config,
                              batch_size=batch_size)
validgen = DetectionGenerator(validset.config,
                              prior.config,
                              batch_size=batch_size)

# --- Loss --------------------------------------------------------------------
# NOTE(review): the meaning of the two positional arguments (1.0, 3.0) is
# defined in loss.py — confirm there and consider passing them by keyword.
ssd_loss = SSDLoss(1.0, 3.)

# --- Training ----------------------------------------------------------------
model = Model(inputs, pred)
# Reuse the loss object defined above instead of building a second identical one.
model.compile(Adam(1e-3),
              loss=ssd_loss)

# Reduce the learning rate 10x when the monitored metric (left at its Keras
# default) stops improving for `patience` epochs.
rlrop = ReduceLROnPlateau(factor=0.1,
                          min_lr=1e-6,
                          patience=5,
                          cooldown=3)
callbacks = [rlrop]

# `fit_generator` is deprecated in TF 2.x; `fit` accepts the generator directly
# and keeps the multiprocessing options. The history is kept for inspection.
history = model.fit(traingen,
                    epochs=n_epochs,
                    validation_data=validgen,
                    use_multiprocessing=True,
                    workers=6,
                    callbacks=callbacks)

In [0]:

def draw_rectangle(image, digits, color=(255,0,0), thickness=1):
    """Draw one rectangle (plus an optional class label) per box on a copy of `image`.

    Args:
        image: numpy image array. If ``image.max() <= 1.0`` it is treated as a
            [0, 1] image and converted to uint8 via ``* 255``; otherwise it is
            copied as-is. The caller's array is never drawn on.
        digits: boxes in centre format, either
            * a ``pandas.DataFrame`` with columns ``cx, cy, w, h`` and
              optionally ``label``, or
            * a ``numpy.ndarray`` with 4 columns (cx, cy, w, h) or 5 columns
              (cx, cy, w, h, label). A single box given as a 1-D array of
              length 4 or 5 is accepted as well.
        color: colour passed to ``cv2.rectangle`` / ``cv2.putText``.
        thickness: line thickness passed to ``cv2.rectangle``.

    Returns:
        The annotated image array.

    Raises:
        TypeError: if `digits` is neither a numpy.ndarray nor a pandas.DataFrame.
        ValueError: if an ndarray `digits` is not (N, 4) / (N, 5), or a single
            1-D box of length 4 / 5.
    """
    if isinstance(digits, np.ndarray):
        # Promote a single 1-D box (e.g. the result of squeezing a (1, 4) array)
        # to a one-row table.
        boxes = np.atleast_2d(digits)
        if boxes.ndim != 2 or boxes.shape[1] not in (4, 5):
            raise ValueError(
                "digits ndarray must have shape (N, 4) or (N, 5), got {}".format(digits.shape))
        # Pick the column names from the column count. The old code checked
        # shape[2] for the 5-column case, which raised IndexError on (N, 5).
        columns = ['cx', 'cy', 'w', 'h', 'label'][:boxes.shape[1]]
        digits = pd.DataFrame(boxes, columns=columns)
    elif not isinstance(digits, pd.DataFrame):
        raise TypeError("digits은 numpy.ndarray 혹은 pandas.Dataframe으로 이루어져 있어야 합니다.")

    # Both branches produce a new array, so drawing never touches the caller's image.
    if image.max() <= 1.0:
        canvas = (image * 255).astype(np.uint8)
    else:
        canvas = image.copy()

    has_label = 'label' in digits.columns
    for _, row in digits.iterrows():
        xmin = row.cx - row.w / 2
        xmax = row.cx + row.w / 2
        ymin = row.cy - row.h / 2
        ymax = row.cy + row.h / 2

        # OpenCV expects integer pixel coordinates; int() truncates toward zero,
        # like the previous int32 cast, but yields plain Python ints.
        start = (int(xmin), int(ymin))
        end = (int(xmax), int(ymax))
        canvas = cv2.rectangle(canvas, start, end, color, thickness)
        if has_label:
            cv2.putText(canvas, str(int(row.label)), start,
                        cv2.FONT_HERSHEY_DUPLEX, 0.3, color)
    return canvas


In [0]:
import numpy as np
import pandas as pd
import cv2 
import matplotlib.pyplot as plt
# Run the trained model on the first validation batch and visualise one image.
images, _ = validgen[0]  # first validation batch; targets are not needed here
predictions = model.predict(images)

idx = 0  # which image of the batch to visualise
# NOTE(review): assumes the last 4 channels of each prediction are box offsets
# and the remaining ones are class scores — confirm against model.py / loss.py.
pred_loc = predictions[idx, :, -4:]
pred_clf = predictions[idx, :, :-4]
# Pass only the spatial size (first two dims) of the image actually shown,
# matching the 2-tuple passed to prior.generate(image_shape) when training.
pr_boxes = prior.generate(images[idx].shape[:2])

# Restore absolute boxes from the predicted offsets:
#   cx = dx * w_p + cx_p,  cy = dy * h_p + cy_p,  w = exp(dw) * w_p,  h = exp(dh) * h_p
# NOTE(review): assumes prior boxes are (cx, cy, w, h) and no variance scaling
# is applied — confirm this inverts the encoder in generator.py.
centers = pred_loc[:, :2] * pr_boxes[:, 2:4] + pr_boxes[:, :2]
sizes = np.exp(pred_loc[:, 2:4]) * pr_boxes[:, 2:4]
decoded_boxes = np.concatenate([centers, sizes], axis=-1)

# Drop boxes whose best class is background.
# NOTE(review): assumes the background class is the last score column.
bg_index = pred_clf.shape[-1] - 1
# A boolean mask keeps the result 2-D, shape (k, 4), for any k. The previous
# argwhere + squeeze collapsed a single surviving box to shape (4,).
fg_mask = pred_clf.argmax(axis=1) != bg_index
restore_boxes = decoded_boxes[fg_mask]
pred_clf = pred_clf[fg_mask].max(axis=-1)  # best class score per kept box, shape (k,)

vis = draw_rectangle(images[idx], restore_boxes)
fig, ax = plt.subplots()
ax.imshow(vis)
ax.set_title('{} foreground boxes (validation image {})'.format(len(restore_boxes), idx))
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [0]:
f