In [None]:
import tensorflow as tf
from sklearn.model_selection import train_test_split
import config
import utils

def load_image(image_path):
    """Read one image from the dataset folder and prepare it for InceptionV3.

    Args:
        image_path: Image location relative to ``config.dataset_images_path``.
            It is appended to that prefix with ``+``, so the prefix must
            already end with whatever separator is needed.

    Returns:
        A 2-tuple ``(image, image_path)``: the decoded 3-channel image resized
        to ``(config.img_height, config.img_width)`` and passed through
        ``inception_v3.preprocess_input``, together with the unchanged
        ``image_path`` argument.
    """
    full_path = config.dataset_images_path + image_path
    encoded = tf.io.read_file(full_path)
    # Force three channels regardless of how the JPEG was stored.
    pixels = tf.image.decode_jpeg(encoded, channels=3)
    target_size = (config.img_height, config.img_width)
    resized = tf.image.resize(pixels, target_size)
    # The path is handed back next to the tensor so callers can tell which
    # file each result belongs to.
    return tf.keras.applications.inception_v3.preprocess_input(resized), image_path

def prepare_data():
    """Assemble the training input pipeline and the held-out evaluation data.

    Captions are read from ``config.dataset_path`` and image features are
    extracted for them via ``utils``. Image file names are then split into
    train and test subsets; the tokenizer is created from the training
    captions only, and only the training captions are tokenized.

    Returns:
        A 2-tuple ``(train_dataset, (X_test, y_test_raw, image_filenames_test))``:
        ``train_dataset`` is a shuffled, batched and prefetched
        ``tf.data.Dataset`` of (image, tokenized caption) pairs, and the inner
        tuple holds the test images, their untokenized captions and the test
        file names.
    """
    captions_by_image = utils.read_captions(config.dataset_path)
    features_by_image = utils.extract_image_features(captions_by_image)

    # Split by file name with a fixed seed so the partition is reproducible.
    train_names, test_names = train_test_split(
        list(captions_by_image.keys()),
        test_size=config.validation_split,
        random_state=1,
    )

    train_images, train_captions = utils.get_images_labels(
        train_names, features_by_image, captions_by_image
    )
    test_images, test_captions = utils.get_images_labels(
        test_names, features_by_image, captions_by_image
    )

    # The tokenizer only ever sees the training captions.
    tokenizer = utils.create_tokenizer(train_captions)
    train_tokens = utils.tokenize_captions(tokenizer, train_captions)

    pipeline = tf.data.Dataset.from_tensor_slices((train_images, train_tokens))
    pipeline = pipeline.shuffle(config.BUFFER_SIZE)
    pipeline = pipeline.batch(config.BATCH_SIZE)
    pipeline = pipeline.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

    # Test captions stay raw (untokenized) in the returned bundle.
    held_out = (test_images, test_captions, test_names)
    return pipeline, held_out