### Understanding the data

In [None]:
import json
from IPython.display import JSON
import pprint
from collections import defaultdict
import numpy as np
from sklearn.model_selection import train_test_split
from keras.preprocessing import image
from keras.applications.inception_v3 import preprocess_input
import os
from time import time

In [None]:
# Load the dataset description JSON; later cells read its 'dataset' and
# 'images' entries. JSON text is UTF-8, so decode it explicitly rather than
# relying on the platform's default encoding.
with open('flickr8k/dataset.json', 'r', encoding='utf-8') as f:
    json_data = json.load(f)

In [None]:
# Top-level keys of the loaded dataset JSON.
json_data.keys()

In [None]:
# Value stored under the 'dataset' key (presumably the dataset name — confirm).
json_data['dataset']

In [None]:
# Fields available on an image record, inspected via the first entry.
json_data['images'][0].keys()

In [None]:
# Pretty-print the complete first image record.
pprint.pprint(json_data['images'][0])

### Ground truth image descriptions

In [None]:
def get_gt_image_descriptions(json_data):
    """Group the raw ground-truth captions of every image by image id.

    Args:
        json_data: Parsed dataset JSON. Each entry of ``json_data['images']``
            needs a ``'filename'`` and a list of ``'sentences'``, each of
            which carries the caption text under ``'raw'``.

    Returns:
        defaultdict(list) mapping the filename without its extension (the
        image id) to its raw captions, in file order. Images with no
        sentences get no entry.
    """
    descriptions = defaultdict(list)
    for record in json_data['images']:
        # splitext removes only the final extension, so ids that themselves
        # contain dots are not truncated (later cells rebuild the file name
        # as image_id + '.jpg').
        image_id = os.path.splitext(record['filename'])[0]
        for sentence in record['sentences']:
            descriptions[image_id].append(sentence['raw'])
    return descriptions

In [None]:
descriptions = get_gt_image_descriptions(json_data)
# Pick one image at random and show each of its ground-truth captions on its own line.
sample_image_id = np.random.choice([*descriptions])
print(*descriptions[sample_image_id], sep='\n')

### Data cleaning

In [None]:
import string

In [None]:
# The punctuation characters that clean_descriptions strips from each word.
string.punctuation

In [None]:
def clean_descriptions(descriptions):
    """Normalise every caption in *descriptions* in place and return it.

    Each caption is split on whitespace, and every word is lower-cased and
    stripped of the characters in ``string.punctuation``. Some words are
    then dropped:

    * words of one character or fewer (for example "a"), and
    * words containing any non-alphabetic character (for example digits).

    The surviving words are re-joined with single spaces.

    Args:
        descriptions: Mapping of image id to a list of caption strings. The
            lists are modified in place.

    Returns:
        The same *descriptions* object, now holding the cleaned captions.
    """
    table = str.maketrans('', '', string.punctuation)

    def _clean(caption):
        # Lower-case and strip punctuation word by word, then filter.
        words = (w.lower().translate(table) for w in caption.split())
        return ' '.join(w for w in words if len(w) > 1 and w.isalpha())

    for desc_list in descriptions.values():
        # Slice assignment updates the existing list objects in place.
        desc_list[:] = [_clean(desc) for desc in desc_list]
    return descriptions

In [None]:
# `clean_descriptions` names the cleaning function until this cell runs, and
# the cleaned dict afterwards (the function mutates `descriptions` in place
# and returns it). Only call it while the name is still the function, so that
# re-running this cell does not fail with "'dict' object is not callable".
if callable(clean_descriptions):
    clean_descriptions = clean_descriptions(descriptions)

In [None]:
# Cleaned captions for the randomly chosen sample image.
clean_descriptions[sample_image_id]

In [None]:
# Number of captions per image (despite the name), then the total caption count.
all_descriptions = [len(captions) for captions in clean_descriptions.values()]
print(sum(all_descriptions))

### Create Vocabulary

In [None]:
def create_vocabulary(clean_descriptions):
    """Return the set of distinct whitespace-separated words in all captions.

    Args:
        clean_descriptions: Mapping of image id to a list of caption strings.

    Returns:
        set of str containing every unique token found in any caption.
    """
    vocabulary = set()
    # Plain loops for the side-effecting update; the original list
    # comprehension built and discarded a list of None values.
    for captions in clean_descriptions.values():
        for caption in captions:
            vocabulary.update(caption.split())
    return vocabulary

In [None]:
# Set of unique words across all cleaned captions.
vocabulary = create_vocabulary(clean_descriptions)

In [None]:
print(f'vocabulary size: {len(vocabulary)}')

In [None]:
def save_descriptions(clean_descriptions, filename):
    """Write every caption to *filename* as one ``"<image_id> <caption>"`` line.

    Args:
        clean_descriptions: Mapping of image id to a list of caption strings.
        filename: Output path; an existing file is overwritten.
    """
    # Build all lines, newline included, before opening the file. The file is
    # therefore not truncated if building a line fails.
    lines = [image_id + ' ' + caption + '\n'
             for image_id, captions in clean_descriptions.items()
             for caption in captions]
    with open(filename, 'w') as f:
        # One writelines() call over whole lines. The original passed a single
        # str to writelines(), which writes it one character at a time.
        f.writelines(lines)

In [None]:
# Persist the cleaned captions as "<image_id> <caption>" lines; the train split
# is reloaded from this file further down.
save_descriptions(clean_descriptions, 'descriptions.txt')

### Train and Test Split

In [None]:
# Randomly split the image ids using train_test_split's default proportions.
# No random_state is passed, so unless the global NumPy RNG is seeded
# elsewhere, the split changes between runs.
train_clean_desc_keys, test_clean_desc_keys = train_test_split([*clean_descriptions])
print(f'train size: {len(train_clean_desc_keys)}')
print(f'test size: {len(test_clean_desc_keys)}')

In [None]:
def load_train_clean_descriptions(train_clean_desc_keys, filename):
    """Load captions for the training images and wrap them in sequence tokens.

    Args:
        train_clean_desc_keys: Iterable of image ids in the training split.
        filename: Path of a file in the format written by
            ``save_descriptions``: one ``"<image_id> <caption words...>"``
            line per caption. Blank lines are ignored.

    Returns:
        dict mapping each training image id found in the file to its captions.
        Each caption is formatted as ``"startseq <caption> endseq"``, and
        captions keep their file order.
    """
    # Build the set once so each membership test is O(1). Testing against a
    # list made loading O(lines * training images).
    train_ids = set(train_clean_desc_keys)
    train_clean_descriptions = {}
    with open(filename, 'r') as f:
        for line in f:
            tokens = line.split()
            if not tokens:
                # Blank line: nothing to parse (tokens[0] would raise).
                continue
            image_id, image_desc = tokens[0], tokens[1:]
            # Skip images not in the train set.
            if image_id not in train_ids:
                continue
            # Add the start and end tokens.
            desc = 'startseq ' + ' '.join(image_desc) + ' endseq'
            train_clean_descriptions.setdefault(image_id, []).append(desc)
    return train_clean_descriptions

In [None]:
# Reload from disk only the captions of training images, each wrapped as
# "startseq ... endseq".
train_descriptions = load_train_clean_descriptions(train_clean_desc_keys, 'descriptions.txt')

In [None]:
print(f'Descriptions: train={len(train_descriptions)}')

In [None]:
# sample_image_id was drawn before the train/test split, so it may belong to
# the test split; fall back to the first training image to avoid a KeyError.
sample_train_id = (sample_image_id if sample_image_id in train_descriptions
                   else next(iter(train_descriptions)))
print(train_descriptions[sample_train_id])

### Feature Vector Extraction

In [None]:
from keras.applications.inception_v3 import InceptionV3
from keras.models import Model

In [None]:
# Full InceptionV3 model initialised with ImageNet weights.
model = InceptionV3(weights='imagenet')

In [None]:
# Inspect the model input, its last two layers, and the output tensor of the
# penultimate layer (the layer used as the feature extractor below).
print(model.input)
print(model.layers[-2:])
print(model.layers[-2].output)

In [None]:
# Feature extractor: same input as InceptionV3, but the output comes from the
# penultimate layer instead of the final layer.
model_new = Model(model.input, model.layers[-2].output)

In [None]:
def preprocess(image_path):
    """Load *image_path* as a one-image batch ready for InceptionV3.

    The image is resized to 299x299 (the size InceptionV3 expects) and
    converted from a PIL image to a numpy array. It then gets a leading batch
    axis and is normalised with ``inception_v3.preprocess_input``.
    """
    pil_img = image.load_img(image_path, target_size=(299, 299))
    batch = np.expand_dims(image.img_to_array(pil_img), axis=0)
    return preprocess_input(batch)

def encode(img):
    """Return the penultimate-layer feature vector of ``model_new`` for one image.

    Args:
        img: Path to the image file; it is loaded and normalised by
            ``preprocess``.

    Returns:
        1-D numpy array: the single row of ``model_new.predict`` output, with
        the batch axis of size 1 removed.
    """
    batch = preprocess(img)
    # verbose=0 stops Keras from printing a progress bar for every image;
    # this function is called once per image in the loops below.
    feat_vec = model_new.predict(batch, verbose=0)
    # (1, D) -> (D,)
    return np.reshape(feat_vec, feat_vec.shape[1])

In [None]:
# Directory containing the image files, looked up below as <image_id>.jpg.
image_dir = 'flickr8k/images/'

In [None]:
# Encode every training image into its feature vector (image id -> vector).
start = time()
encoding_train = {}
for base_img_fn in train_clean_desc_keys:
    img_fn = base_img_fn + '.jpg'
    image_file_path = os.path.abspath(os.path.join(image_dir, img_fn))
    if not os.path.exists(image_file_path):
        # Ids without an image on disk are reported and skipped.
        print('Not found image:', image_file_path)
        continue
    # Encode the same path that was just checked instead of rebuilding it.
    encoding_train[base_img_fn] = encode(image_file_path)
print('encoding time for train:', time() - start)

In [None]:
# `pickle` is not imported anywhere else in the notebook, so this cell would
# raise NameError without the import below.
import pickle

# Persist the training feature vectors (image id -> vector).
with open("encoded_train_images.pkl", "wb") as f:
    pickle.dump(encoding_train, f)

In [None]:
# Encode every test image into its feature vector (image id -> vector),
# printing progress as it goes.
start = time()
encoding_test = {}
for i, base_img_fn in enumerate(test_clean_desc_keys):
    img_fn = base_img_fn + '.jpg'
    image_file_path = os.path.abspath(os.path.join(image_dir, img_fn))
    if not os.path.exists(image_file_path):
        # Ids without an image on disk are reported and skipped.
        print('Not found image:', image_file_path)
        continue
    print('{}: {}'.format(i, img_fn))
    # Encode the same path that was just checked instead of rebuilding it.
    encoding_test[base_img_fn] = encode(image_file_path)
print('encoding time for test:', time() - start)

In [None]:
# `pickle` is not imported anywhere else in the notebook, so this cell would
# raise NameError without the import below.
import pickle

# Persist the test feature vectors (image id -> vector).
with open("encoded_test_images.pkl", "wb") as f:
    pickle.dump(encoding_test, f)