In [1]:
from facenet_pytorch import MTCNN
import torch
# Use the first CUDA GPU when one is available; otherwise everything runs on the CPU.
DEVICE = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print('Running on device: {}'.format(DEVICE))

# Face detector placed on the same device. keep_all=True asks facenet-pytorch's
# MTCNN to return every detected face instead of a single one.
mtcnn = MTCNN(keep_all=True, device=DEVICE)

Running on device: cuda:0


In [2]:
import torchvision.transforms as transforms
from facenet_pytorch import InceptionResnetV1

# Face-embedding network: InceptionResnetV1 with VGGFace2 weights, switched to
# eval mode and moved to DEVICE in a single chained expression.
resnet = (
    InceptionResnetV1(pretrained='vggface2')
    .eval()
    .to(DEVICE)
)

# Image preprocessing: ToTensor, then per-channel Normalize with mean 0.5 and std 0.5.
# NOTE(review): `transform` is not used by any visible cell — confirm it is still needed.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

def get_emb(img_tensor, model=None, device=None):
    """Run a batch of face crops through the embedding network.

    Args:
        img_tensor: Batch of face crops, e.g. the tensor returned by ``mtcnn(frame)``.
            Presumably shaped (n_faces, 3, 160, 160) — TODO confirm against the MTCNN settings.
        model: Network to apply; defaults to the module-level ``resnet``.
        device: Device to run the network on; defaults to the module-level ``DEVICE``.

    Returns:
        The network's output for the batch as a CPU tensor that is not attached
        to any autograd graph (suitable for passing to scikit-learn).
    """
    model = resnet if model is None else model
    device = DEVICE if device is None else device
    # Pure inference: no_grad stops autograd from recording a graph (and keeping
    # activations alive) on every call made from the webcam loop.
    with torch.no_grad():
        return model(img_tensor.to(device)).cpu()

In [3]:
import pickle

# Load the pickled face classifier and the encoder that maps its predicted
# class ids back to labels.
# SECURITY: pickle.load executes arbitrary code embedded in the file — only
# load these artifacts if you produced them or otherwise trust their source.
with open('./model/FaceNet-SVM.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

with open('./model/encoder.pkl', 'rb') as encoder_file:
    encoder = pickle.load(encoder_file)

In [7]:
import time
import cv2
import numpy as np
  
# Live recognition loop: read webcam frames, re-run detection + classification
# once every `update_per_frame` frames, and draw the latest predictions on every
# frame. Press 'q' or Esc, or close the window, to stop.
vid = cv2.VideoCapture(0)
cv2.namedWindow('Camera', cv2.WINDOW_AUTOSIZE)

update_per_frame = 5           # run MTCNN + classifier once every N frames
confidence_threshold = 80      # min. predicted probability (%) to show a name instead of 'unknown'
prev_frame_time = time.time()  # start the FPS clock here so the first reading is meaningful
new_frame_time = prev_frame_time
frame_count = 0                # frames seen since the last detection update
fps = "0"
detected = False

# Most recent predictions; redrawn on frames where detection is skipped.
results = []
result_probs = []
x1 = []
y1 = []
x2 = []
y2 = []

font = cv2.FONT_HERSHEY_SIMPLEX
green = (100, 255, 0)
yellow = (100, 255, 255)
red = (0, 0, 255)

try:
    if not vid.isOpened():
        raise RuntimeError('Could not open camera 0')

    while True:
        # Read
        ret, frame = vid.read()
        if not ret:
            # No frame was delivered (camera unplugged / stream ended); frame is None.
            break
        frame = cv2.flip(frame, 1)  # horizontal mirror

        # Update: every `update_per_frame` frames, refresh the FPS and the predictions.
        frame_count += 1
        if frame_count >= update_per_frame:
            new_frame_time = time.time()
            elapsed = new_frame_time - prev_frame_time
            if elapsed > 0:
                fps = str(int(frame_count / elapsed))
            prev_frame_time = new_frame_time
            frame_count = 0

            # OpenCV delivers BGR frames; facenet-pytorch's MTCNN expects RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            boxes, _ = mtcnn.detect(rgb)
            if boxes is not None:
                detected = True

                boxes = boxes.astype(int)
                x1 = boxes[:, 0]
                y1 = boxes[:, 1]
                x2 = boxes[:, 2]
                y2 = boxes[:, 3]

                # NOTE(review): mtcnn(rgb) repeats the detection internally; it is
                # presumably deterministic, so its crops line up with `boxes`.
                faces = mtcnn(rgb)
                embs = get_emb(faces)
                outputs = model.predict(embs)
                output_probs = model.predict_proba(embs)
                results = encoder.inverse_transform(outputs)
                result_probs = np.max(output_probs * 100, axis=1)
            else:
                detected = False

        # Draw the latest predictions (possibly computed on an earlier frame).
        if detected:
            for i, (name, prob) in enumerate(zip(results, result_probs)):
                top_left = (int(x1[i]), int(y1[i]))
                bottom_right = (int(x2[i]), int(y2[i]))
                # Keep the label inside the frame when a face touches the top edge.
                text_org = (int(x1[i]), max(int(y1[i]) - 10, 10))
                if prob > confidence_threshold:
                    cv2.rectangle(frame, top_left, bottom_right, green, 2)
                    label = f'{name} ({round(prob, 2)}%)'
                    cv2.putText(frame, label, text_org, font, 0.5, green, 2)
                else:
                    cv2.rectangle(frame, top_left, bottom_right, red, 2)
                    cv2.putText(frame, 'unknown', text_org, font, 0.5, red, 2)
        cv2.putText(frame, fps, (0, 25), font, 1, green, 2)
        cv2.imshow('Camera', frame)

        # Quit on 'q' or Esc.
        key = cv2.waitKey(1) & 0xFF
        if key in (ord('q'), 27):
            break
        # Quit when the window has been closed. Recent OpenCV builds typically
        # report this as WND_PROP_VISIBLE < 1 rather than raising; some older
        # builds raise cv2.error instead.
        try:
            if cv2.getWindowProperty('Camera', cv2.WND_PROP_VISIBLE) < 1:
                break
        except cv2.error:
            break
finally:
    # Always free the camera and the window, even if recognition raised mid-loop.
    vid.release()
    cv2.destroyAllWindows()