# In [1]:
from model.cnn.xception_net.model import xception
from model.cnn.xception_net.model import xception
import torch
from module.train_torch import calculate_metric
import random
import numpy as np
import os
import torch.nn as nn
from glob import glob
from tqdm import tqdm
from os.path import join
from utils.util import get_test_metric_by_step2
from operator import truediv
import os, sys
from random import shuffle
import os.path as osp

import torch
from torchvision import datasets, transforms

from dataloader.utils import make_weights_for_balanced_classes, make_weights_for_balanced_classes_2
from dataloader.dual_fft_dataset import DualFFTMagnitudeFeatureDataset, DualFFTMagnitudeImageDataset, TripleFFTMagnitudePhaseDataset
from dataloader.pairwise_dual_fft_dataset import PairwiseDualFFTMagnitudeFeatureDataset, PairwiseDualFFTMagnitudeImageDataset
from dataloader.pairwise_single_fft_dataset import PairwiseSingleFFTMagnitudeImageDataset
from dataloader.pairwise_triple_fft_dataset import PairwiseTripleFFTMagnitudePhaseDataset
from dataloader.triplewise_dual_fft_dataset import TriplewiseDualFFTMagnitudeImageDataset
from dataloader.transform import transform_method

# Fix the seed of every RNG this script draws from (PyTorch, Python's
# `random`, NumPy) so repeated runs start from the same random state.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(0)
del _seed_fn

def eval_kfold_image_stream(model, dataloader, device, criterion, adj_brightness=1.0, adj_contrast=1.0):
    """Evaluate a binary classifier on every batch of ``dataloader``.

    Model outputs are thresholded at 0.5 to obtain predicted labels and are
    passed to ``criterion`` together with the float labels (presumably the
    outputs are probabilities, e.g. for ``nn.BCELoss`` as used in this
    script -- confirm for other criteria).

    Args:
        model: network to evaluate; it is switched to ``eval()`` mode.
        dataloader: iterable yielding ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.
        criterion: loss called as ``criterion(outputs, labels)``.
        adj_brightness: unused; kept for backward compatibility.
        adj_contrast: unused; kept for backward compatibility.

    Returns:
        ``(loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros)``.
        ``loss`` and ``mac_accuracy`` are per-batch values averaged over the
        number of batches; the remaining values come from ``calculate_metric``
        applied to the pooled ground-truth and predicted labels.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    loss = 0
    mac_accuracy = 0
    num_batches = 0
    y_label = []
    y_pred_label = []
    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            num_batches += 1
            # Record the ground truth on the host before moving the batch.
            y_label.extend(labels.cpu().numpy().astype(np.float64))
            inputs, labels = inputs.float().to(device), labels.float().to(device)

            # Call the module itself rather than .forward() so that any
            # registered forward hooks are run.
            logps = model(inputs)
            if logps.dim() == 2:
                # Presumably a (B, 1) output; drop the singleton column so it
                # lines up with the labels (no-op if dim 1 is not size 1).
                logps = logps.squeeze(dim=1)

            # Running sum of per-batch losses, averaged after the loop.
            loss += criterion(logps, labels).item()

            preds = logps > 0.5
            # Per-batch accuracy, computed on-device (no host copy needed).
            mac_accuracy += (labels == preds).float().mean().item()
            y_pred_label.extend(preds.cpu().numpy())

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot compute metrics")
    assert len(y_label) == len(y_pred_label), "Bug"

    loss /= num_batches
    mac_accuracy /= num_batches
    # Pooled (sample-level) metrics over the whole split.
    mic_accuracy, reals, fakes, micros, macros = calculate_metric(y_label, y_pred_label)
    return loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros

def generate_test_dataloader_single_cnn_stream_for_kfold(test_dir, image_size, batch_size, num_workers, adj_brightness=1.0, adj_contrast=1.0):
    """Build a non-shuffled DataLoader over an ``ImageFolder`` test split.

    Each image is resized to ``image_size`` x ``image_size``, then
    brightness- and contrast-adjusted by the given factors (per the
    torchvision docs a factor of 1.0 returns the original image), converted
    to a tensor and normalized with the commonly used ImageNet mean/std.

    Args:
        test_dir: root directory in ``ImageFolder`` layout (one sub-folder per class).
        image_size: side length images are resized to.
        batch_size: batch size of the returned loader.
        num_workers: number of DataLoader worker processes.
        adj_brightness: brightness factor passed to ``adjust_brightness``.
        adj_contrast: contrast factor passed to ``adjust_contrast``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=False``.

    Raises:
        ValueError: if the dataset built from ``test_dir`` is empty.
    """
    transform_fwd = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        # Photometric perturbations for the robustness sweep.
        transforms.Lambda(lambda img: transforms.functional.adjust_brightness(img, adj_brightness)),
        transforms.Lambda(lambda img: transforms.functional.adjust_contrast(img, adj_contrast)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Make dataset using built-in ImageFolder function of torch
    test_dataset = datasets.ImageFolder(test_dir, transform=transform_fwd)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(test_dataset) == 0:
        raise ValueError("Test Dataset is empty!")
    print("Test image dataset: ", len(test_dataset))
    # Fixed order so predictions line up across runs/folds.
    dataloader_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return dataloader_test

################### FROM DEEPFAKE TO OTHER DATASET:
# Evaluates the k-fold Xception checkpoints of every manipulation technique on
# its test split under several contrast factors. At contrast 1.0 the computed
# metrics are cross-checked against the reference metrics for the same
# checkpoint step, and the sweep stops at the first disagreement.

# NOTE: CUDA_VISIBLE_DEVICES must be set before the first CUDA call in this
# process; physical GPU 1 is then the only visible device (cuda:0 here).
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device('cuda')

test_dir_template = "/mnt/disk1/doan/phucnp/Dataset/my_extend_data/extend_data_train/{}/test"

# All techniques share the same checkpoint layout and hyper-parameter tag;
# only the technique name differs.
_checkpoint_template = ("/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/"
                        "checkpoint/extend_data/{}/kfold_xception/"
                        "lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0")
checkpoints = {
    technique: _checkpoint_template.format(technique)
    for technique in ("deepfake", "faceswap_2d", "3dmm", "faceswap_3d",
                      "monkey", "reenact", "stargan", "x2face")
}

brightness = 1.0
contrasts = [1.0, 0.5, 0.75, 1.5, 2.0]
# Order in which techniques are evaluated.
techniques = ['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']

# Evaluation settings shared by every run.
image_size = 128
batch_size = 16
num_workers = 4
# Max allowed difference between computed and reference metrics after both
# are rounded to 4 decimals.
metric_tolerance = 0.001


def _fold_index(fold_dir):
    """Return the integer fold number ending a fold directory name.

    e.g. ``'(0.0001_1.0000)_fold_3'`` -> ``3``; also handles multi-digit folds.
    """
    return int(fold_dir.rsplit('_', 1)[-1])


def _find_best_acc_checkpoint(checkpoint, fold_dir):
    """Return the ``best_test_acc*`` weight file under ``<checkpoint>/<fold_dir>/step``."""
    step_dir = join(checkpoint, fold_dir, 'step')
    # Sorted so the pick is deterministic if several files match.
    matches = sorted(glob(join(step_dir, 'best_test_acc*')))
    if not matches:
        raise FileNotFoundError("No best_test_acc* checkpoint found in {}".format(step_dir))
    return matches[0]


def _load_model(pt_file):
    """Build an Xception network, load the weights in ``pt_file`` and move it to ``device``."""
    model = xception(pretrained=False)
    # map_location lets weights saved on another GPU index load on the visible device.
    model.load_state_dict(torch.load(pt_file, map_location=device))
    return model.to(device)


def _reference_mismatch(pt_file, computed):
    """Compare ``computed`` metrics with those returned by ``get_test_metric_by_step2``.

    ``computed`` maps 'acc', 'pre', 'rec', 'f1' to values. Prints one
    "Error in ..." line per disagreeing metric and returns True if any
    disagrees. The loss is not compared.
    """
    step_ckcpoint = osp.split(pt_file)[0]
    # File names look like ``best_test_acc_<step>_<acc>.pt``.
    step = int(osp.basename(pt_file).split('_')[-2])
    metrics, _bestloss = get_test_metric_by_step2(step_ckcpoint=step_ckcpoint, steps=step,
                                                  pairwise='pairwise' in pt_file)
    reference = metrics['best_test_acc']

    mismatch = False
    for name in ('acc', 'pre', 'rec', 'f1'):
        if abs(round(computed[name], 4) - round(reference[name], 4)) > metric_tolerance:
            print("         Error in {}: {}, {}".format(name, computed[name], reference[name]))
            mismatch = True
    return mismatch


def _run_sweep():
    """Evaluate every fold of every technique at each contrast level (see header)."""
    for contrast in contrasts:
        print("\n*********************************************************")
        print("             CONTRAST: ", contrast)
        for technique in techniques:
            checkpoint = checkpoints[technique]
            print("\n\n ================ {} CHECKPOINT ================".format(technique))

            # Test split and loss do not depend on the fold: build them once per technique.
            dataloader_test = generate_test_dataloader_single_cnn_stream_for_kfold(
                test_dir_template.format(technique), image_size, batch_size, num_workers,
                adj_brightness=brightness, adj_contrast=contrast)
            criterion = nn.BCELoss().to(device)

            for fold_dir in sorted(os.listdir(checkpoint), key=_fold_index):
                pt_file = _find_best_acc_checkpoint(checkpoint, fold_dir)
                model_ = _load_model(pt_file)

                _loss, mac_acc, mic_acc, _reals, _fakes, _micros, macros = eval_kfold_image_stream(
                    model_, dataloader_test, device, criterion)
                print("            pt file: ", '/'.join(pt_file.split('/')[-3:]))
                print("  FOLD: {}   ".format(_fold_index(fold_dir)))
                print("            accuracy =   {:.6f}       |  {:.6f}".format(mic_acc, mac_acc))
                print("            precision =  {:.6f}  |   recall =    {:.6f}   |  f1 =    {:.6f}".format(macros[0], macros[1], macros[2]))

                # Only the unperturbed setting is comparable to the reference metrics.
                if contrast == 1.0:
                    computed = {'acc': mic_acc, 'pre': macros[0], 'rec': macros[1], 'f1': macros[2]}
                    if _reference_mismatch(pt_file, computed):
                        return


_run_sweep()


# ---- Output of the cell above: a tqdm import notice, then the evaluation log.
# ---- The run was interrupted (KeyboardInterrupt) during fold 3 at contrast 1.0.
#   from .autonotebook import tqdm as notebook_tqdm
#
#
#
# *********************************************************
#              CONTRAST:  1.0
#
#
# Test image dataset:  10000
#             pt file:  (0.0001_1.0000)_fold_0/step/best_test_acc_3400_1.000000.pt
#   FOLD: 0
#             accuracy =   1.000000       |  1.000000
#             precision =  1.000000  |   recall =    1.000000   |  f1 =    1.000000
# Test image dataset:  10000
#             pt file:  (0.0000_1.0000)_fold_1/step/best_test_acc_5000_1.000000.pt
#   FOLD: 1
#             accuracy =   1.000000       |  1.000000
#             precision =  1.000000  |   recall =    1.000000   |  f1 =    1.000000
# Test image dataset:  10000
#             pt file:  (0.0001_1.0000)_fold_2/step/best_test_acc_1000_1.000000.pt
#   FOLD: 2
#             accuracy =   1.000000       |  1.000000
#             precision =  1.000000  |   recall =    1.000000   |  f1 =    1.000000
# Test image dataset:  10000
#             pt file:  (0.0000_1.0000)_fold_3/step/best_test_acc_1200_1.000000.pt
#   FOLD: 3
#             accuracy =
#
# KeyboardInterrupt: