In [None]:
import os
import pickle
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from keras.models import Model
from keras.models import Sequential

from keras.layers import Dense, RepeatVector
from keras.layers import Embedding, LSTM, TimeDistributed
from keras.layers import Concatenate, Activation

from keras.utils import to_categorical

from keras.applications import MobileNetV2
from keras.applications.mobilenet_v2 import preprocess_input

from keras.preprocessing.image import load_img, img_to_array
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences

In [None]:
images_path = '../input/flickr8kimagescaptions/flickr8k/images/'
# glob returns files in a filesystem-dependent order. Sort so that the subset
# of images used for training and the image selected by index at inference
# time are the same on every run and every machine.
images = sorted(glob(images_path + '*.jpg'))
print(len(images))

In [None]:
# ImageNet-pretrained MobileNetV2. The feature extractor below stops one layer
# short of the final (classification) layer, so its output is the penultimate
# activations. NOTE(review): presumably the 1280-d pooled vector, matching
# input_shape=(1280,) of the image branch -- confirm against model.summary().
mobilenet = MobileNetV2(include_top=True, weights='imagenet')

feature_layer = mobilenet.layers[-2]
model_output = feature_layer.output
model = Model(inputs=[mobilenet.input], outputs=[model_output])

In [None]:
# Extract one MobileNetV2 feature vector for each of (at most) the first
# MAX_IMAGES images, keyed by bare file name so they can be joined with the
# `image` column of captions.txt.
MAX_IMAGES = 1500

image_to_features = {}
for image_path in images[:MAX_IMAGES]:
    img = load_img(image_path, target_size=(224, 224))
    img = img_to_array(img)
    img = preprocess_input(img)
    img = img.reshape(1, 224, 224, 3)

    # os.path.basename understands the OS path separator. glob joins results
    # with it, so splitting on '/' would leave 'images\\x.jpg' on Windows and
    # the key would not match the caption file's image names.
    image_name = os.path.basename(image_path)
    # verbose=0: otherwise each call can print its own progress bar.
    features = model.predict(img, verbose=0).reshape(-1, )

    image_to_features[image_name] = features

In [None]:
captions_df = pd.read_csv('../input/flickr8kimagescaptions/flickr8k/captions.txt')

# Group the raw captions of every image we extracted features for. Each caption
# is lower-cased and wrapped in the 'sos' / 'eos' markers that the decoder
# starts from and learns to emit at the end.
image_to_captions = {}
for image, raw_caption in zip(captions_df['image'], captions_df['caption']):
    # Skip images without features and missing (NaN) captions, which would
    # otherwise fail on .lower().
    if image not in image_to_features or not isinstance(raw_caption, str):
        continue
    image_to_captions.setdefault(image, []).append('sos ' + raw_caption.lower() + ' eos')

corpus = [caption for image_captions in image_to_captions.values() for caption in image_captions]

tokenizer = Tokenizer()
tokenizer.fit_on_texts(corpus)

# Replace every caption string with its list of token ids (insertion order kept).
image_to_captions = {
    image: tokenizer.texts_to_sequences(image_captions)
    for image, image_captions in image_to_captions.items()
}

In [None]:
VOCAB_LEN = len(tokenizer.word_index)
MAX_LEN = 34
EMBED_SIZE = 128

def get_dataset(image_to_features, image_to_captions):
    """Expand tokenized captions into (image, prefix) -> next-word samples.

    For a caption [w0, w1, ..., wk] this produces k samples. Sample i pairs the
    image's feature vector and the prefix [w0 .. w(i-1)] with target w_i.

    Args:
        image_to_features: maps image file name -> feature vector.
        image_to_captions: maps image file name -> list of token-id lists.

    Returns:
        Three equal-length lists (X_1, X_2, Y):
        - X_1: image feature vectors.
        - X_2: prefixes, post-padded with 0 to MAX_LEN.
        - Y: one-hot targets of length VOCAB_LEN + 1 (id 0 is Keras' padding id).
    """
    X_1 = []
    prefixes = []
    targets = []

    for image, captions in image_to_captions.items():
        features = image_to_features[image]
        for caption in captions:
            for i in range(1, len(caption)):
                X_1.append(features)
                prefixes.append(caption[:i])
                targets.append(caption[i])

    if not targets:
        return X_1, [], []

    # Pad and one-hot the whole dataset in one call each, instead of once per
    # sample. Prefixes longer than MAX_LEN keep their *last* MAX_LEN tokens
    # ('pre' truncation): those are the words right before the target. 'post'
    # truncation would pair the target with context that does not precede it.
    X_2 = list(pad_sequences(prefixes, maxlen=MAX_LEN, padding='post', truncating='pre'))
    Y = list(to_categorical(targets, num_classes=VOCAB_LEN+1))
    return X_1, X_2, Y

In [None]:
X_1, X_2, Y = get_dataset(image_to_features, image_to_captions)

# Stack into arrays for fit(). Keras computes in float32 by default. float64
# would double the memory of Y, which has one column per vocabulary entry for
# every training sample and is by far the largest array here.
X_1 = np.array(X_1, dtype='float32')
X_2 = np.array(X_2, dtype='int32')  # token ids consumed by the Embedding layer
Y = np.array(Y, dtype='float32')

assert X_1.shape[0] == X_2.shape[0] == Y.shape[0], 'inputs and targets are misaligned'

In [None]:
# Image branch: project the image feature vector (1280-d, per input_shape) to
# EMBED_SIZE, then repeat it MAX_LEN times so it lines up step-for-step with
# the caption branch. Output: (batch, MAX_LEN, EMBED_SIZE).
image_model = Sequential([
    Dense(EMBED_SIZE, input_shape=(1280, ), activation='relu'),
    RepeatVector(MAX_LEN),
])

# Caption branch: embed the padded token-id prefix (valid ids 0..VOCAB_LEN),
# run it through an LSTM, and project every timestep to EMBED_SIZE.
# Output: (batch, MAX_LEN, EMBED_SIZE).
caption_model = Sequential([
    Embedding(input_dim=VOCAB_LEN+1, output_dim=EMBED_SIZE, input_length=MAX_LEN),
    LSTM(256, return_sequences=True),
    TimeDistributed(Dense(EMBED_SIZE)),
])

# Decoder:
# - Join both branches per timestep. Concatenate uses the last axis by
#   default, giving 2 * EMBED_SIZE features.
# - Two stacked LSTMs reduce the sequence to one vector.
# - That vector is classified over VOCAB_LEN + 1 token ids (id 0 included).
# NOTE(review): inputs are post-padded and Embedding has no mask_zero, so the
# LSTMs also step over the trailing padding zeros -- consider masking.
concatenated = Concatenate()([image_model.output, caption_model.output])
X = LSTM(128, return_sequences=True)(concatenated)
X = LSTM(512)(X)
X = Dense(VOCAB_LEN+1)(X)
output = Activation('softmax')(X)

# Inputs are [image features, padded prefix]; the output is a next-word
# distribution. categorical_crossentropy matches the one-hot targets that
# get_dataset builds.
captioning_model = Model(inputs=[image_model.input, caption_model.input], outputs=output)
captioning_model.compile(loss='categorical_crossentropy', optimizer='RMSprop', metrics=['accuracy'])

In [None]:
# Train on [image features, padded caption prefixes] -> one-hot next word,
# for 50 epochs in mini-batches of 512 samples.
captioning_model.fit(x=[X_1, X_2], y=Y, epochs=50, batch_size=512)

In [None]:
# Reverse vocabulary (token id -> word), used to decode model predictions.
# Keras' Tokenizer never assigns id 0 (reserved for padding), so this mapping
# has no entry for 0.
index_to_word = tokenizer.index_word

In [None]:
def get_image(image_number):
    """Load ``images[image_number]`` as a one-image MobileNetV2 input batch.

    The file is resized to 224x224 and passed through ``preprocess_input``.
    A leading batch axis is then added.

    Returns:
        Array of shape (1, 224, 224, 3).
    """
    pixels = img_to_array(load_img(images[image_number], target_size=(224, 224)))
    return np.expand_dims(preprocess_input(pixels), axis=0)

In [None]:
img_number = int(input('Image index: '))
MAX_CAPTION_WORDS = 30  # hard cap on generated words

test_X1 = model.predict(get_image(img_number), verbose=0).reshape(1, -1)

# Greedy decoding: feed the caption so far (as token ids) and append the most
# probable next word, until 'eos' or the word cap is reached. Ids are kept
# directly instead of re-tokenizing the growing string on every step.
sequence = [tokenizer.word_index['sos']]
words = []
for _ in range(MAX_CAPTION_WORDS):
    encoded = pad_sequences([sequence], padding='post', truncating='post', maxlen=MAX_LEN)
    probabilities = captioning_model.predict([test_X1, encoded], verbose=0)[0]
    # Id 0 is the padding id and has no entry in index_to_word; picking it
    # would raise KeyError. Restrict the argmax to real word ids.
    prediction = int(np.argmax(probabilities[1:])) + 1

    sampled = index_to_word[prediction]
    if sampled == 'eos':
        break

    words.append(sampled)
    sequence.append(prediction)

caption = ' '.join(words)

img = load_img(images[img_number])
plt.imshow(img)
plt.title(caption)
plt.axis('off')
print(caption)

In [None]:
# NOTE(review): `model` is the MobileNetV2 feature extractor loaded with
# ImageNet weights, and no cell shown here trains it. This file therefore just
# duplicates the pretrained weights.
model.save_weights(filepath="model_weights.h5")

In [None]:
# Persist the trained caption decoder's weights. Restoring them requires
# rebuilding the identical architecture before calling load_weights.
captioning_model.save_weights(filepath="captioning_model_weights.h5")

In [None]:
# Save the fitted tokenizer so captions can be encoded and decoded later
# without refitting. Only unpickle this file from a trusted source: loading a
# pickle can execute arbitrary code.
with open('tokenizer.pickle', mode='wb') as tokenizer_file:
    pickle.dump(tokenizer, tokenizer_file, pickle.HIGHEST_PROTOCOL)