# Train BlazePalm Model

In [1]:
import os
import glob

import tensorflow as tf
import numpy as np
import cv2
import pandas as pd
import matplotlib.pyplot as plt

from utils import anchors_generator, encoder, loss_function
from nets import blaze_palm 

## Prepare Data

In [None]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding ``(image_batch, annotation_batch)`` pairs.

    Every ``*.jpg`` file found in ``image_dir`` is one sample. Its annotation
    is looked up in ``annotation_dir`` under the same base name.

    Args:
        image_dir: Directory that is globbed for ``*.jpg`` images.
        annotation_dir: Directory holding one annotation file per image.
        batch_size: Number of samples per batch; must be positive. A trailing
            partial batch is dropped (see ``__len__``).
        image_shape: Expected ``(height, width, channels)`` of every image.
            Images are NOT resized here, because the annotations would have
            to be rescaled consistently; a mismatching image raises
            ``ValueError``.
        num_coordinates: Stored on the instance; not used by the generator
            itself.
        num_anchors: Stored on the instance; not used by the generator itself.
        shuffle: If true, the sample order is reshuffled at construction and
            whenever ``on_epoch_end`` is called.
        annotation_loader: Optional callable
            ``loader(annotation_path_without_extension) -> array``. All
            returned arrays must share one shape so they can be stacked.
            NOTE(review): the on-disk annotation format is not visible in
            this notebook. The default assumes one NumPy ``.npy`` file per
            image -- confirm, or pass a loader matching the real format.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(self, image_dir, annotation_dir, batch_size=32, image_shape=(256, 256, 3),
                 num_coordinates=18, num_anchors=2944,
                 shuffle=True, annotation_loader=None):
        # Harmless on Keras versions where Sequence defines no __init__, and
        # expected on versions where it does.
        super().__init__()

        if batch_size <= 0:
            raise ValueError('batch_size must be a positive integer')

        # Image paths without their extension. splitext() removes only the
        # trailing '.jpg', whereas str.replace() would also strip a '.jpg'
        # occurring inside a directory name. Sorting makes the order
        # deterministic when shuffle is off.
        self.file_name_list = sorted(
            os.path.splitext(path)[0]
            for path in glob.glob(os.path.join(image_dir, '*.jpg'))
        )
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.batch_size = batch_size
        self.image_shape = tuple(image_shape)
        self.num_coordinates = num_coordinates
        self.num_anchors = num_anchors
        self.shuffle = shuffle
        self.annotation_loader = (
            annotation_loader if annotation_loader is not None
            else self._load_npy_annotation
        )
        self.on_epoch_end()

    @staticmethod
    def _load_npy_annotation(annotation_path):
        """Default loader: read ``<annotation_path>.npy`` with ``np.load``."""
        return np.load(annotation_path + '.npy')

    def __len__(self):
        """Return the number of full batches per epoch."""
        return len(self.file_name_list) // self.batch_size

    def __getitem__(self, index):
        """Return batch number ``index`` as ``(image_batch, annotation_batch)``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError('batch index out of range')
        # Positions (into file_name_list) of the samples in this batch.
        indexes = self.indexes[index * self.batch_size:(index + 1) * self.batch_size]
        file_name_list_temp = [self.file_name_list[k] for k in indexes]
        return self.__data_generation(file_name_list_temp)

    def on_epoch_end(self):
        """Rebuild the sample order, shuffling it when ``shuffle`` is set."""
        self.indexes = np.arange(len(self.file_name_list))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, file_name_list_temp):
        """Load the images and annotations for the given extension-less paths.

        Raises:
            FileNotFoundError: If an image cannot be read by OpenCV.
            ValueError: If an image's shape differs from ``image_shape``.
        """
        image_batch = np.empty((len(file_name_list_temp), *self.image_shape), dtype=np.float32)
        annotations = []
        for i, file_name in enumerate(file_name_list_temp):
            image_path = file_name + '.jpg'
            # cv2.imread signals failure by returning None, not by raising.
            # NOTE(review): pixels are kept as decoded by OpenCV (BGR order,
            # 0-255 range); confirm this is what the model expects.
            input_image = cv2.imread(image_path)
            if input_image is None:
                raise FileNotFoundError('Could not read image: ' + image_path)
            if input_image.shape != self.image_shape:
                raise ValueError(
                    'Image {} has shape {}, expected {}'.format(
                        image_path, input_image.shape, self.image_shape))
            image_batch[i] = input_image

            # The annotation shares the image's base name but lives in
            # annotation_dir.
            annotation_path = os.path.join(self.annotation_dir, os.path.basename(file_name))
            annotations.append(np.asarray(self.annotation_loader(annotation_path)))

        # Stack rather than pre-allocate so the per-sample annotation shape is
        # decided by the loader, not hard-coded here.
        annotation_batch = np.stack(annotations)
        return image_batch, annotation_batch

## Create BlazePalm Model

In [2]:
# Build the BlazePalm network using the project's `nets.blaze_palm` module.
model = blaze_palm.build_blaze_palm_model()

In [3]:
# Print the layer-by-layer summary (output shapes, parameter counts, connectivity).
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d (Conv2D)                 (None, 128, 128, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 128, 128, 32) 0           conv2d[0][0]                     
__________________________________________________________________________________________________
depthwise_conv2d (DepthwiseConv (None, 128, 128, 32) 320         activation[0][0]                 
______________________________________________________________________________________________

In [5]:
# Loss object from the project's `utils.loss_function` module; its bound
# `compute_loss` method is what Keras receives as the loss callable.
ssd_loss = loss_function.SSDLoss()
# Adam optimizer with Keras' default hyper-parameters.
adam = tf.keras.optimizers.Adam()
model.compile(
    loss=ssd_loss.compute_loss,
    optimizer=adam,
)

## Training