In [None]:
import re
import tensorflow as tf
import numpy as np
from tqdm import tqdm
import config

def get_preprocessed_caption(caption):
    """Normalize whitespace in *caption* and wrap it in sequence markers.

    Every run of whitespace is collapsed to a single space, leading and
    trailing whitespace is removed, and the result is surrounded by the
    ``<start>`` and ``<end>`` tokens (each separated by one space).

    Args:
        caption: Raw caption text.

    Returns:
        The string ``"<start> " + cleaned_caption + " <end>"``.
    """
    collapsed = re.sub(r'\s+', ' ', caption).strip()
    return f"<start> {collapsed} <end>"

def read_captions(dataset_path, max_lines=4000):
    """Read ``captions.txt`` and group preprocessed captions by image filename.

    The file is opened at ``dataset_path + "/captions.txt"``. Its first line
    is skipped as a header; each following line is split at the *first* comma
    into an image filename and a caption, and the caption is passed through
    ``get_preprocessed_caption``.

    Args:
        dataset_path: Directory that contains ``captions.txt``.
        max_lines: Maximum number of lines to consume after the header.
            ``None`` reads the whole file. The default of 4000 matches the
            limit that was previously hard-coded.

    Returns:
        A dict mapping each image filename to the list of its preprocessed
        captions, in file order.

    Raises:
        IndexError: If a non-blank line contains no comma.
    """
    images_captions_dict = {}
    with open(dataset_path + "/captions.txt", "r") as dataset_info:
        # Skip the header; the default keeps an empty file from raising
        # StopIteration and simply yields an empty result instead.
        next(dataset_info, None)
        # Stream the file instead of materializing every line just to keep
        # a prefix of it.
        for line_number, info_raw in enumerate(dataset_info):
            if max_lines is not None and line_number >= max_lines:
                break
            if not info_raw.strip():
                # Blank lines (e.g. a trailing newline) carry no caption.
                continue
            # Split only at the first comma: the caption itself may contain
            # commas, and an unbounded split would drop everything after
            # the caption's first one.
            info = info_raw.split(",", 1)
            image_filename = info[0]
            caption = get_preprocessed_caption(info[1])
            images_captions_dict.setdefault(image_filename, []).append(caption)
    return images_captions_dict

def extract_image_features(images_captions_dict):
    """Run every image named in *images_captions_dict* through the encoder.

    The dict keys are fed to ``load_image`` via a ``tf.data`` pipeline in
    batches of 64, each batch is passed to the encoder returned by
    ``get_encoder()``, and the per-image outputs are collected as NumPy
    arrays.

    Args:
        images_captions_dict: Mapping whose keys identify the images to
            encode (presumably file paths understood by ``load_image`` —
            confirm against that helper).

    Returns:
        A dict mapping the UTF-8 decoded path yielded by ``load_image`` to
        the encoder output for that image as a NumPy array.
    """
    image_paths = list(images_captions_dict)
    batches = tf.data.Dataset.from_tensor_slices(image_paths)
    batches = batches.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    batches = batches.batch(64)

    images_dict = {}
    encoder = get_encoder()
    # Each dataset element unpacks into an (image batch, path batch) pair,
    # so load_image is expected to return the path alongside the image.
    for img_tensor, path_tensor in tqdm(batches):
        encoded_batch = encoder(img_tensor)
        images_dict.update(
            (path.numpy().decode("utf-8"), features.numpy())
            for features, path in zip(encoded_batch, path_tensor)
        )
    return images_dict

def get_images_labels(image_filenames, images_dict, images_captions_dict):
    """Pair each image with every one of its captions.

    Args:
        image_filenames: Iterable of keys to look up in both dicts.
        images_dict: Mapping from filename to the image representation.
        images_captions_dict: Mapping from filename to an iterable of
            captions for that image.

    Returns:
        A tuple ``(images, labels)`` of two equally long lists. For every
        filename, its image entry is repeated once per caption (the same
        object each time) and the captions appear in their stored order.

    Raises:
        KeyError: If a filename is missing from either dict.
    """
    images = []
    labels = []
    for name in image_filenames:
        features = images_dict[name]
        captions_for_image = list(images_captions_dict[name])
        # One image entry per caption keeps the two lists aligned by index.
        images.extend(features for _ in captions_for_image)
        labels.extend(captions_for_image)
    return images, labels

def create_tokenizer(captions):
    """Build a Keras ``Tokenizer`` fitted on *captions*, with index 0 as padding.

    The tokenizer is limited to ``config.top_k`` words, maps unknown words
    to ``<unk>``, and is given an explicit ``<pad>`` entry at index 0 in both
    ``word_index`` and ``index_word``.

    Args:
        captions: Texts to fit the vocabulary on.

    Returns:
        The fitted ``tf.keras.preprocessing.text.Tokenizer``.
    """
    # Raw string on purpose: the filter set contains a literal backslash
    # followed by ']'. In a regular string literal '\]' is an invalid escape
    # sequence, which newer Python versions warn about. The runtime value is
    # unchanged. '<' and '>' are not in the set, so the '<start>', '<end>'
    # and '<unk>' markers are not stripped of their brackets.
    filters = r'!"#$%&()*+.,-/:;=?@[\]^_`{|}~ '
    tokenizer = tf.keras.preprocessing.text.Tokenizer(
        num_words=config.top_k,
        oov_token="<unk>",
        filters=filters,
    )
    tokenizer.fit_on_texts(captions)
    # Register the padding token under index 0 in both lookup directions.
    tokenizer.word_index['<pad>'] = 0
    tokenizer.index_word[0] = '<pad>'
    return tokenizer

def tokenize_captions(tokenizer, captions):
    """Convert *captions* to integer sequences padded at the end.

    Args:
        tokenizer: Object providing ``texts_to_sequences`` (e.g. the
            tokenizer returned by ``create_tokenizer``).
        captions: Texts to convert.

    Returns:
        The result of Keras ``pad_sequences`` with ``padding='post'``
        applied to the tokenized captions.
    """
    token_id_sequences = tokenizer.texts_to_sequences(captions)
    pad_sequences = tf.keras.preprocessing.sequence.pad_sequences
    return pad_sequences(token_id_sequences, padding='post')

def loss_function(real, pred, loss_object):
    """Compute the mean loss with positions where ``real == 0`` zeroed out.

    Args:
        real: Target values; entries equal to 0 are treated as masked.
        pred: Predictions passed to ``loss_object`` together with ``real``.
        loss_object: Callable ``loss_object(real, pred)`` returning the loss
            tensor. NOTE(review): the masking only makes sense if it returns
            unreduced, per-element losses — confirm at the call site.

    Returns:
        ``tf.reduce_mean`` of the masked loss. The mean is taken over all
        elements, masked ones included (they contribute zero).
    """
    per_element_loss = loss_object(real, pred)
    is_masked = tf.math.equal(real, 0)
    # Cast the boolean keep-mask to the loss dtype so it can be multiplied in.
    keep = tf.cast(tf.math.logical_not(is_masked), dtype=per_element_loss.dtype)
    return tf.reduce_mean(per_element_loss * keep)

def clean_caption(caption):
    """Return the items of *caption* without the special marker tokens.

    Args:
        caption: Iterable of tokens.

    Returns:
        A new list with every ``'<start>'``, ``'<end>'`` and ``'<pad>'``
        entry removed; the order of the remaining items is preserved.
    """
    special_tokens = ('<start>', '<end>', '<pad>')
    return list(filter(lambda token: token not in special_tokens, caption))

def get_caption(img, encoder, decoder):
    """Generate a caption for *img* by greedy decoding.

    The image is encoded once; the decoder is then called step by step,
    starting from the ``<start>`` token and feeding back the highest-scoring
    word and the returned state, until ``<end>`` is predicted or
    ``config.max_caption_length`` steps have run.

    NOTE: this function reads a module-level ``tokenizer`` (its
    ``word_index`` and ``index_word`` mappings). It is not a parameter and
    is not defined in the visible part of this file, so it must exist as a
    global before the function is called.

    Args:
        img: Single image; a leading batch axis is added before encoding.
        encoder: Callable producing the features from the batched image.
        decoder: Callable invoked as ``decoder(dec_input, features,
            omit_features=..., initial_state=...)`` and returning
            ``(predictions, memory_state, carry_state)``.

    Returns:
        The list of predicted words, passed through ``clean_caption``.
    """
    features = encoder(tf.expand_dims(img, 0))
    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)
    words = []
    state = None

    for step in range(config.max_caption_length):
        # Features are only supplied on the first step; later steps pass
        # omit_features=True and rely on the carried-over state.
        predictions, memory_state, carry_state = decoder(
            dec_input,
            features,
            omit_features=step > 0,
            initial_state=state,
        )
        scores = predictions.numpy().flatten()
        word_index = np.argmax(scores)
        word = tokenizer.index_word[word_index]
        if word == '<end>':
            break
        words.append(word)
        # The chosen word and the returned state seed the next step.
        dec_input = tf.expand_dims([word_index], 0)
        state = [memory_state, carry_state]

    return clean_caption(words)