<a href="https://colab.research.google.com/github/radonys/Image-Captioning/blob/main/Image_Captioning_Flickr8K.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image Captioning using Flickr8K Dataset

### BUDT758 - Big Data and Artificial Intelligence

##### Team Members: Aditya Kismatrao, Manas Mishra, Pratik Pandey & Yash Srivastava

Google Colaboratory Link: https://colab.research.google.com/drive/1ZOA195yLGo2OtVDMALI_pc5mq2lSeMVg?usp=sharing

### Import Modules

In [None]:
import os
import requests
import zipfile
import io
import numpy as np
import pandas as pd
import datetime
import matplotlib.pyplot as plt
from pickle import dump, load

import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

import re
from bs4 import BeautifulSoup


import tensorflow as tf
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.models import Model, load_model
from keras.applications.vgg16 import preprocess_input
from keras.preprocessing.text import Tokenizer
from keras.utils import plot_model
from keras.layers import Input
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Embedding
from keras.layers import Dropout
from keras.layers.merge import add
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint

In [None]:
# Load the TensorBoard notebook extension (needed by the %tensorboard cell after training)
# and render matplotlib figures inline in the notebook.
%load_ext tensorboard
%matplotlib inline

### Flickr8K Dataset

In [None]:
images = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip"
text = "https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip"


def download_and_extract(url, destination, timeout=600):
  """Download the zip archive at `url` and extract it into `destination`.

  Does nothing when `destination` already exists, so re-running the cell is cheap.

  Args:
    url: HTTP(S) URL of a zip archive.
    destination: directory to extract into.
    timeout: seconds requests waits on the connection / between received bytes.

  Raises:
    requests.HTTPError: on a non-2xx response. Without this check an error page
      would only surface later as an opaque zipfile.BadZipFile.
  """
  if os.path.exists(destination):
    return

  response = requests.get(url, timeout=timeout)
  response.raise_for_status()

  # Context manager closes the archive once extraction is done.
  with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
    archive.extractall(destination)


download_and_extract(images, "/content/FlickrImages/")
download_and_extract(text, "/content/FlickrText/")

### Data Cleaning and Processing

#### Image-Caption Dictionary

In [None]:
captions_path = "/content/FlickrText/Flickr8k.token.txt"

# image id (file name without extension) -> list of raw caption strings
image_captions = dict()

with open(captions_path) as file:

  captions = file.read()

  for line in captions.split('\n'):

    # Lines look like "<image>.jpg#<n><TAB><caption>". Split on ANY whitespace:
    # splitting on " " alone left the tab-joined first caption word attached to
    # the id token, so the first word of every caption was silently dropped.
    tokens = line.split()

    if not tokens:
      continue

    image_id = tokens[0].split(".")[0]
    string = ' '.join(tokens[1:])

    image_captions.setdefault(image_id, []).append(string)

#### Training and Validation Data

In [None]:
def read_data(path):
  """Read an image-list file and return the image ids it names.

  Every non-empty line is reduced to the text before its first '.', so
  "1000268201_693b08cb0e.jpg" becomes "1000268201_693b08cb0e".

  Args:
    path: path to a newline-separated list of image file names.

  Returns:
    List of image ids in file order.
  """
  with open(path) as handle:
    lines = handle.read().split("\n")

  return [line.split(".")[0] for line in lines if line]

In [None]:
train_path = "/content/FlickrText/Flickr_8k.trainImages.txt"
val_path = "/content/FlickrText/Flickr_8k.devImages.txt"

# Image ids belonging to the official Flickr8k training and dev (validation) splits.
train_images, val_images = read_data(train_path), read_data(val_path)

for split_label, split_ids in (("Training", train_images), ("Validation", val_images)):
  print("Number of " + split_label + " Samples:", len(split_ids))

#### Clean Text

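- Clean text by stripping HTML, lowercasing, removing bad symbols, and keeping only purely alphabetic words. A stopword list is loaded but stopwords are not removed, since words such as "a" and "in" are needed for fluent captions.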
- Reference: https://github.com/radonys/Reddit-Flair-Detector/blob/master/Jupyter%20Notebooks/Reddit_Flair_Detector.ipynb

In [None]:
# Raw strings: '\[', '\]' and '\|' are invalid escape sequences in a normal string
# literal (DeprecationWarning, SyntaxWarning since Python 3.12). The compiled
# patterns are unchanged.
REPLACE_BY_SPACE_RE = re.compile(r'[/(){}\[\]\|@,;]')
BAD_SYMBOLS_RE = re.compile(r'[^0-9a-z #+_]')
# NOTE(review): computed but never applied by clean_text. Presumably intentional,
# since words like "a" and "in" are needed in generated captions. Confirm before using.
STOPWORDS = set(stopwords.words('english'))

def clean_text(text):
    """Normalise one caption string.

    Steps: strip any HTML markup (BeautifulSoup), lowercase, turn the symbols
    / ( ) { } [ ] | @ , ; into spaces, delete every other character outside
    [0-9a-z #+_], then keep only whitespace-separated tokens that are purely
    alphabetic (which also drops numbers and tokens containing # + _).

    Args:
        text: raw caption string.

    Returns:
        The cleaned caption, tokens joined by single spaces.
    """
    text = BeautifulSoup(text, "lxml").text
    text = text.lower()
    text = REPLACE_BY_SPACE_RE.sub(' ', text)
    text = BAD_SYMBOLS_RE.sub('', text)
    text = ' '.join(word for word in text.split() if word.isalpha())

    return text

# Clean every caption of every image in place.
for key in image_captions:
  image_captions[key] = [clean_text(caption) for caption in image_captions[key]]

#### Text Tokenizer

In [None]:
def tokenize(captions):
  """Fit a Keras Tokenizer on the given captions.

  Args:
    captions: dict mapping image id -> list of caption strings. This is
      normally the training output of `startend`, so the 'startcap' /
      'endcap' markers become part of the vocabulary.

  Returns:
    (tokenizer, vocab_size, max_length):
      tokenizer  -- the fitted Tokenizer.
      vocab_size -- number of indexed words + 1 (Keras word indices start at
                    1, leaving 0 free for padding).
      max_length -- length of the longest caption, in whitespace-separated words.

  Raises:
    ValueError: if `captions` contains no caption strings.
  """
  # Use the captions that were passed in. The previous version iterated the
  # global `image_captions`. That fit the vocabulary on every split (leaking
  # validation/test text) and on captions WITHOUT the start/end markers.
  # texts_to_sequences then silently dropped 'startcap'/'endcap' during both
  # training and inference.
  caption_list = [caption for key_captions in captions.values() for caption in key_captions]

  if not caption_list:
    raise ValueError("tokenize() needs at least one caption")

  tokenizer = Tokenizer()
  tokenizer.fit_on_texts(caption_list)

  vocab_size = len(tokenizer.word_index) + 1
  max_length = max(len(caption.split()) for caption in caption_list)

  return tokenizer, vocab_size, max_length

#### Start and End Identifiers for Captions

Start Identifier: "startcap"

End Identifier: "endcap"

In [None]:
def startend(image_captions, keys):
  """Wrap the captions of the selected images in start/end markers.

  Args:
    image_captions: dict of image id -> list of caption strings.
    keys: iterable of image ids to keep (e.g. the training split).

  Returns:
    New dict containing only ids present in both `image_captions` and `keys`
    that have at least one caption. Entries appear in `image_captions` order.
    Each caption becomes 'startcap <caption> endcap'.
  """
  # A set makes each membership test O(1). `key in keys` on a list made this
  # loop O(len(image_captions) * len(keys)).
  wanted = set(keys)

  marked_captions = dict()

  for key, key_captions in image_captions.items():

    # Skip ids outside the split, and ids with no captions (the original never
    # created an entry for those either).
    if key in wanted and key_captions:
      marked_captions[key] = ['startcap ' + caption + ' endcap' for caption in key_captions]

  return marked_captions

### Convolutional Neural Network: VGG16

In [None]:
# Pretrained VGG16 truncated just before its final layer. The penultimate
# layer's activations serve as the image embedding that image_captioning consumes.
vgg16_full = VGG16()
model_cnn = Model(inputs=vgg16_full.inputs, outputs=vgg16_full.layers[-2].output)
del vgg16_full

#### Save Train and Validation CNN Features

In [None]:
def preprocess(image_path):
    """Load an image file as a VGG16-ready batch containing that single image.

    The image is resized to 224x224 and converted to an array. A leading batch
    axis is added, and VGG16's `preprocess_input` normalisation is applied.

    Args:
        image_path: path to the image file.

    Returns:
        Array of shape (1, 224, 224, channels) ready for `model_cnn.predict`.
    """
    pixels = image.img_to_array(image.load_img(image_path, target_size=(224, 224)))

    batch = np.expand_dims(pixels, axis=0)

    return preprocess_input(batch)

In [None]:
def encode(image, verbose=0):
    """Return the `model_cnn` (truncated VGG16) features for one image file.

    Args:
        image: path to an image file. The name shadows the keras `image`
            module only inside this function; `preprocess` still uses the module.
        verbose: forwarded to `model_cnn.predict`. Defaults to 0 because this is
            called once per image over thousands of images, and a progress bar
            per call floods the notebook output.

    Returns:
        The array returned by `model_cnn.predict` for a batch of one image
        (callers take element [0] as the feature vector).
    """
    batch = preprocess(image)

    return model_cnn.predict(batch, verbose=verbose)

In [None]:
images_path = "/content/FlickrImages/Flicker8k_Dataset/"
train_features_path = "/content/train_image_features.pkl"

# Running VGG16 over every training image is the slowest step of the notebook.
# Reuse the pickle from a previous run when it covers exactly the current
# training ids; otherwise recompute and rewrite it.
# NOTE: unpickling can execute arbitrary code. Only load files this notebook wrote.
train_images_encoded = None

if os.path.exists(train_features_path):
  with open(train_features_path, "rb") as file:
    cached_features = load(file)
  if isinstance(cached_features, dict) and set(cached_features) == set(train_images):
    train_images_encoded = cached_features

if train_images_encoded is None:
  train_images_encoded = dict()

  for img in train_images:
    train_images_encoded[img] = encode(images_path + img + '.jpg')

  with open(train_features_path, "wb") as file:
    dump(train_images_encoded, file)

In [None]:
val_features_path = "/content/val_image_features.pkl"

# Same caching scheme as for the training features: reuse the pickle when it
# covers exactly the current validation ids, otherwise recompute and rewrite it.
# NOTE: unpickling can execute arbitrary code. Only load files this notebook wrote.
val_images_encoded = None

if os.path.exists(val_features_path):
  with open(val_features_path, "rb") as file:
    cached_features = load(file)
  if isinstance(cached_features, dict) and set(cached_features) == set(val_images):
    val_images_encoded = cached_features

if val_images_encoded is None:
  val_images_encoded = dict()

  for img in val_images:
    val_images_encoded[img] = encode(images_path + img + '.jpg')

  with open(val_features_path, "wb") as file:
    dump(val_images_encoded, file)

### Variable Declaration

In [None]:
# Training schedule: number of epochs, and number of image ids per generator batch.
epochs, batch_size = 10, 32

### Model Definition

In [None]:
def image_captioning(vocab_size, max_length, feature_dim=4096, units=256, dropout=0.5):
  """Build the merge-style captioning network (uncompiled).

  One branch projects the precomputed image features. The other embeds the
  partial caption and runs it through an LSTM. The two are added together and
  decoded to a softmax over the next word.

  Args:
    vocab_size: Embedding input size and number of softmax outputs.
    max_length: length of the padded word-index input sequence.
    feature_dim: length of the image feature vector (4096 matches the
      truncated VGG16 `model_cnn` used in this notebook).
    units: width of the image projection, the word embedding, the LSTM and
      the decoder layer. The projection and the LSTM must match in width
      because they are combined with `add`.
    dropout: dropout rate applied in both branches.

  Returns:
    keras Model with inputs [image_features, word_sequence] and a
    (vocab_size,) softmax output.
  """
  # Feature extractor (FE): regularise and project the CNN features
  input1 = Input(shape=(feature_dim,))
  fe_drop = Dropout(dropout)(input1)
  fe_fc = Dense(units, activation='relu')(fe_drop)

  # Sequence model: embed the caption prefix (index 0 is masked as padding) and encode it
  input2 = Input(shape=(max_length,))
  se_embed = Embedding(vocab_size, units, mask_zero=True)(input2)
  se_drop = Dropout(dropout)(se_embed)
  se_lstm = LSTM(units)(se_drop)

  # Decoder: merge both branches and predict the next word
  combine = add([fe_fc, se_lstm])
  combine_fc = Dense(units, activation='relu')(combine)
  output = Dense(vocab_size, activation='softmax')(combine_fc)

  model = Model(inputs=[input1, input2], outputs=output)

  return model

### Data Generator

In [None]:
def data_generator(captions, images, tokenizer, vocab_size, max_length, batch_size):
  """Endlessly yield ([image_features, word_prefixes], next_word_one_hot) batches.

  Args:
    captions: dict of image id -> list of caption strings (as built by `startend`).
    images: dict of image id -> encoded image. Element [0] of each value is used
      as the feature vector (presumably the batch-of-one array from `encode`;
      confirm).
    tokenizer: fitted Keras Tokenizer mapping caption words to indices.
    vocab_size: number of classes in the one-hot next-word target.
    max_length: length each word-index prefix is padded to.
    batch_size: number of image ids (not samples) gathered before yielding.

  A caption with n indexed words contributes n - 1 samples (none if n < 2):
  the first i words, padded, paired with word i + 1. Samples left over at the
  end of one pass over `captions` roll into the first batch of the next pass.
  """
  features, prefixes, targets = [], [], []
  ids_in_batch = 0

  while True:

    for key, key_captions in captions.items():

      ids_in_batch += 1

      for caption in key_captions:

        word_ids = tokenizer.texts_to_sequences([caption])[0]

        for cut in range(1, len(word_ids)):

          prefix = pad_sequences([word_ids[:cut]], maxlen=max_length)[0]
          target = to_categorical([word_ids[cut]], num_classes=vocab_size)[0]

          features.append(images[key][0])
          prefixes.append(prefix)
          targets.append(target)

      if ids_in_batch == batch_size:

        yield ([np.array(features), np.array(prefixes)], np.array(targets))

        features, prefixes, targets = [], [], []
        ids_in_batch = 0

### Model Compilation and Train/Validation Data

In [None]:
# Marked ('startcap ... endcap') captions for each split, keyed by image id.
train_captions = startend(image_captions, train_images)
val_captions = startend(image_captions, val_images)

# Tokenizer, vocabulary size and padded sequence length, requested from the training captions.
tokenizer, vocab_size, max_length = tokenize(train_captions)

In [None]:
# Next-word prediction over the vocabulary with one-hot targets -> categorical cross-entropy.
model = image_captioning(vocab_size, max_length)
model.compile(optimizer='adam', loss='categorical_crossentropy')

### Model Training and Validation

In [None]:
# data_generator batches by image id, so steps are counted in image ids, not samples.
train_steps = len(train_captions)//batch_size
val_steps = len(val_captions)//batch_size

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

for epoch in range(0, epochs):

  print("Epoch:", epoch)

  # Fresh generators so every epoch starts again from the first image id.
  train_generator = data_generator(train_captions, train_images_encoded, tokenizer, vocab_size, max_length, batch_size)
  val_generator = data_generator(val_captions, val_images_encoded, tokenizer, vocab_size, max_length, batch_size)

  # Still exactly one epoch per call, but with its real index. With plain
  # epochs=1, every fit() call is "epoch 0", so the TensorBoard callback wrote
  # every epoch's summaries at the same step and the curves collapsed to one point.
  model.fit(train_generator, initial_epoch=epoch, epochs=epoch + 1, verbose=1,
            steps_per_epoch=train_steps, validation_data=val_generator,
            validation_steps=val_steps, callbacks=[tensorboard_callback])

In [None]:
# Save to the same absolute path the inference section loads from. A bare relative
# name only matched '/content/model_vgg.h5' when the working directory was /content.
model.save("/content/model_vgg.h5")

In [None]:
# Open the TensorBoard dashboard for the runs logged under ./logs by the training loop.
%tensorboard --logdir logs

### Model Output

In [None]:
def int_to_word(integer, tokenizer):
  """Map a word index back to its word.

  Args:
    integer: word index (e.g. the argmax of the model's softmax).
    tokenizer: object exposing a `word_index` dict of word -> index.

  Returns:
    The first word whose index equals `integer`, or None if there is no such
    word (for example index 0, which the tokenizer never assigns).
  """
  matches = (word for word, index in tokenizer.word_index.items() if index == integer)

  return next(matches, None)

In [None]:
def model_output(image_path, model, tokenizer, max_length):
  """Greedily decode a caption for one image.

  Starting from 'startcap', the most probable next word (argmax of the model
  output) is appended repeatedly. Decoding stops after `max_length` words, when
  'endcap' is produced, or when the predicted index maps to no word.

  Args:
    image_path: path to the image file.
    model: trained captioning model taking [image_features, padded_sequence].
    tokenizer: the Tokenizer used in training.
    max_length: padded sequence length the model was built with.

  Returns:
    The caption string, including the leading 'startcap' and a trailing
    'endcap' if one was predicted.
  """
  image_feature = encode(image_path)
  words = ['startcap']

  for _ in range(max_length):

    padded = pad_sequences([tokenizer.texts_to_sequences([' '.join(words)])[0]], maxlen=max_length)

    probabilities = model.predict([image_feature, padded], verbose=0)
    next_word = int_to_word(np.argmax(probabilities), tokenizer)

    if next_word is None:
      break

    words.append(next_word)

    if next_word == 'endcap':
      break

  return ' '.join(words)

In [None]:
# Reload the trained captioning model from disk (architecture and weights).
model = load_model(filepath='/content/model_vgg.h5')

In [None]:
image_path = "/content/FlickrImages/Flicker8k_Dataset/1000268201_693b08cb0e.jpg"
caption = model_output(image_path, model, tokenizer, max_length)

# model_output always starts with 'startcap'. It only ends with 'endcap' when the
# model predicted it; it stops without one at max_length words or on an unknown
# index. Slicing [1:-1] unconditionally therefore dropped a real word in those
# cases, so strip each marker only when it is actually present.
words = caption.split(" ")
if words and words[0] == "startcap":
  words = words[1:]
if words and words[-1] == "endcap":
  words = words[:-1]
caption = " ".join(words)

x = plt.imread(image_path)
plt.imshow(x)
plt.axis("off")
print("Predicted Caption:", caption)
print("Actual Caption(s):", image_captions[image_path.split("/")[-1].split(".")[0]])