In [1]:
import cv2
import dlib
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torchvision import transforms

import matplotlib.pyplot as plt 
import numpy as np
from PIL import Image, ImageOps

import time
from facenet_pytorch import MTCNN

In [2]:
class EyeClassifier(nn.Module):
    """Small CNN that maps a single-channel eye crop to two class logits.

    Pipeline: conv(1->16, 3x3) -> ReLU -> 2x2 max-pool -> conv(16->32, 3x3)
    -> ReLU -> 2x2 max-pool -> flatten -> linear(512->256) -> ReLU
    -> linear(256->2).

    ``fc1`` takes 512 features per sample, i.e. a 32x4x4 feature map, which is
    what a 24x24 input produces (24 -> 22 -> 11 -> 9 -> 4).
    """

    # Features per sample that ``fc1`` expects (32 channels x 4 x 4).
    FLAT_FEATURES = 512

    def __init__(self):
        super().__init__()
        # Attribute names double as state_dict keys, so they must stay as-is
        # for previously saved weights to load.
        self.conv1 = nn.Conv2d(1, 16, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 3)
        self.fc1 = nn.Linear(self.FLAT_FEATURES, 256)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, x):
        """Compute class logits for a batch of eye crops.

        Args:
            x: Tensor whose trailing dimensions are (1, H, W); H and W must
                reduce to a 4x4 feature map (24x24 does).

        Returns:
            Tensor of shape (rows, 2) holding raw (un-normalised) logits.

        Raises:
            ValueError: If the convolutional stack does not yield exactly
                ``FLAT_FEATURES`` values per sample.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))

        # ``view(-1, 512)`` alone would silently turn an oversized feature map
        # into extra batch rows (e.g. a 40x40 input yields 4 "samples"), so the
        # per-sample size is checked before flattening.
        per_sample = x.shape[-3] * x.shape[-2] * x.shape[-1]
        if per_sample != self.FLAT_FEATURES:
            raise ValueError(
                f"EyeClassifier expected {self.FLAT_FEATURES} features per sample "
                f"after the conv layers but got {per_sample}; "
                "resize the input to 24x24"
            )

        x = x.view(-1, self.FLAT_FEATURES)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x


In [3]:
# Load  pre-trained MTCNN face detector
mtcnn = MTCNN()

# Load pre-trained landmark predictor
predictor = dlib.shape_predictor("../Pretrained Detectors/shape_predictor_68_face_landmarks.dat")

# Load CNN eye classifier.
# map_location="cpu" lets a checkpoint that was saved from a GPU run load on a
# CPU-only machine; nothing in the cells shown here moves the model or its
# inputs to another device.
eye_model = EyeClassifier()
eye_model.load_state_dict(torch.load("../Saved Models/model2.pt", map_location="cpu"))
# Switch to inference mode; as the last expression this also echoes the model.
eye_model.eval()

EyeClassifier(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=2, bias=True)
)

In [4]:
# Locates bounding box for a single face
def detect_face(img):
    """Find the single most confident face in an image.

    The image is converted with ``cv2.COLOR_BGR2RGB``, resized to 300x300,
    turned into a blob and run through the global ``face_model``. Only
    detections with confidence above 0.5 are considered, and the one with the
    highest confidence wins.

    NOTE(review): ``face_model`` is not defined in any cell shown in this
    notebook; it has to exist in the kernel before this function is called.

    Args:
        img: Image array accepted by ``cv2.cvtColor(..., cv2.COLOR_BGR2RGB)``.

    Returns:
        ``(dlib.rectangle, confidence)``. When nothing clears the threshold
        the rectangle is all zeros and the confidence is 0.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Frame size is constant for the whole call, so read it once up front.
    height, width = rgb.shape[:2]

    resized = cv2.resize(rgb, (300, 300))
    blob = cv2.dnn.blobFromImage(resized, 1.0, (300, 300), (104.0, 177.0, 123.0))
    face_model.setInput(blob)
    detections = face_model.forward()

    left, top, right, bottom = 0, 0, 0, 0
    best_confidence = 0

    for idx in range(detections.shape[2]):
        candidate = detections[0, 0, idx]
        score = candidate[2]

        # Keep a candidate only if it clears 0.5 and beats the best so far.
        if score > 0.5 and score > best_confidence:
            # Box corners are stored as fractions of the frame size.
            left = int(candidate[3] * width)
            top = int(candidate[4] * height)
            right = int(candidate[5] * width)
            bottom = int(candidate[6] * height)
            best_confidence = score

    return dlib.rectangle(left, top, right, bottom), best_confidence


# Locates bounding box for a single eye
def detect_eye(img, face):
    """Estimate a roughly square box around one eye from facial landmarks.

    Runs the global dlib ``predictor`` on ``face`` and builds the box from
    landmark points 17, 19 and 21 (presumably eyebrow points in the usual
    68-landmark layout - confirm). The original author described this as
    ad-hoc math for the LEFT eye that is free to be changed.

    Args:
        img: Image passed straight to ``predictor``.
        face: Face rectangle passed straight to ``predictor``.

    Returns:
        ``((x1, y1, x2, y2), True)`` on success, or ``((0, 0, 0, 0), False)``
        when the predictor returns no landmark parts.
    """
    shape = predictor(img, face)
    if shape.num_parts == 0:
        return (0, 0, 0, 0), False

    left_x = shape.part(17).x
    right_x = shape.part(21).x
    span = abs(right_x - left_x)
    margin = span * 0.05          # 5% of the span, split between both sides
    half_margin = int(margin / 2)

    x1 = left_x - half_margin
    x2 = right_x + half_margin
    y1 = shape.part(19).y - half_margin
    y2 = y1 + int(span + margin)  # height ~ width, so the box is about square

    # Alternative kept from the original (box built from points 36 and 39):
    #   x1 = landmarks.part(36).x
    #   x2 = landmarks.part(39).x
    #   d = abs(x2-x1)
    #   k = d * 0.9
    #   x1 = x1 - int(k/2)
    #   x2 = x2 + int(k/2)
    #   y1 = landmarks.part(36).y - int((d+k)/2)
    #   y2 = landmarks.part(36).y + int((d+k)/2)

    return (x1, y1, x2, y2), True

In [29]:
# Prepares an image for CNN eye classifier
def preprocess(img):
    """Turn an eye crop into the tensor fed to the eye classifier.

    Steps: wrap the array in a PIL image, convert to single-channel ("L"),
    histogram-equalize, resize to 24x24 and convert to a tensor.

    NOTE(review): PIL's "L" conversion assumes RGB channel order, while the
    video loop passes a crop of the OpenCV frame (BGR). Harmless for
    grey-looking footage, but confirm it matches how the training data was
    prepared.

    Args:
        img: Array accepted by ``Image.fromarray``.

    Returns:
        The tensor produced by ``transforms.ToTensor`` for the 24x24 image.
    """
    gray = Image.fromarray(img).convert("L")
    equalized = ImageOps.equalize(gray)

    resize = transforms.Resize([24, 24])
    to_tensor = transforms.ToTensor()
    return to_tensor(resize(equalized))


# Predicts eye state given a single 1x24x24 tensor
def predict_eye_state(img):
    """Classify a single preprocessed eye image with the global ``eye_model``.

    Args:
        img: Tensor for one eye (1x24x24, as produced by ``preprocess``); a
            batch dimension is added here.

    Returns:
        int: Index of the highest logit. The video loop treats 0 as "Closed".
    """
    # Inference only: skip autograd bookkeeping for every frame.
    with torch.no_grad():
        outputs = eye_model(img.unsqueeze(0))

    # For debugging, F.softmax(outputs, dim=1) gives the class probabilities.
    return outputs.argmax(dim=1).item()

In [5]:
# Video source for the demo loop; alternatives are kept for quick switching.
cap = cv2.VideoCapture("../Datasets/DROZY/videos_i8/11-2.mp4")
#cap = cv2.VideoCapture("./Datasets/X/02/5.mov")
#cap = cv2.VideoCapture(0)
#cap = cv2.VideoCapture("./Datasets/Fold4_part1/40/0.mp4")

# Per frame: detect the face with MTCNN, build a square box around the first
# landmark point, classify the crop, then draw the overlays and show the frame.
# `boxes` and `landmarks` deliberately stay at module scope because later
# cells print them.
try:
    while cap.isOpened():
        ret, frame = cap.read()     # return status and image

        #time.sleep(0.05)

        if not ret:
            print("Can't retreive frame")
            break

        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        rgb = Image.fromarray(rgb)

        try:
            t = time.time()
            (boxes, probs, landmarks) = mtcnn.detect(rgb, landmarks = True)
            print(time.time() - t)

            # MTCNN reports "no face" as None; skip such frames explicitly
            # instead of relying on an exception from indexing None.
            if boxes is not None and landmarks is not None:
                eye_x = int(landmarks[0][0][0])
                eye_y = int(landmarks[0][0][1])

                # Half-size of the eye box: 20% of the face box width.
                d = int(0.2 * abs(boxes[0][0] - boxes[0][2]))

                # Clamp to the frame so a negative start cannot wrap around
                # in the slice below.
                frame_h, frame_w = frame.shape[:2]
                x1 = max(eye_x - d, 0)
                y1 = max(eye_y - d, 0)
                x2 = min(eye_x + d, frame_w)
                y2 = min(eye_y + d, frame_h)

                # Classify BEFORE drawing: the rectangle border and landmark
                # circle would otherwise be part of the crop the CNN sees.
                pred = None
                if x2 > x1 and y2 > y1:
                    eye = preprocess(frame[y1:y2, x1:x2])
                    pred = predict_eye_state(eye)

                # TODO: Do stuff with PERCLOS

                cv2.rectangle(frame, (int(boxes[0][0]), int(boxes[0][1])), (int(boxes[0][2]), int(boxes[0][3])), (0, 255, 0), 1)
                cv2.circle(frame, (eye_x, eye_y), 2, (255, 0, 0), 1)
                cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 1)

                if pred == 0:
                    cv2.putText(frame, "Closed", (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)

        except Exception as exc:
            # Stay best-effort per frame, but report what went wrong and let
            # KeyboardInterrupt through so the kernel can be interrupted.
            print(f"Skipping frame: {exc!r}")

        # Display frame
        cv2.imshow("img", frame)

        # Exit window using "q" key
        if cv2.waitKey(1) == ord("q"):
            break

finally:
    # Always free the capture and close the window, even on error/interrupt.
    cap.release()
    cv2.destroyAllWindows()

0.13403010368347168
0.07101631164550781
0.07001590728759766
0.08301901817321777
0.07501721382141113
0.0780172348022461
0.09002113342285156
0.08301806449890137
0.06801557540893555
0.07301640510559082
0.0670156478881836
0.06501460075378418
0.06601476669311523
0.06601476669311523
0.07101631164550781
0.07001543045043945
0.06401371955871582
0.07001566886901855
0.07701778411865234
0.089019775390625
0.06401419639587402
0.06601524353027344
0.0690147876739502
0.0690155029296875
0.07701802253723145
0.07301664352416992
0.07701754570007324
0.08701944351196289
0.07001590728759766
0.09103012084960938
0.08201861381530762
0.0780181884765625
0.07401657104492188
0.06601524353027344
0.06601452827453613
0.07701754570007324
0.07601714134216309
0.07501745223999023
0.07601737976074219
0.08101797103881836
0.07001686096191406
0.07301688194274902
0.0780179500579834
0.07201647758483887
0.07201623916625977
0.08802056312561035
0.08101820945739746
0.08301925659179688
0.07701706886291504
0.08501935005187988
0.079017

In [8]:
# Debug: show the face boxes left in `boxes` by the last mtcnn.detect call in the video loop.
print(boxes)

[[248.72688 130.89388 314.76306 220.96274]]


In [13]:
# Debug: show the landmark points left in `landmarks` by the last mtcnn.detect call in the video loop.
print(landmarks)

[[[268.3033  167.30743]
  [300.82962 166.44177]
  [284.50043 180.28775]
  [272.05902 203.69574]
  [296.6056  203.28278]]]
