# Train BlazePalm Model

In [1]:
import os
import glob
import json

import tensorflow as tf
import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from utils import anchors_generator, encoder, loss_function
from nets import blaze_palm 
from utils.data_generator import DataGenerator

In [2]:
# Display the TensorFlow version this notebook is running against.
tf.__version__

'2.7.0'

## Create BlazePalm Model

In [3]:
# Build the network via the project's factory in the `blaze_palm` module imported from `nets`.
model = blaze_palm.build_blaze_palm_model()

In [4]:
# Print the layer-by-layer summary (output shapes, parameter counts, connections).
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 32  896         ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 activation (Activation)        (None, 128, 128, 32  0           ['conv2d[0][0]']                 
                                )                                                             

In [None]:
# Create the optimizer and the SSD loss, then attach both to the model.
adam = tf.keras.optimizers.Adam()
ssd_loss = loss_function.SSDLoss(alpha=1./256.)
# Pass the optimizer instance rather than the string 'adam': the string makes Keras build
# a second, default optimizer, so any settings later applied to `adam` would be ignored.
model.compile(optimizer=adam, loss=ssd_loss.compute_loss)

## Prepare Data

In [None]:
# Both generators are built from the same constructor arguments.
generator_kwargs = dict(
    image_dir=os.path.join('dataset', 'image'),
    batch_size=32,
    annotation_dir=os.path.join('dataset', 'annotation'),
)

train_data_generator = DataGenerator(**generator_kwargs)
# NOTE(review): this reads the same image/annotation directories as the training
# generator, so it is not a held-out set -- confirm whether a separate split was intended.
val_data_generator = DataGenerator(**generator_kwargs)

In [None]:
# Smoke test before training: fetch the first batch and evaluate the loss
# on the model's predictions for it.
batch = train_data_generator[0]
batch_inputs, batch_targets = batch[0], batch[1]
ssd_loss.compute_loss(batch_targets, model.predict(batch_inputs))

## Training

> Note: `fit_generator` does not work here; use `model.fit`, which accepts the data generator directly.

In [None]:
# Train on the training generator for 1000 epochs and keep the returned history for plotting.
# No `validation_data` is passed in this call.
history = model.fit(x=train_data_generator, epochs=1000)

In [None]:
# Plot the training loss, sampled every 10th epoch to keep the curve readable.
loss_per_epoch = history.history['loss']
# Plot against the real epoch indexes; without explicit x values the sub-sampled
# curve would be drawn at 0, 1, 2, ... and the "Epoch" axis would be off by a factor of 10.
sampled_epochs = np.arange(0, len(loss_per_epoch), 10)

plt.figure(figsize=(10,6))
plt.plot(sampled_epochs, loss_per_epoch[::10], ls='--')
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train'])
plt.show()

## Test Model

In [None]:
from utils.anchor_config import AnchorsConfig
from utils.anchors_generator import AnchorsGenerator
from utils.encoder import center_to_corner

In [None]:
def sigmoid(values):
    """Apply the logistic function 1 / (1 + e^(-x)) element-wise.

    :param values: scalar or NumPy array
    :return: the logistic function evaluated at ``values``
    """
    denominator = 1 + np.exp(-values)
    return 1 / denominator

In [None]:
def non_max_suppression_fast(boxes, probabilities=None, overlap_threshold=0.3):
    """
    Greedy non-maximum suppression over centre-format bounding boxes.

    Boxes are visited from the highest to the lowest ranking value. Each visited box is
    kept, and every remaining box whose intersection with it covers more than
    ``overlap_threshold`` of that remaining box's own area is discarded. No score
    thresholding happens here; callers filter low-confidence boxes beforehand.
    Source: https://www.pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/
    :param boxes: array-like of shape (N, 4) with rows (center x, center y, width, height)
    :param probabilities: optional length-N ranking values (higher is preferred); when
        omitted, boxes are ranked by the y-coordinate of their bottom edge
    :param overlap_threshold: the maximum overlap that is allowed
    :return: list of indexes into ``boxes`` of the kept boxes, highest-ranked first
    """
    # Work in float64 throughout; this also accepts plain lists and integer arrays.
    boxes = np.asarray(boxes, dtype=np.float64)
    # Nothing to suppress (covers both a (0, 4) array and an empty 1-D input).
    if boxes.size == 0:
        return []

    # Convert (center, size) to corner coordinates.
    half_w = boxes[:, 2] / 2
    half_h = boxes[:, 3] / 2
    x1 = boxes[:, 0] - half_w
    y1 = boxes[:, 1] - half_h
    x2 = boxes[:, 0] + half_w
    y2 = boxes[:, 1] + half_h
    area = boxes[:, 2] * boxes[:, 3]  # width * height

    # Rank by the supplied values if any, otherwise by the bottom edge; argsort is
    # ascending, so the best candidate is always the last entry.
    ranking = y2 if probabilities is None else probabilities
    idxs = np.argsort(ranking)

    pick = []
    while len(idxs) > 0:
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)
        rest = idxs[:last]

        # Intersection rectangle between the picked box and every remaining box.
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        # NOTE: the "+ 1" is the inclusive-pixel convention of the referenced source,
        # while `area` above is plain width * height; kept as-is so results match the
        # previous behaviour.
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        overlap = (w * h) / area[rest]

        # Drop the picked box and everything overlapping it too much. Written as
        # "not greater" so NaN ratios (zero-area boxes) are retained, as before.
        idxs = rest[~(overlap > overlap_threshold)]
    # return only the indexes of the bounding boxes that were picked
    return pick

In [None]:
def preprocess(bgr_image, w, h, normailze=True):
    """Turn a BGR image into a batched, square-padded RGB array of size (h, w).

    The shorter side is zero-padded until the image is square, the result is resized
    to ``(w, h)``, optionally rescaled, and given a leading batch axis.

    :param bgr_image: 3-D image array (height, width, channels) with channels in BGR order
    :param w: output width in pixels
    :param h: output height in pixels
    :param normailze: (sic; spelling kept so existing keyword callers keep working)
        when true, scale pixel values as ``2 * (x / 255 - 0.5)`` and cast to float32;
        when false, return the resized pixels unchanged
    :return: array with shape (1, h, w, channels)
    """
    # convert to rgb
    rgb_image = bgr_image[:, :, ::-1]
    # pad to square; when the size difference is odd the extra pixel goes after the
    # image, so the padded result is exactly square instead of one pixel short
    height, width = rgb_image.shape[:2]
    side = max(height, width)
    pad_rows = side - height
    pad_cols = side - width
    rgb_image = np.pad(
        rgb_image,
        ((pad_rows // 2, pad_rows - pad_rows // 2),
         (pad_cols // 2, pad_cols - pad_cols // 2),
         (0, 0)),
        mode='constant',
    )
    # cv2.resize takes the target size as (width, height)
    rgb_image = cv2.resize(rgb_image, (w, h))
    rgb_image = np.ascontiguousarray(rgb_image)
    # normalize
    if normailze:
        rgb_image = np.ascontiguousarray(2 * ((rgb_image / 255) - 0.5).astype('float32'))
    # add the leading batch axis expected by model.predict
    rgb_image = rgb_image[np.newaxis, ...]
    return rgb_image

In [None]:
# Generate the anchor set with the project's anchor utilities.
anchor_config = AnchorsConfig(fixed_anchor_size=False)
anchor_generator = AnchorsGenerator(anchor_config)
anchors = anchor_generator.generate()

# Flatten the anchor objects into one (x_center, y_center, w, h) row per anchor.
# Presumably these are fractions of the input size (they are multiplied by 256 downstream) -- TODO confirm.
anchors_normalized = np.array(
    [(a.x_center, a.y_center, a.w, a.h) for a in anchors]
)

In [None]:
# Run the detector on a random dataset image, then decode, filter and draw the detections.
image_path = np.random.choice(glob.glob(os.path.join('dataset','image','*.jpg')))
bgr_image = cv2.imread(image_path)  # decoded once, reused for inference and for display
image = preprocess(bgr_image, 256, 256)
pred = model.predict(image)

# Per anchor, column 0 is turned into a score with a sigmoid; the remaining columns are
# the regression outputs.
output_clf = pred[:,:,0][0]
output_reg = pred[:,:,1:][0]
scores = sigmoid(output_clf)
t = 0.95  # score threshold
confident = scores > t  # computed once and shared by the three selections below
output_reg = output_reg[confident]
output_clf = output_clf[confident]
candidate_anchors = anchors_normalized[confident]

# The first two regression values are offsets from the anchor centre, which is scaled
# by the 256-pixel input size.
moved_output_reg = output_reg.copy()
moved_output_reg[:, :2] = moved_output_reg[:, :2] + candidate_anchors[:, :2] * 256

box_ids = non_max_suppression_fast(moved_output_reg[:, :4], output_clf)

center_wo_offst = candidate_anchors[box_ids,:2] * 256
bboxes = moved_output_reg[box_ids, :4].astype('int')
# The remaining 14 regression values are read as 7 (x, y) keypoints, also relative to the
# anchor centre; broadcasting adds each box's centre to all of its keypoints at once.
keypoints_set = output_reg[box_ids,4:].reshape(-1, 7, 2) + center_wo_offst[:, np.newaxis, :]

color = ['limegreen', 'r']
width = [3, 3]
style = ['-', '--']
alpha = [1, 0.5]


plt.figure(figsize=(7, 7))
plt.title("Result")
image_show = preprocess(bgr_image, 256, 256, False)
result_image = image_show[0]

for i, (cx, cy, w, h) in enumerate(bboxes):
    x1, y1, x2, y2 = (cx-w//2, cy-h//2, cx+w//2, cy+h//2)
    # Only the first rectangle carries a legend label so the legend has a single entry.
    rect = patches.Rectangle((x1,y1), x2-x1, y2-y1, lw=width[1], ec=color[0], ls=style[0], alpha=alpha[1],
                             facecolor="none", label='Prediction' if i == 0 else '_nolegend_')
    plt.gca().add_patch(rect)
    plt.scatter((x2+x1)/2, (y2+y1)/2, color=color[0], s=54.87, alpha=alpha[1])
    for key_point in keypoints_set[i]:
        plt.scatter(key_point[0], key_point[1], color=color[1], s=9.487)

# With no detection above the threshold there is no labelled artist; skip the legend
# instead of triggering matplotlib's "no handles" warning.
if len(bboxes) > 0:
    plt.legend()
plt.imshow(result_image)
plt.show()

## Test Model: Top-Scoring Anchor Only

In [None]:
# Visualise only the single highest-scoring anchor of the in-memory model (no threshold, no NMS).
image_path = np.random.choice(glob.glob(os.path.join('dataset','image','*.jpg')))
bgr_image = cv2.imread(image_path)  # decoded once, reused for inference and for display
image = preprocess(bgr_image, 256, 256)
image_show = preprocess(bgr_image, 256, 256, False)

pred = model.predict(image)

pred_id = np.argmax(pred[0,:,0])
# Copy the row: the decoding below writes in place and would otherwise modify `pred` itself.
pred_loc = pred[0, pred_id, 1:].copy()
anchor_center = anchors_normalized[pred_id,:2]*256
# Box centre and keypoints are offsets from the anchor centre (scaled by the 256-pixel input size).
pred_loc[:2] = anchor_center + pred_loc[:2]
pred_loc[:4] = center_to_corner(np.array([pred_loc[:4]]))
pred_loc[4:] = (pred_loc[4:].reshape(-1, 2) + anchor_center).reshape(-1)

color = ['limegreen', 'r']
width = [3, 3]
style = ['-', '--']
alpha = [1, 0.5]


plt.figure(figsize=(7, 7))
plt.title("Result")
result_image = image_show[0]

# After center_to_corner the first four values are unpacked as corner coordinates.
x1, y1, x2, y2 = pred_loc[:4]
rect = patches.Rectangle((x1,y1), x2-x1, y2-y1,
                         lw=width[1], ec=color[0],
                         ls=style[0], alpha=alpha[1], facecolor="none", label='Prediction')
plt.gca().add_patch(rect)
for key_point in pred_loc[4:].reshape(-1, 2):
    plt.scatter(key_point[0], key_point[1], color=color[1], s=9.487)

plt.legend()
plt.imshow(result_image)
plt.show()

## Save Model

In [None]:
# Save the trained model in HDF5 format.
# NOTE(review): the filename is spelled "balzepalm" and differs from the 'model/palm.h5'
# path loaded in the next cell -- confirm which file is intended.
model.save('balzepalm_first_version.h5')

In [None]:
# Load a saved model from disk; the custom SSD loss has to be supplied under the name
# 'compute_loss' so Keras can resolve it while deserialising.
palm_model = tf.keras.models.load_model('model/palm.h5', custom_objects={'compute_loss': loss_function.SSDLoss(alpha=1./256.).compute_loss})

In [None]:
# Visualise only the single highest-scoring anchor of the model loaded from disk (no threshold, no NMS).
image_path = np.random.choice(glob.glob(os.path.join('dataset','image','*.jpg')))
bgr_image = cv2.imread(image_path)  # decoded once, reused for inference and for display
image = preprocess(bgr_image, 256, 256)
image_show = preprocess(bgr_image, 256, 256, False)

pred = palm_model.predict(image)

pred_id = np.argmax(pred[0,:,0])
# Copy the row: the decoding below writes in place and would otherwise modify `pred` itself.
pred_loc = pred[0, pred_id, 1:].copy()
anchor_center = anchors_normalized[pred_id,:2]*256
# Box centre and keypoints are offsets from the anchor centre (scaled by the 256-pixel input size).
pred_loc[:2] = anchor_center + pred_loc[:2]
pred_loc[:4] = center_to_corner(np.array([pred_loc[:4]]))
pred_loc[4:] = (pred_loc[4:].reshape(-1, 2) + anchor_center).reshape(-1)

color = ['limegreen', 'r']
width = [3, 3]
style = ['-', '--']
alpha = [1, 0.5]


plt.figure(figsize=(7, 7))
plt.title("Result")
result_image = image_show[0]

# After center_to_corner the first four values are unpacked as corner coordinates.
x1, y1, x2, y2 = pred_loc[:4]
rect = patches.Rectangle((x1,y1), x2-x1, y2-y1,
                         lw=width[1], ec=color[0],
                         ls=style[0], alpha=alpha[1], facecolor="none", label='Prediction')
plt.gca().add_patch(rect)
for key_point in pred_loc[4:].reshape(-1, 2):
    plt.scatter(key_point[0], key_point[1], color=color[1], s=9.487)

plt.legend()
plt.imshow(result_image)
plt.show()