In [25]:
import keras
from keras.models import Model
from keras.layers import Input, Dense, LSTM, BatchNormalization, Concatenate, Flatten, Conv2D

In [28]:
class GazeNet:
    """Two-branch Keras model fusing an image embedding with a gaze-sequence embedding.

    Image branch: three stacked ``Conv2D`` layers (``padding='valid'``), flattened.
    Gaze branch: three stacked ``LSTM`` layers; only the last one's final output is kept.
    The two embeddings are concatenated and passed through three ``Dense`` layers.
    The built (uncompiled) model is stored on ``self.model``.

    Args:
        image_shape: shape of one image sample, excluding the batch axis.
        gaze_shape: shape of one gaze sample, excluding the batch axis
            (presumably ``(timesteps, 2)`` x/y coordinates -- TODO confirm).
        units: width of every Conv2D (filters), LSTM and Dense layer.
        activation: activation used by the Conv2D and Dense layers. The default
            ``None`` (linear) reproduces the original architecture.

    NOTE(review): with ``activation=None`` each 3-layer Conv2D/Dense stack is a
    composition of linear maps and adds no expressive power over a single layer;
    consider e.g. ``activation='relu'``.
    """

    def __init__(self, image_shape=(200, 200, 3), gaze_shape=(100, 2),
                 units=128, activation=None):
        self.image_shape = tuple(image_shape)
        self.gaze_shape = tuple(gaze_shape)
        self.units = units
        self.activation = activation
        self.model = self.create_model()

    def convolution(self, kernel_size=3, filters=128, activation=None):
        """Return a callable applying three stacked ``Conv2D`` layers to a tensor.

        Each layer uses stride 1 and ``'valid'`` padding, so every layer shrinks
        each spatial dimension by ``kernel_size - 1``.
        """
        def f(inputs):
            x = inputs
            for _ in range(3):
                x = Conv2D(filters, kernel_size, strides=(1, 1), padding='valid',
                           activation=activation)(x)
            return x
        return f

    def lstm(self, units=128):
        """Return a callable applying three stacked ``LSTM`` layers to a sequence.

        The first two layers return full sequences so the next LSTM can consume
        them; the last returns only its final output, shape ``(batch, units)``.
        """
        def f(inputs):
            x = LSTM(units, return_sequences=True)(inputs)
            x = LSTM(units, return_sequences=True)(x)
            return LSTM(units)(x)
        return f

    def dense(self, units=128, activation=None):
        """Return a callable applying three stacked ``Dense`` layers to a tensor."""
        def f(inputs):
            x = inputs
            for _ in range(3):
                x = Dense(units, activation=activation)(x)
            return x
        return f

    def create_model(self):
        """Build and return the two-input Keras ``Model`` (``[image, gaze] -> output``).

        NOTE(review): with the default 200x200x3 input, three valid 3x3 convs
        leave a 194x194x128 map, so ``Flatten`` yields ~4.8M features and the
        first Dense(128) holds ~617M weights. Consider pooling/striding or
        global pooling before the merge.
        """
        image = Input(shape=self.image_shape)
        gaze = Input(shape=self.gaze_shape)
        image_embedding = self.convolution(filters=self.units,
                                           activation=self.activation)(image)
        gaze_embedding = self.lstm(units=self.units)(gaze)
        flatten = Flatten()(image_embedding)
        merged = Concatenate()([flatten, gaze_embedding])
        output = self.dense(units=self.units, activation=self.activation)(merged)

        # Keras 2 / tf.keras use `inputs`/`outputs`; the old `input`/`output`
        # keywords are deprecated (Keras 2) or rejected (tf.keras).
        return Model(inputs=[image, gaze], outputs=output)

In [31]:
from IPython.display import SVG
from keras.utils.vis_utils import model_to_dot
from keras.utils import plot_model


In [32]:
# Build the network with its default configuration.
gaze_net = GazeNet()
# Write the architecture diagram to disk (requires pydot + graphviz).
plot_model(gaze_net.model, to_file='model.png')
# IPython only renders a cell's final expression, so the inline SVG
# diagram must be the last line of the cell to be displayed.
SVG(model_to_dot(gaze_net.model).create(prog='dot', format='svg'))

