In [3]:
# if running on cpu
!pip install numpy pillow tqdm sentence_transformers faiss-cpu pandas scikit-image
# -- OR -- 
# if running on gpu
#!pip install numpy pillow tqdm sentence_transformers faiss-gpu pandas scikit-image



In [5]:
import os, glob, numpy as np, pandas as pd
from PIL import Image
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
import faiss
from skimage.metrics import structural_similarity as ssim

'''
This notebook scans images and uses a pretrained CLIP-ViT-B-32 model to create image embeddings, compares them to the reference
embeddings index created in the 01-create-reference-embeddings.ipynb notebook, then loads the images into a clean GUI to enable
Human in the Loop (HITL) side-by-side image review.  This brings the Augmented AI project full circle.

###

This particular cell performs the AI analysis component of this Augmented AI project.  It scans images and uses a 
pretrained CLIP-ViT-B-32 model to create image embeddings, then it compares these scanned embeddings to an FAISS 
(Facebook AI Similarity Search) HNSW (Hierarchical Navigable Small World) index built using embeddings created by 
running a collection of reference images through the same pretrained CLIP model via the 
01-create-reference-embeddings.ipynb notebook.  

For similarity calculations, this script uses a combination of cosine similarity and SSIM (structural similarity index) 
thumbnail comparison to identify images which are probable matches.  Through fine-tuning the configuration values, this
comparison method can eliminate false-negatives completely and dramatically reduce false-positives.  This comparison 
approach also computes much more quickly than other methods (e.g. ORB).  

The output file from this script only contains information for the input images which are determined to be reasonably
similar to one of the reference images based on the decision thresholds.  
'''

# Retrieval / decision settings
# A candidate is accepted outright when its cosine similarity reaches COS_SIM_THRESHOLD.  Failing that, it can
# still be accepted when its cosine similarity reaches COS_SIM_THRESH_FOR_FALLBACK and its thumbnail SSIM
# reaches SSIM_THRESHOLD.
K = 12                              # top-k candidates retrieved from FAISS per scanned image
EF_SEARCH = 128                     # HNSW search beam: Higher = more recall
COS_SIM_THRESHOLD = 0.92            # cosine similarity acceptance threshold
COS_SIM_THRESH_FOR_FALLBACK = 0.8   # minimum cosine similarity to be considered for SSIM override
SSIM_THRESHOLD = 0.75               # SSIM acceptance threshold for SSIM override

# input images filepath (searched recursively by load_image_paths)
FOLDER_FOR_IMAGES_TO_ANALYZE = "./input_channel/input_images"
# output filepath (CSV with one row per input image that found a probable match)
OUTPUT_FILEPATH = "./ai_image_analysis/image_mapping.csv"

# reference artifacts: the FAISS index and the array of reference image paths that is indexed by the
# ids the FAISS search returns
INDEX_PATH = "./reference_embeddings/faiss_hnsw.index"
REF_IMAGE_PATHS = "./reference_embeddings/ref_image_paths.npy"

# Load the model
model = SentenceTransformer('clip-ViT-B-32')
# put the model in evaluation (inference) mode
model.eval()

# Image types (lower-case; load_image_paths lower-cases each path before comparing its suffix)
IMG_EXTS = (".jpg",".jpeg",".png",".webp",".bmp",".tif",".tiff")


def load_image_paths(root):
    """
    Collects the paths of all image files found anywhere below a folder.

    :param root: folder to search; sub-folders are searched recursively
    :return: list of paths whose lower-cased name ends with one of IMG_EXTS
    """

    search_pattern = os.path.join(root, "**/*")
    image_paths = []

    for candidate in glob.glob(search_pattern, recursive=True):
        # case-insensitive extension check against the tuple of supported suffixes
        if candidate.lower().endswith(IMG_EXTS):
            image_paths.append(candidate)

    return image_paths


def pil_open_rgb(p):
    """
    Opens an image file and returns it converted to RGB.

    Failures are reported by printing the offending path and returning None, so one
    unreadable file does not abort a whole scan.

    :param p: path of the image file to open
    :return: RGB PIL image, or None when the file could not be opened or converted
    """

    try:
        # the context manager releases the file handle as soon as the converted copy exists
        with Image.open(p) as img:
            return img.convert("RGB")
    except Exception:
        # Deliberately broad: a damaged image can fail in several different ways.  Unlike a
        # bare ``except:`` this still lets KeyboardInterrupt / SystemExit stop the run.
        print(p)
        return None


# SSIM thumbnail comparison
def fast_ssim(pil_a, pil_b, size=(64,64)):
    """
    Scores the structural similarity of two images using small grayscale thumbnails.

    :param pil_a: first PIL image
    :param pil_b: second PIL image
    :param size: (width, height) both images are resized to before comparison
    :return: the SSIM score as a Python float
    """

    def _gray_thumbnail(pil_img):
        # bicubic resize first, then collapse to a single 8-bit luminance channel
        thumbnail = pil_img.resize(size, Image.BICUBIC).convert("L")
        return np.asarray(thumbnail, dtype=np.uint8)

    pixels_a = _gray_thumbnail(pil_a)
    pixels_b = _gray_thumbnail(pil_b)

    # gaussian_weights adds a little robustness; with full=True the call returns the
    # score followed by the per-pixel similarity image, which is not needed here
    ssim_result = ssim(pixels_a, pixels_b, full=True, gaussian_weights=True, use_sample_covariance=False)

    return float(ssim_result[0])

    
# Load FAISS HNSW index & reference image paths
index = faiss.read_index(INDEX_PATH)
index.hnsw.efSearch = EF_SEARCH
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects, so REF_IMAGE_PATHS must only
# ever point at a file you generated yourself.
ref_image_paths = np.load(REF_IMAGE_PATHS, allow_pickle=True)


def _basename_any_sep(path):
    """Return the last component of ``path``, treating both "/" and "\\" as separators."""
    return str(path).replace("\\", "/").rsplit("/", 1)[-1]


# Load the images to be analyzed
scan_imgs = load_image_paths(FOLDER_FOR_IMAGES_TO_ANALYZE)
results = []

for full_fpath in tqdm(scan_imgs, desc="Scanning images"):
    # Bare file name, independent of which separator the platform's glob produced
    fname = _basename_any_sep(full_fpath)
    # Everything after the last forward slash, exactly as before.
    # NOTE(review): when glob joins with backslashes this presumably keeps the input folder name in
    # front of the file name -- confirm that is the intended content of "input_filepath".
    partial_fpath = full_fpath.rsplit("/", 1)[-1]

    input_img = pil_open_rgb(full_fpath)
    if input_img is None:
        continue

    # batch_imgs can be a list of PIL Images, NumPy arrays, or image file paths
    input_img_vec = model.encode(
        input_img,
        convert_to_numpy=True,
        normalize_embeddings=True,  # L2 normalize
        show_progress_bar=False,
        batch_size=32,  # adjust to your GPU/CPU memory
    ).astype("float32").reshape(1, -1)

    # Retrieve top-k candidates from the image catalog
    D, I = index.search(input_img_vec, K)  # inner product on normalized vectors == cosine similarity
    sims = D[0]
    idxs = I[0]

    # Pick the best candidate that satisfies the similarity threshold
    best = None
    best_sim = -1.0

    for sim, cand_idx in zip(sims, idxs):
        if cand_idx < 0:
            continue

        accept_on_cosine = sim >= COS_SIM_THRESHOLD and sim >= best_sim
        try_ssim_fallback = best is None and sim >= COS_SIM_THRESH_FOR_FALLBACK
        if not (accept_on_cosine or try_ssim_fallback):
            # Neither rule can select this candidate, so skip opening and comparing its image
            continue

        cand_path = str(ref_image_paths[cand_idx])
        cand_fname = _basename_any_sep(cand_path)

        # Calculate thumbnail SSIM score; an unreadable reference image gets NaN instead of
        # crashing the scan (NaN never satisfies the SSIM threshold below)
        cand_img = pil_open_rgb(cand_path)
        if cand_img is None:
            ssim_score = float("nan")
        else:
            ssim_score = fast_ssim(input_img, cand_img)

        if accept_on_cosine:
            # Accept on similarity alone
            best = (cand_fname, float(sim), ssim_score)
            best_sim = float(sim)
        elif ssim_score >= SSIM_THRESHOLD:
            # SSIM fallback for near-threshold cases
            best = (cand_fname, float(sim), ssim_score)
            best_sim = float(sim)

    if best is not None:
        results.append({
            "input_filename": fname,
            "match_filename": best[0],
            "cosine_similarity": best[1],
            "thumbnail_ssim": best[2],
            "input_filepath": partial_fpath
        })

# Make sure the output folder exists so the CSV write cannot fail on a fresh checkout
output_folder = os.path.dirname(OUTPUT_FILEPATH)
if output_folder:
    os.makedirs(output_folder, exist_ok=True)

df = pd.DataFrame(results)
df.to_csv(OUTPUT_FILEPATH, index=False)
print(f"Wrote {OUTPUT_FILEPATH} with {len(df)} rows")


Loading weights:   0%|          | 0/398 [00:00<?, ?it/s]

CLIPModel LOAD REPORT from: C:\Users\tev\.cache\huggingface\hub\models--sentence-transformers--clip-ViT-B-32\snapshots\327ab6726d33c0e22f920c83f2ff9e4bd38ca37f\0_CLIPModel
Key                                  | Status     |  | 
-------------------------------------+------------+--+-
text_model.embeddings.position_ids   | UNEXPECTED |  | 
vision_model.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
Scanning images: 100%|█████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00,  7.26it/s]

Wrote ./output_image_analysis/image_mapping.csv with 16 rows





In [7]:
import os
import pandas as pd
from tqdm import tqdm
import shutil

'''
This cell acts as the bridge between the AI side of this project and the Human in the Loop (HITL). It copies the 
images which the previous cell (via the CLIP ViT model) determined to be possible matches into a folder that
will be used by the next cell to present images to a human user for final visual review.   
'''

input_channel_dir = "./input_channel"
possible_matches_input_file = "./ai_image_analysis/image_mapping.csv"
input_image_dir = "input_images"
destination_image_dir = "./possible_matching_images"
input_image_folder_name = input_image_dir

# Create destination folder
if not os.path.exists(destination_image_dir):
    os.makedirs(destination_image_dir)
    print(f"Directory '{destination_image_dir}' created.")

# reading input file
data = pd.read_csv(possible_matches_input_file, encoding='utf-8')

# File names flagged as possible matches.  The column is read directly instead of squeezing the
# frame: squeezing a one-row, multi-column file produces a row Series whose "input_filename" entry is
# a plain string, and looping over that string compares single characters, so nothing gets copied.
matched_filenames = set(data["input_filename"].tolist())

# Number of entries in each sub-folder of the input channel (each folder name is printed)
base_image_counts = {}
for class_folder in os.listdir(input_channel_dir):
    print(class_folder)
    class_path = os.path.join(input_channel_dir, class_folder)
    if os.path.isdir(class_path):
        num_images = len(os.listdir(class_path))
        base_image_counts[class_folder] = num_images

images_copied_counter = 0

input_image_dir_list = os.listdir(input_channel_dir + "/" + input_image_dir)
input_image_dir_list.sort()  # otherwise lowercase and uppercase would get sorted differently from ImageFolder

for fname in tqdm(input_image_dir_list):

    # set membership replaces a full scan of the CSV column for every file
    if fname not in matched_filenames:
        continue

    input_fpath = input_channel_dir + "/" + input_image_dir + "/" + fname
    output_fpath = destination_image_dir + "/" + fname

    try:
        shutil.copy2(input_fpath, output_fpath)
        images_copied_counter += 1

    except OSError:
        print("Problem copying: {}".format(input_fpath))

print(f"Copied {images_copied_counter} image(s) to {destination_image_dir}")
print("\nSuccess!\n")


input_images


100%|████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 1066.41it/s]


Success!






In [13]:
from datetime import datetime
import tkinter as tk
from tkinter.ttk import Progressbar
from PIL import ImageTk, Image
from threading import Thread
import pandas as pd

'''
This cell closes the loop on this HITL AI Image Retrieval project. It loads the images which were identified earlier 
in this notebook as possible matches alongside the best matching reference image in an easy-to-use GUI that 
allows a human user to determine whether the input images are true matches.  The feedback from the user is then saved
as an output file which allows downstream workflows to locate the matching files.  
'''

# 1-based position of the first candidate pair to show; image_counter below is its 0-based twin
start_image_counter = 1

input_channel_dir = "./input_channel"
input_image_dir = "input_images"  # NEED THE ACTUAL FOLDER -- NOT A NESTING INPUT FOLDER
reference_image_dir = "./reference_images" # NEED THE ACTUAL FOLDER -- NOT A NESTING INPUT FOLDER
input_image_fpath = input_channel_dir + "/" + input_image_dir
possible_matches_input_file = "./ai_image_analysis/image_mapping.csv"
output_dir = "./hitl_output"

# Index of the pair currently on screen, plus the parallel lists that collect the reviewer's answers
image_counter = start_image_counter - 1
visual_review_image_filenames = []
visual_review_feedback = []
visual_review_image_filepaths = []
visual_review_matching_filenames = []

input_data = pd.read_csv(possible_matches_input_file)

# Keep only the rows that actually name a match.  An empty CSV cell is read back as a float (NaN),
# so a float in the match column marks a row with nothing to review.  Filtering in a single pass
# replaces a pop-while-iterating loop that never terminated on an empty list and could leave a
# missing-match row behind when two of them sat at the end of the file.
rows_to_review = [
    (input_name, match_name, input_path)
    for input_name, match_name, input_path in zip(
        input_data['input_filename'].tolist(),
        input_data['match_filename'].tolist(),
        input_data['input_filepath'].tolist(),
    )
    if not isinstance(match_name, float)
]
input_image_filenames = [row[0] for row in rows_to_review]
match_image_filenames = [row[1] for row in rows_to_review]
input_image_filepaths = [row[2] for row in rows_to_review]

# # IF THE FILENAMES DON'T HAVE "small_variant" in front...
# for i in range(len(match_image_filenames)):
#     match_image_filenames[i] = "small_variant_" + str(match_image_filenames[i])

num_input_files = len(input_image_filenames)

"""
Configuring GUI
"""

window = tk.Tk()
window.title("Augmented AI Review")


def _configure_grid(container, row_minsizes, column_specs):
    """
    Applies grid settings to a container.

    :param container: window or frame whose grid is configured
    :param row_minsizes: minimum height of each row, in row order (rows never stretch)
    :param column_specs: (minsize, weight) of each column, in column order
    :return:
    """
    for row_index, row_minsize in enumerate(row_minsizes):
        container.rowconfigure(row_index, minsize=row_minsize, weight=0)
    for column_index, (column_minsize, column_weight) in enumerate(column_specs):
        container.columnconfigure(column_index, minsize=column_minsize, weight=column_weight)


# Root window: two fixed-height rows and one stretchable column
_configure_grid(window, [50, 50], [(600, 1)])

# One borderless frame per region of the GUI
(frm_greeting, frm_progress, frm_finished, frm_image_banner, frm_image_display,
 frm_user_input_button, frm_user_input_processing, frm_counter) = [
    tk.Frame(window, bd=0) for _ in range(8)
]

# Single-cell frames that hold one message or the progress bar
for single_cell_frame in (frm_greeting, frm_progress, frm_finished):
    _configure_grid(single_cell_frame, [50], [(400, 1)])

# Banner and image rows: two fixed 250px columns centred between stretchable margins
_configure_grid(frm_image_banner, [50], [(50, 1), (250, 0), (250, 0), (50, 1)])
_configure_grid(frm_image_display, [250], [(50, 10), (250, 0), (250, 0), (50, 10)])

# Button row plus a thin spacer row beneath it
_configure_grid(frm_user_input_button, [50, 15],
                [(100, 1), (90, 0), (40, 0), (90, 0), (90, 1), (90, 0), (100, 1)])

# "<current> of <total>" counter centred between stretchable margins
_configure_grid(frm_counter, [50], [(150, 1), (40, 0), (20, 0), (40, 0), (150, 1)])

# Instruction banner shown at the top of the window
lbl_greeting = tk.Label(frm_greeting,
                        text="Review images below and click the appropriate button.",
                        font=("Arial", 14))
lbl_greeting.grid(row=0, column=0, sticky="nsew", padx=25, pady=25)

# NOTE(review): neither the progress bar nor lbl_finished is placed with .grid() anywhere in the
# visible code, so they presumably never appear on screen -- confirm whether that is intended.
progressbar = Progressbar(frm_progress, mode="indeterminate")

lbl_finished = tk.Label(frm_finished, text="Processing completed!  Output file saved in Output_Files folder.",
                        foreground="green", font=("Arial", 12))

# Column headers above the two pictures
lbl_input_image = tk.Label(frm_image_banner, text="Input Image", font=("Arial", 12))
lbl_input_image.grid(row=0, column=1, sticky="nsew", padx=15, pady=15)
lbl_match_image = tk.Label(frm_image_banner, text="Possible Match", font=("Arial", 12))
lbl_match_image.grid(row=0, column=2, sticky="nsew", padx=15, pady=15)

# First input image.  The PhotoImage is kept in a module-level name so it is not garbage collected
# while the label is still displaying it.
input_image = ImageTk.PhotoImage(Image.open(input_image_fpath + "/" + input_image_filenames[image_counter]).resize((250, 250)))
input_image_panel = tk.Label(frm_image_display, image=input_image)
input_image_panel.grid(row=0, column=1, sticky="nsew", padx=15, pady=15)

# First reference image: try the stored name with ".jpg" appended, then the stored name as-is.
# Only file-level failures (missing or unreadable file) trigger the second attempt; anything else
# is a real error and is allowed to surface.
try:
    match_image = ImageTk.PhotoImage(Image.open(reference_image_dir + "/" + match_image_filenames[image_counter] + ".jpg").resize((250, 250)))
except OSError:
    match_image = ImageTk.PhotoImage(Image.open(reference_image_dir + "/" + match_image_filenames[image_counter]).resize((250, 250)))
match_image_panel = tk.Label(frm_image_display, image=match_image)
match_image_panel.grid(row=0, column=2, sticky="nsew", padx=15, pady=15)

# "<current> of <total>" position indicator
counter_current = tk.Label(frm_counter, text="1", font=("Arial", 12))
counter_current.grid(row=0, column=1, sticky="new", padx=5, pady=5)
counter_of = tk.Label(frm_counter, text="of", font=("Arial", 12))
counter_of.grid(row=0, column=2, sticky="new", padx=5, pady=5)
counter_total = tk.Label(frm_counter, text=num_input_files, font=("Arial", 12))
counter_total.grid(row=0, column=3, sticky="new", padx=5, pady=5)

# Action buttons.  They are created without a command here because the handler functions are
# defined further down the cell.
btn_back = tk.Button(frm_user_input_button, text="Back", border=4)
btn_back.grid(row=0, column=1, sticky="nsew", padx=10, pady=5)

btn_match = tk.Button(frm_user_input_button, text="Match", border=4)
btn_match.grid(row=0, column=3, sticky="nsew", padx=10, pady=5)

btn_not_match = tk.Button(frm_user_input_button, text="Not a Match", border=4)
btn_not_match.grid(row=0, column=4, sticky="nsew", padx=10, pady=5)

btn_save = tk.Button(frm_user_input_button, text="Save File", border=4)
btn_save.grid(row=0, column=5, sticky="nsew", padx=10, pady=5)


def go_back():
    """
    Background-thread body started by process_go_back when "Back" is clicked.

    It only asks the main window to process its pending idle tasks; the picture swap
    itself happens in check_if_done_go_back once this thread has ended.

    :return:
    """

    window.update_idletasks()


def process_go_back():
    """
    Click handler for the "Back" button.

    When there is an earlier pair to return to, it disables all four buttons, moves
    image_counter back by one, runs go_back on a separate thread and starts polling
    for that thread to finish.  On the first pair it does nothing beyond refreshing
    the window and hiding the "finished" frame.

    :return:
    """

    global image_counter

    window.update_idletasks()
    frm_finished.grid_remove()

    # guard clause: there is nothing before the first pair
    if image_counter <= 0:
        return

    for button in (btn_back, btn_match, btn_not_match, btn_save):
        button["state"] = "disabled"

    image_counter -= 1

    worker = Thread(target=go_back)
    worker.start()

    schedule_check_go_back(worker)


def schedule_check_go_back(thread):
    """
    Schedules check_if_done_go_back to run once, 200 milliseconds from now.

    :param thread: the thread running go_back, handed on to the check
    :return:
    """

    poll_delay_ms = 200
    window.after(poll_delay_ms, check_if_done_go_back, thread)


def check_if_done_go_back(thread):
    """
    Completes a "Back" action once its worker thread has ended.

    While the thread is still alive another check is scheduled.  Afterwards the most
    recently recorded review is removed from the four result lists and the pair that
    image_counter now points at is displayed again with the buttons re-enabled.

    :param thread: The thread on which the go_back function is running
    :return:
    """

    global input_image
    global match_image
    global input_image_panel
    global match_image_panel
    global counter_current

    if thread.is_alive():
        # Otherwise check again after .2 seconds.
        schedule_check_go_back(thread)
        return

    progressbar.stop()
    frm_progress.grid_remove()
    frm_finished.grid(row=2, column=0, sticky="nsew")

    # Undo the latest review.  Every answer is appended to all four lists together, so the entry
    # to discard is always the last one in each list.  Popping the tail keeps the lists aligned
    # (the feedback list used to be popped one position too early) and also removes the very
    # first review when the user steps back onto it, so re-reviewing it cannot produce a duplicate.
    for recorded in (visual_review_image_filenames, visual_review_feedback,
                     visual_review_image_filepaths, visual_review_matching_filenames):
        if recorded:
            recorded.pop()

    if image_counter == num_input_files:
        process_save_progress()

    if image_counter >= num_input_files:
        return

    btn_back["state"] = "normal"
    btn_match["state"] = "normal"
    btn_not_match["state"] = "normal"
    btn_save["state"] = "normal"

    input_image = ImageTk.PhotoImage(
        Image.open(input_image_fpath + "/" + input_image_filenames[image_counter]).resize((250, 250)))
    input_image_panel = tk.Label(frm_image_display, image=input_image)
    input_image_panel.grid(row=0, column=1, sticky="nsew", padx=15, pady=15)

    # Try the stored reference name with ".jpg" appended, then the name as stored.  Only
    # file-level failures fall through to the second attempt.
    try:
        match_image = ImageTk.PhotoImage(
            Image.open(reference_image_dir + "/" + match_image_filenames[image_counter] + ".jpg").resize(
                (250, 250)))
    except OSError:
        match_image = ImageTk.PhotoImage(
            Image.open(reference_image_dir + "/" + match_image_filenames[image_counter]).resize((250, 250)))
    match_image_panel = tk.Label(frm_image_display, image=match_image)
    match_image_panel.grid(row=0, column=2, sticky="nsew", padx=15, pady=15)

    counter_current = tk.Label(frm_counter, text=str(image_counter + 1), font=("Arial", 12))
    counter_current.grid(row=0, column=1, sticky="new", padx=5, pady=5)


def flag_match():
    """
    Background-thread body started by process_flag_match when "Match" is clicked.

    It only asks the main window to process its pending idle tasks; recording the
    answer and showing the next pair happens in check_if_done_flag_match.

    :return:
    """

    window.update_idletasks()


def process_flag_match():
    """
    Click handler for the "Match" button.

    Disables all four buttons, advances image_counter by one, runs flag_match on a
    separate thread and starts polling for that thread to finish.

    :return:
    """

    global image_counter

    window.update_idletasks()
    frm_finished.grid_remove()

    for button in (btn_back, btn_match, btn_not_match, btn_save):
        button["state"] = "disabled"

    image_counter += 1

    worker = Thread(target=flag_match)
    worker.start()

    schedule_check_flag_match(worker)


def schedule_check_flag_match(thread):
    """
    Schedules check_if_done_flag_match to run once, 200 milliseconds from now.

    :param thread: the thread running flag_match, handed on to the check
    :return:
    """

    poll_delay_ms = 200
    window.after(poll_delay_ms, check_if_done_flag_match, thread)


def check_if_done_flag_match(thread):
    """
    Completes a "Match" action once its worker thread has ended.

    While the thread is still alive another check is scheduled.  Afterwards the pair that
    was just reviewed is recorded as "match"; then either the results are saved (when it was
    the last pair) or the next pair is displayed and the buttons are re-enabled.

    :param thread: The thread on which the flag_match function is running
    :return:
    """

    global image_counter
    global input_image
    global input_image_panel
    global match_image
    global match_image_panel
    global counter_current  # declared so the new counter label replaces the shared reference

    if thread.is_alive():
        # Otherwise check again after .2 seconds.
        schedule_check_flag_match(thread)
        return

    progressbar.stop()
    frm_progress.grid_remove()
    frm_finished.grid(row=2, column=0, sticky="nsew")

    # image_counter was already advanced by the click handler, so the reviewed pair is one behind
    reviewed_index = image_counter - 1
    visual_review_image_filenames.append(input_image_filenames[reviewed_index])
    visual_review_feedback.append("match")
    visual_review_image_filepaths.append(input_image_filepaths[reviewed_index])
    visual_review_matching_filenames.append(match_image_filenames[reviewed_index])

    if image_counter == num_input_files:
        process_save_progress()

    if image_counter >= num_input_files:
        return

    btn_back["state"] = "normal"
    btn_match["state"] = "normal"
    btn_not_match["state"] = "normal"
    btn_save["state"] = "normal"

    if isinstance(match_image_filenames[image_counter - 1], float):
        # NOTE(review): when the pair just reviewed has a float (presumably NaN / missing) match
        # name, the counter is advanced again without drawing a new pair -- confirm this is intended.
        image_counter += 1
        return

    input_image = ImageTk.PhotoImage(
        Image.open(input_image_fpath + "/" + input_image_filenames[image_counter]).resize((250, 250)))
    input_image_panel = tk.Label(frm_image_display, image=input_image)
    input_image_panel.grid(row=0, column=1, sticky="nsew", padx=15, pady=15)

    # Try the stored reference name with ".jpg" appended, then the name as stored.  Only
    # file-level failures fall through to the second attempt.
    try:
        match_image = ImageTk.PhotoImage(
            Image.open(reference_image_dir + "/" + match_image_filenames[image_counter] + ".jpg").resize(
                (250, 250)))
    except OSError:
        match_image = ImageTk.PhotoImage(
            Image.open(reference_image_dir + "/" + match_image_filenames[image_counter]).resize((250, 250)))
    match_image_panel = tk.Label(frm_image_display, image=match_image)
    match_image_panel.grid(row=0, column=2, sticky="nsew", padx=15, pady=15)

    counter_current = tk.Label(frm_counter, text=str(image_counter + 1), font=("Arial", 12))
    counter_current.grid(row=0, column=1, sticky="new", padx=5, pady=5)


def flag_not_match():
    """
    Background-thread body started by process_flag_not_match when "Not a Match" is clicked.

    It only asks the main window to process its pending idle tasks; recording the
    answer and showing the next pair happens in check_if_done_flag_not_match.

    :return:
    """

    window.update_idletasks()


def process_flag_not_match():
    """
    Click handler for the "Not a Match" button.

    Disables all four buttons, advances image_counter by one, runs flag_not_match on a
    separate thread and starts polling for that thread to finish.

    :return:
    """

    global image_counter

    window.update_idletasks()
    frm_finished.grid_remove()

    for button in (btn_back, btn_match, btn_not_match, btn_save):
        button["state"] = "disabled"

    image_counter += 1

    worker = Thread(target=flag_not_match)
    worker.start()

    schedule_check_flag_not_match(worker)


def schedule_check_flag_not_match(thread):
    """
    Schedules check_if_done_flag_not_match to run once, 200 milliseconds from now.

    :param thread: the thread running flag_not_match, handed on to the check
    :return:
    """

    poll_delay_ms = 200
    window.after(poll_delay_ms, check_if_done_flag_not_match, thread)


def check_if_done_flag_not_match(thread):
    """
    Completes a "Not a Match" action once its worker thread has ended.

    While the thread is still alive another check is scheduled.  Afterwards the pair that
    was just reviewed is recorded as "not a match"; then either the results are saved (when it
    was the last pair) or the next pair is displayed and the buttons are re-enabled.

    :param thread: The thread on which the flag_not_match function is running
    :return:
    """

    global input_image
    global match_image
    global input_image_panel
    global match_image_panel
    global counter_current

    if thread.is_alive():
        # Otherwise check again after .2 seconds.
        schedule_check_flag_not_match(thread)
        return

    progressbar.stop()
    frm_progress.grid_remove()
    frm_finished.grid(row=2, column=0, sticky="nsew")

    # image_counter was already advanced by the click handler, so the reviewed pair is one behind
    reviewed_index = image_counter - 1
    visual_review_image_filenames.append(input_image_filenames[reviewed_index])
    visual_review_feedback.append("not a match")
    visual_review_image_filepaths.append(input_image_filepaths[reviewed_index])
    visual_review_matching_filenames.append(match_image_filenames[reviewed_index])

    if image_counter == num_input_files:
        process_save_progress()

    if image_counter >= num_input_files:
        return

    btn_back["state"] = "normal"
    btn_match["state"] = "normal"
    btn_not_match["state"] = "normal"
    btn_save["state"] = "normal"

    input_image = ImageTk.PhotoImage(
        Image.open(input_image_fpath + "/" + input_image_filenames[image_counter]).resize((250, 250)))
    input_image_panel = tk.Label(frm_image_display, image=input_image)
    input_image_panel.grid(row=0, column=1, sticky="nsew", padx=15, pady=15)

    # Try the stored reference name with ".jpg" appended, then the name as stored.  Only
    # file-level failures fall through to the second attempt.
    try:
        match_image = ImageTk.PhotoImage(
            Image.open(reference_image_dir + "/" + match_image_filenames[image_counter] + ".jpg").resize(
                (250, 250)))
    except OSError:
        match_image = ImageTk.PhotoImage(
            Image.open(reference_image_dir + "/" + match_image_filenames[image_counter]).resize((250, 250)))
    match_image_panel = tk.Label(frm_image_display, image=match_image)
    match_image_panel.grid(row=0, column=2, sticky="nsew", padx=15, pady=15)

    counter_current = tk.Label(frm_counter, text=str(image_counter + 1), font=("Arial", 12))
    counter_current.grid(row=0, column=1, sticky="new", padx=5, pady=5)


def save_progress():
    """
    Background-thread body started by process_save_progress.

    It only asks the main window to process its pending idle tasks; the output file
    is written by check_if_done_save_progress once this thread has ended.

    :return:
    """

    window.update_idletasks()


def process_save_progress():
    """
    Starts saving the reviewer's feedback.

    Used as the click handler of the "Save File" button and also called by the review
    handlers once the last pair has been answered.  It hides the image frames, starts the
    progress bar, updates the banner text, disables all four buttons, advances
    image_counter, runs save_progress on a separate thread and starts polling for that
    thread to finish.

    :return:
    """

    global image_counter

    window.update_idletasks()

    frm_image_banner.grid_remove()
    frm_image_display.grid_remove()

    frm_progress.grid(row=1, column=0, sticky="nsew")
    progressbar.start()

    # The review handlers call this with image_counter already equal to num_input_files, so that
    # is the condition that means "every pair has been reviewed".  The banner label is updated in
    # place rather than stacking a new label on top of the previous one.
    if image_counter >= num_input_files:
        lbl_greeting.config(text="End of files, saving user feedback...")
    else:
        lbl_greeting.config(text="Saving user feedback...")

    for button in (btn_back, btn_match, btn_not_match, btn_save):
        button["state"] = "disabled"

    image_counter += 1

    t = Thread(target=save_progress)
    t.start()

    schedule_check_save_progress(t)

    return


def schedule_check_save_progress(thread):
    """
    Schedules check_if_done_save_progress to run once, 1500 milliseconds from now.

    :param thread: the thread running save_progress, handed on to the check
    :return:
    """

    poll_delay_ms = 1500
    window.after(poll_delay_ms, check_if_done_save_progress, thread)


def check_if_done_save_progress(thread):
    """
    Writes the collected feedback to a timestamped CSV once the save thread has ended.

    While the thread is still alive another check is scheduled.  Afterwards the four result
    lists are combined into one table, written to output_dir, the banner is switched to
    "Save completed" and all four buttons are left disabled.

    :param thread: The thread on which the save_progress function is running
    :return:
    """

    import os  # local import: this cell's import block does not bring in os

    if thread.is_alive():
        # Otherwise check again after 1.5 seconds.
        schedule_check_save_progress(thread)
        return

    progressbar.stop()
    frm_progress.grid_remove()

    # Columns are assigned one at a time, in this order, so the first column fixes the row index
    visual_results_df = pd.DataFrame(columns=['input_filename', 'input_filepath', 'user_feedback',
                                              'potential_match_reference_filename'])
    visual_results_df['input_filename'] = pd.Series(visual_review_image_filenames)
    visual_results_df['input_filepath'] = pd.Series(visual_review_image_filepaths)
    visual_results_df['user_feedback'] = pd.Series(visual_review_feedback)
    visual_results_df['potential_match_reference_filename'] = pd.Series(visual_review_matching_filenames)

    # Create the output folder on first use so the write cannot fail on a missing directory
    os.makedirs(output_dir, exist_ok=True)
    visual_results_df.to_csv(path_or_buf=output_dir + "/hitl_output_" +
                                         str(datetime.now().strftime('%Y-%m-%d_%H.%M.%S')) +
                                         '.csv', sep=',', encoding='utf-8', index=False)

    # Report success only after the file has actually been written; the banner label is updated
    # in place rather than stacking a new label on top of the previous one
    lbl_greeting.config(text="Save completed")

    for button in (btn_back, btn_match, btn_not_match, btn_save):
        button["state"] = "disabled"

# Attach the click handlers to the buttons that already exist in the button frame.  Configuring
# them in place avoids stacking a second, identical set of four buttons on top of the first.
btn_back.configure(command=process_go_back)
btn_match.configure(command=process_flag_match)
btn_not_match.configure(command=process_flag_not_match)
btn_save.configure(command=process_save_progress)

# Stack the visible regions top to bottom in the window's single column
frm_greeting.grid(row=0, column=0, sticky="nsew")
frm_image_banner.grid(row=1, column=0, sticky="nsew")
frm_image_display.grid(row=2, column=0, sticky="nsew")
frm_user_input_button.grid(row=3, column=0, sticky="nsew")
frm_counter.grid(row=4, column=0, sticky="nsew", pady=10)

# Hand control to Tk; this call returns when the window is closed
window.mainloop()
