In [1]:
from os import listdir
from pickle import dump
from pickle import load
import numpy as np
import string
import seaborn as sns
import os

In [2]:
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm
import matplotlib.pyplot as plt

In [6]:
def load_image(image_path):
    """Read a JPEG from disk and turn it into a single-image VGG16 input batch.

    Args:
        image_path: path of the JPEG file to load.

    Returns:
        A tensor of shape (1, 224, 224, 3) that has been passed through
        ``tf.keras.applications.vgg16.preprocess_input``.
    """
    raw_bytes = tf.io.read_file(image_path)
    # force three colour channels regardless of how the file was encoded
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.image.resize(pixels, (224, 224))
    # prepend the batch axis: (224, 224, 3) -> (1, 224, 224, 3)
    batch = tf.expand_dims(resized, axis=0)
    return tf.keras.applications.vgg16.preprocess_input(batch)

In [7]:
# extract features from each photo in the directory
def extract_features(directory):
    """Compute a VGG16 feature vector for every image file in ``directory``.

    The stock Keras VGG16 network is re-wired so that its output is the
    penultimate layer (the one feeding the final classification layer),
    and each image is pushed through that truncated model.

    Args:
        directory: path of a folder whose entries are all JPEG images
            readable by ``load_image``.

    Returns:
        dict mapping image id (the file name up to its first '.') to the
        array returned by ``model.predict`` for that image.
    """
    base_model = keras.applications.VGG16()
    # Re-structure the model: take the output of the second-to-last layer.
    # NOTE: the previous `model.layers.pop()` approach does not work with
    # tf.keras, where `Model.layers` hands back a new list on every access;
    # the pop was a no-op and the classifier output was extracted instead
    # of the feature layer. Indexing [-2] on the untouched model avoids that.
    model = keras.Model(inputs=base_model.inputs,
                        outputs=base_model.layers[-2].output)
    features = dict()
    for name in tqdm(listdir(directory)):
        # build the path portably instead of hard-coding a Windows separator
        filename = os.path.join(directory, name)
        image = load_image(filename)
        feature = model.predict(image, verbose=0)
        # key by the file name without its extension
        image_id = name.split('.')[0]
        features[image_id] = feature
    return features

In [8]:
# extract features from all images
directory = 'Flicker8k_Dataset'
features = extract_features(directory)
print('Extracted Features: %d' % len(features))
# save to file; the context manager makes sure the pickle is flushed and the
# handle is closed even if dump() raises
with open('features.pkl', 'wb') as features_file:
    dump(features, features_file)


100%|██████████████████████████████████████████████████████████████████████████████| 8091/8091 [32:07<00:00,  4.20it/s]


Extracted Features: 8091


In [9]:
# load doc into memory
def load_doc(filename):
    """Return the entire contents of a text file as one string.

    Args:
        filename: path of the file to read (opened in text mode, read only).

    Returns:
        The file's full text.
    """
    # `with` closes the handle even when read() raises
    with open(filename, 'r') as file:
        return file.read()

In [10]:
# extract descriptions for images
def load_descriptions(doc):
    """Parse caption text into a mapping of image id -> list of descriptions.

    Each non-empty line is split on whitespace: the first token is taken as
    the image identifier (truncated at its first '.'), and the remaining
    tokens are re-joined with single spaces to form one description.

    Args:
        doc: the full caption text, one entry per line.

    Returns:
        dict mapping image id to the list of its descriptions, in the order
        they appear in ``doc``.
    """
    mapping = dict()
    for line in doc.split('\n'):
        # split line by white space
        tokens = line.split()
        # Skip lines that are too short to hold an entry. Also skip lines
        # that contain only whitespace: they pass the length check but have
        # no tokens, and indexing tokens[0] would raise IndexError.
        if len(line) < 2 or not tokens:
            continue
        # first token is the image id, the rest is the description
        image_id, image_desc = tokens[0], tokens[1:]
        # remove filename extension (and anything after it) from image id
        image_id = image_id.split('.')[0]
        # create the list on first sight of this id, then store the description
        mapping.setdefault(image_id, []).append(' '.join(image_desc))
    return mapping

In [11]:
def clean_descriptions(descriptions):
    """Normalise every caption stored in ``descriptions``, in place.

    For each caption the words are lower-cased, stripped of ASCII
    punctuation, and then dropped if they are a single character or
    contain anything other than letters. The surviving words are joined
    with single spaces and written back into the same list slot.

    Args:
        descriptions: dict mapping image id to a list of caption strings;
            the lists are modified in place.

    Returns:
        None.
    """
    # translation table that deletes every punctuation character
    strip_punct = str.maketrans('', '', string.punctuation)
    for captions in descriptions.values():
        for idx, caption in enumerate(captions):
            # lower-case and de-punctuate each whitespace-separated token
            words = (raw.lower().translate(strip_punct) for raw in caption.split())
            # drop one-letter leftovers (hanging 's' / 'a') and non-alphabetic tokens
            kept = [w for w in words if len(w) > 1 and w.isalpha()]
            captions[idx] = ' '.join(kept)

In [12]:
# save descriptions to file, one per line
def save_descriptions(descriptions, filename):
    """Write all descriptions to ``filename``, one "<key> <description>" per line.

    Args:
        descriptions: dict mapping image id to a list of description strings.
        filename: path of the text file to create or overwrite.

    Returns:
        None. Lines are joined with '\\n'; no trailing newline is written.
    """
    lines = [key + ' ' + desc
             for key, desc_list in descriptions.items()
             for desc in desc_list]
    data = '\n'.join(lines)
    # `with` closes the handle even when write() raises
    with open(filename, 'w') as file:
        file.write(data)

In [13]:
# Caption pipeline: read the raw token file, group captions per image,
# normalise the text, and write the result out for later cells.
filename = 'Flickr8k.token.txt'
# read the whole caption file into one string
doc = load_doc(filename)
# build the image id -> list of captions mapping
descriptions = load_descriptions(doc)
# number of distinct image ids found in the caption file
print('Loaded: %d ' % len(descriptions))
# normalise captions; clean_descriptions mutates `descriptions` in place
clean_descriptions(descriptions)
# write one "<image id> <caption>" line per caption
save_descriptions(descriptions, 'descriptions.txt')

Loaded: 8092 
