In [2]:
from model.cnn.xception_net.model import xception
from model.cnn.mesonet4.model import mesonet
import torch
from module.train_torch import calculate_metric
import random
import numpy as np
import os
import torch.nn as nn
from glob import glob
from tqdm import tqdm
from os.path import join
from utils.util import get_test_metric_by_step2
from operator import truediv
import os, sys
from random import shuffle
import os.path as osp

import torch
from torchvision import datasets, transforms

from dataloader.utils import make_weights_for_balanced_classes, make_weights_for_balanced_classes_2
from dataloader.dual_fft_dataset import DualFFTMagnitudeFeatureDataset, DualFFTMagnitudeImageDataset, TripleFFTMagnitudePhaseDataset
from dataloader.pairwise_dual_fft_dataset import PairwiseDualFFTMagnitudeFeatureDataset, PairwiseDualFFTMagnitudeImageDataset
from dataloader.pairwise_single_fft_dataset import PairwiseSingleFFTMagnitudeImageDataset
from dataloader.pairwise_triple_fft_dataset import PairwiseTripleFFTMagnitudePhaseDataset
from dataloader.triplewise_dual_fft_dataset import TriplewiseDualFFTMagnitudeImageDataset
from dataloader.transform import transform_method

from model.cnn.capsule_net.model import VggExtractor, CapsuleNet
from loss.capsule_loss import CapsuleLoss
from torch.autograd import Variable
from sklearn.metrics import accuracy_score

# Seed torch, Python's `random` module and NumPy's global RNG with the same
# value (in that order) so that repeated evaluation runs are reproducible.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(0)
del _seed_rng

def eval_kfold_capsulenet(capnet, vgg_ext, dataloader, device, capsule_loss):
    """Evaluate a VGG-feature-extractor + CapsuleNet pair on ``dataloader``.

    Both networks are put in eval mode and run without gradient tracking.
    Every label greater than 1 is collapsed onto 1, so scoring is binary.

    Args:
        capnet: capsule network called as ``capnet(features, random=False)``;
            it must return ``(classes, class_)``, where ``class_`` has (at
            least) two score columns that are compared to pick the label.
        vgg_ext: feature extractor applied to each input batch.
        dataloader: iterable of ``(inputs, labels)`` CPU batches supporting
            ``len()``.
        device: torch device the batches are moved to.
        capsule_loss: criterion called as ``capsule_loss(classes, labels)``.

    Returns:
        ``(loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros)``.
        ``loss`` and ``mac_accuracy`` are averaged over batches. The other
        values come from ``calculate_metric`` applied to all labels and
        predictions.

    Raises:
        ValueError: if ``dataloader`` has no batches.
    """
    capnet.eval()
    vgg_ext.eval()

    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to evaluate")

    y_label = []
    y_pred_label = []
    loss = 0.0
    mac_accuracy = 0.0

    # Inference only: skip building the autograd graph to save (GPU) memory.
    with torch.no_grad():
        for inputs, labels in dataloader:
            # Binary task: every label > 1 is mapped to 1.
            labels[labels > 1] = 1
            # np.float was removed in NumPy 1.24; np.float64 is the same dtype.
            img_label = labels.numpy().astype(np.float64)
            inputs, labels = inputs.to(device), labels.to(device)

            features = vgg_ext(inputs)
            classes, class_ = capnet(features, random=False)

            loss += capsule_loss(classes, labels).item()
            output_dis = class_.detach().cpu().numpy()
            # Predict 1 when column 1 scores >= column 0 (ties resolve to 1).
            output_pred = (output_dis[:, 1] >= output_dis[:, 0]).astype(np.float64)

            y_label.extend(img_label)
            y_pred_label.extend(output_pred)
            # Per-batch accuracy; averaged over batches below ("macro" accuracy).
            mac_accuracy += accuracy_score(img_label, output_pred)

    mac_accuracy /= num_batches
    loss /= num_batches
    assert len(y_label) == len(y_pred_label), "Bug"
    ######## Calculate metrics:
    # built-in methods for calculating metrics
    mic_accuracy, reals, fakes, micros, macros = calculate_metric(y_label, y_pred_label)
    return loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros

def generate_test_dataloader_single_cnn_stream_for_kfold(test_dir, image_size, batch_size, num_workers, adj_brightness=1.0, adj_contrast=1.0):
    """Build an unshuffled test DataLoader over an ImageFolder directory tree.

    Each image is resized to ``image_size`` x ``image_size``. Its brightness
    and contrast are then adjusted by ``adj_brightness`` / ``adj_contrast``
    via torchvision's ``adjust_brightness`` / ``adjust_contrast``. Finally it
    is converted to a tensor and normalized with the mean/std listed below.

    Args:
        test_dir: root directory laid out as expected by ``ImageFolder``.
        image_size: side length of the square resize.
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        adj_brightness: brightness factor passed to ``adjust_brightness``.
        adj_contrast: contrast factor passed to ``adjust_contrast``.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=False``.

    Raises:
        ValueError: if the dataset contains no images.
    """
    from functools import partial

    # functools.partial instead of lambdas: partial objects can be pickled,
    # so the transform also works when DataLoader workers are spawned
    # (rather than forked) processes.
    transform_fwd = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.Lambda(partial(transforms.functional.adjust_brightness, brightness_factor=adj_brightness)),
        transforms.Lambda(partial(transforms.functional.adjust_contrast, contrast_factor=adj_contrast)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Make dataset using built-in ImageFolder function of torch
    test_dataset = datasets.ImageFolder(test_dir, transform=transform_fwd)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(test_dataset) == 0:
        raise ValueError("Test Dataset is empty!")
    print("Test image dataset: ", len(test_dataset))
    # Make dataloader (keep file order so results are reproducible)
    return torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

################### FROM DEEPFAKE TO OTHER DATASET:
# Re-evaluate every k-fold CapsuleNet checkpoint on each technique's test
# split under several brightness factors. At the reference brightness the
# freshly computed metrics are compared with the values logged during
# training (via get_test_metric_by_step2); the first mismatch stops the sweep.
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised earlier in this process -- confirm no earlier cell touched it.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
device = torch.device('cuda')

test_dir_template = "/mnt/disk1/doan/phucnp/Dataset/my_extend_data/extend_data_train/{}/test"
checkpoint_template = "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/{}/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0"

techniques = ['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']
# Every technique's run lives in the same hyper-parameter directory name.
checkpoints = {technique: checkpoint_template.format(technique) for technique in techniques}

brightnesses = [1.0, 0.5, 0.75, 1.5, 2.0]
contrast = 1.0
batch_size = 16
# Brightness at which results must reproduce the training-time log.
reference_brightness = 1.0
# Allowed absolute gap between 4-decimal-rounded metrics.
metric_tolerance = 0.001


def _fold_index(fold_dir_name):
    """Return the integer formed by the trailing digits of a fold directory name."""
    import re

    match = re.search(r'(\d+)$', fold_dir_name)
    if match is None:
        raise ValueError("Cannot parse a fold index from {!r}".format(fold_dir_name))
    return int(match.group(1))


def _single_checkpoint(fold_path, pattern):
    """Return the first file matching ``<fold_path>/step/<pattern>`` (sorted)."""
    matches = sorted(glob(join(fold_path, 'step', pattern)))
    if not matches:
        raise FileNotFoundError("No checkpoint matching {!r} in {}".format(pattern, join(fold_path, 'step')))
    return matches[0]


my_break = False
for brightness in brightnesses:
    print("\n*********************************************************")
    print("             BRIGHTNESS: ", brightness)
    for technique in techniques:
        checkpoint = checkpoints[technique]
        print("\n\n ================ {} CHECKPOINT ================".format(technique))

        # The test split depends only on (technique, brightness): build the
        # loader once and reuse it for every fold.
        test_dir = test_dir_template.format(technique)
        dataloader_test = generate_test_dataloader_single_cnn_stream_for_kfold(test_dir, 128, batch_size, 4, adj_brightness=brightness, adj_contrast=contrast)

        fold_dirs = sorted(
            (name for name in os.listdir(checkpoint) if osp.isdir(join(checkpoint, name))),
            key=_fold_index,
        )
        for fold_dir in fold_dirs:
            fold_id = _fold_index(fold_dir)
            fold_path = join(checkpoint, fold_dir)
            capsule_ckcpoint = _single_checkpoint(fold_path, 'bestcapsule_test_acc*')
            vggext_ckcpoint = _single_checkpoint(fold_path, 'bestvgg_test_acc*')
            pt_file = capsule_ckcpoint

            vgg_ext = VggExtractor().to(device)
            capnet = CapsuleNet(num_class=2, device=device).to(device)
            # map_location loads the weights onto `device` regardless of
            # which GPU they were saved from.
            vgg_ext.load_state_dict(torch.load(vggext_ckcpoint, map_location=device))
            capnet.load_state_dict(torch.load(capsule_ckcpoint, map_location=device))
            capsule_loss = CapsuleLoss().to(device)

            loss, mac_acc, mic_acc, reals, fakes, micros, macros = eval_kfold_capsulenet(capnet, vgg_ext, dataloader_test, device, capsule_loss)
            print("            pt file: ", '/'.join(pt_file.split('/')[-3:]))
            print("  FOLD: {}   ".format(fold_id))
            print("            accuracy =   {:.6f}       |  {:.6f}".format(mic_acc, mac_acc))
            print("            precision =  {:.6f}  |   recall =    {:.6f}   |  f1 =    {:.6f}".format(macros[0], macros[1], macros[2]))

            if brightness == reference_brightness:
                step_ckcpoint = osp.split(pt_file)[0]
                # File names look like 'bestcapsule_test_acc_<step>_<acc>.pt'.
                step = int(osp.basename(pt_file).split('_')[-2])
                pairwise = 'pairwise' in pt_file
                metrics, bestloss = get_test_metric_by_step2(step_ckcpoint=step_ckcpoint, steps=step, pairwise=pairwise)
                best = metrics['best_test_acc']
                acc, pre, rec, f1 = best['acc'], best['pre'], best['rec'], best['f1']

                # The evaluation loss is deliberately not compared with bestloss.
                checks = (
                    ("acc", mic_acc, acc),
                    ("pre", macros[0], pre),
                    ("rec", macros[1], rec),
                    ("f1", macros[2], f1),
                )
                for metric_name, computed, logged in checks:
                    if abs(round(computed, 4) - round(logged, 4)) > metric_tolerance:
                        print("         Error in {}: {}, {}".format(metric_name, computed, logged))
                        my_break = True

                if my_break:
                    break

        if my_break:
            break

    if my_break:
        break



*********************************************************
             BRIGHTNESS:  1.0


Test image dataset:  10000
            pt file:  (1.9607_0.9026_0.9421)_fold_0/step/bestcapsule_test_acc_3000_0.942100.pt
  FOLD: 0   
            accuracy =   0.942100       |  0.942100
            precision =  0.944215  |   recall =    0.942100   |  f1 =    0.942031
Test image dataset:  10000
            pt file:  (1.9594_0.8933_0.9810)_fold_1/step/bestcapsule_test_acc_11600_0.981000.pt
  FOLD: 1   
            accuracy =   0.981000       |  0.981000
            precision =  0.981079  |   recall =    0.981000   |  f1 =    0.980999
Test image dataset:  10000
            pt file:  (1.9577_0.9029_0.9557)_fold_2/step/bestcapsule_test_acc_15000_0.955700.pt
  FOLD: 2   
            accuracy =   0.955700       |  0.955700
            precision =  0.955792  |   recall =    0.955700   |  f1 =    0.955698
Test image dataset:  10000
            pt file:  (1.9603_0.8770_0.9690)_fold_3/step/bestcapsule_test