In [None]:
import google.colab
from pathlib import Path

google.colab.drive.mount("/content/drive")
AUX_DATA_ROOT = Path("/content/drive/My Drive")

import zipfile
with zipfile.ZipFile(AUX_DATA_ROOT / 'eye_openness_data/dataset_final.zip', 'r') as archive:
    archive.extractall()

Mounted at /content/drive


In [1]:
import os
import shutil
import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
# from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Dataset, DataLoader
import torchvision
from torchvision import transforms
import torchvision.models as models

from scipy.optimize import brentq
from scipy.interpolate import interp1d
from sklearn.metrics import roc_curve

In [2]:
def compute_eer(labels, scores):
    """Compute the Equal Error Rate (EER) from the predictions and scores.
    Args:
        labels (list[int]): values indicating whether the ground truth
            value is positive (1) or negative (0).
        scores (list[float]): the confidence of the prediction that the
            given sample is a positive.
    Return:
        (float, thresh): the Equal Error Rate and the corresponding threshold
    NOTES:
       The EER corresponds to the point on the ROC curve that intersects
       the line given by the equation 1 = FPR + TPR.
       The implementation of the function was taken from here:
       https://yangcha.github.io/EER-ROC/
    """
    fpr, tpr, thresholds = roc_curve(labels, scores, pos_label=1)
    eer = brentq(lambda x : 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
    thresh = interp1d(fpr, thresholds)(eer)
    return eer, thresh

In [56]:
def eval(model, loader, ckpt_path=False):
    if ckpt_path:
        model.load_state_dict(torch.load(ckpt_path))

    val_loss = 0.0
    total = 0
    correct = 0
    val_labels = []
    val_probs = []
    model.eval()  

    with torch.no_grad():
        for batch in loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
          
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            probs = torch.softmax(outputs, dim=1)
            preds = torch.argmax(probs, dim=1)
            total += labels.size(0)
            correct += (preds == labels).sum().item()

            val_labels += labels.cpu().numpy().tolist()
            val_probs += probs[:, 1].cpu().numpy().tolist()

    eer, _ = compute_eer(val_labels, val_probs)

    return val_loss, correct / total, eer

def train(model, epoch_num, optimizer, ckpt_save_path):
    print(criterion, optimizer)
    min_val_eer = np.inf

    for epoch in range(epoch_num):
        train_loss = 0.0

        for batch in train_loader:
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            
        val_loss, val_accuracy, val_eer = eval(model, val_loader)
        test_loss, test_accuracy, test_eer = eval(model, test_loader)

        print('\033[1m' + f'Epoch {epoch}:' + '\033[0m' + f' train loss = {train_loss:.4f}')
        print(f'val loss = {val_loss:.4f}, val accuracy = {val_accuracy:.4f}, val eer = {val_eer:.4f}')
        print(f'test loss = {test_loss:.4f}, test accuracy = {test_accuracy:.4f}, test eer = {test_eer:.4f}')

        if (val_eer < min_val_eer) or (val_eer < .02):
            min_val_eer = val_eer
            torch.save(model.state_dict(), ckpt_save_path)
            print(f'Saving new weights with current val loss = {val_loss:.4f}, val accuracy = {val_accuracy:.4f}, val eer = {val_eer:.4f}, test eer = {test_eer:.4f}')


    return eer

Удаление ненужных файлов '.ipynb_checkpoints' и '.DS_Store'


In [13]:
DATA_DIR = '/content/dataset_final/'

In [None]:
print(os.listdir(DATA_DIR))
os.rmdir(os.path.join(DATA_DIR, '.ipynb_checkpoints'))
os.remove(os.path.join(DATA_DIR, '.DS_Store'))

for mode in ['train', 'val', 'test']:
    path = os.path.join(DATA_DIR, mode)
    print(os.listdir(path))
    try:
        os.rmdir(os.path.join(path, '.ipynb_checkpoints'))
    except:
        print('.ipynb_checkpoints file is already deleted')
    try:
        os.remove(os.path.join(path, '.DS_Store'))
    except:
        print('.DS_Store file is already deleted')
    


['val', 'train', 'test']
['closed', 'open']
.ipynb_checkpoints file is already deleted
.DS_Store file is already deleted
['closed', 'open']
.ipynb_checkpoints file is already deleted
.DS_Store file is already deleted
['closed', 'open']
.ipynb_checkpoints file is already deleted
.DS_Store file is already deleted


In [11]:
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.RandomHorizontalFlip(p=0.2),
    transforms.RandomRotation(degrees=(-45, 45)),
    transforms.RandomPerspective(distortion_scale=0.7, p=1, interpolation=2, fill=0),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4976, 0.4976, 0.4976],
                         std=[0.1970, 0.1970, 0.1970]),
    ])

val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4976, 0.4976, 0.4976],
                          std=[0.1970, 0.1970, 0.1970])
    ])

  "Argument interpolation should be of type InterpolationMode instead of int. "


In [14]:
BATCH_SIZE = 64

train_dataset = torchvision.datasets.ImageFolder(root=os.path.join(DATA_DIR, 'train'), transform=train_transform)
train_loader = DataLoader(train_dataset, 
                          batch_size=BATCH_SIZE, 
                          shuffle=True,  
                          num_workers=0)

val_dataset = torchvision.datasets.ImageFolder(root=os.path.join(DATA_DIR, 'val'), 
                                               transform=val_transform)
val_loader = DataLoader(val_dataset, 
                        batch_size=BATCH_SIZE, 
                        shuffle=True, 
                        num_workers=0) 

test_dataset = torchvision.datasets.ImageFolder(root=os.path.join(DATA_DIR, 'test'), 
                                                transform=val_transform)
test_loader = DataLoader(test_dataset, 
                         batch_size=BATCH_SIZE, 
                         shuffle=True, 
                         num_workers=0) 

In [15]:
train_dataset.find_classes(os.path.join(DATA_DIR, 'train'))

(['closed', 'open'], {'closed': 0, 'open': 1})

In [None]:
model = models.wide_resnet50_2(pretrained=True)
model.fc = nn.Linear(2048, 2)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), 
                      lr=0.0008, 
                      momentum=0.9, 
                      nesterov=True, 
                      weight_decay=0.002)

_ = train(model, 80, optimizer, '/content/wide_resnet50_2.pth')

CrossEntropyLoss() SGD (
Parameter Group 0
    dampening: 0
    lr: 0.0008
    momentum: 0.9
    nesterov: True
    weight_decay: 0.002
)


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


[1mEpoch 0:[0m train loss = 31.2839
val loss = 2.7283, val accuracy = 0.9127, val eer = 0.0863
test loss = 5.4325, test accuracy = 0.9137, test eer = 0.0699
Saving new weights with current val loss = 2.7283, val accuracy = 0.9127, val eer = 0.0863, test eer = 0.0699
[1mEpoch 1:[0m train loss = 12.0012
val loss = 1.0023, val accuracy = 0.9401, val eer = 0.0457
test loss = 1.7968, test accuracy = 0.9551, test eer = 0.0308
Saving new weights with current val loss = 1.0023, val accuracy = 0.9401, val eer = 0.0457, test eer = 0.0308
[1mEpoch 2:[0m train loss = 7.6901
val loss = 0.7814, val accuracy = 0.9526, val eer = 0.0355
test loss = 1.5704, test accuracy = 0.9586, test eer = 0.0379
Saving new weights with current val loss = 0.7814, val accuracy = 0.9526, val eer = 0.0355, test eer = 0.0379
[1mEpoch 3:[0m train loss = 5.9289
val loss = 0.7285, val accuracy = 0.9626, val eer = 0.0254
test loss = 1.4480, test accuracy = 0.9610, test eer = 0.0213
Saving new weights with current val 

KeyboardInterrupt: ignored

In [58]:
eval(model, val_loader, 'wide_resnet50_2.pth')

(0.5923457383178174, 0.9850374064837906, 0.015228426396226367)

In [57]:
eval(model, test_loader, 'wide_resnet50_2.pth')

(1.1331980407703668, 0.983451536643026, 0.015366430259682356)

In [51]:
class OpenEyesClassificator:
    def __init__(self, ckpt_path):
        self.ckpt_path = ckpt_path
        self.model = models.wide_resnet50_2()
        self.model.fc = nn.Linear(2048, 2)
        self.model.load_state_dict(torch.load(ckpt_path))
        self.transform = transforms.Compose([
                                             transforms.Resize(256),
                                             transforms.CenterCrop(256),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.4976, 0.4976, 0.4976],
                                                                  std=[0.1970, 0.1970, 0.1970])
                                             ])
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def predict(self, inplm):
        img = Image.open(inplm)
        img = img.convert('RGB') # 3 channels needed for pretrained wideresnet weights usage 
        img = self.transform(img)
        img = torch.unsqueeze(img, 0)
        img = img.to(self.device)
        
        with torch.no_grad():
            output = self.model(img)
            probs = torch.softmax(output, dim=1)
            is_open_score = probs[:, 1].cpu().numpy()
            preds = torch.argmax(probs, dim=1)

        return is_open_score

In [59]:
classifier = OpenEyesClassificator('wide_resnet50_2.pth')

In [61]:
path = '/content/dataset_final/val/closed/closed_102.jpg'
classifier.predict(path)

array([1.781304e-06], dtype=float32)

In [None]:
path = '/content/dataset_final/val/open/open_197.jpg'
classifier.predict(path)