In [2]:
import os,magic
from dnn_api import ModelWrapper

In [20]:
class Predictor():
    """Classify image files as 'cat' or 'dog' with a pre-trained model.

    Files that are missing, or whose detected MIME type is not accepted,
    are reported as 'not valid'. Images for which no model probability
    reaches ``class_confidence_threshold`` are reported as 'unknown'.
    """

    # MIME types predict_file will send to the model; anything else is
    # reported as 'not valid'. Override in a subclass to accept more types.
    VALID_MIMETYPES = ('image/jpeg', 'image/png')

    def __init__(self,
                 model_path,
                 class_confidence_threshold = 0.9999):
        """Load the model and create a reusable MIME-type detector.

        Args:
            model_path: path handed to ``ModelWrapper.load``.
            class_confidence_threshold: at least one model probability must
                reach this value, otherwise the prediction is 'unknown'.
        """
        self.model = ModelWrapper.load(model_path)
        self.class_confidence_threshold = class_confidence_threshold
        # Build the libmagic handle once: constructing magic.Magic per file
        # reloads the magic database on every call, which dominates the cost
        # of predict_dir on large directories.
        self._mime_detector = magic.Magic(mime=True)

    def predict_file(self, filepath):
        """Classify a single file.

        Args:
            filepath: path of the file to classify.

        Returns:
            'not valid' if ``filepath`` is not a regular file or its MIME
            type is not in ``VALID_MIMETYPES``; 'unknown' if every model
            probability is below ``class_confidence_threshold``; otherwise
            'cat' when the first probability is strictly greater than the
            second, else 'dog' (ties go to 'dog').
        """
        if not os.path.isfile(filepath):
            return 'not valid'
        file_mimetype = self._mime_detector.from_file(filepath)
        if file_mimetype not in self.VALID_MIMETYPES:
            return 'not valid'

        # Index 0 is treated as the cat probability and index 1 as dog;
        # presumably ModelWrapper.predict returns a 2-element sequence — verify.
        pred_probs = self.model.predict(filepath)
        if all(p < self.class_confidence_threshold for p in pred_probs):
            return 'unknown'
        return 'cat' if pred_probs[0] > pred_probs[1] else 'dog'

    def predict_dir(self, dirpath):
        """Classify every entry directly inside ``dirpath`` (non-recursive).

        Args:
            dirpath: directory whose entries are classified.

        Returns:
            dict mapping each entry name from ``os.listdir`` to the result of
            ``predict_file``; non-file entries such as subdirectories map to
            'not valid'.
        """
        return {filename: self.predict_file(os.path.join(dirpath, filename))
                for filename in os.listdir(dirpath)}

In [9]:
def prettyprint(dct):
    """Print a mapping one entry per line, sorted by key, colons aligned.

    Each line is ``<key left-justified to the longest key> : <value>``.
    An empty mapping prints nothing instead of raising ``ValueError``.

    Args:
        dct: mapping to print; keys must be mutually sortable.
    """
    # default=0 keeps an empty mapping from crashing max(); the loop then
    # simply prints nothing.
    width = max((len(str(key)) for key in dct), default=0)
    for key in sorted(dct):
        print(str(key).ljust(width), ':', dct[key])

In [21]:
# Build the classifier; Predictor loads the model via ModelWrapper.load.
p = Predictor(model_path='model/model.hd5')

In [22]:
# Classify every entry of sample/; result maps file name -> label.
prediction = p.predict_dir(dirpath='sample/')

In [15]:
# Show the predictions sorted by file name.
prettyprint(dct=prediction)

cat.11737.jpg  : cat
cat.2266.jpg   : cat
cat.2921.jpg   : cat
cat.3570.jpg   : cat
cat.394.jpg    : cat
cat.4600.jpg   : cat
cat.4865.jpg   : cat
cat.9021.jpg   : cat
dog.1402.jpg   : dog
dog.1614.jpg   : dog
dog.2423.jpg   : cat
dog.6391.jpg   : dog
dog.6768.jpg   : dog
dog.8091.jpg   : dog
dog.8643.jpg   : dog
dog.9077.jpg   : dog
notimage.1.txt : not valid
notimage.2     : not valid
random.1.jpg   : unknown
random.2.jpg   : dog
random.3.jpg   : dog
random.4.jpg   : cat
random.5.jpg   : dog
random.6.png   : dog
