In [None]:
%matplotlib inline

import sys
import os
import numpy as np
import cv2 as cv

import tensorflow as tf
import matplotlib.pyplot as plt

from models import resnet as resnet

In [None]:
def _load_dictionary(dict_file):
    dictionary = dict()
    with open(dict_file, 'r') as lines:
        for line in lines:
            sp = line.rstrip('\n').split('\t')
            idx, name = sp[0], sp[1]
            dictionary[idx] = name
    return dictionary

def preprocess(img, resize_size=256, crop_size=224):
    """Resize, center-crop and normalize a BGR image for the network.

    The shorter side is resized to ``resize_size`` (aspect ratio preserved).
    A centered ``crop_size`` x ``crop_size`` patch is cut out. Values are then
    mapped with ``(x / 255 - 0.5) * 2`` and the channel order is reversed
    (BGR -> RGB).

    Args:
        img: image array as returned by ``cv.imread`` (H x W x 3, BGR).
            2-D grayscale and 4-channel BGRA inputs are converted to BGR first.
        resize_size: target length of the shorter side after resizing.
        crop_size: side length of the square center crop.

    Returns:
        Float array of shape (crop_size, crop_size, 3) in RGB order. For 8-bit
        input, values lie in [-1, 1].

    Raises:
        ValueError: if ``img`` is None or ``crop_size`` exceeds ``resize_size``.
    """
    if img is None:
        # cv.imread returns None (rather than raising) when a file cannot be decoded.
        raise ValueError('preprocess() received None; the image could not be read')
    if crop_size > resize_size:
        raise ValueError(
            'crop_size ({}) must not exceed resize_size ({})'.format(crop_size, resize_size)
        )

    # Normalize channel layout to 3-channel BGR so the final [..., ::-1] swaps
    # channels (on a 2-D image it would instead mirror the columns).
    if img.ndim == 2:
        img = cv.cvtColor(img, cv.COLOR_GRAY2BGR)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = cv.cvtColor(img, cv.COLOR_BGRA2BGR)

    raw_h = float(img.shape[0])
    raw_w = float(img.shape[1])
    new_h = float(resize_size)
    new_w = float(resize_size)
    crop = float(crop_size)

    # Scale the longer side so the shorter one becomes exactly resize_size.
    if raw_h <= raw_w:
        new_w = (raw_w / raw_h) * new_h
    else:
        new_h = (raw_h / raw_w) * new_w
    img = cv.resize(img, (int(new_w), int(new_h)))  # cv.resize takes (width, height)

    # Center crop; offsets come from the un-truncated float sizes.
    top = int((new_h - crop) / 2)
    left = int((new_w - crop) / 2)
    size = int(crop)
    img = img[top:top + size, left:left + size]

    img = ((img / 255.0) - 0.5) * 2.0
    return img[..., ::-1]


def load_model(sess, checkpoint, input_shape=(224, 224, 3)):
    """Build the ResNet inference graph and restore its weights.

    Args:
        sess: TensorFlow session into which the variables are restored.
        checkpoint: checkpoint path handed to ``tf.train.Saver.restore``
            (e.g. ``"./checkpoints/model.ckpt"``).
        input_shape: per-example input shape, e.g. (H, W, C). Any sequence
            (tuple or list) is accepted. A leading ``None`` batch dimension
            is prepended.

    Returns:
        The ``resnet.ResNet`` instance, built with ``is_training=False``.
    """
    # build model -- batch dimension left as None so any number of images can
    # be fed; list(...) accepts input_shape as either a tuple or a list.
    images = tf.placeholder(dtype=tf.float32, shape=[None] + list(input_shape))
    net = resnet.ResNet(images, is_training=False)
    net.build_model()

    # restore model -- every global variable in the default graph is restored
    # from the checkpoint into `sess`.
    saver = tf.train.Saver(tf.global_variables())
    saver.restore(sess, checkpoint)

    return net

In [None]:
# Interactive session with the GPU device count set to 0 (CPU-only inference).
session_config = tf.ConfigProto(device_count={'GPU': 0})
sess = tf.InteractiveSession(config=session_config)

# Build the ResNet graph and restore its weights into `sess`.
net = load_model(
    sess,
    "./checkpoints/model.ckpt",
    input_shape=(224, 224, 3),
)

# Softmax over the logits, then keep the 20 most probable classes per image.
class_probs = tf.nn.softmax(net.logit)
prob_topk, pred_topk = tf.nn.top_k(class_probs, k=20)

# Class-index string -> class-name lookup used to label predictions.
dictionary = _load_dictionary("data/ml2020_dictionary.txt")

In [None]:
types = "center"  # not read below; preprocess() always takes a center crop
prefix = 'images/'

# Run every regular file in `prefix` through the network and show it with its
# top-k predictions, one "<class index> <class name> <probability>" per line.
for line in sorted(os.listdir(prefix)):
    path = os.path.join(prefix, line)
    if not os.path.isfile(path):
        continue
    try:
        raw_img = cv.imread(path)
        if raw_img is None:
            # cv.imread signals an unreadable / non-image file by returning None.
            print('Image issue {}: could not be decoded'.format(line))
            continue
        img = preprocess(raw_img)
        probs_topk, preds_topk = sess.run(
            [prob_topk, pred_topk], {net.images: np.expand_dims(img, axis=0)}
        )
        # The batch holds a single image: take row 0 instead of np.squeeze so
        # the class axis survives even when k == 1.
        probs_topk = probs_topk[0]
        preds_topk = preds_topk[0]
        names_topk = [dictionary[str(i)] for i in preds_topk]

        predictions = [
            "{} {} {:.4f}".format(pred, name, prob)
            for pred, name, prob in zip(preds_topk, names_topk, probs_topk)
        ]

        # preprocess() returns RGB values in [-1, 1]; imshow clips float RGB data
        # to [0, 1], so map back to [0, 1] for a faithful display.
        display_img = np.clip((img + 1.0) / 2.0, 0.0, 1.0)
        fig, ax = plt.subplots()
        ax.imshow(display_img)
        ax.set_title("\n".join(predictions), loc="left")
        ax.axis('off')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop
    except Exception as err:
        # Report the cause and continue with the remaining images.
        print('Image issue {}: {}'.format(line, err))