In [17]:
import torch.nn as nn
import cv2
from torchvision import transforms
import torch
import torchvision
import numpy as np
import ds_ear
import glob
from PIL import Image
from matplotlib import image
from DLBio import pt_training
from os.path import join


CATEGORIES = ["Falco", "Jesse", "Konrad", "Nils"]
AUTHORIZED = ["Falco","Konrad"]
RESIZE_Y = 150
RESIZE_X = 100
DATA_TEST_FOLDER = "../test/*jpg"


def get_data(folder):
    img_array = []
    img_array_resized = []
    files = glob.glob (folder)
    for idx, f in zip(range(len(files)),files):
        image = cv2.imread(f)
        img_array.append (image)
        img_array_resized.append(cv2.resize(img_array[idx],(RESIZE_Y,RESIZE_X)))
    return np.asarray(img_array_resized)


model = torch.load('./class_sample/model.pt')



In [18]:
image_array = []
files = glob.glob (DATA_TEST_FOLDER)
files.sort()
for idx, f in zip(range(len(files)),files):
    image = Image.open(f)
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize((RESIZE_Y, RESIZE_X)),
        torchvision.transforms.Lambda(lambda x: x.convert('RGB')),
        torchvision.transforms.ToTensor(),

        torchvision.transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225]
        )
    ])
    image_transformed = transform(image)
    image_transformed = image_transformed.reshape(-1, RESIZE_Y, RESIZE_X, 1)
    image_transformed = image_transformed.permute(3, 0, 1, 2)
    image_array.append(image_transformed.type('torch.cuda.FloatTensor'))


In [19]:
all_classes = []
for i in image_array:
	with torch.no_grad():
		pred = model(i)
		pred = torch.softmax(pred, 1)
		pred = pred.cpu().numpy()

	classes = np.argmax(pred, 1)
	all_classes.append(classes[0])

	pred = np.append(pred, classes)
	pred = np.append(pred, CATEGORIES[classes[0]])	
	print(pred, "\n")
print(all_classes)


['0.9667887091636658' '0.0005871779285371304' '0.02024783194065094'
 '0.012376310303807259' '0.0' 'Falco'] 

['0.9290536046028137' '0.003952649887651205' '0.02345476858317852'
 '0.04353899508714676' '0.0' 'Falco'] 

['0.9492008686065674' '0.02982499450445175' '0.011885423213243484'
 '0.009088710881769657' '0.0' 'Falco'] 

['0.9383549690246582' '0.0022658249363303185' '0.03458288311958313'
 '0.024796441197395325' '0.0' 'Falco'] 

['0.8705261945724487' '0.09643664956092834' '0.029231686145067215'
 '0.0038054778706282377' '0.0' 'Falco'] 

['0.27524513006210327' '0.20828931033611298' '0.48518282175064087'
 '0.03128273785114288' '2.0' 'Konrad'] 

['0.7961697578430176' '0.17696717381477356' '0.01918684132397175'
 '0.007676254026591778' '0.0' 'Falco'] 

['0.8900232315063477' '0.042689815163612366' '0.049403268843889236'
 '0.017883557826280594' '0.0' 'Falco'] 

['0.0021648008842021227' '0.017974229529500008' '0.9748411774635315'
 '0.005019825417548418' '2.0' 'Konrad'] 

['0.0019261729903519154

In [20]:
NUMBER_AUTHORIZED = int(.8*len(image_array))
authentification_dict = {CATEGORIES[i]:all_classes.count(i) for i in all_classes}
print(authentification_dict) 

for a in authentification_dict:
    if a in AUTHORIZED and authentification_dict[a] >= NUMBER_AUTHORIZED:
        print("Acces granted! Welcome "  + a + "!")
    else:
        print("Acces denied")
        break


{'Falco': 7, 'Konrad': 3}
Acces denied
