In [41]:
import os

import cv2
import torch
from facenet_pytorch import MTCNN, InceptionResnetV1
from torch.utils.data import DataLoader
from torchvision import datasets

def collate_fn(x):
    """Unwrap a batch by returning only its first element.

    The DataLoader in this notebook uses the default batch size, so each
    batch is a list holding one dataset item. Returning that item as-is
    lets the loop unpack ``img, idx`` directly, with no tensor collation.
    """
    first_item = x[0]
    return first_item

In [42]:
# Face detector; image_size=240 sets the size of the face crop it returns.
mtcnn = MTCNN(image_size=240)

# Embedding network with VGGFace2 weights, switched to inference mode.
resnet = InceptionResnetV1(pretrained='vggface2')
resnet.eval()

# One sub-folder per person; invert class_to_idx so labels map back to names.
dataset = datasets.ImageFolder('images/face_recog')
idx_to_class = dict((idx, cls) for cls, idx in dataset.class_to_idx.items())

loader = DataLoader(dataset, collate_fn=collate_fn)

In [46]:
# Build the face "database": one embedding per image in which MTCNN found a
# face with detection probability above MIN_FACE_PROB.
MIN_FACE_PROB = 0.90  # detections at or below this probability are skipped
MODEL_PATH = 'saved_models/face_recog_model.pt'

names = []
face_embeddings = []

with torch.no_grad():  # inference only: no autograd graph needed
    for img, idx in loader:
        face, prob = mtcnn(img, return_prob=True)
        # mtcnn returns face=None (and no usable prob) when nothing is detected.
        if face is not None and prob > MIN_FACE_PROB:
            embedded_data = resnet(face.unsqueeze(0))  # add a batch dimension
            face_embeddings.append(embedded_data.detach())
            names.append(idx_to_class[idx])

# Save [embeddings, names]; recognize_face() reads them back by index 0 / 1.
# torch.save does not create missing directories, so create it first.
model_dir = os.path.dirname(MODEL_PATH)
if model_dir:
    os.makedirs(model_dir, exist_ok=True)
data = [face_embeddings, names]
torch.save(data, MODEL_PATH)

In [49]:
def recognize_face(img_path, model_path):
    """Identify the face in an image by nearest-neighbour search over saved embeddings.

    Args:
        img_path: Path of the query image, read with OpenCV.
        model_path: Path of the file written by ``torch.save`` holding
            ``[face_embeddings, names]``.

    Returns:
        Tuple ``(name, distance)``: the name of the closest stored embedding
        and its ``torch.dist`` (2-norm) distance to the query embedding.

    Raises:
        FileNotFoundError: if ``img_path`` cannot be read as an image.
        ValueError: if no face is detected or the saved database is empty.
    """
    bgr = cv2.imread(img_path)
    if bgr is None:  # cv2.imread signals failure by returning None, not raising
        raise FileNotFoundError(f'Could not read image: {img_path}')
    img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)  # OpenCV loads channels as BGR

    face = mtcnn(img)
    if face is None:
        raise ValueError(f'No face detected in {img_path}')

    with torch.no_grad():  # inference only
        emb = resnet(face.unsqueeze(0))

    # Use the caller-supplied path rather than a hard-coded one.
    # NOTE: torch.load unpickles; only load files you trust.
    saved_model = torch.load(model_path)
    face_embeddings = saved_model[0]
    names = saved_model[1]
    if len(face_embeddings) == 0:
        raise ValueError(f'No stored embeddings in {model_path}')

    dist_list = [torch.dist(emb, emb_db).item() for emb_db in face_embeddings]
    # First index of the smallest distance (same tie-breaking as list.index).
    idx_min = min(range(len(dist_list)), key=dist_list.__getitem__)
    return (names[idx_min], dist_list[idx_min])


query_image = 'images/michelle_detected.png'
result = recognize_face(query_image, 'saved_models/face_recog_model.pt')
matched_name, _match_distance = result
print('Face matched with: ', matched_name)

Face matched with:  Michelle Obama
