In [None]:
from model.cnn.xception_net.model import xception
import torch
from module.train_torch import calculate_metric
from dataloader.gen_dataloader import *
import random
import numpy as np
import os
import torch.nn as nn
from glob import glob
from tqdm import tqdm
from os.path import join
from utils.util import get_test_metric_by_step

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

def eval_kfold_image_stream(model ,dataloader, device, criterion, adj_brightness=1.0, adj_contrast=1.0 ):
    loss = 0
    mac_accuracy = 0
    model.eval()
    y_label = []
    y_pred_label = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            # Push to device
            y_label.extend(labels.cpu().numpy().astype(np.float64))
            inputs, labels = inputs.float().to(device), labels.float().to(device)

            # Forward network
            logps = model.forward(inputs)
            if len(logps.shape) == 2:
                logps = logps.squeeze(dim=1)

            # Loss in a batch
            batch_loss = criterion(logps, labels)
            # Cumulate into running val loss
            loss += batch_loss.item()

            # Find accuracy
            equals = (labels == (logps > 0.5))
            mac_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
            #
            logps_cpu = logps.cpu().numpy()
            pred_label = (logps_cpu > 0.5)
            y_pred_label.extend(pred_label)
            
    assert len(y_label) == len(y_pred_label), "Bug"
    ######## Calculate metrics:
    loss /= len(dataloader)
    mac_accuracy /= len(dataloader)
    # built-in methods for calculating metrics
    mic_accuracy, reals, fakes, micros, macros = calculate_metric(y_label, y_pred_label)
    return loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros

################### FROM DEEPFAKE TO OTHER DATASET:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device('cuda')

test_dir_template = "/mnt/disk1/doan/phucnp/Dataset/my_extend_data/extend_data_train/{}/test"

checkpoints = {
    "deepfake": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/deepfake/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "faceswap_2d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/faceswap_2d/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "3dmm": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/3dmm/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "faceswap_3d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/faceswap_3d/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "monkey": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/monkey/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "reenact": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/reenact/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "stargan": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/stargan/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0",
    "x2face": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/x2face/kfold_xception/lr0.0002_batch16_esnone_lossbce_nf5_trick1_pre1_seed0_drmlp0.0_aug0"
}

for technique in ['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']:
    checkpoint = checkpoints[technique]
    print("\n\n ==================================== {} CHECKPOINT =========================".format(technique))
    for idx, other_technique in enumerate(['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']):
        print("\n************************ So sánh với technique: ", other_technique)
        my_break = False
        for pt_file in glob(join(checkpoint, '*/*/*.pt')):
            if 'best_val_loss' not in pt_file:
                continue
                
            fold_id = pt_file.split('/')[-3][-1]

            model_ = xception(pretrained=False)
            model_.load_state_dict(torch.load(pt_file))
            model_ = model_.to(device)
            
            batch_size = 16

            test_dir = test_dir_template.format(other_technique)
            dataloader_test = generate_test_dataloader_single_cnn_stream_for_kfold(test_dir, 128, batch_size, 4)

            criterion = nn.BCELoss().to(device)
            ####
            loss, mac_acc, mic_acc, reals, fakes, micros, macros = eval_kfold_image_stream(model_, dataloader_test, device, criterion)
            print("            pt file: ", '/'.join(pt_file.split('/')[-3:]))
            print("        {}. FOLD: {}    -   cross with dataset: {}".format(fold_id, idx, other_technique))
            print("            accuracy =   {:.6f}       |  {:6f}".format(mic_acc, mac_acc))
            print("            precision =  {:.6f}  |   recall =    {:.6f}   |  f1 =    {:.6f}".format(macros[0], macros[1], macros[2]))

            if other_technique == technique:
                step_ckcpoint = osp.split(pt_file)[0]
                step = int(osp.basename(pt_file).split('_')[-2])
                pairwise = True if 'pairwise' in pt_file else False
                metrics, bestloss = get_test_metric_by_step(step_ckcpoint=step_ckcpoint, steps=step, pairwise=pairwise)
                acc, pre, rec, f1 = metrics['best_val_bceloss']['acc'], metrics['best_val_bceloss']['pre'], metrics['best_val_bceloss']['rec'], metrics['best_val_bceloss']['f1']

                # if abs(round(loss, 4) - round(bestloss, 4)) > 0.001:
                #     print("         Error in loss: {}, {}".format(loss, bestloss))
                #     my_break = True
                if abs(round(mic_acc, 4) - round(acc, 4)) > 0.001:
                    print("         Error in acc: {}, {}".format(mic_acc, acc))
                    my_break = True
                if abs(round(macros[0], 4) - round(pre, 4)) > 0.001:
                    print("         Error in pre: {}, {}".format(macros[0], pre))
                    my_break = True
                if abs(round(macros[1], 4) - round(rec, 4)) > 0.001:
                    print("         Error in rec: {}, {}".format(macros[1], rec))
                    my_break = True
                if abs(round(macros[2], 4) - round(f1, 4)) > 0.001:
                    print("         Error in f1: {}, {}".format(macros[2], f1))
                    my_break = True

                if my_break:
                    break

        if my_break:
            break

    if my_break:
        break

In [3]:
from model.cnn.xception_net.model import xception
import torch
from module.train_torch import calculate_metric
from dataloader.gen_dataloader import *
import random
import numpy as np
import os
import torch.nn as nn
from glob import glob
from tqdm import tqdm
from os.path import join
from utils.util import get_test_metric_by_step
from model.cnn.capsule_net.model import VggExtractor, CapsuleNet
from loss.capsule_loss import CapsuleLoss
from torch.autograd import Variable
from sklearn.metrics import accuracy_score

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

def eval_kfold_capsulenet(capnet, vgg_ext, dataloader, device, capsule_loss):
    capnet.eval()
    vgg_ext.eval()

    y_label = []
    y_pred = []
    y_pred_label = []
    loss = 0
    mac_accuracy = 0
    
    for inputs, labels in dataloader:
        labels[labels > 1] = 1
        img_label = labels.numpy().astype(np.float)
        inputs, labels = inputs.to(device), labels.to(device)

        input_v = Variable(inputs)
        x = vgg_ext(input_v)
        classes, class_ = capnet(x, random=False)

        loss_dis = capsule_loss(classes, Variable(labels, requires_grad=False))
        loss_dis_data = loss_dis.item()
        output_dis = class_.data.cpu().numpy()

        output_pred = np.zeros((output_dis.shape[0]), dtype=np.float)

        for i in range(output_dis.shape[0]):
            if output_dis[i,1] >= output_dis[i,0]:
                output_pred[i] = 1.0
            else:
                output_pred[i] = 0.0

        loss += loss_dis_data
        y_label.extend(img_label)
        y_pred.extend(output_dis)
        y_pred_label.extend(output_pred)
        mac_accuracy += accuracy_score(img_label, output_pred)
        
    mac_accuracy /= len(dataloader)
    loss /= len(dataloader)
    assert len(y_label) == len(y_pred_label), "Bug"
    ######## Calculate metrics:
    # built-in methods for calculating metrics
    mic_accuracy, reals, fakes, micros, macros = calculate_metric(y_label, y_pred_label)
    return loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros


################### FROM DEEPFAKE TO OTHER DATASET:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device('cuda')

test_dir_template = "/mnt/disk1/doan/phucnp/Dataset/my_extend_data/extend_data_train/{}/test"

checkpoints = {
    "deepfake": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/deepfake/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "faceswap_2d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/faceswap_2d/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "3dmm": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/3dmm/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "faceswap_3d":"/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/faceswap_3d/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "monkey": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/monkey/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "reenact": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/reenact/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "stargan": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/stargan/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0",
    "x2face": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extenddata/x2face/kfold_capsule/lr0.001_batch16_esnone_nf5_trick1_beta0.9_dropout0.05_seed0_drmlp0.0_aug0"
}

for technique in ['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']:
    checkpoint = checkpoints[technique]
    print("\n\n ==================================== {} CHECKPOINT =========================".format(technique))
    for idx, other_technique in enumerate(['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']):
        print("\n************************ So sánh với technique: ", other_technique)
        my_break = False
        for fold_dir in glob(join(checkpoint, '*')):
                
            fold_id = fold_dir[-1]

            vgg_ext = VggExtractor().to(device)
            capnet = CapsuleNet(num_class=2, device=device).to(device)
            capsule_ckcpoint = glob(join(fold_dir, 'step', 'bestcapsule_val_loss*.pt'))[0]
            vggext_ckcpoint = glob(join(fold_dir, 'step', 'bestvgg_val_loss*.pt'))[0]
            vgg_ext.load_state_dict(torch.load(vggext_ckcpoint))
            capnet.load_state_dict(torch.load(capsule_ckcpoint))
            
            batch_size = 16

            test_dir = test_dir_template.format(other_technique)
            dataloader_test = generate_test_dataloader_single_cnn_stream_for_kfold(test_dir, 128, batch_size, 4)

            capsule_loss = CapsuleLoss().to(device)
            ####
            loss, mac_acc, mic_acc, reals, fakes, micros, macros = eval_kfold_capsulenet(capnet, vgg_ext, dataloader_test, device, capsule_loss)
            print("            pt file: ", '/'.join(capsule_ckcpoint.split('/')[-3:]))
            print("        {}. FOLD: {}    -   cross with dataset: {}".format(fold_id, idx, other_technique))
            print("            accuracy =   {:.6f}       |  {:6f}".format(mic_acc, mac_acc))
            print("            precision =  {:.6f}  |   recall =    {:.6f}   |  f1 =    {:.6f}".format(macros[0], macros[1], macros[2]))

            if other_technique == technique:
                pt_file = capsule_ckcpoint
                step_ckcpoint = osp.split(pt_file)[0]
                step = int(osp.basename(pt_file).split('_')[-2])
                pairwise = True if 'pairwise' in pt_file else False
                metrics, bestloss = get_test_metric_by_step(step_ckcpoint=step_ckcpoint, steps=step, pairwise=pairwise)
                acc, pre, rec, f1 = metrics['best_val_bceloss']['acc'], metrics['best_val_bceloss']['pre'], metrics['best_val_bceloss']['rec'], metrics['best_val_bceloss']['f1']

                # if abs(round(loss, 4) - round(bestloss, 4)) > 0.001:
                #     print("         Error in loss: {}, {}".format(loss, bestloss))
                #     my_break = True
                if abs(round(mic_acc, 4) - round(acc, 4)) > 0.001:
                    print("         Error in acc: {}, {}".format(mic_acc, acc))
                    my_break = True
                if abs(round(macros[0], 4) - round(pre, 4)) > 0.001:
                    print("         Error in pre: {}, {}".format(macros[0], pre))
                    my_break = True
                if abs(round(macros[1], 4) - round(rec, 4)) > 0.001:
                    print("         Error in rec: {}, {}".format(macros[1], rec))
                    my_break = True
                if abs(round(macros[2], 4) - round(f1, 4)) > 0.001:
                    print("         Error in f1: {}, {}".format(macros[2], f1))
                    my_break = True

                if my_break:
                    break
                    
        if my_break:
            break

    if my_break:
        break




************************ So sánh với technique:  deepfake
Test image dataset:  10000
            pt file:  (1.9607_0.9026_0.9421)_fold_0/step/bestcapsule_val_loss_12600_1.960673.pt
        0. FOLD: 0    -   cross with dataset: deepfake
            accuracy =   0.902600       |  0.902600
            precision =  0.917092  |   recall =    0.902600   |  f1 =    0.901747
Test image dataset:  10000
            pt file:  (1.9594_0.9057_0.9684)_fold_4/step/bestcapsule_val_loss_14000_1.959382.pt
        4. FOLD: 0    -   cross with dataset: deepfake
            accuracy =   0.905700       |  0.905700
            precision =  0.919757  |   recall =    0.905700   |  f1 =    0.904904
Test image dataset:  10000
            pt file:  (1.9594_0.8933_0.9810)_fold_1/step/bestcapsule_val_loss_11000_1.959354.pt
        1. FOLD: 0    -   cross with dataset: deepfake
            accuracy =   0.893300       |  0.893300
            precision =  0.911121  |   recall =    0.893300   |  f1 =    0.892131
Tes