In [2]:
import sys
sys.path.insert(0, 'glasses')

from typing import Tuple
from image_folder import ImageFolderWithPaths, SingleImageFolder
from train import NORMALIZITAION_FOR_PRETRAINED

import shutil
import torch
import numpy as np
import os

from ipywidgets import IntProgress
from IPython.display import display
from facenet_pytorch import MTCNN, InceptionResnetV1
from PIL import Image
from torchvision import transforms, models, utils

In [3]:
# Classifiers that have already been built, keyed by (architecture, weights path,
# device), so the weights are read from disk once per kernel instead of once per photo.
_GLASSES_MODEL_CACHE = {}


def _load_glasses_model(model_kind, model_params, device):
    """Return the single-output classifier for ``model_kind``, building it on first use."""
    cache_key = (model_kind, model_params, str(device))
    if cache_key not in _GLASSES_MODEL_CACHE:
        if model_kind == 'resnet':
            model = models.resnet18(pretrained=False, num_classes=1)
        else:
            model = models.squeezenet1_1(pretrained=False, num_classes=1)

        if model_params is not None:
            # NOTE: torch.load unpickles the file; only point this at trusted checkpoints.
            model.load_state_dict(torch.load(model_params, map_location=device))

        # Device placement and eval mode are needed whether or not weights were loaded.
        model.to(device)
        model.eval()
        _GLASSES_MODEL_CACHE[cache_key] = model
    return _GLASSES_MODEL_CACHE[cache_key]


def check_glasses(images_path):
    """Report whether any image found under ``images_path`` is classified as having glasses.

    Every item produced by ``SingleImageFolder(images_path, ...)`` is centre-cropped
    to 400 px, resized to 224 px, converted to a tensor, normalised with
    ``NORMALIZITAION_FOR_PRETRAINED`` and passed through the classifier.

    Args:
        images_path: Location handed to ``SingleImageFolder``.

    Returns:
        True as soon as one image's model output exceeds the threshold (4),
        otherwise False.
    """
    model_kind = 'squeeze'
    model_params = 'glasses/dist/squeezenet_params'
    threshold = 4
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = _load_glasses_model(model_kind, model_params, device)

    transform = transforms.Compose([
        transforms.CenterCrop(400),
        transforms.Resize(224),
        transforms.ToTensor(),
        NORMALIZITAION_FOR_PRETRAINED
    ])
    images = SingleImageFolder(images_path, transform=transform)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        for img_tensor, _path in images:
            # Items without a tensor are skipped (presumably unreadable images -- TODO confirm).
            if img_tensor is None:
                continue
            score = model(img_tensor.to(device).unsqueeze(0))
            if score.item() > threshold:
                return True
    return False

    
def get_face_coordinates(path, mtcnn_model):
    """Return the bounding box of the only face detected in an image.

    Args:
        path: Path of the image file to open.
        mtcnn_model: Detector exposing ``detect(img)`` that returns a
            ``(boxes, extra)`` pair; only ``boxes`` is used.

    Returns:
        ``boxes[0]`` when ``detect`` reports exactly one box in a NumPy array;
        otherwise an empty list (no array returned, or several boxes).
    """
    # The context manager releases the file handle as soon as detection is done.
    with Image.open(path) as img:
        boxes, _ = mtcnn_model.detect(img)

    # A non-array result (presumably None when nothing is found -- TODO confirm)
    # and multi-face results are both reported as "no usable face".
    if isinstance(boxes, np.ndarray) and len(boxes) == 1:
        return boxes[0]
    return []
        
        
def check_shape(boxes, threshold):
    """Tell whether a box is strictly larger than ``threshold`` in both dimensions.

    Items 0 and 2 of ``boxes`` give the width, items 1 and 3 the height
    (presumably x1, y1, x2, y2 -- confirm against the detector). Negative
    coordinates are clamped to 0 before the extents are measured.

    Args:
        boxes: Indexable with at least four numeric items.
        threshold: Minimum extent, exclusive, required for width and height.

    Returns:
        True when both extents exceed ``threshold``, otherwise False.
    """
    first_x, first_y, second_x, second_y = (max(0, boxes[i]) for i in range(4))
    wide_enough = abs(first_x - second_x) > threshold
    tall_enough = abs(first_y - second_y) > threshold
    return bool(wide_enough and tall_enough)



def process_photo(filename, input_path, mtcnn_model, output_path,
                  temp_dir=None, min_face_size=255):
    """Move a photo to ``output_path`` when it shows one large face with glasses.

    The photo qualifies when ``get_face_coordinates`` returns a box,
    ``check_shape`` accepts that box, and ``check_glasses`` returns True for
    the temporary copy written to ``temp_dir``.

    Args:
        filename: Name of the file inside ``input_path``; ``'.DS_Store'`` is ignored.
        input_path: Directory holding the source photos.
        mtcnn_model: Face detector forwarded to ``get_face_coordinates``.
        output_path: Directory that receives qualifying photos.
        temp_dir: Directory for the temporary ``temp.jpeg`` handed to the
            classifier. Defaults to the notebook-level ``temp_path``.
        min_face_size: Threshold forwarded to ``check_shape``.

    Returns:
        None. The only effects are the temporary file and the possible move.
    """
    if filename == '.DS_Store':
        return

    if temp_dir is None:
        temp_dir = temp_path  # notebook-level setting, as before

    path = os.path.join(input_path, filename)

    coordinates = get_face_coordinates(path, mtcnn_model)
    # An empty result means "no single face". len() is used because comparing a
    # NumPy box with [] is an elementwise comparison, not an emptiness test.
    if len(coordinates) == 0 or not check_shape(coordinates, min_face_size):
        return

    with Image.open(path) as img:
        # TODO: crop to an expanded face box before classification.
        img.save(os.path.join(temp_dir, 'temp.jpeg'))

    # Classify the very directory the temporary file was written to.
    if check_glasses(os.path.normpath(temp_dir)):
        shutil.move(path, os.path.join(output_path, filename))
    

In [4]:
mtcnn_model = MTCNN()

input_path = 'Input_Dataset/'
output_path = 'Output_Dataset/'
temp_path = 'Check_glasses/'

# List the directory once so the progress maximum and the loop use the same files.
images = sorted(os.listdir(input_path))

progress = IntProgress(min=0, max=len(images), value=0)
display(progress)

for i, filename in enumerate(images):
    process_photo(filename, input_path, mtcnn_model, output_path)

    # Refresh the widget only every 10th file.
    if i % 10 == 0:
        progress.value = i

# The last in-loop refresh stops short of the maximum; show completion explicitly.
progress.value = len(images)
            
    

IntProgress(value=0, max=1)