In [1]:
import copy
import os
import glob
import json
import torch
import cv2
from PIL import Image
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from facenet_pytorch import MTCNN

from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from torch.nn import functional as F
from torchvision.models import resnet18
from albumentations import Normalize, Compose
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import multiprocessing as mp

# Use the first CUDA GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(f'Running on device: {device}')

  check_for_updates()


Running on device: cuda:0


In [None]:
# Root folder holding metadata.csv and one sub-directory of face crops (*.png) per metadata row.
# Raw string: '\W', '\V' and '\D' are invalid escape sequences in a normal literal
# (SyntaxWarning on recent Python versions); the resulting value is unchanged.
train_dir = r'D:\W\VS\VS Folder\DFD\DFDC MTCNN Extracted/'
SAVE_PATH = 'googleViT.pth' # The location where the model should be saved.
# NOTE(review): the file name mentions ViT, but the model code in this notebook wraps a ResNet encoder — confirm the name.
PRETRAINED_MODEL_PATH = ''  # presumably '' means "no checkpoint to resume from" — confirm
N_FACES = 5  # face crops stacked per sample (input channels = 3 * N_FACES for RGB crops)
TEST_SIZE = 0.3
RANDOM_STATE = 123

BATCH_SIZE = 32
# One DataLoader worker per CPU core.
# NOTE(review): on Windows, worker processes are spawned and may fail to find Dataset classes
# defined inside a notebook — fall back to 0 if loading hangs or errors.
NUM_WORKERS = mp.cpu_count()

# Two-stage schedule: a short warm-up followed by fine-tuning at a much lower learning rate
# (presumably paired with freeze_middle_layers / unfreeze_all_layers — confirm).
WARM_UP_EPOCHS = 10
WARM_UP_LR = 1e-4
FINE_TUNE_EPOCHS = 100
FINE_TUNE_LR = 1e-6

THRESHOLD = 0.5  # probability at or above which a prediction counts as "fake" in calculate_f1
EPSILON = 1e-7   # added to denominators in calculate_f1 to avoid division by zero

# Seed the global RNGs (FaceDataset samples crops with np.random) so runs are repeatable.
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

  train_dir = 'D:\W\VS\VS Folder\DFD\DFDC MTCNN Extracted/'


In [6]:
def calculate_f1(preds, labels):
    '''
    Binary F1 score of thresholded predictions.

    Parameters:
        preds: Predicted scores; a score counts as positive when it is >= THRESHOLD.
        labels: Ground-truth labels; cast to uint8, and any non-zero value counts as positive.

    Returns:
        f1 score as a float. EPSILON is added to every denominator, so the score is
        0.0 (rather than a ZeroDivisionError) when there are no true positives.
    '''
    truth = np.array(labels, dtype=np.uint8).astype(bool)
    guess = np.array(preds) >= THRESHOLD

    true_pos = np.count_nonzero(truth & guess)
    false_pos = np.count_nonzero(~truth & guess)
    false_neg = np.count_nonzero(truth & ~guess)

    precision = true_pos / (true_pos + false_pos + EPSILON)
    recall = true_pos / (true_pos + false_neg + EPSILON)
    return (2 * precision * recall) / (precision + recall + EPSILON)


def _log_loss(preds, labels):
    '''Binary cross-entropy of collected probabilities against labels, returned as a Python float.'''
    # The inputs are plain Python lists, so compute on CPU; .item() yields a float that can be
    # stored in a NumPy array without relying on implicit tensor conversion.
    return nn.BCELoss()(
        torch.tensor(preds, dtype=torch.float32),
        torch.tensor(labels, dtype=torch.float32)
    ).item()


def _run_epoch(model, dataloader, criterion, optimizer=None, desc=None):
    '''
    Run one pass over `dataloader`.

    Trains (model.train(), backward pass, optimizer step) when `optimizer` is given; otherwise
    evaluates with model.eval() and autograd disabled.

    Returns:
        mean batch loss, flattened predictions (list), flattened labels (list)
    '''
    training = optimizer is not None
    model.train(training)

    batch_losses = []
    all_preds = []
    all_labels = []
    postfix_key = 'loss' if training else 'val_loss'
    pbar = tqdm(dataloader, desc=desc)

    # Gradients are only needed when the weights are being updated.
    with torch.set_grad_enabled(training):
        for sample_batched in pbar:
            faces = sample_batched['faces'].to(device)
            labels = sample_batched['label'].to(device)
            y_pred = model(faces)

            all_labels.extend(labels.squeeze(dim=-1).tolist())
            all_preds.extend(y_pred.squeeze(dim=-1).tolist())

            loss = criterion(y_pred, labels)
            batch_losses.append(loss.item())

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            pbar.set_postfix({postfix_key: batch_losses[-1]})

    return np.array(batch_losses).mean(), all_preds, all_labels


def train_the_model(
    model,
    criterion,
    optimizer,
    epochs,
    train_dataloader,
    val_dataloader,
    best_val_loss=1e7,
    best_val_logloss=1e7,
    save_the_best_on='val_logloss'
):
    '''
    Train `model` for `epochs` epochs, validating after each one and keeping the best checkpoint.

    Parameters:
        model: The model to train. Its outputs are fed to nn.BCELoss for the log-loss metric,
            so they must be probabilities in [0, 1].
        criterion: Loss function used for optimisation.
        optimizer: The optimizer.
        epochs: The number of epochs.
        train_dataloader: The dataloader used to generate training samples.
        val_dataloader: The dataloader used to generate validation samples.
        best_val_loss: The initial value of the best val loss (default: 1e7.)
        best_val_logloss: The initial value of the best val log loss (default: 1e7.)
        save_the_best_on: Whether to save the best model based on "val_loss" or "val_logloss" (default: val_logloss.)

    Returns:
        losses: All computed losses.
        val_losses: All computed val_losses.
        loglosses: All computed loglosses.
        val_loglosses: All computed val_loglosses.
        f1_scores: All computed f1_scores.
        val_f1_scores: All computed val_f1_scores.
        best_val_loss: New value of the best val loss.
        best_val_logloss: New value of the best val log loss.
        best_model_state_dict: A copy of the state_dict of the best model (None if no epoch improved).
        best_optimizer_state_dict: A copy of the optimizer state_dict corresponding to the best model.

    Raises:
        ValueError: If `save_the_best_on` is neither "val_loss" nor "val_logloss".
    '''
    if save_the_best_on not in ('val_loss', 'val_logloss'):
        raise ValueError(f"save_the_best_on must be 'val_loss' or 'val_logloss', got {save_the_best_on!r}")

    losses = np.zeros(epochs)
    val_losses = np.zeros(epochs)
    loglosses = np.zeros(epochs)
    val_loglosses = np.zeros(epochs)
    f1_scores = np.zeros(epochs)
    val_f1_scores = np.zeros(epochs)
    best_model_state_dict = None
    best_optimizer_state_dict = None

    for i in tqdm(range(epochs)):
        losses[i], train_preds, train_labels = _run_epoch(
            model, train_dataloader, criterion, optimizer, desc=f'Epoch {i+1}'
        )
        loglosses[i] = _log_loss(train_preds, train_labels)
        f1_scores[i] = calculate_f1(train_preds, train_labels)

        val_losses[i], val_preds, val_labels = _run_epoch(
            model, val_dataloader, criterion, desc='Validating'
        )
        val_loglosses[i] = _log_loss(val_preds, val_labels)
        val_f1_scores[i] = calculate_f1(val_preds, val_labels)

        print(f'loss: {losses[i]} | val loss: {val_losses[i]} | f1: {f1_scores[i]} | val f1: {val_f1_scores[i]} | log loss: {loglosses[i]} | val log loss: {val_loglosses[i]}')

        # Update the best values. state_dict() returns references to the live tensors, so the
        # checkpoint is deep-copied; otherwise later optimizer steps would overwrite it.
        if val_losses[i] < best_val_loss:
            best_val_loss = val_losses[i]
            if save_the_best_on == 'val_loss':
                print('Found a better checkpoint!')
                best_model_state_dict = copy.deepcopy(model.state_dict())
                best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())
        if val_loglosses[i] < best_val_logloss:
            best_val_logloss = val_loglosses[i]
            if save_the_best_on == 'val_logloss':
                print('Found a better checkpoint!')
                best_model_state_dict = copy.deepcopy(model.state_dict())
                best_optimizer_state_dict = copy.deepcopy(optimizer.state_dict())

    return losses, val_losses, loglosses, val_loglosses, f1_scores, val_f1_scores, best_val_loss, best_val_logloss, best_model_state_dict, best_optimizer_state_dict


def _plot_train_val_curve(train_values, val_values, ylabel, legend_labels, legend_loc):
    '''Plot one training curve and one validation curve against epoch number on a new figure, then show it.'''
    fig = plt.figure(figsize=(16, 8))
    ax = fig.add_axes([0, 0, 1, 1])

    # Epochs are numbered from 1.
    ax.plot(np.arange(1, len(train_values) + 1), train_values)
    ax.plot(np.arange(1, len(val_values) + 1), val_values)
    ax.set_xlabel('epoch', fontsize='xx-large')
    ax.set_ylabel(ylabel, fontsize='xx-large')
    ax.legend(
        legend_labels,
        loc=legend_loc,
        fontsize='xx-large',
        shadow=True
    )
    plt.show()


def visualize_results(
    losses,
    val_losses,
    loglosses,
    val_loglosses,
    f1_scores,
    val_f1_scores,
    loss_name='focal loss'
):
    '''
    Show three figures: training criterion, log loss and F1 score, each with its validation counterpart.

    Parameters:
        losses: A list of losses.
        val_losses: A list of val losses.
        loglosses: A list of loglosses.
        val_loglosses: A list of val loglosses.
        f1_scores: A list of f1 scores.
        val_f1_scores: A list of val f1 scores.
        loss_name: Y-axis label for the first figure. train_the_model accepts any criterion,
            so pass the one actually used (default: 'focal loss').
    '''
    panels = (
        (losses, val_losses, loss_name, ['loss', 'val loss'], 'upper right'),
        (loglosses, val_loglosses, 'log loss', ['log loss', 'val log loss'], 'upper right'),
        (f1_scores, val_f1_scores, 'f1 score', ['f1', 'val f1'], 'upper left'),
    )
    for train_values, val_values, ylabel, legend_labels, legend_loc in panels:
        _plot_train_val_curve(train_values, val_values, ylabel, legend_labels, legend_loc)

In [7]:
class DeepfakeClassifier(nn.Module):
    '''
    Binary deepfake classifier built around a ResNet-style encoder.

    The encoder's first convolution is replaced so it accepts `in_channels` input channels
    (e.g. several RGB face crops stacked along the channel axis). Its final fully connected
    layer is replaced by a new linear head with `num_classes` outputs. `forward` returns
    sigmoid probabilities, not logits.

    Parameters:
        encoder: A module exposing `conv1` and `fc` attributes (e.g. torchvision's resnet18).
        in_channels: Number of input channels of the new first convolution (default: 3).
        num_classes: Number of outputs of the new head (default: 1).
    '''

    def __init__(self, encoder, in_channels=3, num_classes=1):
        super(DeepfakeClassifier, self).__init__()
        self.encoder = encoder

        # Reuse the encoder's own layer widths when it exposes them, so deeper ResNets
        # (e.g. resnet50, whose head takes 2048 features) also work. Fall back to the
        # resnet18 values (64 / 512) otherwise.
        conv_out_channels = getattr(getattr(encoder, 'conv1', None), 'out_channels', 64)
        fc_in_features = getattr(getattr(encoder, 'fc', None), 'in_features', 512)

        # Modify input layer. NOTE: conv1 is always replaced and freshly initialised,
        # even when in_channels == 3.
        self.encoder.conv1 = nn.Conv2d(
            in_channels,
            conv_out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False
        )

        # Modify output layer.
        self.encoder.fc = nn.Linear(fc_in_features, num_classes)

    def forward(self, x):
        '''Return sigmoid probabilities of the encoder output.'''
        return torch.sigmoid(self.encoder(x))

    def freeze_all_layers(self):
        '''Disable gradients for every encoder parameter.'''
        for param in self.encoder.parameters():
            param.requires_grad = False

    def freeze_middle_layers(self):
        '''Freeze the encoder except the replaced input convolution and output head.'''
        self.freeze_all_layers()

        for param in self.encoder.conv1.parameters():
            param.requires_grad = True

        for param in self.encoder.fc.parameters():
            param.requires_grad = True

    def unfreeze_all_layers(self):
        '''Enable gradients for every encoder parameter.'''
        for param in self.encoder.parameters():
            param.requires_grad = True


class FaceDataset(Dataset):
    '''
    Training dataset that stacks randomly chosen face crops of one sample directory.

    Each item draws `n_faces` PNG crops from the item's directory. Crops are drawn without
    replacement when enough exist and with replacement otherwise. Each crop is converted from
    BGR to RGB and optionally passed through `preprocess`. The crops are concatenated along the
    last axis, then transposed channels-first.

    Parameters:
        img_dirs: Directories containing face crops (*.png), one per item.
        labels: Labels aligned with `img_dirs`.
        n_faces: Number of crops stacked per item (default: 1).
        preprocess: Optional callable invoked as preprocess(image=...) that returns a dict
            with an 'image' key.
    '''

    def __init__(self, img_dirs, labels, n_faces=1, preprocess=None):
        self.img_dirs = img_dirs
        self.labels = labels
        self.n_faces = n_faces
        self.preprocess = preprocess

    def __len__(self):
        return len(self.img_dirs)

    def _load_face(self, face_path):
        '''Read one crop, convert BGR -> RGB, and apply `preprocess` if set.'''
        rgb = cv2.cvtColor(cv2.imread(face_path, 1), cv2.COLOR_BGR2RGB)
        if self.preprocess is None:
            return rgb
        return self.preprocess(image=rgb)['image']

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_dir, label = self.img_dirs[idx], self.labels[idx]
        candidates = glob.glob(f'{img_dir}/*.png')

        # Only sample with replacement when the directory holds fewer crops than requested.
        chosen = np.random.choice(candidates, self.n_faces, replace=len(candidates) < self.n_faces)
        stacked = np.concatenate([self._load_face(path) for path in chosen], axis=-1)

        return {'faces': stacked.transpose(2, 0, 1), 'label': np.array([label], dtype=float)}
    
    
class FaceValDataset(Dataset):
    '''
    Deterministic validation dataset that stacks evenly spaced face crops of one sample directory.

    Crops are expected to be PNG files named "<frame>_<anything>.png" or "<frame>.png". The
    integer before the first "_" is taken as the frame index. Each item takes `n_faces` crops
    from frames spread over the frame range. Each crop is converted from BGR to RGB and
    optionally passed through `preprocess`. The crops are concatenated along the last axis,
    then transposed channels-first.

    Parameters:
        img_dirs: Directories containing face crops (*.png), one per item.
        labels: Labels aligned with `img_dirs`.
        n_faces: Number of crops stacked per item (default: 1).
        preprocess: Optional callable invoked as preprocess(image=...) that returns a dict
            with an 'image' key.
    '''

    def __init__(self, img_dirs, labels, n_faces=1, preprocess=None):
        self.img_dirs = img_dirs
        self.labels = labels
        self.n_faces = n_faces
        self.preprocess = preprocess

    def __len__(self):
        return len(self.img_dirs)

    @staticmethod
    def _frame_index(path):
        '''Return the integer frame index at the start of a crop's file name.'''
        # os.path.basename (rather than splitting on '\\') works for both Windows and POSIX paths.
        return int(os.path.basename(path).split('.')[0].split('_')[0])

    def _select_face_paths(self, face_paths):
        '''Pick exactly `n_faces` crop paths from frames spread over the video; `face_paths` must be non-empty.'''
        # Sort so the choice does not depend on the filesystem's listing order.
        paths_by_frame = {}
        for path in sorted(face_paths):
            paths_by_frame.setdefault(self._frame_index(path), []).append(path)
        max_idx = max(paths_by_frame)

        selected_paths = []
        for i in range(self.n_faces):
            stride = int((max_idx + 1) / (self.n_faces ** 2))
            sample = np.linspace(i * stride, max_idx + i * stride, self.n_faces).astype(int)

            for frame_idx in sample:
                # Exact frame match; a prefix glob such as f'{idx}*.png' would also pick
                # frames 10-19 when idx is 1.
                selected_paths.extend(paths_by_frame.get(int(frame_idx), []))
                if len(selected_paths) >= self.n_faces:
                    break

            if len(selected_paths) >= self.n_faces:
                break

        if len(selected_paths) < self.n_faces:
            # Too few crops on the sampled frames: repeat what was found (or any crop) so every
            # item has the same number of stacked channels.
            pool = selected_paths or sorted(face_paths)
            selected_paths = [pool[k % len(pool)] for k in range(self.n_faces)]

        return selected_paths[:self.n_faces]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_dir = self.img_dirs[idx]
        label = self.labels[idx]
        face_paths = glob.glob(f'{img_dir}/*.png')
        if not face_paths:
            raise ValueError(f'No face crops (*.png) found in {img_dir}')

        faces = []
        for selected_path in self._select_face_paths(face_paths):
            img = cv2.imread(selected_path, 1)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if self.preprocess is not None:
                img = self.preprocess(image=img)['image']
            faces.append(img)

        faces = np.concatenate(faces, axis=-1).transpose(2, 0, 1)

        return {
            'faces': faces,
            'label': np.array([label], dtype=float)
        }


class FocalLoss(nn.Module):
    '''
    Binary focal loss computed from raw logits.

    Per element, the loss is exp(gamma * log(1 - p_t)) * BCE(logit, target). Here p_t is the
    sigmoid probability assigned to the target class (for 0/1 targets). With gamma=0 this
    reduces to binary cross-entropy with logits.

    Parameters:
        gamma: Focusing exponent (default: 2).
        sample_weight: Optional multiplier applied to the per-sample losses before averaging.

    NOTE(review): the input is treated as a logit, but DeepfakeClassifier.forward returns
    sigmoid probabilities. Passing that output here would apply the sigmoid twice — confirm
    which tensor the training loop is meant to pass.
    '''

    def __init__(self, gamma=2, sample_weight=None):
        super().__init__()
        self.gamma = gamma
        self.sample_weight = sample_weight

    def forward(self, logit, target):
        y = target.float()

        # Numerically stable BCE-with-logits, with m = max(-x, 0):
        #   x - x*y + m + log(exp(-m) + exp(-x - m)).
        # Neither exponential can overflow.
        shift = torch.clamp(-logit, min=0)
        bce = logit - logit * y + shift + \
            torch.log(torch.exp(-shift) + torch.exp(-logit - shift))

        # log(1 - p_t): the logit is negated for positive targets and kept for negative ones.
        log_one_minus_pt = F.logsigmoid(-logit * (y * 2.0 - 1.0))
        per_element = torch.exp(log_one_minus_pt * self.gamma) * bce

        # 2-D inputs are summed over their second axis, giving one loss per row.
        per_sample = per_element.sum(dim=1) if per_element.dim() == 2 else per_element
        if self.sample_weight is not None:
            per_sample = per_sample * self.sample_weight
        return per_sample.mean()

In [None]:
# Load metadata.csv and point each row at the directory holding its face crops: train_dir joined
# with the file name cut at its first '.'.
# NOTE(review): split('.')[0] truncates names that contain several dots; it presumably matches
# how the crop directories were named — confirm before switching to os.path.splitext.
metadata_csv = os.path.join(train_dir, 'metadata.csv')
train_df = pd.read_csv(metadata_csv)
train_df['path'] = train_df['filename'].map(
    lambda filename: os.path.join(train_dir, filename.split('.')[0])
)
train_df.head()