In [6]:
# project path: root folder that is expected to hold tokenizer.pkl and model\model_9.h5
# NOTE: hard-coded absolute path for one machine; edit it before running elsewhere
prpath = r'C:\Users\nbxyz\Desktop\image-caption-generate'

# prepare image input

In [1]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from keras.preprocessing.sequence import pad_sequences

def extract_features(filename, model):
    """Load an image file and run it through a feature-extraction model.

    The image is resized to 299x299, forced to three RGB channels, scaled
    from [0, 255] to [-1, 1] and given a leading batch axis before being
    passed to ``model.predict``.

    Args:
        filename: path of the image file to open with PIL.
        model: object exposing ``predict(array)``.

    Returns:
        Whatever ``model.predict`` returns for the prepared
        (1, 299, 299, 3) array.

    Raises:
        Exception: the original error from ``Image.open`` is re-raised after
            the hint message is printed (previously the failure surfaced as
            an unrelated UnboundLocalError on ``image``).
    """
    try:
        source = Image.open(filename)
    except Exception:
        print("ERROR: Couldn't open image! Make sure the image path and extension is correct")
        raise
    # the context manager releases the underlying file once pixels are loaded
    with source:
        resized = source.resize((299, 299))
        # convert *after* resizing so RGBA inputs yield the same RGB values as
        # before (alpha is simply dropped); grayscale / palette / CMYK images,
        # which used to crash on image.shape[2], are now handled as well
        image = np.array(resized.convert('RGB'))
    image = np.expand_dims(image, axis=0)
    # map pixel values from [0, 255] to [-1, 1]
    image = image / 127.5
    image = image - 1.0
    feature = model.predict(image)
    return feature
def word_for_id(integer, tokenizer):
    """Look up the word whose index in ``tokenizer.word_index`` equals ``integer``.

    Returns the first matching word in the mapping's iteration order, or
    ``None`` when no entry has that index.
    """
    candidates = (
        token
        for token, token_id in tokenizer.word_index.items()
        if token_id == integer
    )
    return next(candidates, None)
def generate_desc(model, tokenizer, photo, max_length):
    """Greedily build a caption one word at a time.

    Starting from the seed word 'start', the text so far is encoded with the
    tokenizer, padded to ``max_length`` and fed to ``model.predict`` together
    with ``photo``; the arg-max of the prediction is mapped back to a word and
    appended. Generation stops after ``max_length`` steps, when the predicted
    id has no word, or right after the word 'end' is appended.

    Returns:
        The space-separated caption, including the leading 'start' (and the
        trailing 'end' when it was produced).
    """
    caption_words = ['start']
    for _ in range(max_length):
        encoded = tokenizer.texts_to_sequences([' '.join(caption_words)])[0]
        padded = pad_sequences([encoded], maxlen=max_length)
        scores = model.predict([photo, padded], verbose=0)
        next_word = word_for_id(np.argmax(scores), tokenizer)
        if next_word is None:
            # predicted id is not in the vocabulary: stop without appending
            break
        caption_words.append(next_word)
        if next_word == 'end':
            break
    return ' '.join(caption_words)

# GUI for drag-and-drop image selection

In [2]:
import sys, os
from PyQt5.QtWidgets import QApplication, QWidget, QLabel, QVBoxLayout, QMessageBox
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QPixmap


class ImageLabel(QLabel):
    """Centered label with a dashed border that starts out showing a drop prompt."""

    def __init__(self):
        super().__init__()

        # initial prompt text, centered inside the dashed frame
        self.setText('\n\n Drop Image Here \n\n')
        self.setAlignment(Qt.AlignCenter)
        self.setStyleSheet('''
            QLabel{
                border: 4px dashed #aaa
            }
        ''')

    def setPixmap(self, image):
        """Hand ``image`` to the base ``QLabel.setPixmap`` unchanged."""
        QLabel.setPixmap(self, image)

class AppDemo(QWidget):
    """Window that accepts a dropped image file, previews it and remembers its path.

    The path of the most recently dropped file is kept in ``img_path`` and is
    available through ``get_path()``. Pressing Return/Enter closes the window.
    """

    def __init__(self):
        super().__init__()
        self.img_path = ''  # stays '' until a file has been dropped
        self.resize(400, 400)
        self.setAcceptDrops(True)

        mainLayout = QVBoxLayout()

        self.photoViewer = ImageLabel()
        mainLayout.addWidget(self.photoViewer)

        self.setLayout(mainLayout)

    @staticmethod
    def _dropped_file_path(event):
        """Return the local path of the first URL carried by ``event``, or ''.

        Dropped files arrive as URLs in the mime data. The previous check,
        ``mimeData().hasImage`` without parentheses, was a bound method and
        therefore always truthy, so every drag was accepted and ``urls()[0]``
        could raise IndexError for drags that carry no URLs.
        """
        mime = event.mimeData()
        if not mime.hasUrls():
            return ''
        urls = mime.urls()
        if not urls:
            return ''
        # toLocalFile() yields '' for URLs that do not point at a local file
        return urls[0].toLocalFile()

    def dragEnterEvent(self, event):
        """Accept the drag only when it carries a local file URL."""
        if self._dropped_file_path(event):
            event.accept()
        else:
            event.ignore()

    def dragMoveEvent(self, event):
        """Keep accepting the drag while it carries a local file URL."""
        if self._dropped_file_path(event):
            event.accept()
        else:
            event.ignore()

    def dropEvent(self, event):
        """Store and preview the dropped file; ignore drops without a local file."""
        file_path = self._dropped_file_path(event)
        if file_path:
            event.setDropAction(Qt.CopyAction)
            self.set_path(file_path)
            self.set_image(file_path)

            event.accept()
        else:
            event.ignore()

    def set_image(self, file_path):
        """Show the file at ``file_path`` in the preview label."""
        self.photoViewer.setPixmap(QPixmap(file_path))

    def set_path(self, file_path):
        """Remember ``file_path`` as the currently selected image."""
        self.img_path = file_path

    def get_path(self):
        """Return the last dropped file path ('' when nothing was dropped)."""
        return self.img_path

    def quit(self):
        """Close the window."""
        self.close()

    def keyPressEvent(self, event):
        """Close on Return or keypad Enter; defer every other key to QWidget."""
        if event.key() in (Qt.Key_Return, Qt.Key_Enter):
            self.quit()
        else:
            super().keyPressEvent(event)

In [3]:
def showdialog(image_desc):
    """Display ``image_desc`` in a modal information message box."""
    box = QMessageBox()
    box.setWindowTitle("Image description")
    box.setIcon(QMessageBox.Information)
    box.setText(image_desc)
    # exec_() blocks until the user dismisses the box
    box.exec_()

# run model

In [4]:
import tensorflow as tf
from keras.applications.xception import Xception
from pickle import load
from keras.models import load_model

def extract_caption():
    """Let the user drop an image, caption it, and show the caption in a dialog.

    Opens the ``AppDemo`` window and blocks in the Qt event loop until it is
    closed. The dropped image is then passed through Xception for features and
    through the trained caption model loaded from ``prpath``; the generated
    text is shown with ``showdialog``. If the window is closed without an
    image being dropped, a short notice is shown instead and nothing is loaded.
    """
    # reuse the running QApplication when the cell is executed more than once
    app = QApplication.instance() or QApplication(sys.argv)
    window = AppDemo()
    window.show()
    app.exec_()
    image_path = window.get_path()
    if not image_path:
        # nothing was dropped; an empty path would only fail when opened
        showdialog("No image was dropped, so there is nothing to describe.")
        return
    with tf.device('/GPU:0'):
        # presumably the maximum caption length used at training time -- TODO confirm
        max_length = 32
        # NOTE(security): pickle.load can execute arbitrary code; only load a
        # tokenizer file that you created yourself
        with open(os.path.join(prpath, "tokenizer.pkl"), "rb") as tokenizer_file:
            tokenizer = load(tokenizer_file)
        # os.path.join avoids the invalid '\m' escape of the old string literal
        model = load_model(os.path.join(prpath, 'model', 'model_9.h5'))
        # load and prepare the photograph
        xception_model = Xception(include_top=False, pooling="avg")
        photo = extract_features(image_path, xception_model)
        # generate description
        image_desc = generate_desc(model, tokenizer, photo, max_length)
    showdialog(image_desc)

In [None]:
# open the drop window, then caption the dropped image and show the result in a dialog
extract_caption()