In [1]:
from model.cnn.xception_net.model import xception
from model.cnn.mesonet4.model import mesonet
import torch
from module.train_torch import calculate_metric
from dataloader.gen_dataloader import *
import random
import numpy as np
import os
import torch.nn as nn
from glob import glob
from tqdm import tqdm
from os.path import join
from utils.util import get_test_metric_by_step

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

def eval_kfold_image_stream(model ,dataloader, device, criterion, adj_brightness=1.0, adj_contrast=1.0 ):
    loss = 0
    mac_accuracy = 0
    model.eval()
    y_label = []
    y_pred_label = []
    with torch.no_grad():
        for inputs, labels in dataloader:
            # Push to device
            y_label.extend(labels.cpu().numpy().astype(np.float64))
            inputs, labels = inputs.float().to(device), labels.float().to(device)

            # Forward network
            logps = model.forward(inputs)
            if len(logps.shape) == 2:
                logps = logps.squeeze(dim=1)

            # Loss in a batch
            batch_loss = criterion(logps, labels)
            # Cumulate into running val loss
            loss += batch_loss.item()

            # Find accuracy
            equals = (labels == (logps > 0.5))
            mac_accuracy += torch.mean(equals.type(torch.FloatTensor)).item()
            #
            logps_cpu = logps.cpu().numpy()
            pred_label = (logps_cpu > 0.5)
            y_pred_label.extend(pred_label)
            
    assert len(y_label) == len(y_pred_label), "Bug"
    ######## Calculate metrics:
    loss /= len(dataloader)
    mac_accuracy /= len(dataloader)
    # built-in methods for calculating metrics
    mic_accuracy, reals, fakes, micros, macros = calculate_metric(y_label, y_pred_label)
    return loss, mac_accuracy, mic_accuracy, reals, fakes, micros, macros

################### FROM DEEPFAKE TO OTHER DATASET:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
device = torch.device('cuda')

test_dir_template = "/mnt/disk1/doan/phucnp/Dataset/my_extend_data/extend_data_train/{}/test"

checkpoints = {
    "deepfake": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/deepfake/kfold_meso4/lr6e-05_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "faceswap_2d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/faceswap_2d/kfold_meso4/lr0.0002_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "3dmm": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/3dmm/kfold_meso4/lr0.0002_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "faceswap_3d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/faceswap_3d/kfold_meso4/lr0.0002_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "monkey": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/monkey/kfold_meso4/lr6e-05_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "reenact": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/reenact/kfold_meso4/lr0.0002_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "stargan": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/stargan/kfold_meso4/lr0.0002_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0",
    "x2face": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/dl_technique/checkpoint/extend_data/x2face/kfold_meso4/lr0.0002_batch16_esnone_lossbce_nf5_trick1_seed0_drmlp0.0_aug0"
}

for technique in ['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']:
    checkpoint = checkpoints[technique]
    print("\n\n ==================================== {} CHECKPOINT =========================".format(technique))
    my_break = False
    for idx, other_technique in enumerate(['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']):
        print("\n************************ So sánh với technique: ", other_technique)
        for pt_file in glob(join(checkpoint, '*/*/*.pt')):
            if 'best_val_loss' not in pt_file:
                continue
                
            fold_id = pt_file.split('/')[-3][-1]

            model_ = mesonet(image_size=128)
            model_.load_state_dict(torch.load(pt_file))
            model_ = model_.to(device)
            
            batch_size = 16

            test_dir = test_dir_template.format(other_technique)
            dataloader_test = generate_test_dataloader_single_cnn_stream_for_kfold(test_dir, 128, batch_size, 4)

            criterion = nn.BCELoss().to(device)
            ####
            loss, mac_acc, mic_acc, reals, fakes, micros, macros = eval_kfold_image_stream(model_, dataloader_test, device, criterion)
            print("            pt file: ", '/'.join(pt_file.split('/')[-3:]))
            print("        {}. FOLD: {}    -   cross with dataset: {}".format(fold_id, idx, other_technique))
            print("            accuracy =   {:.6f}       |  {:6f}".format(mic_acc, mac_acc))
            print("            precision =  {:.6f}  |   recall =    {:.6f}   |  f1 =    {:.6f}".format(macros[0], macros[1], macros[2]))

            if other_technique == technique:
                step_ckcpoint = osp.split(pt_file)[0]
                step = int(osp.basename(pt_file).split('_')[-2])
                pairwise = True if 'pairwise' in pt_file else False
                metrics, bestloss = get_test_metric_by_step(step_ckcpoint=step_ckcpoint, steps=step, pairwise=pairwise)
                acc, pre, rec, f1 = metrics['best_val_bceloss']['acc'], metrics['best_val_bceloss']['pre'], metrics['best_val_bceloss']['rec'], metrics['best_val_bceloss']['f1']

                # if abs(round(loss, 4) - round(bestloss, 4)) > 0.001:
                #     print("         Error in loss: {}, {}".format(loss, bestloss))
                #     my_break = True
                if abs(round(mic_acc, 4) - round(acc, 4)) > 0.001:
                    print("         Error in acc: {}, {}".format(mic_acc, acc))
                    my_break = True
                if abs(round(macros[0], 4) - round(pre, 4)) > 0.001:
                    print("         Error in pre: {}, {}".format(macros[0], pre))
                    my_break = True
                if abs(round(macros[1], 4) - round(rec, 4)) > 0.001:
                    print("         Error in rec: {}, {}".format(macros[1], rec))
                    my_break = True
                if abs(round(macros[2], 4) - round(f1, 4)) > 0.001:
                    print("         Error in f1: {}, {}".format(macros[2], f1))
                    my_break = True

                if my_break:
                    break

        if my_break:
            break

    if my_break:
        break

  from .autonotebook import tqdm as notebook_tqdm





************************ So sánh với technique:  deepfake
Test image dataset:  10000
            pt file:  (0.0983_0.9043_0.9128)_fold_3/step/best_val_loss_17600_0.098254.pt
        3. FOLD: 0    -   cross with dataset: deepfake
            accuracy =   0.904300       |  0.904300
            precision =  0.910634  |   recall =    0.904300   |  f1 =    0.903930
Test image dataset:  10000
            pt file:  (0.0982_0.8780_0.8948)_fold_2/step/best_val_loss_19000_0.098220.pt
        2. FOLD: 0    -   cross with dataset: deepfake
            accuracy =   0.878000       |  0.878000
            precision =  0.885869  |   recall =    0.878000   |  f1 =    0.877375
Test image dataset:  10000
            pt file:  (0.2025_0.8786_0.9478)_fold_4/step/best_val_loss_13200_0.202490.pt
        4. FOLD: 0    -   cross with dataset: deepfake
            accuracy =   0.878600       |  0.878600
            precision =  0.888445  |   recall =    0.878600   |  f1 =    0.877826
Test image dataset:  100