In [None]:
import tensorflow as tf

class CNN_Encoder(tf.keras.Model):
    """Project an input tensor to an ``embedding_dim``-sized vector.

    The input is flattened (every non-batch axis collapsed into one) and
    passed through a single dense layer with no activation.
    """

    def __init__(self, embedding_dim):
        """Create the flatten and projection layers.

        Args:
            embedding_dim: Number of output units of the dense projection.
        """
        super().__init__()
        self.flat = tf.keras.layers.Flatten()
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        """Return the dense projection of the flattened input ``x``."""
        flattened = self.flat(x)
        return self.fc(flattened)

class RNN_Decoder(tf.keras.Model):
    """LSTM decoder producing per-step vocabulary scores from token ids.

    Token ids are embedded, optionally prefixed with a feature vector that
    acts as the first time step, run through an LSTM, and projected to
    ``vocab_size`` scores by a dense layer with no activation.
    """

    def __init__(self, embedding_dim, units, vocab_size):
        """Create the embedding, LSTM and output projection layers.

        Args:
            embedding_dim: Size of each token embedding. When features are
                concatenated in ``call`` their last dimension must match it.
            units: Number of LSTM units (size of each LSTM state).
            vocab_size: Number of distinct token ids; also the size of the
                output projection.
        """
        super().__init__()
        self.units = units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTM(
            self.units, return_sequences=True, return_state=True
        )
        self.fc = tf.keras.layers.Dense(vocab_size)

    def call(self, captions, features, omit_features=False, initial_state=None):
        """Run the decoder over ``captions``.

        Args:
            captions: Integer token ids; assumed to be shaped
                (batch, seq_len) -- TODO confirm against callers.
            features: Feature vectors inserted as the first time step;
                assumed to be shaped (batch, embedding_dim). Not read when
                ``omit_features`` is true, so ``None`` is acceptable then.
            omit_features: If true, feed only the embedded captions to the
                LSTM (e.g. when continuing from ``initial_state``).
            initial_state: Optional initial LSTM state forwarded unchanged
                to the LSTM layer.

        Returns:
            A tuple ``(output, memory_state, carry_state)``: the projected
            scores for every LSTM time step, followed by the two final
            state tensors returned by the LSTM layer.
        """
        embed = self.embedding(captions)
        if omit_features:
            # Features are deliberately not touched here, so callers that
            # decode step by step do not need to supply a placeholder.
            lstm_input = embed
        else:
            # Add a length-1 time axis so the features become step 0,
            # followed by the embedded tokens.
            first_step = tf.expand_dims(features, 1)
            lstm_input = tf.concat([first_step, embed], axis=-2)
        output, memory_state, carry_state = self.lstm(
            lstm_input, initial_state=initial_state
        )
        output = self.fc(output)
        return output, memory_state, carry_state

    def reset_state(self, batch_size):
        """Return a zero tensor of shape ``(batch_size, units)``.

        NOTE(review): only one tensor is returned, while the LSTM layer
        tracks two states (memory and carry); callers presumably build the
        pair themselves -- confirm.
        """
        return tf.zeros((batch_size, self.units))