In [1]:
import os
import sys
import tensorflow as tf
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, LSTM, Lambda
from keras.engine.topology import Input
from keras import optimizers
from keras.utils.np_utils import to_categorical
from keras.models import Sequential, load_model
from keras.layers import Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout
from keras.layers import Activation, BatchNormalization, MaxPooling2D
import time
import math
from keras.utils import plot_model
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
from keras.engine.topology import Input
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from keras import backend as K
K.set_image_dim_ordering('tf')

import gazenetGenerator as gaze_gen

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# global param
dataset_path = 'gaze_dataset'
learning_rate = 0.0001
time_steps = 32
num_classes = 6
batch_size = 4
time_skip = 2
origin_image_size = 360    # size of the origin image before the cropWithGaze
img_size = 128    # size of the input image for network
num_channel = 3
steps_per_epoch=400
epochs=100
validation_step=20

In [3]:
class GazeNet():
    def __init__(self,learning_rate,time_steps,num_classes,batch_size):
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.kernel_size = 15
        self.kernel_num = 256
        self.gaussian_sigma = 1
        self.gaussian_weight = self.create_gaussian_weight()
        self.model = self.create_model()

    def create_gaussian_weight(self):
        kernel_size = self.kernel_size    #same with the shape of the layer before flatten
        kernel_num = self.kernel_num
        r = (kernel_size - 1) // 2
        sigma_2 = float(self.gaussian_sigma * self.gaussian_sigma)
        pi = 3.1415926
        ratio = 1 / (2*pi*sigma_2)

        kernel = np.zeros((kernel_size, kernel_size))
        for i in range(-r, r+1):
            for j in range(-r, r+1):
                tmp = math.exp(-(i*i+j*j)/(2*sigma_2))
                kernel[i+r][j+r] = round(tmp, 3)
        kernel *= ratio
        kernel = np.expand_dims(kernel, axis=2)
        kernel = np.tile(kernel, (1,1,kernel_num))
        # print(kernel.shape)
        return kernel

    def create_model(self):

        model = Sequential()

        def input_reshape(input):
            return tf.reshape(input, [self.batch_size*self.time_steps,128,128,3])
        
        model.add(Lambda(input_reshape, input_shape=(self.time_steps,128,128,3,)))
        #block 1
        model.add(Conv2D(96,(5,5),strides = (2,2),
                            padding = 'valid',
                            activation = 'relu'))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size = (2,2)))

        #block 2
        model.add(Conv2D(256,(3,3),padding = 'same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))

        model.add(Conv2D(256,(3,3),padding = 'same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))

        model.add(Conv2D(256,(3,3),padding = 'same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(MaxPooling2D(2,2))

        def multiply_constant(input):
            for i in range(self.batch_size*self.time_steps):
                tmp = tf.multiply(tf.cast(input[i], tf.float32), tf.cast(self.gaussian_weight, tf.float32))
                tmp = tf.expand_dims(tmp, 0)
                if i == 0:
                    res = tmp
                else:
                    res = tf.concat([res, tmp], 0)
            res = tf.reshape(res,[self.batch_size,self.time_steps,
                                  self.kernel_size,self.kernel_size,self.kernel_num])
            res = tf.reshape(res,[self.batch_size,self.time_steps,
                                  self.kernel_size*self.kernel_size*self.kernel_num])
            return res

        model.add(Lambda(multiply_constant))

        def mean_value(input):
            return tf.reduce_mean(input,1)

        model.add(LSTM(128,return_sequences = True))
        model.add(LSTM(6,return_sequences = True))
        model.add(Lambda(mean_value))

        adam = optimizers.Adam(lr = self.learning_rate)
        model.compile(loss='categorical_crossentropy', optimizer='adam')
        print(model.summary())

        return model
    
#     def train(self):
#         # categorical_labels = to_categorical(int_labels, num_classes=None)
#         pass

#     def load_data(self):
#         print("Hello world")

#     def save_model_weights(self, folder_path, suffix):
#         # Helper function to save your model / weights.
#         self.model.save_weights(folder_path + 'weights-' +  str(suffix) + '.h5')
#         self.model.save(folder_path + 'model-' +  str(suffix) + '.h5')

#     def load_model(self, model_file):
#         # Helper function to load an existing model.
#         self.model = load_model(model_file)

#     def load_model_weights(self,weight_file):
#         # Helper funciton to load model weights.
#         self.model.load_weights(weight_file)

In [4]:
def main(args):
    # generate model
    gaze_net = GazeNet(learning_rate,time_steps,num_classes,batch_size)
    model = gaze_net.model
    plot_model(model, to_file='model.png')
    print("generate model!")
    
    # generatr generator
    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(dataset_path, subset='training',time_steps=time_steps, 
                                                    batch_size=batch_size, crop=False,
                                                    gaussian_std=0.01, time_skip=time_skip, crop_with_gaze=True,
                                                   crop_with_gaze_size=128)
    val_data = trainGenerator.flow_from_directory(dataset_path, subset='validation', time_steps=time_steps, 
                                                  batch_size=batch_size, crop=False,
                                                    gaussian_std=0.01, time_skip=time_skip, crop_with_gaze=True,
                                                   crop_with_gaze_size=128)
    # [img_seq, gaze_seq], output = next(trainGeneratorDirectory)
    print("fetch data!")
    
    # start training
    checkpoint = ModelCheckpoint('weight.{epoch:02d}.hdf5', monitor='val_acc', mode='max', period=5)
    model.fit_generator(train_data, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[checkpoint], 
                    validation_data=val_data, validation_steps=validation_step, shuffle=False)
    print("finished training!")
    
    
#     [img_seq, gaze_seq], output = next(trainGeneratorDirectory)
#     img_seq = cropWithGaze(img_seq,gaze_seq,batch_size,time_steps,img_size,num_channel)


In [5]:
if __name__ == '__main__':
    main(sys.argv)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lambda_1 (Lambda)            (128, 128, 128, 3)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (128, 62, 62, 96)         7296      
_________________________________________________________________
batch_normalization_1 (Batch (128, 62, 62, 96)         384       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (128, 31, 31, 96)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (128, 31, 31, 256)        221440    
_________________________________________________________________
batch_normalization_2 (Batch (128, 31, 31, 256)        1024      
_________________________________________________________________
activation_1 (Activation)    (128, 31, 31, 256)        0         
__________

KeyboardInterrupt: 