In [1]:
# Root folder of the project; later cells load the trained model and the
# vocabulary pickles relative to it.
# NOTE(review): machine-specific absolute Windows path — edit before running elsewhere.
base_path = r"C:\Users\nbxyz\Desktop\image-caption-generate"

# Prepare image input

In [2]:
#import all required libraries
import tensorflow as tf
import numpy as np
from tqdm import tqdm
from tensorflow.keras.preprocessing.image import load_img, img_to_array
#load Inception model for transfer learning
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input

In [3]:
# Show the devices TensorFlow can see; later cells pin work to '/GPU:0'.
tf.config.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

In [4]:
from tensorflow.keras.models import Model
# Feature extractor: ImageNet-pretrained InceptionV3 cut at its second-to-last
# layer, so predict() returns the activations that feed the final classifier.
# The former `pooling=max` argument was removed: it passed the builtin `max`
# function instead of a string, and `pooling` is only consulted when
# include_top=False, so it never had any effect.
base_model = InceptionV3(weights='imagenet', input_shape=(299, 299, 3))
model = Model(base_model.input, base_model.layers[-2].output)

In [5]:
#load an image from disk into a preprocessed array
def process_img(img_path):
    """Load an image file and turn it into a single-image InceptionV3 input batch.

    Args:
        img_path: path of the image file to read.

    Returns:
        NumPy array with a leading batch axis of size 1, already passed
        through ``preprocess_input``.
    """
    # InceptionV3 expects 299 x 299 inputs; target_size is (height, width) only,
    # the channel count is decided by load_img's color mode.
    img = load_img(img_path, target_size=(299, 299))
    img_arr = img_to_array(img)
    # Add the batch axis, then scale pixels to the [-1, 1] range.
    # No device scope is needed here: these steps run in PIL/NumPy on the host.
    img_arr = np.expand_dims(img_arr, axis=0)
    return preprocess_input(img_arr)

In [6]:
def encode_feature(img_arr):
    """Run the module-level ``model`` on ``img_arr`` and flatten the result.

    Args:
        img_arr: preprocessed input batch accepted by ``model.predict``.

    Returns:
        1-D array whose length is the second dimension of the prediction
        (so the prediction must hold exactly one sample).
    """
    with tf.device('/GPU:0'):
        prediction = model.predict(img_arr)
        n_features = prediction.shape[1]
        # Drop the batch axis: (1, n_features) -> (n_features,)
        return np.reshape(prediction, n_features)

# GUI for drag-and-drop image selection

In [7]:
import sys, os
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QVBoxLayout, QMessageBox
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPixmap


class ImageLabel(QLabel):
    """Centered placeholder label with a dashed border; displays a pixmap once one is set."""

    def __init__(self):
        super().__init__()

        # Placeholder text first, then centering and the dashed frame.
        self.setText('\n\n Drop Image Here \n\n')
        self.setAlignment(Qt.AlignCenter)
        self.setStyleSheet('''
            QLabel{
                border: 4px dashed #aaa
            }
        ''')

    def setPixmap(self, image):
        """Show ``image`` in the label by delegating to ``QLabel.setPixmap``."""
        super().setPixmap(image)

class AppDemo(QWidget):
    """Window that accepts a dropped image file, previews it and remembers its path.

    Pressing Return closes the window; the path of the last accepted drop is
    then available through ``get_path()`` ('' if nothing was dropped).
    """

    def __init__(self):
        super().__init__()
        self.img_path = ''  # path of the last accepted drop; '' until then
        self.resize(400, 400)
        self.setAcceptDrops(True)

        mainLayout = QVBoxLayout()

        self.photoViewer = ImageLabel()
        mainLayout.addWidget(self.photoViewer)

        self.setLayout(mainLayout)

    @staticmethod
    def _local_file_from(event):
        """Return the local path of the first dragged file, or '' if there is none.

        Files dragged from a file manager arrive as URLs, so ``hasUrls()`` is the
        relevant check. The previous ``mimeData().hasImage`` test never called
        the method and was therefore always true, which let any drag through and
        raised IndexError on ``urls()[0]`` for drags that carried no URL.
        """
        mime = event.mimeData()
        if not mime.hasUrls():
            return ''
        urls = mime.urls()
        if not urls:
            return ''
        # toLocalFile() yields '' for non-file URLs (e.g. http links).
        return urls[0].toLocalFile()

    def dragEnterEvent(self, event):
        """Accept the drag only when it carries a local file."""
        if self._local_file_from(event):
            event.accept()
        else:
            event.ignore()

    def dragMoveEvent(self, event):
        """Keep accepting the drag while it carries a local file."""
        if self._local_file_from(event):
            event.accept()
        else:
            event.ignore()

    def dropEvent(self, event):
        """Store and preview the dropped file; ignore drops without a local file."""
        file_path = self._local_file_from(event)
        if file_path:
            event.setDropAction(Qt.CopyAction)
            self.set_path(file_path)
            self.set_image(file_path)

            event.accept()
        else:
            event.ignore()

    def set_image(self, file_path):
        """Display the image stored at ``file_path`` in the preview label."""
        self.photoViewer.setPixmap(QPixmap(file_path))

    def set_path(self, file_path):
        """Remember ``file_path`` as the currently selected image."""
        self.img_path = file_path

    def get_path(self):
        """Return the path of the last accepted drop ('' if none)."""
        return self.img_path

    def quit(self):
        """Close the window."""
        self.close()

    def keyPressEvent(self, event):
        """Close the window when Return is pressed."""
        if event.key() == Qt.Key_Return:
            self.quit()

In [8]:
def showdialog(image_desc):
    """Show ``image_desc`` in a modal information box titled "Image description"."""
    box = QMessageBox()
    box.setWindowTitle("Image description")
    box.setText(image_desc)
    box.setIcon(QMessageBox.Information)
    # Blocks until the user dismisses the dialog.
    box.exec_()

# Generating Caption

In [9]:
from tensorflow.keras.preprocessing.sequence import pad_sequences
import pickle

In [10]:
#load model
from tensorflow.keras.models import load_model
# Trained captioning model. The backslash before "model_18" is now escaped:
# '\m' is an invalid escape sequence that only worked by accident and triggers
# a warning on current Python versions. The resulting path is unchanged.
final_model = load_model(base_path + '\\models\\model_18.h5')

In [11]:
# Vocabulary lookups used by the decoder: wordtoix maps a word to its integer
# index, ixtoword maps an index back to its word.
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
# `with` closes the file handles, which the previous bare open() calls leaked.
with open(base_path+'\\wordtoint.pickle', 'rb') as fh:
    wordtoix = pickle.load(fh)
with open(base_path+'\\inttoword.pickle', 'rb') as fh:
    ixtoword = pickle.load(fh)

In [12]:
# Length the token sequence is padded to, and the maximum number of decoding
# steps in greedy_search.
# Presumably the caption length the model was trained with — TODO confirm.
max_len = 31

In [13]:
#for one image
def input_img(img_path):
    """Return the feature vector for the image at ``img_path``.

    Chains ``process_img`` (load + preprocess) and ``encode_feature``
    (feature extraction) for a single image.
    """
    return encode_feature(process_img(img_path))

In [14]:
def greedy_search(img_feature, trained_model, max_len):
    """Generate a caption by repeatedly picking the most probable next word.

    Uses the module-level ``wordtoix`` / ``ixtoword`` vocabularies.

    Args:
        img_feature: 1-D feature vector of the image (a batch axis is added here).
        trained_model: object whose ``predict([features, sequence])`` returns
            scores over the vocabulary for the next word.
        max_len: length the token sequence is padded to, and the maximum
            number of words to generate.

    Returns:
        The generated words joined by single spaces, without the
        'startseq' / 'endseq' markers.
    """
    start_token, end_token = 'startseq', 'endseq'
    img_feature = np.expand_dims(img_feature, axis=0)
    words = [start_token]
    for _ in range(max_len):
        # Words missing from the vocabulary are skipped when encoding.
        seq = [wordtoix[w] for w in words if w in wordtoix]
        seq = pad_sequences([seq], maxlen=max_len)
        yhat = trained_model.predict([img_feature, seq])
        word = ixtoword[np.argmax(yhat)]
        if word == end_token:
            break
        words.append(word)
    # Only the start marker is stripped here. The end marker is never appended,
    # so a caption that reaches max_len without 'endseq' keeps its last word
    # (the old `[1:-1]` slice dropped it).
    return ' '.join(words[1:])

In [15]:
def extract_caption():
    """Let the user drop an image, caption it and show the caption in a dialog.

    Opens the ``AppDemo`` window and blocks until it is closed. If no image
    was dropped, returns without running the model; otherwise the caption from
    ``greedy_search`` is shown with ``showdialog``.
    """
    # Reuse the existing QApplication when this cell is run more than once.
    app = QApplication.instance()
    if app is None:
        app = QApplication(sys.argv)
    window = AppDemo()
    window.show()
    app.exec_()  # returns once the window has been closed

    image_path = window.get_path()
    if not image_path:
        # Window closed without a drop: nothing to caption, and passing ''
        # to the image loader would raise.
        return
    with tf.device('/GPU:0'):
        img_feature = input_img(image_path)
        # generate description
        image_desc = greedy_search(img_feature, final_model, max_len)
    showdialog(image_desc)

In [41]:
# Open the drop window, caption the dropped image and show the result in a dialog.
extract_caption()