In [39]:
import tensorflow as tf

In [40]:
# Number of distinct token ids; used as the embedding table size and as the
# width of the per-step word softmax.
VOCAB_SIZE = 10

# Number of token positions in the question input.
MAX_QUESTION_LEN = 4

In [41]:
class VQANet:
    """Keras VQA network that feeds an image embedding and a question to an LSTM.

    The image feature vector is projected to the question-embedding width and
    prepended to the embedded question as an extra time step. One LSTM reads
    the combined sequence; its per-step outputs drive a word softmax and its
    final hidden state drives an answer softmax.

    Args:
        combine_type: How image and question are combined. Only
            ``'show-and-tell'`` is implemented.
        question_embed_dim: Width of the word embedding (and of the projected
            image embedding, so the two can be concatenated along time).
        lstm_dim: Number of LSTM units.
        n_answers: Number of answer classes.
        image_feature_dim: Width of the image feature input. Defaults to 10,
            the value that was previously hard-coded.

    Attributes:
        model: The built ``tf.keras.Model`` with inputs
            ``[image_input, question_input]`` and outputs
            ``[question_pred, answer_pred]``.

    Raises:
        ValueError: If ``combine_type`` is not supported.
    """

    SUPPORTED_COMBINE_TYPES = ('show-and-tell',)

    def __init__(self, combine_type, question_embed_dim, lstm_dim, n_answers,
                 image_feature_dim=10):
        self.combine_type = combine_type
        self.question_embed_dim = question_embed_dim
        self.lstm_dim = lstm_dim
        self.n_answers = n_answers
        self.image_feature_dim = image_feature_dim
        # Populated by build(); kept on the instance so callers can compile,
        # train, and predict with the network after construction.
        self.model = None
        self.build()

    def build(self):
        """Construct the Keras model, store it on ``self.model``, and return it.

        Prints the layer summary as a side effect.

        Raises:
            ValueError: If ``self.combine_type`` is not supported.
        """
        # Fail loudly instead of silently producing an object with no model.
        if self.combine_type not in self.SUPPORTED_COMBINE_TYPES:
            raise ValueError(
                "Unsupported combine_type {!r}; expected one of {!r}".format(
                    self.combine_type, self.SUPPORTED_COMBINE_TYPES))

        image_input = tf.keras.layers.Input(shape=(self.image_feature_dim,),
                                            dtype='float32')

        # Project the image features to the word-embedding width ...
        image_embedding = tf.keras.layers.Dense(units=self.question_embed_dim,
                                                activation='elu',
                                                name='image_embedding')(inputs=image_input)

        # ... and add a time axis: (batch, dim) -> (batch, 1, dim).
        image_embedding = tf.keras.layers.Reshape(
            (1, self.question_embed_dim))(image_embedding)

        question_input = tf.keras.layers.Input(shape=(MAX_QUESTION_LEN,),
                                               dtype='int32',
                                               name='question_input')

        # (batch, MAX_QUESTION_LEN) -> (batch, MAX_QUESTION_LEN, dim)
        question_embedding = tf.keras.layers.Embedding(input_dim=VOCAB_SIZE,
                                                       output_dim=self.question_embed_dim,
                                                       input_length=MAX_QUESTION_LEN,
                                                       name='question_embedding')(inputs=question_input)

        # Prepend the image as the first time step:
        # (batch, 1 + MAX_QUESTION_LEN, dim).
        image_question_embedding = tf.keras.layers.Concatenate(
            axis=1,
            name='image_question_embedding')(inputs=[image_embedding, question_embedding])

        # With return_state=True the LSTM returns (sequence, hidden, cell);
        # the cell state is not used.
        question_features, last_h, _ = tf.keras.layers.LSTM(units=self.lstm_dim,
                                                            return_sequences=True,
                                                            return_state=True,
                                                            name='question_generator')(inputs=image_question_embedding)

        # Per-time-step distribution over the vocabulary.
        question_pred = tf.keras.layers.TimeDistributed(
            layer=tf.keras.layers.Dense(units=VOCAB_SIZE,
                                        activation='softmax',
                                        name='word_classifier'))(inputs=question_features)

        # Answer distribution computed from the final hidden state.
        answer_pred = tf.keras.layers.Dense(units=self.n_answers,
                                            activation='softmax',
                                            name='answer_classifier')(inputs=last_h)

        self.model = tf.keras.Model(inputs=[image_input, question_input],
                                    outputs=[question_pred, answer_pred])
        # summary() prints the table itself and returns None, so it must not be
        # wrapped in print() (that would emit a stray "None").
        self.model.summary()
        return self.model

In [42]:
# Build the network with deliberately tiny dimensions; construction calls
# build(), which prints the layer summary shown below.
model = VQANet(
    combine_type='show-and-tell',
    question_embed_dim=2,
    lstm_dim=6,
    n_answers=3,
)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_16 (InputLayer)           (None, 10)           0                                            
__________________________________________________________________________________________________
image_embedding (Dense)         (None, 2)            22          input_16[0][0]                   
__________________________________________________________________________________________________
question_input (InputLayer)     (None, 4)            0                                            
__________________________________________________________________________________________________
reshape_12 (Reshape)            (None, 1, 2)         0           image_embedding[0][0]            
__________________________________________________________________________________________________
question_e