In [1]:
import os
import sys
import tensorflow as tf
import numpy as np
from keras.models import Model
from keras.layers import Input, Dense, LSTM, Lambda
from keras.engine.topology import Input
from keras import optimizers
from keras.utils.np_utils import to_categorical
from keras.models import Sequential, load_model
from keras.layers import Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dropout
from keras.layers import Activation, BatchNormalization, MaxPooling2D
import time
import math
from keras.utils import plot_model
from keras.callbacks import ModelCheckpoint
from keras.preprocessing.image import ImageDataGenerator
from keras.engine.topology import Input
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model
from keras import backend as K
K.set_image_dim_ordering('tf')

import gazenetGenerator as gaze_gen

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# global param
dataset_path = 'gaze_dataset'
learning_rate = 0.0001
time_steps = 32
num_classes = 6
batch_size = 4
time_skip = 2
origin_image_size = 360    # size of the origin image before the cropWithGaze
img_size = 128    # size of the input image for network
num_channel = 3
steps_per_epoch=400
epochs=100
validation_step=20

In [3]:
class GazeNet():
    def __init__(self,learning_rate,time_steps,num_classes,batch_size):
        self.learning_rate = learning_rate
        self.time_steps = time_steps
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.kernel_size = 15
        self.kernel_num = 256
        self.gaussian_sigma = 1
        self.gaussian_weight = self.create_gaussian_weight()
        self.model = self.create_model()

    def create_gaussian_weight(self):
        kernel_size = self.kernel_size    #same with the shape of the layer before flatten
        kernel_num = self.kernel_num
        r = (kernel_size - 1) // 2
        sigma_2 = float(self.gaussian_sigma * self.gaussian_sigma)
        pi = 3.1415926
        ratio = 1 / (2*pi*sigma_2)

        kernel = np.zeros((kernel_size, kernel_size))
        for i in range(-r, r+1):
            for j in range(-r, r+1):
                tmp = math.exp(-(i*i+j*j)/(2*sigma_2))
                kernel[i+r][j+r] = round(tmp, 3)
        kernel *= ratio
        kernel = np.expand_dims(kernel, axis=2)
        kernel = np.tile(kernel, (1,1,kernel_num))
        # print(kernel.shape)
        return kernel

    def create_model(self):

        model = Sequential()

        def input_reshape(input):
            return tf.reshape(input, [self.batch_size*self.time_steps,128,128,3])
        
        model.add(Lambda(input_reshape, input_shape=(self.time_steps,128,128,3,)))
        #block 1
        model.add(Conv2D(96,(5,5),strides = (2,2),
                            padding = 'valid',
                            activation = 'relu'))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size = (2,2)))

        #block 2
        model.add(Conv2D(256,(3,3),padding = 'same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))

        model.add(Conv2D(256,(3,3),padding = 'same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))

        model.add(Conv2D(256,(3,3),padding = 'same'))
        model.add(BatchNormalization())
        model.add(Activation('relu'))
        model.add(MaxPooling2D(2,2))

        def multiply_constant(input):
            for i in range(self.batch_size*self.time_steps):
                tmp = tf.multiply(tf.cast(input[i], tf.float32), tf.cast(self.gaussian_weight, tf.float32))
                tmp = tf.expand_dims(tmp, 0)
                if i == 0:
                    res = tmp
                else:
                    res = tf.concat([res, tmp], 0)
            res = tf.reshape(res,[self.batch_size,self.time_steps,
                                  self.kernel_size,self.kernel_size,self.kernel_num])
            res = tf.reshape(res,[self.batch_size,self.time_steps,
                                  self.kernel_size*self.kernel_size*self.kernel_num])
            return res

        model.add(Lambda(multiply_constant))

        def mean_value(input):
            return tf.reduce_mean(input,1)

        model.add(LSTM(128,return_sequences = True))
        model.add(LSTM(6,return_sequences = True))
        model.add(Lambda(mean_value))

        adam = optimizers.Adam(lr = self.learning_rate)
        model.compile(loss='categorical_crossentropy', optimizer='adam')
        print(model.summary())

        return model
    
#     def train(self):
#         # categorical_labels = to_categorical(int_labels, num_classes=None)
#         pass

#     def load_data(self):
#         print("Hello world")

#     def save_model_weights(self, folder_path, suffix):
#         # Helper function to save your model / weights.
#         self.model.save_weights(folder_path + 'weights-' +  str(suffix) + '.h5')
#         self.model.save(folder_path + 'model-' +  str(suffix) + '.h5')

#     def load_model(self, model_file):
#         # Helper function to load an existing model.
#         self.model = load_model(model_file)

#     def load_model_weights(self,weight_file):
#         # Helper funciton to load model weights.
#         self.model.load_weights(weight_file)

In [4]:
def main(args):
    # generate model
    gaze_net = GazeNet(learning_rate,time_steps,num_classes,batch_size)
    model = gaze_net.model
    plot_model(model, to_file='model.png')
    print("generate model!")
    
    # generatr generator
    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(dataset_path, subset='training',time_steps=time_steps, 
                                                    batch_size=batch_size, crop=True, target_size=(360,360),
                                                    gaussian_std=0.01, time_skip=time_skip, crop_with_gaze=True,
                                                   crop_with_gaze_size=128)
    val_data = trainGenerator.flow_from_directory(dataset_path, subset='validation', time_steps=time_steps, 
                                                  batch_size=batch_size, crop=True, target_size=(360,360),
                                                    gaussian_std=0.01, time_skip=time_skip, crop_with_gaze=True,
                                                   crop_with_gaze_size=128)
    # [img_seq, gaze_seq], output = next(trainGeneratorDirectory)
    print("fetch data!")
    
    # start training
    checkpoint = ModelCheckpoint('weight.{epoch:02d}.hdf5', monitor='val_acc', mode='max', period=5)
    model.fit_generator(train_data, steps_per_epoch=steps_per_epoch, epochs=epochs, callbacks=[checkpoint], 
                    validation_data=val_data, validation_steps=validation_step, shuffle=False)
    print("finished training!")
    
    
#     [img_seq, gaze_seq], output = next(trainGeneratorDirectory)
#     img_seq = cropWithGaze(img_seq,gaze_seq,batch_size,time_steps,img_size,num_channel)


In [None]:
if __name__ == '__main__':
    main(sys.argv)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lambda_1 (Lambda)            (128, 128, 128, 3)        0         
_________________________________________________________________
conv2d_1 (Conv2D)            (128, 62, 62, 96)         7296      
_________________________________________________________________
batch_normalization_1 (Batch (128, 62, 62, 96)         384       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (128, 31, 31, 96)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (128, 31, 31, 256)        221440    
_________________________________________________________________
batch_normalization_2 (Batch (128, 31, 31, 256)        1024      
_________________________________________________________________
activation_1 (Activation)    (128, 31, 31, 256)        0         
__________

26.00577
170.99599
26.007435
170.99904
25.984152
171.00081
26.000223
171.02403
24.998505
167.00705
25.002728
163.9956
25.011904
163.00766
25.001019
161.01645
24.999443
160.99648
24.022621
159.98929
23.998718
159.98822
24.000237
159.99744
24.001413
159.99069
23.99179
159.00246
24.011503
159.98201
24.00651
160.00119
24.014818
158.98854
24.022968
159.00522
23.999144
158.9956
24.00388
159.98325
24.002262
160.00139
63.994377
153.00002
63.987225
166.01027
63.996967
152.00603
63.984444
165.01025
64.00083
152.01442
63.996326
152.00435
63.990215
165.00053
64.00657
164.99057
63.997837
150.98189
64.016785
165.01431
64.00134
151.98778
65.00766
164.98157
64.9874
152.00667
66.9918
166.00055
66.99172
166.00179
67.020996
152.99513
67.98956
167.99606
67.98938
155.99373
68.000496
156.0137
68.01185
169.01976
68.000435
157.00687
67.99251
167.99733
68.00663
167.99345
68.00811
156.99582
68.00989
168.00104
68.00484
157.00829
68.00432
167.016
68.00481
168.00179
68.00992
157.00581
67.99144
156.9986
68.00099
15

-0.98371494
35.99872
-1.9970846
41.985783
-1.0186459
36.008976
-2.008122
41.988945
-1.9987328
42.012577
-0.9974482
37.00296
-1.9979758
42.00288
-1.0239396
36.99225
-0.9927292
37.010105
-0.99589044
42.004505
-0.9971772
36.992966
-1.9977584
42.999958
-2.0071094
43.00138
-1.9925932
43.000546
-0.9963807
37.0041
-2.0004764
42.982174
-0.9937798
37.031258
-1.0031395
42.990936
-1.0107328
42.988472
-0.9895294
37.001286
-0.99661744
37.016468
-0.99856246
37.005547
-0.9799359
42.006035
-0.9934301
36.99536
24.00982
12.993196
23.990858
13.005707
23.003504
6.994958
24.00692
12.00642
23.00721
7.000104
22.988585
5.9984465
23.992044
10.989472
23.009928
5.0036407
23.00115
5.0086164
24.003824
9.998652
24.001684
10.989759
23.006134
2.997732
24.001657
11.985
22.993082
3.0009217
25.000114
13.003725
24.995865
12.992575
24.009583
4.0096145
23.992441
14.019877
24.005112
2.9918458
25.006556
14.013912
24.007893
2.000573
25.01012
14.993268
23.981035
2.9766932
24.004549
3.0062065
24.987371
15.009996
24.015396
2.000

-30.995974
219.01012
-31.00526
219.00313
-37.990376
182.99692
-32.000458
218.99248
-31.99331
218.99654
-31.995052
218.99706
-32.007694
219.002
-31.994354
218.9985
-31.000032
221.00594
-30.988329
220.00995
-31.009718
220.00519
-31.005817
220.00804
-28.987936
221.00717
-27.992247
219.98831
-31.991507
222.9962
-31.989193
222.98492
-31.99527
222.99713
-31.990105
222.98863
-32.003124
222.99387
-31.997477
225.99292
-31.993692
226.00026
-32.003326
225.99422
-31.994648
225.98627
-31.98577
226.004
-32.006546
223.99919
-32.00158
224.00806
-32.00436
223.98434
-32.000587
224.02373
-32.00441
224.01315
-32.009087
224.00381
-31.991276
224.00557
-32.017986
223.98221
35841.977
36000.004
52.00069
393.02347
52.02557
394.98953
52.00222
395.00674
51.997375
395.0033
51.995186
396.00543
53.00917
396.98456
53.00805
397.00586
52.99245
396.9972
53.020107
396.99417
52.991833
396.99106
52.99016
398.00916
52.975468
397.00433
51.00277
392.01944
50.007977
375.9947
49.98599
376.00256
52.006306
360.00244
54.000607
341

22.018978
199.9896
22.008524
199.99942
23.003437
198.00293
23.013529
198.00293
23.992355
197.99872
25.010427
197.00067
25.004942
196.99333
26.00481
197.0041
26.002237
197.01736
26.997547
194.00269
27.002792
193.99991
26.997107
193.97894
28.02077
193.00914
29.001389
193.00745
29.995506
191.00111
32.00697
190.98677
31.990234
191.00749
33.01423
189.03123
33.99136
188.98882
34.989773
185.98166
34.989883
185.99551
36.01406
185.98979
36.999767
185.98685
38.00943
185.0087
37.992306
184.00575
38.9973
184.01027
38.990543
183.0088
38.99109
182.99406
38.998524
182.00557
39.996456
182.00827
39.994225
181.9985
41.02233
182.01027
34.990444
168.01366
35.022137
168.01115
34.98962
168.02016
34.9999
167.98697
34.99646
167.98413
36.00604
168.99275
35.99765
168.00037
36.006847
169.00252
35.988846
167.97865
36.998848
168.99182
36.989513
167.98378
36.974663
169.01111
36.994495
168.98372
37.003323
169.00357
37.007355
169.00037
38.004517
168.00548
38.00928
167.98833
37.99272
168.01086
38.01074
168.00633
38.01