In [1]:
import torch
import random
import numpy as np
import os
import torch.nn as nn
from glob import glob
from tqdm import tqdm
from os.path import join
import os.path as osp
import pickle
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

def calculate_metric(y_label, y_pred_label):
    mic_accuracy = accuracy_score(y_label, y_pred_label)
    macro_precision = precision_score(y_label, y_pred_label, average='macro')
    macro_recall = recall_score(y_label, y_pred_label, average='macro')
    macro_f1 = f1_score(y_label, y_pred_label, average='macro')
    return mic_accuracy, macro_precision, macro_recall, macro_f1

def eval_spectrum(data, model_file):
    pkl_file = open(data, 'rb')
    data = pickle.load(pkl_file)
    pkl_file.close()
    X = data["data"]
    y = data["label"]
    with open(model_file, 'rb') as f:
        svclassifier_r = pickle.load(f)
    # print("Loaded.")
    SVM_score = svclassifier_r.score(X, y)
    # print("accuracy by score function: " + str(SVM_score))
    ##### CALCULATE METRIC:
    y_pred = svclassifier_r.predict(X)
    acc, pre, rec, f1 = calculate_metric(y, y_pred)
    # print("accuracy: " + str(acc))
    # print("precision: " + str(pre))
    # print("recall: " + str(rec))
    # print("f1 score: " + str(f1))
    return SVM_score, acc, pre, rec, f1

test_dir_template = "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/spectrum/output_features/{}/test.pkl"

checkpoints = {
    "deepfake": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/deepfake/spectrum/c_2.000000",
    "faceswap_2d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/faceswap_2d/spectrum/c_2.000000",
    "3dmm": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/3dmm/spectrum/c_2.000000",
    "faceswap_3d": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/faceswap_3d/spectrum/c_2.000000",
    "monkey": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/monkey/spectrum/c_2.000000",
    "reenact": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/reenact/spectrum/c_2.000000",
    "stargan": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/stargan/spectrum/c_2.000000",
    "x2face": "/mnt/disk1/doan/phucnp/Graduation_Thesis/my_thesis/forensics/ml_technique/checkpoint/x2face/spectrum/c_2.000000"
}

for technique in ['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']:
    checkpoint = checkpoints[technique]
    print("\n\n ==================================== {} CHECKPOINT =========================".format(technique))
    for idx, other_technique in enumerate(['deepfake', '3dmm', 'faceswap_2d', 'faceswap_3d', 'monkey', 'reenact', 'stargan', 'x2face']):
        print("\n************************ So sánh với technique: ", other_technique)
        my_break = False
        for pkl_file in glob(join(checkpoint, '*/*.pkl')):
            if 'model' not in pkl_file:
                continue
                
            fold_id = pkl_file.split('/')[-2][-1]

            test_features = test_dir_template.format(other_technique)          
            ####
            SVM_score, acc, pre, rec, f1 = eval_spectrum(test_features, pkl_file)
            print("            pt file: ", '/'.join(pkl_file.split('/')[-2:]))
            print("    FOLD: {}    -   cross with dataset: {}".format(fold_id, other_technique))
            print("            accuracy =   {:.6f}       |  {:6f}".format(SVM_score, acc))
            print("            precision =  {:.6f}  |   recall =    {:.6f}   |  f1 =    {:.6f}".format(pre, rec, f1))

            if other_technique == technique:
                from_acc = float(pkl_file.split('/')[-2].split('_')[0])
                # if abs(round(loss, 4) - round(bestloss, 4)) > 0.001:
                #     print("         Error in loss: {}, {}".format(loss, bestloss))
                #     my_break = True
                if abs(round(from_acc, 4) - round(acc, 4)) > 0.001:
                    print("         Error in acc: {}, {}".format(from_acc, acc))
                    my_break = True

                if my_break:
                    break

        if my_break:
            break

    if my_break:
        break

  from .autonotebook import tqdm as notebook_tqdm





************************ So sánh với technique:  deepfake
            pt file:  0.987_fold_4/model.pkl
    FOLD: 4    -   cross with dataset: deepfake
            accuracy =   0.987000       |  0.987000
            precision =  0.987105  |   recall =    0.986990   |  f1 =    0.986999
            pt file:  0.9858_fold_2/model.pkl
    FOLD: 2    -   cross with dataset: deepfake
            accuracy =   0.985800       |  0.985800
            precision =  0.985942  |   recall =    0.985788   |  f1 =    0.985799
            pt file:  0.9854_fold_1/model.pkl
    FOLD: 1    -   cross with dataset: deepfake
            accuracy =   0.985400       |  0.985400
            precision =  0.985555  |   recall =    0.985388   |  f1 =    0.985399
            pt file:  0.988_fold_3/model.pkl
    FOLD: 3    -   cross with dataset: deepfake
            accuracy =   0.988000       |  0.988000
            precision =  0.988111  |   recall =    0.987990   |  f1 =    0.987999
            pt file:  0.9876_