# Imports

In [None]:
import os
import pickle
import h5py
import numpy as np
import matplotlib.pyplot as plt

import sys
# Make the project root (parent of this file's folder) importable so that
# `models.*` resolves. Inside a Jupyter kernel `__file__` is not defined, so
# fall back to the current working directory, which is the notebook's folder
# when it is launched from there (the relative '../dataset' paths assume this).
try:
    _notebook_dir = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _notebook_dir = os.getcwd()
sys.path.append(os.path.dirname(_notebook_dir))

from models.topic_category_model import load_category_model

# Load Data

In [None]:
def load_data(filename, processed_data_dir, data_type):
    """Read an entire dataset from an HDF5 file into memory.

    Args:
        filename: Name of the ``.h5`` file inside ``processed_data_dir``.
        processed_data_dir: Directory holding the processed HDF5 files.
        data_type: Key of the dataset to read (e.g. ``'images'`` or ``'labels'``).

    Returns:
        The full contents of ``data_type``, read with ``[:]`` so the result
        stays usable after the file is closed.
    """
    path = os.path.join(processed_data_dir, filename)
    # The context manager closes the file even when the key lookup or the
    # read raises, which the original manual close() did not guarantee.
    with h5py.File(path, 'r') as h5f:
        return h5f[data_type][:]

In [None]:
# Folder containing the dataset
data_dir = '../dataset'
# The processed HDF5 files live inside the dataset folder; derive the path
# from data_dir so the two settings cannot drift apart.
processed_data_dir = os.path.join(data_dir, 'processed_topic_data')

In [None]:
# Total number of categories.
# NOTE: pickle.load can execute arbitrary code; this file is presumably
# produced by the project's own preprocessing — never point it at untrusted data.
with open(os.path.join(data_dir, 'coco_raw.pickle'), 'rb') as pickle_file:
    coco_raw = pickle.load(pickle_file)
# Lookup from class index to category name, used to decode label vectors.
id_category = coco_raw['id_category']
num_classes = len(id_category)

In [None]:
# Test-split images, read from the 'images' dataset of test_images.h5.
test_images = load_data(
    filename='test_images.h5',
    processed_data_dir=processed_data_dir,
    data_type='images',
)

In [None]:
# Test-split category label vectors, read from the 'labels' dataset.
test_categories = load_data(
    filename='test_categories.h5',
    processed_data_dir=processed_data_dir,
    data_type='labels',
)

# Load Model

In [None]:
# Saved weights file for the topic-category model (relative to the notebook folder).
weights_path = '../weights/topic_category_model.keras'
# Build the model for `num_classes` outputs; presumably load_category_model
# also restores the weights at weights_path — confirm in models/topic_category_model.
model = load_category_model(num_classes, weights_path)

# Test Model

In [None]:
def get_predictions(model, image, label, id_category, threshold=0.5):
    """Print true vs. predicted category names for one image, then show it.

    Args:
        model: Object with a ``predict`` method; it is called on a batch that
            contains only ``image``, and the first row of its output is used
            as per-class scores.
        image: A single image without a batch axis (one is added here).
        label: Per-class indicator vector; positions whose value equals 1 are
            reported as true categories.
        id_category: Index-to-category-name lookup (indexed by class position).
        threshold: Scores strictly greater than this count as predictions.
            Defaults to 0.5, the previously hard-coded cut-off.

    Returns:
        None. The labels are printed and the image is displayed with matplotlib.
    """
    # predict() works on batches, so wrap the single image in a batch of one.
    image_batch = np.expand_dims(image, axis=0)
    scores = model.predict(image_batch)[0]

    prediction_labels = [
        id_category[index]
        for index, score in enumerate(scores)
        if score > threshold
    ]
    true_labels = [
        id_category[index]
        for index, value in enumerate(label)
        if value == 1
    ]

    print('True labels:', true_labels)
    print('Predictions:', prediction_labels)

    print('Image:')
    plt.imshow(image)
    plt.show()

In [None]:
# Position of the test sample to inspect; must be a valid index into both
# test_images and test_categories.
idx: int = 789

In [None]:
# Compare ground-truth and predicted categories for the selected test sample.
get_predictions(
    model=model,
    image=test_images[idx],
    label=test_categories[idx],
    id_category=id_category,
)