In [1]:
import torch
import os
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import cv2
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchsummary import summary
from tqdm import tqdm
from matplotlib.pyplot import figure
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
torch.cuda.empty_cache()
from sklearn.metrics import f1_score, recall_score, precision_score, roc_curve, roc_auc_score
import seaborn as sns
import re

sns.set()
plt.rc('font', family = 'serif')

In [2]:
def get_binary_testset(dataset_name):
    """
    dataset_name -> trainset, valset, testset
    """
    trainset, valset, testset = None, None, None
    dataset_name = dataset_name.lower()
    path_2_root = "../.."

    # if not "oc" in dataset_name:
    #     if "dar" in dataset_name and "oc" not in dataset_name: # DariusAf_Deepfake_Database
    #         trainset = f"{path_2_root}/_DATASETS/DariusAf_Deepfake_Database/train_test"
    #         testset = f"{path_2_root}/_DATASETS/DariusAf_Deepfake_Database/validation"

    #     if "avg" in dataset_name: # Celeb-avg-30
    #         if "celeb" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/Celeb-DF-v2/Celeb-avg-30"
    #             testset = f"{path_2_root}/_DATASETS/Celeb-DF-v2/Celeb-avg-30-test"
                
    #         elif "deepfake" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/Deepfakes_avg"
    #         elif "face2face" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/Face2Face_avg"
    #         elif "shift" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/FaceShifter_avg"
    #         elif "swap" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/FaceSwap_avg"
    #         elif "neural" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/NeuralTextures_avg"

    #     elif "rnd" in dataset_name: # Celeb-rnd-30
    #         if "celeb" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/Celeb-DF-v2/Celeb-rnd-30"
    #             testset = f"{path_2_root}/_DATASETS/Celeb-DF-v2/Celeb-rnd-30-test"
                
    #         elif "deepfake" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/Deepfakes_rnd"
    #         elif "face2face" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/Face2Face_rnd"
    #         elif "shift" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/FaceShifter_rnd"
    #         elif "swap" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/FaceSwap_rnd"
    #         elif "neural" in dataset_name:
    #             trainset = f"{path_2_root}/_DATASETS/FF/NeuralTextures_rnd"

    #     elif "diff" in dataset_name: # Celeb-diff-30
    #         trainset = f"{path_2_root}/_DATASETS/Celeb-DF-v2/Celeb-diff-30"
    #         testset = f"{path_2_root}/_DATASETS/Celeb-DF-v2/Celeb-diff-30-test"

    #     elif "ff" in dataset_name and "deepfake" in dataset_name:
    #         trainset = f"{path_2_root}/_DATASETS/FF/Deepfakes_avg"

    #     elif "ff" in dataset_name and "deepfake" in dataset_name:
    #         trainset = f"{path_2_root}/_DATASETS/FF/Deepfakes_avg"

    # else: # oc
    if "dar" in dataset_name: # DariusAf_Deepfake_Database
        trainset = f"{path_2_root}/_DATASETS/DariusAf_Deepfake_Database-OC/real-train/" # unary
        testset = f"{path_2_root}/_DATASETS/DariusAf_Deepfake_Database-OC/realfake-test/" # binary

    elif "avg" in dataset_name:
        trainset = f"{path_2_root}/_DATASETS/Celeb-DF-v2-OC/Celeb-avg-30-OC-real-train/" # unary
        # valset = f"{path_2_root}/_DATASETS/Celeb-DF-v2-OC/Celeb-avg-30-OC-real-val/"  # unary
        testset = f"{path_2_root}/_DATASETS/Celeb-DF-v2-OC/Celeb-avg-30-OC-realfake-test/" # binary

    elif "rnd" in dataset_name: # Celeb-rnd-30-OC
        trainset = f"{path_2_root}/_DATASETS/Celeb-DF-v2-OC/Celeb-rnd-30-OC-real-train/" # unary
        # valset = f"{path_2_root}/_DATASETS/Celeb-DF-v2-OC/Celeb-rnd-30-OC-real-val/"  # unary
        testset = f"{path_2_root}/_DATASETS/Celeb-DF-v2-OC/Celeb-rnd-30-OC-realfake-test/" # binary

    testset = trainset if testset==None else testset
    valset = trainset if valset==None else valset
    return trainset, valset, testset

In [3]:
batch_size = 128
epochs = 1000
img_size = 100
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
transform_train = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

transform_test = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

In [5]:
class vae(nn.Module):
    def __init__(self):
        super(vae, self).__init__()

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3,stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3,stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(32)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3,stride=1, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, kernel_size=3,stride=2, padding=1, bias=False)
        self.bn4 = nn.BatchNorm2d(32)

        self.fc1 = nn.Linear(25 * 25 * 32, 1000)
        self.fc1_bn = nn.BatchNorm1d(1000)
        self.fc2_mean = nn.Linear(1000, 100)
        self.fc2_logvar = nn.Linear(1000, 100)

        self.fc3 = nn.Linear(100, 1000)
        self.fc3_bn = nn.BatchNorm1d(1000)
        self.fc4 = nn.Linear(1000, 25 * 25 * 32)
        self.fc4_bn = nn.BatchNorm1d(25 * 25 * 32)

        self.relu = nn.ReLU()

        self.conv5 = nn.ConvTranspose2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn5 = nn.BatchNorm2d(64)
        self.conv6 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        self.bn6 = nn.BatchNorm2d(32)
        self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn7 = nn.BatchNorm2d(16)
        self.conv8 = nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False)
        
    def encode(self, data):
        conv1 = self.relu(self.bn1(self.conv1(data)))
        conv2 = self.relu(self.bn2(self.conv2(conv1)))
        conv3 = self.relu(self.bn3(self.conv3(conv2)))
        conv4 = self.relu(self.bn4(self.conv4(conv3)))

        fc1 = self.relu(self.fc1_bn(self.fc1(conv4.view(-1, 25 * 25 * 32))))
        mean = self.fc2_mean(fc1)
        logvar = self.fc2_logvar(fc1)
        return mean, logvar
    
    def reparameterize(self, mean, logvar):
        std = 0.5 * torch.exp(logvar)
        z = (std.data.new(std.size()).normal_()) * std + mean
        return z
    
    def decode(self, z):
        fc3 = self.relu(self.fc3_bn(self.fc3(z)))
        fc4 = self.relu(self.fc4_bn(self.fc4(fc3)))
        conv5 = self.relu(self.bn5(self.conv5(fc4.view(-1, 32, 25, 25))))
        conv6 = self.relu(self.bn6(self.conv6(conv5)))
        conv7 = self.relu(self.bn7(self.conv7(conv6)))
        conv8 = self.conv8(conv7)
        return conv8.view(-1, 3, img_size, img_size)
    
    def forward(self, x):
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        return self.decode(z), mean, logvar

def loss_function(recon_x, x, mean, logvar):
    mse_loss = nn.MSELoss(reduction="sum")
    reconstruction_loss = mse_loss(recon_x, x)
    kld_loss = -0.5 * torch.sum(1+logvar-torch.exp(logvar)-mean**2)
    return reconstruction_loss + kld_loss

In [6]:
# TRAIN DFDB
trainset, valset, testset = get_binary_testset("oc dar") 
dataset_name =  "DFDB"
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(trainset, 
                         transform=transform_train),
                         batch_size=batch_size, 
                         shuffle=True,
                         drop_last = True)
fakedect = vae().to(device)
optimizer = optim.Adam(fakedect.parameters(), lr=1e-4)
history_dfdb = []
# LOAD OLD WEIGHTS
weights_path = f"../../_WEIGHTS/oc_fakedect/OCFD1_Retrained_{dataset_name}.pkl"
fakedect = torch.load(weights_path)
# TRAIN LOOP
n_dpoints = len(train_loader.dataset)
for epoch in range(epochs):
    fakedect.train() # train mode
    train_loss = 0
    for batch_idx, (inputs, _) in tqdm(enumerate(train_loader), total=n_dpoints//batch_size-1):
        inputs = inputs.to(device)
        gen_imgs, mean, logvar = fakedect(inputs)
        loss = loss_function(gen_imgs, inputs, mean, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    L = train_loss / n_dpoints
    history_dfdb += [L]
    print(f"Epoch={epoch}\tloss={L:.4f}", end="\r")
weights_path = f"../../_WEIGHTS/oc_fakedect/OCFD1_Retrained2_{dataset_name}.pkl"
torch.save(fakedect, weights_path)

# TRAIN CDFv2 RND
fakedect = vae().to(device) # RESET WEIGHTS
trainset, valset, testset = get_binary_testset("oc rnd") 
dataset_name =  "RND"
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(trainset, 
                         transform=transform_train),
                         batch_size=batch_size, 
                         shuffle=True,
                         drop_last = True)
optimizer = optim.Adam(fakedect.parameters(), lr=1e-4)
history_rnd = []
# TRAIN LOOP
n_dpoints = len(train_loader.dataset)
for epoch in range(epochs):
    fakedect.train() # train mode
    train_loss = 0
    for batch_idx, (inputs, _) in tqdm(enumerate(train_loader), total=n_dpoints//batch_size-1):
        inputs = inputs.to(device)
        gen_imgs, mean, logvar = fakedect(inputs)
        loss = loss_function(gen_imgs, inputs, mean, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    L = train_loss / n_dpoints
    history_rnd += [L]
    print(f"Epoch={epoch}\tloss={L:.4f}", end="\r")
weights_path = f"../../_WEIGHTS/oc_fakedect/OCFD1_Retrained2_{dataset_name}.pkl"
torch.save(fakedect, weights_path)

# TRAIN CDFv2 AVG
fakedect = vae().to(device) # RESET WEIGHTS
trainset, valset, testset = get_binary_testset("oc avg") 
dataset_name =  "AVG"
train_loader = torch.utils.data.DataLoader(
    datasets.ImageFolder(trainset, 
                         transform=transform_train),
                         batch_size=batch_size, 
                         shuffle=True,
                         drop_last = True)
optimizer = optim.Adam(fakedect.parameters(), lr=1e-4)
history_avg = []
# TRAIN LOOP
n_dpoints = len(train_loader.dataset)
for epoch in range(epochs):
    fakedect.train() # train mode
    train_loss = 0
    for batch_idx, (inputs, _) in tqdm(enumerate(train_loader), total=n_dpoints//batch_size-1):
        inputs = inputs.to(device)
        gen_imgs, mean, logvar = fakedect(inputs)
        loss = loss_function(gen_imgs, inputs, mean, logvar)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    L = train_loss / n_dpoints
    history_avg += [L]
    print(f"Epoch={epoch}\tloss={L:.4f}", end="\r")
weights_path = f"../../_WEIGHTS/oc_fakedect/OCFD1_Retrained2_{dataset_name}.pkl"
torch.save(fakedect, weights_path)
fakedect = vae().to(device) # RESET WEIGHTS
    # if L < min_loss:
    #     min_loss = L
    #     torch.save(fakedect, f"../../_WEIGHTS/oc_fakedect/OC_FD_e{epoch+1}_l{int(L)}_{dataset_name}.pkl")
    #     print(f"../../_WEIGHTS/oc_fakedect/OC_FD_e{epoch+1}_l{int(L)}_{dataset_name}.pkl")
    #     early_stop_coutner = 0
    # else:
    #     early_stop_coutner += 1
    
    # if early_stop_coutner>20:
    #     break
# print(summary(fakedect, (3, img_size, img_size)))

89it [00:25,  3.54it/s]                        


Epoch=0	loss=1285.3620

89it [00:21,  4.07it/s]                        


Epoch=1	loss=1288.9911

89it [00:21,  4.21it/s]                        


Epoch=2	loss=1284.2431

89it [00:21,  4.12it/s]                        


Epoch=3	loss=1289.4821

89it [00:21,  4.07it/s]                        


Epoch=4	loss=1289.4330

89it [00:21,  4.09it/s]                        


Epoch=5	loss=1286.8143

89it [00:21,  4.07it/s]                        


Epoch=6	loss=1288.1885

89it [00:21,  4.18it/s]                        


Epoch=7	loss=1288.3527

89it [00:21,  4.23it/s]                        


Epoch=8	loss=1291.1315

89it [00:20,  4.25it/s]                        


Epoch=9	loss=1286.9455

89it [00:20,  4.27it/s]                        


Epoch=10	loss=1290.8944

89it [00:21,  4.22it/s]                        


Epoch=11	loss=1290.2804

89it [00:20,  4.24it/s]                        


Epoch=12	loss=1287.1789

89it [00:20,  4.25it/s]                        


Epoch=13	loss=1291.6184

89it [00:21,  4.24it/s]                        


Epoch=14	loss=1291.7479

89it [00:20,  4.25it/s]                        


Epoch=15	loss=1287.5761

89it [00:20,  4.30it/s]                        


Epoch=16	loss=1291.2055

89it [00:20,  4.28it/s]                        


Epoch=17	loss=1287.6275

89it [00:20,  4.27it/s]                        


Epoch=18	loss=1288.4650

89it [00:21,  4.14it/s]                        


Epoch=19	loss=1286.7685

89it [00:23,  3.84it/s]                        


Epoch=20	loss=1287.7170

89it [00:23,  3.73it/s]                        


Epoch=21	loss=1291.7654

89it [00:23,  3.76it/s]                        


Epoch=22	loss=1292.1570

89it [00:22,  3.98it/s]                        


Epoch=23	loss=1293.4698

89it [00:23,  3.76it/s]                        


Epoch=24	loss=1292.2750

89it [00:23,  3.72it/s]                        


Epoch=25	loss=1288.8402

89it [00:24,  3.71it/s]                        


Epoch=26	loss=1289.5315

89it [00:23,  3.84it/s]                        


Epoch=27	loss=1288.6393

89it [00:23,  3.86it/s]                        


Epoch=28	loss=1288.0507

89it [00:22,  3.88it/s]                        


Epoch=29	loss=1288.7474

89it [00:23,  3.86it/s]                        


Epoch=30	loss=1289.4293

89it [00:23,  3.86it/s]                        


Epoch=31	loss=1294.0653

89it [00:23,  3.86it/s]                        


Epoch=32	loss=1291.2440

89it [00:23,  3.83it/s]                        


Epoch=33	loss=1291.2201

89it [00:23,  3.85it/s]                        


Epoch=34	loss=1287.3018

89it [00:23,  3.86it/s]                        


Epoch=35	loss=1290.5991

89it [00:23,  3.85it/s]                        


Epoch=36	loss=1290.6472

89it [00:23,  3.85it/s]                        


Epoch=37	loss=1287.9298

89it [00:23,  3.84it/s]                        


Epoch=38	loss=1290.9744

89it [00:22,  3.99it/s]                        


Epoch=39	loss=1287.1197

89it [00:21,  4.14it/s]                        


Epoch=40	loss=1289.9746

89it [00:20,  4.24it/s]                        


Epoch=41	loss=1294.3691

89it [00:21,  4.15it/s]                        


Epoch=42	loss=1291.4810

89it [00:20,  4.24it/s]                        


Epoch=43	loss=1289.8682

89it [00:20,  4.25it/s]                        


Epoch=44	loss=1289.8071

89it [00:21,  4.23it/s]                        


Epoch=45	loss=1294.7748

89it [00:20,  4.24it/s]                        


Epoch=46	loss=1291.6626

89it [00:21,  4.16it/s]                        


Epoch=47	loss=1290.4978

89it [00:22,  3.89it/s]                        


Epoch=48	loss=1289.3130

89it [00:22,  3.88it/s]                        


Epoch=49	loss=1290.2385

89it [00:22,  3.94it/s]                        


Epoch=50	loss=1288.1449

89it [00:22,  3.91it/s]                        


Epoch=51	loss=1288.8809

89it [00:22,  3.91it/s]                        


Epoch=52	loss=1289.0093

89it [00:22,  3.98it/s]                        


Epoch=53	loss=1286.7111

89it [00:23,  3.83it/s]                        


Epoch=54	loss=1292.1888

89it [00:23,  3.84it/s]                        


Epoch=55	loss=1289.8462

89it [00:23,  3.81it/s]                        


Epoch=56	loss=1290.8962

89it [00:23,  3.84it/s]                        


Epoch=57	loss=1290.5942

89it [00:23,  3.84it/s]                        


Epoch=58	loss=1289.3994

89it [00:22,  3.96it/s]                        


Epoch=59	loss=1284.3067

89it [00:22,  4.00it/s]                        


Epoch=60	loss=1290.8211

89it [00:22,  3.99it/s]                        


Epoch=61	loss=1289.6643

89it [00:21,  4.05it/s]                        


Epoch=62	loss=1291.0863

89it [00:21,  4.05it/s]                        


Epoch=63	loss=1291.7116

89it [00:21,  4.06it/s]                        


Epoch=64	loss=1292.5107

89it [00:21,  4.05it/s]                        


Epoch=65	loss=1286.8197

89it [00:21,  4.06it/s]                        


Epoch=66	loss=1288.7666

89it [00:21,  4.06it/s]                        


Epoch=67	loss=1290.1618

89it [00:21,  4.05it/s]                        


Epoch=68	loss=1292.9755

89it [00:21,  4.06it/s]                        


Epoch=69	loss=1286.0506

89it [00:21,  4.07it/s]                        


Epoch=70	loss=1292.2912

89it [00:22,  4.03it/s]                        


Epoch=71	loss=1290.5465

89it [00:22,  3.99it/s]                        


Epoch=72	loss=1291.2835

89it [00:22,  4.03it/s]                        


Epoch=73	loss=1286.6434

89it [00:22,  4.04it/s]                        


Epoch=74	loss=1292.5955

89it [00:21,  4.06it/s]                        


Epoch=75	loss=1289.9296

89it [00:21,  4.05it/s]                        


Epoch=76	loss=1289.3432

89it [00:21,  4.06it/s]                        


Epoch=77	loss=1289.7652

89it [00:22,  3.99it/s]                        


Epoch=78	loss=1289.1603

89it [00:21,  4.18it/s]                        


Epoch=79	loss=1288.5855

89it [00:21,  4.21it/s]                        


Epoch=80	loss=1289.7675

89it [00:22,  3.96it/s]                        


Epoch=81	loss=1294.2527

89it [00:21,  4.05it/s]                        


Epoch=82	loss=1289.4908

89it [00:22,  3.99it/s]                        


Epoch=83	loss=1289.1639

89it [00:23,  3.83it/s]                        


Epoch=84	loss=1290.2865

89it [00:22,  3.88it/s]                        


Epoch=85	loss=1293.7476

89it [00:22,  3.94it/s]                        


Epoch=86	loss=1291.3157

89it [00:23,  3.74it/s]                        


Epoch=87	loss=1288.2652

89it [00:23,  3.73it/s]                        


Epoch=88	loss=1288.4198

89it [00:23,  3.76it/s]                        


Epoch=89	loss=1288.5485

89it [00:23,  3.74it/s]                        


Epoch=90	loss=1288.1082

89it [00:23,  3.79it/s]                        


Epoch=91	loss=1292.0670

89it [00:23,  3.78it/s]                        


Epoch=92	loss=1292.2149

89it [00:23,  3.79it/s]                        


Epoch=93	loss=1289.7493

89it [00:23,  3.79it/s]                        


Epoch=94	loss=1285.8194

89it [00:23,  3.77it/s]                        


Epoch=95	loss=1296.4275

89it [00:23,  3.79it/s]                        


Epoch=96	loss=1291.0303

89it [00:22,  3.88it/s]                        


Epoch=97	loss=1292.5902

89it [00:22,  3.90it/s]                        


Epoch=98	loss=1289.9219

89it [00:23,  3.81it/s]                        


Epoch=99	loss=1288.1171

89it [00:23,  3.85it/s]                        


Epoch=100	loss=1288.5721

89it [00:22,  3.90it/s]                        


Epoch=101	loss=1288.4740

89it [00:22,  3.89it/s]                        


Epoch=102	loss=1291.0362

89it [00:23,  3.71it/s]                        


Epoch=103	loss=1290.1899

89it [00:22,  4.04it/s]                        


Epoch=104	loss=1292.1700

89it [00:22,  3.95it/s]                        


Epoch=105	loss=1288.2309

89it [00:21,  4.05it/s]                        


Epoch=106	loss=1288.4937

89it [00:21,  4.08it/s]                        


Epoch=107	loss=1289.3761

89it [00:21,  4.10it/s]                        


Epoch=108	loss=1288.7735

89it [00:22,  4.03it/s]                        


Epoch=109	loss=1290.4265

89it [00:22,  4.04it/s]                        


Epoch=110	loss=1292.6277

89it [00:21,  4.07it/s]                        


Epoch=111	loss=1293.4580

89it [00:21,  4.07it/s]                        


Epoch=112	loss=1293.2467

89it [00:21,  4.24it/s]                        


Epoch=113	loss=1291.1910

89it [00:20,  4.29it/s]                        


Epoch=114	loss=1286.3692

89it [00:21,  4.21it/s]                        


Epoch=115	loss=1286.5655

89it [00:20,  4.25it/s]                        


Epoch=116	loss=1291.0605

89it [00:21,  4.22it/s]                        


Epoch=117	loss=1292.2536

89it [00:20,  4.24it/s]                        


Epoch=118	loss=1287.6223

89it [00:21,  4.21it/s]                        


Epoch=119	loss=1295.1265

89it [00:20,  4.29it/s]                        


Epoch=120	loss=1291.4328

89it [00:21,  4.12it/s]                        


Epoch=121	loss=1287.7991

89it [00:21,  4.20it/s]                        


Epoch=122	loss=1289.7309

89it [00:21,  4.19it/s]                        


Epoch=123	loss=1289.6311

89it [00:21,  4.08it/s]                        


Epoch=124	loss=1287.9656

89it [00:21,  4.06it/s]                        


Epoch=125	loss=1293.7885

89it [00:22,  4.02it/s]                        


Epoch=126	loss=1291.2466

89it [00:21,  4.10it/s]                        


Epoch=127	loss=1288.4156

89it [00:21,  4.18it/s]                        


Epoch=128	loss=1290.4915

89it [00:20,  4.24it/s]                        


Epoch=129	loss=1291.1017

89it [00:21,  4.24it/s]                        


Epoch=130	loss=1286.8573

89it [00:21,  4.23it/s]                        


Epoch=131	loss=1290.7530

89it [00:21,  4.23it/s]                        


Epoch=132	loss=1290.8073

89it [00:21,  4.23it/s]                        


Epoch=133	loss=1293.5741

89it [00:21,  4.22it/s]                        


Epoch=134	loss=1292.1440

89it [00:20,  4.24it/s]                        


Epoch=135	loss=1285.8025

89it [00:21,  4.23it/s]                        


Epoch=136	loss=1290.5013

89it [00:21,  4.23it/s]                        


Epoch=137	loss=1286.8555

89it [00:21,  4.23it/s]                        


Epoch=138	loss=1288.1309

89it [00:21,  4.20it/s]                        


Epoch=139	loss=1291.6869

 74%|███████▍  | 65/88 [00:15<00:05,  3.84it/s]

In [None]:
# history = []

# early_stop_coutner = 0
# weights_path = f"../../_WEIGHTS/oc_fakedect/OC_FD_e29_l500_AVG.pkl"
# fakedect = torch.load(weights_path)

# n_dpoints = len(train_loader.dataset)
# min_loss = float("inf")
# for epoch in range(epochs):#tqdm(range(epochs), total=epochs):
#     fakedect.train() # train mode
#     train_loss = 0
#     for batch_idx, (inputs, _) in tqdm(enumerate(train_loader), total=n_dpoints//batch_size-1):
#         inputs = inputs.to(device)
#         gen_imgs, mean, logvar = fakedect(inputs)
#         loss = loss_function(gen_imgs, inputs, mean, logvar)
#         optimizer.zero_grad()
#         loss.backward()
#         optimizer.step()
#         train_loss += loss.item()
#     L = train_loss / n_dpoints
#     history += [L]
#     print(f"Epoch={epoch}\tloss={L:.4f}")
#     if L < min_loss:
#         min_loss = L
#         torch.save(fakedect, f"../../_WEIGHTS/oc_fakedect/OC_FD_e{epoch+1}_l{int(L)}_{dataset_name}.pkl")
#         print(f"../../_WEIGHTS/oc_fakedect/OC_FD_e{epoch+1}_l{int(L)}_{dataset_name}.pkl")
#         early_stop_coutner = 0
#     else:
#         early_stop_coutner += 1
    
#     if early_stop_coutner>20:
#         break

In [None]:
# load best weights for eval
# weights_path = "../../_WEIGHTS/oc_fakedect/OC_FD_e29_l500_AVG.pkl"
# weights_path = "../../_WEIGHTS/oc_fakedect/OC_FD_e2_l532_RND.pkl"

# weights_path = "../../_WEIGHTS/oc_fakedect/OC_FD_e204_l1429_DFDB.pkl"
# weights_path = "../../_WEIGHTS/oc_fakedect/OC_FD_e1_l1431_DFDB.pkl"
# fakedect = torch.load(weights_path)

In [None]:
# # history_dfdb, history_rnd, history_avg = [1,2], [2,3], [3,4]
# history_dfdb
# history_rnd
# history_avg

In [None]:
fig, ax = plt.subplots(1, figsize = (10, 10), sharex=True)
ax.plot(history_dfdb, label="DFDB Training Loss")
ax.plot(history_rnd, label="CDFv2 (RF) Training Loss")
ax.plot(history_avg, label="CDFv2 (AF) Training Loss")
ax.set_title(f'OC-FakeDect1 Training Losses')
ax.set_ylabel('Loss')
ax.legend()
plt.savefig(
    f"./Results/Losses2.pdf",
    bbox_inches="tight",
)
plt.show()

In [None]:
# shwcse_img_path = "D:/MInf/Datasets/Celeb-DF-v2-OC/Celeb-rnd-30-OC-realfake-test/Celeb-real/id11_0009_1.png"
# img = cv2.cvtColor(cv2.imread(shwcse_img_path), cv2.COLOR_BGR2RGB)
# image = transform_test(img)
# image = image.float().to(device)
# # print(image.shape)
# # fakedect = torch.load(weights_path)
# fakedect.eval()
# with torch.no_grad():
#     x = image.view(-1,3,100,100)
#     print(x.shape)
#     decode_z, mean_z, logvar_z = fakedect(x)
#     # print(decode_z.shape )
# decode_z = decode_z.view(3,100,100).cpu()
# decode_z = 255*(decode_z*0.5 + 0.5).numpy()
# decode_z = decode_z.astype(int)
# plt.subplot(1,2,1)
# plt.imshow(img)
# plt.subplot(1,2,2)
# plt.imshow(decode_z.transpose(1,2,0))

In [None]:
def solve(m1,m2,std1,std2):
  a = 1/(2*std1**2) - 1/(2*std2**2)
  b = m2/(std2**2) - m1/(std1**2)
  c = m1**2 /(2*std1**2) - m2**2 / (2*std2**2) - np.log(std2/std1)
  return np.roots([a,b,c])

In [None]:
#               # experiment_name,                                       model_name,          testset_name,                 fname
# EXPERIMENTS = [

#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on DFDB Dataset",  f"OC-FakeDect1_{dataset_name}", "DariusAf_Deepfake_Database", f"OC-FakeDect1_{dataset_name}_onDFDB"),

#                 # OC-FakeDect1 RND on Celebs indep eval
#                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on AVG CDFv2 Dataset",  f"OC-FakeDect1_{dataset_name}", "Celeb-avg-30", f"OC-FakeDect1_{dataset_name}_onAVG"),

#                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on RND CDFv2 Dataset",  f"OC-FakeDect1_{dataset_name}", "Celeb-rnd-30", f"OC-FakeDect1_{dataset_name}_onRND"),

#                 # OC-FakeDect1 RND on Celebs rae
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on AVG CDFv2 Dataset\
#                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Celeb-avg-30", f"OC-FakeDect1_{dataset_name}_raeonAVG"),

#                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on RND CDFv2 Dataset\
#                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Celeb-rnd-30", f"OC-FakeDect1_{dataset_name}_raeonRND"),

#                 # OC-FakeDect1 RND on FF indep eval
#                 # DF
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on AVG Deepfakes Dataset\
#                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_avg", f"OC-FakeDect1_{dataset_name}_raeonDFavg"),
             
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on AVG Deepfakes Dataset",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_avg", f"OC-FakeDect1_{dataset_name}_DFavg"),
             
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on RND Deepfakes Dataset\
#                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_rnd", f"OC-FakeDect1_{dataset_name}_raeonDFrnd"),
             
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                \nTested on RND Deepfakes Dataset",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_rnd", f"OC-FakeDect1_{dataset_name}_DFrnd"),

#                 # F2F
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG Face2Face Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Face2Face_avg", f"OC-FakeDect1_{dataset_name}_raeonF2Favg"),
             
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG Face2Face Dataset",  f"OC-FakeDect1_{dataset_name}", "Face2Face_avg", f"OC-FakeDect1_{dataset_name}_F2Favg"),
              
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND Face2Face Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Face2Face_rnd", f"OC-FakeDect1_{dataset_name}_raeonF2Frnd"),
               
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND Face2Face Dataset",  f"OC-FakeDect1_{dataset_name}", "Face2Face_rnd", f"OC-FakeDect1_{dataset_name}_F2Frnd"),

#                 # FaceShifter
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG FaceShifter Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_avg", f"OC-FakeDect1_{dataset_name}_raeonFSHFTavg"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG FaceShifter Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_avg", f"OC-FakeDect1_{dataset_name}_FSHFTavg"),
               
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND FaceShifter Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_rnd", f"OC-FakeDect1_{dataset_name}_raeonFSHFTrnd"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND FaceShifter Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_rnd", f"OC-FakeDect1_{dataset_name}_FSHFTrnd"),

#                 # FaceSwap
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG FaceSwap Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_avg", f"OC-FakeDect1_{dataset_name}_raeonFSavg"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG FaceSwap Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_avg", f"OC-FakeDect1_{dataset_name}_FSavg"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND FaceSwap Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_rnd", f"OC-FakeDect1_{dataset_name}_raeonFSrnd"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND FaceSwap Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_rnd", f"OC-FakeDect1_{dataset_name}_FSrnd"),

#                 # NeuralTextures
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG NeuralTextures Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_avg", f"OC-FakeDect1_{dataset_name}_raeonNTavg"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on AVG NeuralTextures Dataset",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_avg", f"OC-FakeDect1_{dataset_name}_NTavg"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND NeuralTextures Dataset\
#                 \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_rnd", f"OC-FakeDect1_{dataset_name}_raeonNTrnd"),
                
#                 (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
#                 \nTested on RND NeuralTextures Dataset",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_rnd", f"OC-FakeDect1_{dataset_name}_NTrnd"),
#                ]

In [None]:
for dataset_name in ["DFDB", "AVG", "RND"]:
    weights_path = f"OCFD1_Retrained2_{dataset_name}.pkl"
# for dataset_name, weights_path in zip(["DFDB", "AVG", "RND"], ["OC_FD_e204_l1429_DFDB.pkl", "OC_FD_e29_l500_AVG.pkl", "OC_FD_e2_l532_RND.pkl"]):
    fakedect = torch.load("../../_WEIGHTS/oc_fakedect/"+weights_path)
                  # experiment_name,                                       model_name,          testset_name,                 fname
    EXPERIMENTS = [

                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on DFDB Dataset",  f"OC-FakeDect1_{dataset_name}", "DariusAf_Deepfake_Database", f"OC-FakeDect1_{dataset_name}_onDFDB"),

                # OC-FakeDect1 RND on Celebs indep eval
               (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on AVG CDFv2 Dataset",  f"OC-FakeDect1_{dataset_name}", "Celeb-avg-30", f"OC-FakeDect1_{dataset_name}_onAVG"),

               (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on RND CDFv2 Dataset",  f"OC-FakeDect1_{dataset_name}", "Celeb-rnd-30", f"OC-FakeDect1_{dataset_name}_onRND"),

                # OC-FakeDect1 RND on Celebs rae
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on AVG CDFv2 Dataset\
               \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Celeb-avg-30", f"OC-FakeDect1_{dataset_name}_raeonAVG"),

               (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on RND CDFv2 Dataset\
               \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Celeb-rnd-30", f"OC-FakeDect1_{dataset_name}_raeonRND"),

                # OC-FakeDect1 RND on FF indep eval
                # DF
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on AVG Deepfakes Dataset\
               \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_avg", f"OC-FakeDect1_{dataset_name}_raeonDFavg"),
             
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on AVG Deepfakes Dataset",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_avg", f"OC-FakeDect1_{dataset_name}_DFavg"),
             
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on RND Deepfakes Dataset\
               \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_rnd", f"OC-FakeDect1_{dataset_name}_raeonDFrnd"),
             
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
               \nTested on RND Deepfakes Dataset",  f"OC-FakeDect1_{dataset_name}", "Deepfakes_rnd", f"OC-FakeDect1_{dataset_name}_DFrnd"),

                # F2F
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG Face2Face Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Face2Face_avg", f"OC-FakeDect1_{dataset_name}_raeonF2Favg"),
             
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG Face2Face Dataset",  f"OC-FakeDect1_{dataset_name}", "Face2Face_avg", f"OC-FakeDect1_{dataset_name}_F2Favg"),
              
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND Face2Face Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "Face2Face_rnd", f"OC-FakeDect1_{dataset_name}_raeonF2Frnd"),
               
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND Face2Face Dataset",  f"OC-FakeDect1_{dataset_name}", "Face2Face_rnd", f"OC-FakeDect1_{dataset_name}_F2Frnd"),

                # FaceShifter
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG FaceShifter Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_avg", f"OC-FakeDect1_{dataset_name}_raeonFSHFTavg"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG FaceShifter Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_avg", f"OC-FakeDect1_{dataset_name}_FSHFTavg"),
               
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND FaceShifter Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_rnd", f"OC-FakeDect1_{dataset_name}_raeonFSHFTrnd"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND FaceShifter Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceShifter_rnd", f"OC-FakeDect1_{dataset_name}_FSHFTrnd"),

                # FaceSwap
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG FaceSwap Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_avg", f"OC-FakeDect1_{dataset_name}_raeonFSavg"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG FaceSwap Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_avg", f"OC-FakeDect1_{dataset_name}_FSavg"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND FaceSwap Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_rnd", f"OC-FakeDect1_{dataset_name}_raeonFSrnd"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND FaceSwap Dataset",  f"OC-FakeDect1_{dataset_name}", "FaceSwap_rnd", f"OC-FakeDect1_{dataset_name}_FSrnd"),

                # NeuralTextures
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG NeuralTextures Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_avg", f"OC-FakeDect1_{dataset_name}_raeonNTavg"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on AVG NeuralTextures Dataset",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_avg", f"OC-FakeDect1_{dataset_name}_NTavg"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND NeuralTextures Dataset\
                \nRunning Average Evaluation",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_rnd", f"OC-FakeDect1_{dataset_name}_raeonNTrnd"),
                
                (f"OC-FakeDect1 (Trained on {dataset_name} Dataset)\
                \nTested on RND NeuralTextures Dataset",  f"OC-FakeDect1_{dataset_name}", "NeuralTextures_rnd", f"OC-FakeDect1_{dataset_name}_NTrnd"),
               ]
    for experiment_name, model_name, testset_name, fname in EXPERIMENTS[::-1]:
        _, _, testset = get_binary_testset(testset_name)
        print(f"'{experiment_name}'", model_name, testset_name, testset, fname, "\n", sep="\n")

        # recon scores
        test_n = sum([len(files) for _, _, files in os.walk(testset)])  # // batch_size
        score = {0: [], 1: []}
        avg_eval = "running" in experiment_name.lower()
        if avg_eval:
            # running average eval, set of images from each video is single data point (not indep.)
            path_to_testset_real_class = f"{testset}/{[d for d in os.listdir(testset) if 'real' in d][0]}"
            path_to_testset_fake_class = f"{testset}/{[d for d in os.listdir(testset) if 'real' not in d][0]}"

            for y_dir_pth, y_label in [(path_to_testset_fake_class, 1), (path_to_testset_real_class, 0)]:
                img_dir_list = os.listdir(y_dir_pth)
                Vpths = []
                # loop over all single videos
                for og_fname in sorted(list(set([re.split("_\d+.png", img_name)[0] for img_name in img_dir_list]))):

                    all_imgs_for_vid = [i for i in img_dir_list if og_fname in i]
                    path_all_imgs_for_vid = [f"{y_dir_pth}/{i}" for i in all_imgs_for_vid]
                    rsmeS = []

                    # loop over all frames from single video
                    for path_to_test_img in path_all_imgs_for_vid:
                        image = cv2.cvtColor(cv2.imread(test_img_path), cv2.COLOR_BGR2RGB)
                        image = transform_test(image)
                        image = image.float().to(device)

                        fakedect.eval()  # test mode
                        with torch.no_grad():
                            x = image.view(-1, 3, 100, 100)
                            x_prime, _, _ = fakedect(x)
                            xi = x.flatten()
                            xo = x_prime.flatten()
                        rsmeS += [((((xi.cpu() - xo.cpu()) ** 2).sum() / test_n) ** 0.5).item()]
                    score[y_label] += [np.average(rsmeS)]
        else:
            for class_dir in os.listdir(f"{testset}/"):
                y = 0 if "real" in class_dir else 1
                for test_img in tqdm(os.listdir(f"{testset}/{class_dir}")):
                    test_img_path = f"{testset}/{class_dir}/{test_img}"
                    if (
                        ".png" in test_img_path
                        or ".jpg" in test_img_path
                        or ".jpeg" in test_img_path
                    ):
                        image = cv2.cvtColor(cv2.imread(test_img_path), cv2.COLOR_BGR2RGB)
                        image = transform_test(image)
                        image = image.float().to(device)

                        fakedect.eval()  # test mode
                        with torch.no_grad():
                            x = image.view(-1, 3, 100, 100)
                            x_prime, _, _ = fakedect(x)
                            xi = x.flatten()
                            xo = x_prime.flatten()
                        rsme = ((((xi.cpu() - xo.cpu()) ** 2).sum() / test_n) ** 0.5).item()
                        score[y] += [rsme]

        y_true = np.array([0] * len(score[0]) + [1] * len(score[1]))
        y_pred = np.array(score[0] + score[1])

        k = max(np.array(score[0]).std() * 0.075, np.array(score[1]).std() * 0.075)
        min_r, max_r = min(min(score[0]), min(score[1])), max(max(score[0]), max(score[1]))
        bins = [x for x in np.arange(min_r, max_r, k)]

        figure(figsize=(10, 6))  # , dpi = 80)
        c0, _, p = plt.hist(score[0], bins, alpha=0.7, label="Real")
        c1, _, p = plt.hist(score[1], bins, alpha=0.7, label="Fake")
    
        # mode_0s = bins[np.where(c0 == c0.max())[0][0]] #np.array(score[0]).mean()
        # mode_1s = bins[np.where(c1 == c1.max())[0][0]] #np.array(score[1]).mean()

        # determine best threshold (highest f1)
        threshold_canditates = []
        for t in bins:
            y_pred_rint = (y_pred <= t).astype(float)
            threshold_canditates += [(roc_auc_score(y_true, y_pred_rint), t)]
            # threshold_canditates += [(f1_score(y_true, y_pred_rint), t)]
        threshold = max(threshold_canditates, key=lambda t: t[0])[1]

        # if threshold is not in between means, this is messed up and we need to try and new strat
        # if not (min(mode_0s, mode_1s) <= threshold <= max(mode_0s, mode_1s)):
        #     try:
        #         threshold = solve(np.mean(score[1]), np.mean(score[0]), np.var(score[1]), np.mean(score[0]))[0]
        #     except LinAlgError:
        #         threshold = (mode_0s + mode_1s)/2 # mid point between modes

        plt.vlines(
            threshold,
            0,
            int(max(max(c0), max(c1)) // 0.9),
            colors="red",
            label=f"Class Boundary @ {threshold:.3f}",
        )
        plt.legend()
        plt.xlabel("RMSE")
        plt.ylabel("Count")
        plt.title(f"{experiment_name.replace('Trained on RND', 'Trained on CDFv2 (RF)').replace('Trained on AVG', 'Trained on CDFv2 (AF)')}\nReconstruction Scores")
        plt.savefig(f"Results/thrsh_{fname}.pdf", bbox_inches="tight")
        plt.show()

        auroc = roc_auc_score(y_true, y_pred)
        fpr, tpr, _ = roc_curve(y_true, y_pred)

        # If model is worse than random but so much worse that, it's predicting the opposite way
        if auroc < 0.5:
            auroc = 1 - auroc
            fpr, tpr = tpr, fpr
            y_pred = np.ones(y_pred.shape) - y_pred
        # Get F1, Precision and Recall (WITH BEST THRESHOLD)
        y_pred_rint = (y_pred <= threshold).astype(float)
        f1 = f1_score(y_true, y_pred_rint)
        prec = precision_score(y_true, y_pred_rint)
        recall = recall_score(y_true, y_pred_rint)

        # Plot AUC
        plt.figure()
        lw = 2
        plt.plot(fpr, tpr, color="magenta", lw=lw, label="ROC Curve (Area = %0.3f)" % auroc)
        plt.plot([0, 1], [0, 1], color="black", lw=lw, linestyle="--")
        extra_xylim = 0.025
        plt.xlim([0.0 - extra_xylim, 1.0 + extra_xylim])
        plt.ylim([0.0 - extra_xylim, 1.0 + extra_xylim])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title(f"Test AUROC {experiment_name.replace('Trained on RND', 'Trained on CDFv2 (RF)').replace('Trained on AVG', 'Trained on CDFv2 (AF)')}")
        plt.legend(loc="lower right")
        plt.savefig(f"./Results/{fname}_AUC.pdf", bbox_inches="tight")
        plt.show()

        with open(f"./Results/{fname}.txt", "w") as f:
            f.write(f"{experiment_name, model_name, dataset_name, fname}\n")
            f.write(f"auroc={auroc}\n")
            f.write(f"f1={f1}\n")
            f.write(f"prec={prec}\n")
            f.write(f"recall={recall}\n")
            f.write(f"\nfpr={[f for f in fpr]}\n")
            f.write(f"\ntpr={[t for t in tpr]}\n")
            f.write(f"\ny_true={[t for t in y_true]}\n")
            f.write(f"\ny_pred={[p[0] if 'arr' in str(type(p)) and len(p) else p for p in y_pred]}\n")
            f.write(f"\score={score}")
    # break


In [None]:

# for dataset_name in ("DFDB", "CDFv2 (AF)", "CDFv2 (DF)")
#     for experiment_name, model_name, testset_name, fname in EXPERIMENTS[:1]:
#         print(experiment_name, model_name, testset_name, fname, "\n", sep="\n")
#         _, _, trainset = get_binary_testset(testset_name)

#         # recon scores
#         test_n = sum([len(files) for _, _, files in os.walk(trainset)]) #// batch_size
#         score = {0:[],1:[]}
#         avg_eval = "running" in experiment_name.lower()
#         if avg_eval:
#             # running average eval, set of images from each video is single data point (not indep.)
#             path_to_testset_real_class = f"{trainset}/{[d for d in os.listdir(trainset) if 'real' in d][0]}"
#             path_to_testset_fake_class = f"{trainset}/{[d for d in os.listdir(trainset) if 'real' not in d][0]}"

#             for y_dir_pth, y_label in [(path_to_testset_fake_class, 1), (path_to_testset_real_class, 0)]:
#                 img_dir_list = os.listdir(y_dir_pth)
#                 Vpths = []
#                 # loop over all single videos
#                 for og_fname in set([re.split('_\d+.png', img_name)[0] for img_name in img_dir_list]):

#                     all_imgs_for_vid = [i for i in img_dir_list if og_fname in i]
#                     path_all_imgs_for_vid = [f"{y_dir_pth}/{i}" for i in all_imgs_for_vid]
#                     rsmeS = []

#                     # loop over all frames from single video
#                     for path_to_test_img in path_all_imgs_for_vid:
#                         image = cv2.cvtColor(cv2.imread(test_img_path), cv2.COLOR_BGR2RGB)
#                         image = transform_test(image)
#                         image = image.float().to(device)

#                         fakedect.eval() # test mode
#                         with torch.no_grad():
#                             x = image.view(-1,3,100,100)
#                             x_prime, _, _ = fakedect(x)
#                             # z_mean, z_logvar = fakedect.encode(x_prime)
#                             xi = x.flatten()
#                             xo = x_prime.flatten()
#                         rsmeS += [((((xi.cpu() - xo.cpu())**2).sum()/test_n)**0.5).item()]
#                     score[y_label] += [np.average(rsmeS)]

#         else:
#             for class_dir in os.listdir(f"{testset}/"):
#                 y = 0 if "real" in class_dir else 1
#                 for test_img in tqdm(os.listdir(f"{testset}/{class_dir}")):
#                     test_img_path = f"{testset}/{class_dir}/{test_img}"
#                     if ".png" in test_img_path or ".jpg" in test_img_path or ".jpeg" in test_img_path:
#                         image = cv2.cvtColor(cv2.imread(test_img_path), cv2.COLOR_BGR2RGB)
#                         image = transform_test(image)
#                         image = image.float().to(device)

#                         fakedect.eval() # test mode
#                         with torch.no_grad():
#                             x = image.view(-1,3,100,100)
#                             x_prime, _, _ = fakedect(x)
#                             # z_mean, z_logvar = fakedect.encode(x_prime)
#                             xi = x.flatten()
#                             xo = x_prime.flatten()
#                         rsme = ((((xi.cpu() - xo.cpu())**2).sum()/test_n)**0.5).item()
#                         score[y] += [rsme]
        
#         k = max(np.array(score[0]).std() * 0.075, np.array(score[1]).std() * 0.075)
#         min_r, max_r = min(min(score[0]), min(score[1])), max(max(score[0]), max(score[1]))
#         bins = [x for x in np.arange(min_r, max_r, k)]

#         try:
#             threshold = solve(np.mean(score[1]), np.mean(score[0]), np.var(score[1]), np.mean(score[0]))[0]
#         except LinAlgError:
#             threshold = (min_r+max_r)/2

#         figure(figsize = (10, 6))#, dpi = 80)
#         c0,_,p = plt.hist(score[0], bins, alpha=0.7, label="Real")
#         c1,_,p = plt.hist(score[1], bins, alpha=0.7, label="Fake")
#         plt.vlines(threshold, 0, int(max(max(c0),max(c1))//0.9), colors="red", label=f"Class Boundary @ {threshold:.3f}" )
#         plt.legend()
#         plt.xlabel("RMSE")
#         plt.ylabel("Count")
#         plt.title(f"{experiment_name}\nReconstruction Scores")
#         plt.savefig(f"Results/thrsh_{fname}.pdf", bbox_inches="tight")
#         plt.show()

#         y_true = np.array([0]*len(score[0]) + [1]*len(score[1]))
#         y_pred = np.array(score[0] + score[1])
#         y_pred_rint = (y_pred <= threshold).astype(float)
#         y_pred = y_pred_rint

#         auroc = roc_auc_score(y_true, y_pred)
#         fpr, tpr, _  = roc_curve(y_true, y_pred)

#         # If model is worse than random but so much worse that, it's predicting the opposite way
#         if auroc < .5:
#             auroc = 1 - auroc
#             fpr, tpr = tpr, fpr
#             y_pred = np.ones(y_pred.shape) - y_pred

#         # Get F1, Precision and Recall
#         f1 = f1_score(y_true, y_pred_rint)
#         prec = precision_score(y_true, y_pred_rint)
#         recall = recall_score(y_true, y_pred_rint)

#         # Plot AUC
#         plt.figure()
#         lw = 2
#         plt.plot(fpr, tpr, color='magenta', lw=lw, label='ROC Curve (Area = %0.3f)' % auroc)
#         plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
#         extra_xylim = 0.025
#         plt.xlim([0.0 - extra_xylim, 1.0 + extra_xylim])
#         plt.ylim([0.0 - extra_xylim, 1.0 + extra_xylim])
#         plt.xlabel('False Positive Rate')
#         plt.ylabel('True Positive Rate')
#         plt.title(f"Test AUROC {experiment_name}")
#         plt.legend(loc="lower right")
#         # plt.savefig(f"./Results/{fname}_AUC.pdf")
#         plt.show()

#         # with open(f"./Results/{fname}.txt", "w") as f:
#         #     f.write(f"{experiment_name, model_name, dataset_name, fname}\n")
#         #     f.write(f"auroc={auroc}\n")
#         #     f.write(f"f1={f1}\n")
#         #     f.write(f"prec={prec}\n")
#         #     f.write(f"recall={recall}\n")
#         #     f.write(f"\nfpr={[f for f in fpr]}\n")
#         #     f.write(f"\ntpr={[t for t in tpr]}\n")
#         #     f.write(f"\ny_true={[t for t in y_true]}\n")
#         #     f.write(f"\ny_pred={[p[0] if 'arr' in str(type(p)) and len(p) else p for p in y_pred]}\n")
#         #     f.write(f"\score={score}")
#         break