# Face Anti-Spoofing Neural Network

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

## Display functions

In [None]:
def imshow_np(img):
    """Display a NumPy image with matplotlib (``plt.imshow`` then ``plt.show``).

    Accepts either a 2-D ``(H, W)`` array or a 3-D ``(H, W, C)`` array. A
    trailing single-channel axis is dropped so matplotlib renders the image as
    a plain 2-D array.
    """
    img = np.asarray(img)
    # A (H, W, 1) array is squeezed to (H, W); 2-D input is passed through
    # unchanged instead of failing on a 3-way shape unpack.
    if img.ndim == 3 and img.shape[2] == 1:
        img = img[:, :, 0]
    plt.imshow(img)
    plt.show()

def imshow(img):
    """Display a torch tensor image by converting it to NumPy and calling imshow_np.

    The tensor is detached from autograd and moved to CPU first, because
    ``Tensor.numpy()`` raises for tensors that require grad or live on a GPU.
    The axis order is passed through unchanged, so the tensor must already be
    laid out as ``imshow_np`` expects: 2-D ``(H, W)`` or channels-last
    ``(H, W, C)``. A channels-first ``(C, H, W)`` tensor must be transposed by
    the caller.
    """
    imshow_np(img.detach().cpu().numpy())

## Dataset creation

In [None]:
# Data creation.
#
# Each .npz archive maps string keys to arrays. Every key is parsed with int()
# and used as the row index of the output arrays, so this assumes the keys are
# "0" .. str(n - 1) (TODO confirm against the export script).
# The archives are opened in a `with` block so their underlying file handles
# are closed once everything has been copied into the in-memory arrays below.
with np.load('images_sample.npz') as Images, \
        np.load('labels_D_sample.npz') as Labels_D, \
        np.load('label_sample.npz') as Labels:
    # Change of basis (anchors) -- currently disabled:
    # Anchors = np.load('anchors_sample.npz')

    # Number of samples = number of entries in the image archive.
    n = len(Images)

    data_images = np.zeros((n, 256, 256, 3), dtype=np.float32)
    # data_anchors = np.zeros((n, 2, 4096), dtype=np.float32)
    data_labels_D = np.zeros((n, 32, 32, 1), dtype=np.float32)  # "D" label maps
    data_labels = np.zeros((n), dtype=np.float32)  # spoofing labels

    for item in Images.files:
        idx = int(item)
        data_images[idx, :, :, :] = Images[item]
        # data_anchors[idx, :, :] = Anchors[item]
        data_labels_D[idx, :, :, :] = Labels_D[item]
        data_labels[idx] = Labels[item]

In [None]:
# Sequential train/test split: rows [0, n_train) are used for training and
# rows [n_train, n) for testing. No shuffling happens here, so row order
# (and therefore the per-sample label order) is preserved in both parts.
training_part = 45 / 55
n_train = int(n * training_part)

# data_anchors_train, data_anchors_test = data_anchors[:n_train], data_anchors[n_train:]
data_images_train, data_images_test = data_images[:n_train], data_images[n_train:]
data_labels_D_train, data_labels_D_test = data_labels_D[:n_train], data_labels_D[n_train:]
data_labels_train, data_labels_test = data_labels[:n_train], data_labels[n_train:]

In [None]:
def prepare_dataloader_D(data_images_train, data_images_test, data_labels_D_train, data_labels_D_test, batch_size=5):
    """Wrap the train/test image arrays and their D-label arrays in DataLoaders.

    Images are converted from channels-last ``(N, H, W, C)`` to channels-first
    ``(N, C, H, W)`` with ``np.transpose(..., (0, 3, 1, 2))``. The D labels are
    wrapped unchanged.

    NOTE(review): the D labels therefore stay channels-last (e.g.
    ``(N, 32, 32, 1)``) while images become channels-first. If the network's D
    output is ``(N, 1, H, W)``, ``MSELoss`` will broadcast instead of comparing
    element-wise. Confirm the shapes.

    Args:
        data_images_train: training images, channels-last, 4-D.
        data_images_test: test images, channels-last, 4-D.
        data_labels_D_train: D labels aligned with ``data_images_train``.
        data_labels_D_test: D labels aligned with ``data_images_test``.
        batch_size: samples per batch (default 5).

    Returns:
        ``(trainloader_D, testloader_D)``, each yielding ``(images, labels_D)``.
    """
    def _make_loader(images, labels_D):
        # Build one unshuffled loader from an (images, labels_D) pair.
        dataset = torch.utils.data.TensorDataset(
            torch.tensor(np.transpose(images, (0, 3, 1, 2))),
            torch.tensor(labels_D),
        )
        # shuffle=False on purpose: the training/evaluation code maps batch
        # positions back to rows of the per-sample spoofing-label array, which
        # only works if batches keep the original row order.
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)

    trainloader_D = _make_loader(data_images_train, data_labels_D_train)
    testloader_D = _make_loader(data_images_test, data_labels_D_test)
    return trainloader_D, testloader_D

# Build the train/test (images, D-label) loaders from the split arrays.
trainloader_D, testloader_D = prepare_dataloader_D(
    data_images_train=data_images_train,
    data_images_test=data_images_test,
    data_labels_D_train=data_labels_D_train,
    data_labels_D_test=data_labels_D_test,
)

## Model creation

In [None]:
from Models import Anti_Spoof_net

# The network being trained, the loss shared by both training phases, and an
# Adam optimizer over the network's parameters.
mon_model = Anti_Spoof_net.Anti_spoof_net()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(
    mon_model.parameters(),
    lr=3e-3,
    betas=(0.9, 0.999),
    eps=1e-08,
)

## Training

In [None]:
def train_CNN(net, optimizer, trainloader, criterion, n_epoch = 10, show_every=50):
    """Train ``net`` on its first ("D") output against the D labels.

    For every batch ``(images, labels_D)`` it runs ``outputs_D, _ = net(images)``
    and takes an optimizer step on ``criterion(outputs_D, labels_D)``. A batch
    whose ``outputs_D`` contains NaN is skipped (no backward, no step).

    Args:
        net: module whose forward returns a pair ``(outputs_D, outputs_F)``.
        optimizer: optimizer over ``net``'s parameters.
        trainloader: iterable of ``(images, labels_D)`` batches.
        criterion: loss applied to ``(outputs_D, labels_D)``.
        n_epoch: number of passes over ``trainloader``.
        show_every: positive int or None. When set, the first input image and
            its predicted D map are displayed for batches ``i`` with
            ``i % show_every in (0, 1)``. None disables display.

    Returns:
        One float per epoch: the mean loss over the batches that were actually
        used for an update in that epoch (NaN if every batch was skipped).
    """
    net.train()
    epoch_losses = []

    for epoch in range(n_epoch):
        # Per-epoch statistics: mean of per-batch losses. The old code divided
        # by a sample count that was never reset between epochs.
        running_loss = 0.0
        n_batches = 0

        for i, (images, labels_D) in enumerate(trainloader):
            optimizer.zero_grad()
            outputs_D, _ = net(images)

            # Skip the whole update when the network produced NaN.
            if torch.isnan(outputs_D).any():
                continue

            if show_every is not None and i % show_every in (0, 1):
                imshow_np(np.transpose(images[0].numpy(), (1, 2, 0)))
                imshow_np(np.transpose(outputs_D[0].detach().numpy(), (1, 2, 0)))

            loss = criterion(outputs_D, labels_D)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / n_batches))

        print('Epoch finished')
        epoch_losses.append(running_loss / n_batches if n_batches else float('nan'))

    print('Finished Training')
    return epoch_losses

In [None]:
def train_RNN(net, optimizer, trainloader, labels, criterion, n_epoch = 10, show_every=50):
    """Train ``net`` on its second ("F") output against per-batch spoofing labels.

    For every batch it runs ``_, outputs_F = net(images)``. The target is a
    tensor shaped like ``outputs_F`` filled with 0.0 if the batch's label is 0,
    and 1.0 otherwise. A batch whose ``outputs_F`` contains NaN is skipped.

    Args:
        net: module whose forward returns a pair ``(outputs_D, outputs_F)``.
        optimizer: optimizer over ``net``'s parameters.
        trainloader: iterable of ``(images, labels_D)`` batches, in the same
            row order as ``labels`` (i.e. not shuffled).
        labels: per-sample spoofing labels indexed like the loader's rows. The
            label of a batch is the label of its first sample.
        criterion: loss applied to ``(outputs_F, target)``.
        n_epoch: number of passes over ``trainloader``.
        show_every: positive int or None. When set, the first input image and
            ``outputs_F`` are shown for batches ``i`` with
            ``i % show_every in (0, 1)``. None disables display.

    Returns:
        One float per epoch: the mean loss over the batches that were actually
        used for an update in that epoch (NaN if every batch was skipped).
    """
    net.train()
    epoch_losses = []

    for epoch in range(n_epoch):
        running_loss = 0.0
        n_batches = 0
        # Row index (into `labels`) of the first sample of the next batch.
        # Tracked from actual batch sizes instead of assuming 5 per batch.
        next_start = 0

        for i, (images, _) in enumerate(trainloader):
            batch_start = next_start
            next_start += images.size(0)

            optimizer.zero_grad()
            _, outputs_F = net(images)

            # Skip the whole update when the network produced NaN.
            if torch.isnan(outputs_F).any():
                continue

            if show_every is not None and i % show_every in (0, 1):
                imshow_np(np.transpose(images[0].numpy(), (1, 2, 0)))
                print('F:')
                print(outputs_F)

            # Per the original author's note, all images in a batch come from
            # the same video, so the first sample's label stands for the batch.
            # full_like matches outputs_F's shape, dtype and device, so a short
            # final batch no longer mismatches a hard-coded (5, 1, 2) target.
            target_value = 0.0 if labels[batch_start] == 0 else 1.0
            label = torch.full_like(outputs_F, target_value)

            loss = criterion(outputs_F, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            n_batches += 1
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / n_batches))

        print('Epoch finished')
        epoch_losses.append(running_loss / n_batches if n_batches else float('nan'))

    print('Finished Training')
    return epoch_losses

### For the overall model training, we alternately train the CNN part and the CNN/RNN part

In [None]:
def train_All(net, optimizer, trainloader, labels, criterion, n_epoch = 10, save_path='mon_model'):
    """Alternate one epoch of ``train_CNN`` with one epoch of ``train_RNN``.

    The same optimizer, loader and criterion are used by both phases. After
    each phase the whole module is written with ``torch.save(net, save_path)``,
    overwriting the previous checkpoint.

    Args:
        net: module trained by both phases.
        optimizer: optimizer over ``net``'s parameters.
        trainloader: loader of ``(images, labels_D)`` batches (not shuffled).
        labels: per-sample spoofing labels aligned with ``trainloader``'s rows.
        criterion: loss used by both phases.
        n_epoch: number of CNN/RNN alternation rounds.
        save_path: checkpoint file path (default ``'mon_model'``). None
            disables checkpointing.
    """
    def _checkpoint():
        # Persist the full module after each phase, unless disabled.
        if save_path is not None:
            torch.save(net, save_path)

    for _ in range(n_epoch):
        train_CNN(net=net, optimizer=optimizer, trainloader=trainloader, criterion=criterion, n_epoch=1)
        _checkpoint()
        train_RNN(net=net, optimizer=optimizer, trainloader=trainloader, labels=labels, criterion=criterion, n_epoch=1)
        _checkpoint()
    
import os

# Resume from the checkpoint written by train_All if one exists; otherwise
# keep training the freshly constructed `mon_model`.
# NOTE(review): recent torch releases default torch.load to weights_only=True,
# which refuses to unpickle a whole nn.Module saved via torch.save(net, ...);
# pass weights_only=False (trusted files only) if that error appears.
if os.path.exists('mon_model'):
    mon_model = torch.load('mon_model')
    # `optimizer` was built on the parameters of the previous `mon_model`
    # object. Without rebuilding it, optimizer.step() would keep updating the
    # discarded module and the reloaded weights would never change.
    optimizer = torch.optim.Adam(mon_model.parameters(), lr=3e-3, betas=(0.9, 0.999), eps=1e-08)

outputs = train_All(net=mon_model, optimizer=optimizer, trainloader=trainloader_D, labels=data_labels_train, criterion=criterion, n_epoch = 10)

In [None]:
# Reload the latest checkpoint written by train_All before evaluation.
mon_model = torch.load(f='mon_model')

In [None]:
def accuracy(net, criterion, testloader, label, threshold=850, lam=0.015):
    """Evaluate the score-threshold decision rule on ``testloader``, one decision per batch.

    For each batch ``(images, labels_D)`` the score
    ``torch.norm(outputs_D) + lam * torch.norm(outputs_F)`` is computed over
    the whole batch. A batch counts as correct when the score is above
    ``threshold`` and its label is 1, or below ``threshold`` and its label is 0.
    A score exactly equal to ``threshold`` is never counted as correct.

    The network is evaluated in eval mode without autograd. Its previous
    train/eval mode is restored afterwards.

    Args:
        net: module whose forward returns ``(outputs_D, outputs_F)``.
        criterion: loss applied to ``(outputs_D, labels_D)`` for the reported
            test loss.
        testloader: iterable of ``(images, labels_D)`` batches, in the same
            row order as ``label`` (not shuffled).
        label: per-sample labels indexed like the test rows (the label of a
            batch is that of its first sample), or a single scalar label
            applied to every batch.
        threshold: decision threshold on the score (default 850).
        lam: weight of ``outputs_F`` in the score (default 0.015).

    Returns:
        ``(accuracy, loss)``: fraction of correctly classified batches and the
        mean per-batch ``criterion`` value. Both are NaN if ``testloader``
        yields no batches.
    """
    scalar_label = np.ndim(label) == 0
    correct = 0
    total = 0
    loss_sum = 0.0
    next_start = 0  # row index of the next batch's first sample

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for images, labels_D in testloader:
                outputs_D, outputs_F = net(images)
                score = torch.norm(outputs_D) + lam * torch.norm(outputs_F)

                batch_label = label if scalar_label else label[next_start]
                next_start += images.size(0)

                if (score > threshold and batch_label == 1) or (score < threshold and batch_label == 0):
                    correct += 1
                total += 1
                loss_sum += criterion(outputs_D, labels_D).item()

                # Running accuracy so far.
                print(correct / total)
    finally:
        net.train(was_training)

    if total == 0:
        return float('nan'), float('nan')
    return correct / total, loss_sum / total

# Evaluate the reloaded model on the held-out split with per-sample labels.
accuracy(
    net=mon_model,
    criterion=criterion,
    testloader=testloader_D,
    label=data_labels_test,
)