# Import

## Modules

In [None]:
import os
import pickle
import numpy as np
from tqdm.notebook import tqdm # ui for data processing
import tarfile
from google.colab import drive
import glob
import shutil
import re
import string
import zipfile
import requests
import random

import tensorflow as tf
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, add
from tensorflow.keras.layers import TextVectorization

In [None]:
# Mount Google Drive under /content/gdrive; GDRIVE_WORKING_PATH (set in the
# global-variables cell) points to a folder inside this mount.
drive.mount('/content/gdrive')

# Setup Global Variables

In [None]:
# Google Drive paths
# Working folder on the mounted drive: the dataset copy and the pickles
# written by the cells below are stored under it.
GDRIVE_WORKING_PATH = "/content/gdrive/MyDrive/image_cap"

# When True, models/imports are meant to be rerun; when False, pre-saved files
# and directories are used instead.
# NOTE(review): not referenced by any cell visible in this part of the notebook.
RERUN = False

## Dataset

Download the dataset and export only the **Images** and **Captions** data to Google Drive.

Recursive function to flatten the directories --> https://gist.github.com/oatkiller/4429244

In [None]:
def currentDir(parent):
    """Flatten the directory tree rooted at ``parent``.

    Calls ``dir`` with ``parent`` both as the destination directory and as the
    directory where the walk starts.
    """
    dir(parent, parent)

def dir(parentDir,currentDir):
    """Recursively move every file found below ``currentDir`` into ``parentDir``.

    Sub-directories are processed depth-first and removed once they have been
    emptied. Files that already sit directly in ``parentDir`` are left alone.

    Note: this function shadows the builtin ``dir`` for the rest of the notebook.

    Args:
        parentDir: directory that receives all the files.
        currentDir: directory being visited (equal to ``parentDir`` on the
            first call).
    """
    at_top_level = parentDir == currentDir

    for entry in os.listdir(currentDir):
        entry_path = os.path.join(currentDir, entry)

        if os.path.isdir(entry_path):
            # Empty (and delete) the sub-directory before moving on.
            dir(parentDir, entry_path)
            continue

        if at_top_level:
            # Already in the destination directory: nothing to move.
            continue

        try:
            shutil.move(entry_path, parentDir)
        except shutil.Error:
            # shutil.move refused the move (e.g. a file with the same name is
            # already in parentDir): fall back to a plain rename onto that name.
            os.rename(entry_path, os.path.join(parentDir, entry))

    # Only the visited sub-directories are removed, never the destination itself.
    if not at_top_level:
        os.rmdir(currentDir)

In [None]:
# One-off dataset import, kept disabled: switch the guard to True only when the
# dataset has to be downloaded and exported to Google Drive again.
if False:
  url = "http://www-i6.informatik.rwth-aachen.de/imageclef/resources/iaprtc12.tgz"
  # Absolute paths throughout, so the cell does not depend on the kernel's
  # current working directory.
  extract_root = "/content"
  archive_path = f"{extract_root}/iaprtc12.tgz"
  dataset_root = f"{extract_root}/iaprtc12"

  # Stream the archive to disk chunk by chunk instead of loading it in memory,
  # and fail right here on an HTTP error rather than later on a missing file.
  with requests.get(url, stream=True, timeout=60) as response:
    response.raise_for_status()
    with open(archive_path, "wb") as f:
      for chunk in response.iter_content(chunk_size=1 << 20):
        f.write(chunk)

  # extract files
  # NOTE(review): extractall() trusts the member paths of an archive fetched
  # over plain HTTP; acceptable for this known dataset, not for arbitrary URLs.
  with tarfile.open(archive_path) as archive:
    archive.extractall(extract_root)

  # flatten images and captions directories
  currentDir(f"{dataset_root}/images")
  currentDir(f"{dataset_root}/annotations_complete_eng")

  # save flattened images and captions folder to drive
  shutil.copytree(f"{dataset_root}/images", f"{GDRIVE_WORKING_PATH}/iaprtc12/images")
  shutil.copytree(f"{dataset_root}/annotations_complete_eng", f"{GDRIVE_WORKING_PATH}/iaprtc12/captions")

  # clean content
  shutil.rmtree(dataset_root)
  os.remove(archive_path)

In [None]:
# Dataset folders on Google Drive, both located under GDRIVE_WORKING_PATH.
IMAGES_DIR = os.path.join(GDRIVE_WORKING_PATH, "iaprtc12/images") # directory that contains all the images
CAPTION_DIR = os.path.join(GDRIVE_WORKING_PATH, "iaprtc12/captions") # directory that contains all the caption files

# Dataset Cleaning

Keep only the images that have a caption, and vice versa, by comparing the file names.

In [None]:
def rem_type(s):
  """Return the file name ``s`` without its extension.

  Only the last extension is stripped, so a stem that itself contains dots is
  kept whole and ``rem_type(name) + ext`` rebuilds the original file name.

  Args:
    s: file name, e.g. ``"1234.jpg"``.

  Returns:
    The name without its trailing extension, e.g. ``"1234"``.
  """
  return os.path.splitext(s)[0]

In [None]:
# Ids (file names without extension) found in the images folder ...
list_ids_image = [
  rem_type(file_name)
  for file_name in os.listdir(os.path.join(GDRIVE_WORKING_PATH, "iaprtc12/images"))
]

# ... and in the captions folder.
list_ids_caption = [
  rem_type(file_name)
  for file_name in os.listdir(os.path.join(GDRIVE_WORKING_PATH, "iaprtc12/captions"))
]

In [None]:
# number of entries listed in the images folder
len(list_ids_image)

20000

In [None]:
# number of entries listed in the captions folder
len(list_ids_caption)

20000

In [None]:
# Do both listings start with the same five ids? os.listdir returns entries in
# arbitrary order, so False here does not by itself mean that ids are missing.
list_ids_image[:5] == list_ids_caption[:5]

False

In [None]:
# ids that appear in only one of the two listings
list(set(list_ids_image).symmetric_difference(list_ids_caption))

['4092', '4072']

In [None]:
# Ids that have both an image and a caption. Sorted because the iteration
# order of a set of strings is not stable from one interpreter run to the next,
# which would make the processing order below irreproducible.
full_list = sorted(set(list_ids_caption) & set(list_ids_image))
len(full_list)

19999

Check if the ids are unique

In [None]:
# full_list is built from a set intersection, so this is expected to equal
# len(full_list)
len(set(full_list))

19999

For the remaining (image, caption) pairs, we should check that the image description is not an empty, too short, or malformed string.

In [None]:
# Each caption file wraps its text as <DESCRIPTION> Caption </DESCRIPTION>.
description_re = re.compile('<DESCRIPTION>(.*)</DESCRIPTION>')

# id -> caption text, for every id that has both an image and a caption file
list_captions = {}
for image_id in tqdm(full_list):
  caption_path = f"{CAPTION_DIR}/{image_id}.eng"
  with open(caption_path, "r", encoding='latin-1') as f:
    # drop the line breaks so the tag content can be matched on a single line
    caption_str = f.read().replace("\n", "")
  result = description_re.search(caption_str)
  # A file without the tag pair gets an empty caption instead of aborting the
  # whole loop with an AttributeError; the "fewer than 3 words" filter applied
  # further down then discards it like any other unusable caption.
  list_captions[f"{image_id}"] = result.group(1) if result else ""

  0%|          | 0/19999 [00:00<?, ?it/s]

In [None]:
# Persist the raw id -> caption mapping; the context manager makes sure the
# file is flushed and closed before it is read back.
with open(f"{GDRIVE_WORKING_PATH}/mapping_id_caption.pkl", "wb") as f:
  pickle.dump(list_captions, f)

Remove all the captions with fewer than 3 words

In [None]:
# Ids whose caption has fewer than 3 words. split() without an argument splits
# on any run of whitespace, so repeated, leading or trailing blanks are not
# counted as extra (empty) words.
to_remove_ids = [
  caption_id
  for caption_id, caption in list_captions.items()
  if len(caption.split()) < 3
]

In [None]:
# Ids kept after dropping the short captions. Sorted so that the list does not
# inherit the set iteration order, which changes between interpreter runs and
# would make the seeded shuffle (and therefore the split) irreproducible.
list_ids = sorted(set(full_list) - set(to_remove_ids))
len(list_ids)

19983

Remove captions from the mapping

In [None]:
with open(f"{GDRIVE_WORKING_PATH}/mapping_id_caption.pkl", "rb") as f:
  mapping_caption_id = pickle.load(f)

# pop() rather than del: re-running the cell on an already filtered file must
# not raise KeyError.
for caption_id in to_remove_ids:
  mapping_caption_id.pop(caption_id, None)

# Save the filtered mapping, not list_captions, which still holds the short
# captions.
with open(f"{GDRIVE_WORKING_PATH}/mapping_id_caption.pkl", "wb") as f:
  pickle.dump(mapping_caption_id, f)

# Dataset Split

In [None]:
# first ten ids before the shuffle
list_ids[:10]

['8390',
 '11582',
 '14421',
 '19338',
 '7937',
 '14260',
 '17803',
 '37429',
 '24007',
 '9182']

In [None]:
# Shuffle with a fixed seed. The list is sorted first so that the result
# depends only on the seed and on the ids themselves: not on the order in which
# list_ids happened to be built, nor on how many times this cell has been run.
random.seed(1)
list_ids.sort()
random.shuffle(list_ids)

In [None]:
# first ten ids after the seeded shuffle
list_ids[:10]

['32811',
 '30644',
 '24393',
 '1176',
 '13858',
 '38876',
 '40226',
 '10304',
 '39401',
 '895']

In [None]:
# Persist the shuffled order; the context manager makes sure the file is
# flushed and closed before it is read back.
with open(f"{GDRIVE_WORKING_PATH}/list_id_fixed_order.pkl", "wb") as f:
  pickle.dump(list_ids, f)

Check if the dump and load process impacts the order

In [None]:
with open(f"{GDRIVE_WORKING_PATH}/list_id_fixed_order.pkl", "rb") as f:
  loaded_list = pickle.load(f)

# Compare the whole list with the in-memory order, not only a visual sample.
assert loaded_list == list_ids, "pickle round-trip changed the id order"

In [None]:
# first ten ids of the list read back from disk
loaded_list[:10]

['32811',
 '30644',
 '24393',
 '1176',
 '13858',
 '38876',
 '40226',
 '10304',
 '39401',
 '895']

In [None]:
# Split boundaries over the shuffled ids:
# [0, 70%) -> train, [70%, 90%) -> test, [90%, end) -> val.
train_start, train_end = 0, int(np.floor(len(list_ids) * 0.7))
test_start, test_end = train_end, int(np.floor(len(list_ids) * 0.9))
val_start, val_end = test_end, len(list_ids)

print(f"Train --> start index {train_start}, end index {train_end}")
print(f"Test --> start index {test_start}, end index {test_end}")
print(f"Val --> start index {val_start}, end index {val_end}")

Train --> start index 0, end index 13988
Test --> start index 13988, end index 17984
Val --> start index 17984, end index 19983


In [None]:
# Cut the shuffled ids at the boundaries computed in the previous cell.
train_ids = list_ids[train_start:train_end] # 13988 values (70%)
test_ids = list_ids[test_start:test_end] # 3996 values (20%)
val_ids = list_ids[val_start:val_end] # 1999 values (10%)

# Share of each split, in percent of all the ids, rounded to 4 decimals.
n_total_ids = len(list_ids)
train_pct = np.round(len(train_ids)/n_total_ids*100, 4)
test_pct = np.round(len(test_ids)/n_total_ids*100, 4)
val_pct = np.round(len(val_ids)/n_total_ids*100, 4)

print(f"Train has length {len(train_ids)} --> {train_pct}% --> almost 70%")
print(f"Test has length {len(test_ids)} --> {test_pct}% --> almost 20%")
print(f"Val has length {len(val_ids)} --> {val_pct}% --> almost 10%")
print(f"Total length is {n_total_ids}")

Train has length 13988 --> 69.9995% --> almost 70%
Test has length 3996 --> 19.997% --> almost 20%
Val has length 1999 --> 10.0035% --> almost 10%
Total length is 19983


In [None]:
# Persist the three splits as list_id_train.pkl, list_id_test.pkl and
# list_id_val.pkl; the context manager closes each file as soon as it has been
# written.
for split_name, split_ids in (("train", train_ids), ("test", test_ids), ("val", val_ids)):
  with open(f"{GDRIVE_WORKING_PATH}/list_id_{split_name}.pkl", "wb") as f:
    pickle.dump(split_ids, f)