In [None]:
!pip install vit-keras

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras import layers, Model, metrics, optimizers
from keras.callbacks import EarlyStopping, ModelCheckpoint
from tqdm import tqdm
from sklearn.metrics import f1_score, recall_score, precision_score, roc_curve, roc_auc_score
from vit_keras import vit, utils

In [None]:
def get_binary_testset(dataset_name):
    """
    Translate a dataset alias into its directory path below `../../_DATASETS`.

    Known aliases:
        `DariusAf_Deepfake_Database` (train_test)
        `Celeb-avg-30-(train/test)`
        `Celeb-rnd-30-(train/test)`
        `Celeb-diff-30-(train/test)`
        `DariusAf-OC`, `DariusAf-OC-test`
        `Celeb-DF-v2-OC`, `Celeb-DF-v2-OC-val`, `Celeb-DF-v2-OC-test`

    Returns the path as a string, or None when the alias is not recognised.
    """
    datasets_root = "../../_DATASETS"
    celeb_dir = f"{datasets_root}/Celeb-DF-v2"
    celeb_oc_dir = f"{datasets_root}/Celeb-DF-v2-OC"
    darius_oc_dir = f"{datasets_root}/DariusAf_Deepfake_Database-OC"

    alias_to_path = {
        "DariusAf_Deepfake_Database": f"{datasets_root}/DariusAf_Deepfake_Database/train_test",

        # Celeb-DF-v2 test splits
        "Celeb-avg-30-test": f"{celeb_dir}/Celeb-avg-30-test",
        "Celeb-rnd-30-test": f"{celeb_dir}/Celeb-rnd-30-test",
        "Celeb-diff-30-test": f"{celeb_dir}/Celeb-diff-30-test",

        # Celeb-DF-v2 train splits (directory names carry no "-train" suffix)
        "Celeb-avg-30-train": f"{celeb_dir}/Celeb-avg-30",
        "Celeb-rnd-30-train": f"{celeb_dir}/Celeb-rnd-30",
        "Celeb-diff-30-train": f"{celeb_dir}/Celeb-diff-30",

        # One-class (OC) variants of the DariusAf database
        "DariusAf-OC": f"{darius_oc_dir}/real-train/",          # unary
        "DariusAf-OC-test": f"{darius_oc_dir}/realfake-test/",  # binary

        # One-class (OC) variants of Celeb-DF-v2
        "Celeb-DF-v2-OC": f"{celeb_oc_dir}/Celeb-rnd-30-OC-real-train/",     # unary
        "Celeb-DF-v2-OC-val": f"{celeb_oc_dir}/Celeb-rnd-30-OC-real-val/",   # unary, only has real class
        # Labelled "unary" in the original notes although the directory is
        # named realfake-test -- TODO confirm which is right.
        "Celeb-DF-v2-OC-test": f"{celeb_oc_dir}/Celeb-rnd-30-OC-realfake-test/",
    }

    # Equality scan (rather than dict.get) so that any argument, even an
    # unhashable one, simply falls through to None when nothing matches.
    for alias, location in alias_to_path.items():
        if dataset_name == alias:
            return location
    return None

In [None]:
# --- Experiment configuration -------------------------------------------------
SEED = 42         # handed to the directory iterators as their shuffling seed
BATCH_SIZE = 64
EPOCHS = 10
# 112x112 input: 7 patches per side at 16 px each, so the image size is an
# exact multiple of the 16-pixel patch size of the B/16 backbone.
IMAGE_SIZE = 16 * 7

# Image directories, resolved through the alias lookup.
DATASET = get_binary_testset("Celeb-rnd-30-train")
TEST_DATASET = get_binary_testset("Celeb-rnd-30-test")

In [None]:
# Training / validation pipeline: ViT input preprocessing plus an 80/20 split
# of the images found under DATASET.
IMG_DATAGEN = ImageDataGenerator(
    validation_split=0.2,  # 20% of the training directory is held out for validation
    preprocessing_function=vit.preprocess_inputs,
)

# Test pipeline: same preprocessing, but no validation split because the whole
# test directory is used.
TEST_DATAGEN = ImageDataGenerator(
    preprocessing_function=vit.preprocess_inputs,
)

GEN = IMG_DATAGEN.flow_from_directory(
    DATASET,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    seed=SEED,
    subset="training",
    class_mode="binary",
)

VAL_GEN = IMG_DATAGEN.flow_from_directory(
    DATASET,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    seed=SEED,
    subset="validation",
    class_mode="binary",
)

# Fix: the test iterator previously read DATASET (the training directory), so
# every "test set" evaluation was really measured on training images.
# Shuffling is disabled so repeated evaluations see the images in the same order.
TEST_GEN = TEST_DATAGEN.flow_from_directory(
    TEST_DATASET,
    target_size=(IMAGE_SIZE, IMAGE_SIZE),
    batch_size=BATCH_SIZE,
    shuffle=False,
    class_mode="binary",
)

In [None]:
# Pretrained ViT-B/16 backbone without its classification head.
PRE_TRAINED_MODEL = vit.vit_b16(
    image_size=IMAGE_SIZE,
    activation='sigmoid',
    pretrained=True,
    include_top=False,
    pretrained_top=False,
)

# Freeze every backbone layer so that only the new head below is trained.
for layer in PRE_TRAINED_MODEL.layers:
    layer.trainable = False

# Binary classification head on top of the backbone features.
# Fix: Flatten always produces a rank-2 (batch, features) tensor, while
# GlobalAveragePooling1D only accepts rank-3 input, so the pooling layer that
# used to sit between BatchNormalization and Dense made model construction
# fail. It is removed; the flattened features feed the sigmoid unit directly.
x = layers.Flatten()(PRE_TRAINED_MODEL.output)
x = layers.BatchNormalization()(x)
x = layers.Dense(1, activation='sigmoid')(x)

VIT_FE = Model(PRE_TRAINED_MODEL.input, x)  # Vision Transformer Feature Extractor

# Plain SGD with binary cross-entropy; AUC is tracked as 'auc', which is the
# name the callbacks monitor as 'val_auc'.
VIT_FE.compile(
    optimizer='sgd',
    loss='binary_crossentropy',
    metrics=[metrics.AUC(name='auc')],
)
VIT_FE.summary()

In [None]:
# Score the current VIT_FE weights on the test iterator, report
# AUROC / F1 / precision / recall and plot the ROC curve.
EVAL_GEN = TEST_GEN
experiment_name = "Vision Transformer Untrained (CelebDFv2 Test Set)"

# The iterator does not stop on its own, so draw exactly one pass worth of
# batches. Zipping with a bounded range ends the loop (and completes the
# progress bar) without pulling an extra batch.
n_batches = len(EVAL_GEN)
y_pred = []
y_true = []
for _, (X, y) in zip(tqdm(range(n_batches)), EVAL_GEN):
    # ravel() flattens the (batch, 1) sigmoid output into a list of scores
    y_pred += VIT_FE.predict(X, verbose=0).ravel().tolist()
    y_true += y.tolist()

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Get AUC
auroc = roc_auc_score(y_true, y_pred)

# If the model is worse than random it is ranking the classes the wrong way
# round, so invert its scores.
# Fix: the inversion now happens BEFORE the ROC curve and the thresholded
# predictions are derived, so AUROC, the plotted curve and F1/precision/recall
# all describe the same set of scores. Previously the hard predictions were
# taken from the un-inverted scores and the curve was built by swapping
# fpr/tpr, which is not the ROC of the inverted scores.
scores_inverted = auroc < .5
if scores_inverted:
    y_pred = 1.0 - y_pred
    auroc = roc_auc_score(y_true, y_pred)

fpr, tpr, _ = roc_curve(y_true, y_pred)
y_pred_rint = np.rint(y_pred)  # hard 0/1 predictions

# Get F1, Precision and Recall
f1 = f1_score(y_true, y_pred_rint)
prec = precision_score(y_true, y_pred_rint)
recall = recall_score(y_true, y_pred_rint)
print(
    f"{experiment_name}: AUROC={auroc:.3f} F1={f1:.3f} "
    f"Precision={prec:.3f} Recall={recall:.3f}"
    + (" (scores inverted because raw AUROC < 0.5)" if scores_inverted else "")
)

# Plot AUC
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='magenta', lw=lw, label='ROC Curve (Area = %0.3f)' % auroc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance diagonal
extra_xylim = 0.025  # small margin so the curve is not drawn on the axes
plt.xlim([0.0 - extra_xylim, 1.0 + extra_xylim])
plt.ylim([0.0 - extra_xylim, 1.0 + extra_xylim])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f"AUROC of {experiment_name}")
plt.legend(loc="lower right")
plt.show()

In [None]:
# Stop training once validation AUC has failed to improve for half of the
# epoch budget, and roll the model back to its best weights.
EARLY_STOP = EarlyStopping(
    monitor='val_auc',
    mode='max',
    patience=EPOCHS // 2,
    restore_best_weights=True,
    verbose=1,
)

# Weights-only checkpoint written whenever val_auc reaches a new maximum.
# NOTE(review): MODEL_CHECKPOINT is not passed to any fit() call in this
# notebook, so as written no checkpoint files are produced.
path_2_weights_dir = f"../../_WEIGHTS/vit/E{EPOCHS}"
MODEL_CHECKPOINT = ModelCheckpoint(
    filepath=path_2_weights_dir,
    monitor='val_auc',
    mode='max',
    save_best_only=True,
    save_weights_only=True,
    save_freq='epoch',
    verbose=1,
)

In [None]:
# First training run: backbone frozen, no class weighting.
history = VIT_FE.fit(
    x=GEN,
    validation_data=VAL_GEN,
    epochs=EPOCHS,
    callbacks=[EARLY_STOP],
    verbose=1,
)

In [None]:
# Per-epoch loss curves of the training run above.
# Fix: the titles used to read "Masked Autoencoder ..." although the model is
# the ViT classifier, and the legends listed 'mae' / 'val_mae' series that are
# never plotted.

# Training loss
plt.figure()
plt.plot(history.history['loss'])
plt.title('ViT Training Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['loss'], loc='center right')
plt.show()

# Validation loss
plt.figure()
plt.plot(history.history['val_loss'])
plt.title('ViT Validation Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['val_loss'], loc='upper right')
plt.show()

In [None]:
# Score the trained VIT_FE weights on the test iterator, report
# AUROC / F1 / precision / recall and plot the ROC curve.
EVAL_GEN = TEST_GEN
experiment_name = "Vision Transformer Trained (CelebDFv2 Test Set)"

# The iterator does not stop on its own, so draw exactly one pass worth of
# batches. Zipping with a bounded range ends the loop (and completes the
# progress bar) without pulling an extra batch.
n_batches = len(EVAL_GEN)
y_pred = []
y_true = []
for _, (X, y) in zip(tqdm(range(n_batches)), EVAL_GEN):
    # ravel() flattens the (batch, 1) sigmoid output into a list of scores
    y_pred += VIT_FE.predict(X, verbose=0).ravel().tolist()
    y_true += y.tolist()

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Get AUC
auroc = roc_auc_score(y_true, y_pred)

# If the model is worse than random it is ranking the classes the wrong way
# round, so invert its scores.
# Fix: the inversion now happens BEFORE the ROC curve and the thresholded
# predictions are derived, so AUROC, the plotted curve and F1/precision/recall
# all describe the same set of scores. Previously the hard predictions were
# taken from the un-inverted scores and the curve was built by swapping
# fpr/tpr, which is not the ROC of the inverted scores.
scores_inverted = auroc < .5
if scores_inverted:
    y_pred = 1.0 - y_pred
    auroc = roc_auc_score(y_true, y_pred)

fpr, tpr, _ = roc_curve(y_true, y_pred)
y_pred_rint = np.rint(y_pred)  # hard 0/1 predictions

# Get F1, Precision and Recall
f1 = f1_score(y_true, y_pred_rint)
prec = precision_score(y_true, y_pred_rint)
recall = recall_score(y_true, y_pred_rint)
print(
    f"{experiment_name}: AUROC={auroc:.3f} F1={f1:.3f} "
    f"Precision={prec:.3f} Recall={recall:.3f}"
    + (" (scores inverted because raw AUROC < 0.5)" if scores_inverted else "")
)

# Plot AUC
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='magenta', lw=lw, label='ROC Curve (Area = %0.3f)' % auroc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance diagonal
extra_xylim = 0.025  # small margin so the curve is not drawn on the axes
plt.xlim([0.0 - extra_xylim, 1.0 + extra_xylim])
plt.ylim([0.0 - extra_xylim, 1.0 + extra_xylim])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f"AUROC of {experiment_name}")
plt.legend(loc="lower right")
plt.show()

In [None]:
# Balanced class weights computed from the label counts of the training and
# validation subsets combined: each class is weighted by total / (2 * count),
# so the rarer class contributes proportionally more to the loss.
all_labels = np.concatenate([np.asarray(GEN.classes), np.asarray(VAL_GEN.classes)])

# Plain ints keep the printed weights as ordinary Python floats.
neg = int(np.count_nonzero(all_labels == 0))
pos = int(np.count_nonzero(all_labels == 1))
total = pos + neg

weight_for_0 = (1.0 / neg) * total / 2.0
weight_for_1 = (1.0 / pos) * total / 2.0

CLASS_WEIGHT = {0: weight_for_0, 1: weight_for_1}
print(f'Class weights = {CLASS_WEIGHT}')

In [None]:
# Second training run, this time with per-class loss weights.
# VIT_FE is not rebuilt before this call, so training continues from whatever
# weights the model currently holds.
# Fix: a comma was missing after the callbacks list, which made this cell a
# SyntaxError.
history = VIT_FE.fit(
    GEN,
    validation_data = VAL_GEN,
    epochs = EPOCHS,
    verbose = 1,
    callbacks = [
        EARLY_STOP,
    ],
    class_weight = CLASS_WEIGHT,
    )

In [None]:
# Per-epoch loss curves of the class-weighted training run.
# Fix: the legends used to list 'mae' / 'val_mae' series that are never
# plotted; they now name only the curve that is drawn.

# Training loss
plt.figure()
plt.plot(history.history['loss'])
plt.title('ViT Training Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['loss'], loc='center right')
plt.show()

# Validation loss
plt.figure()
plt.plot(history.history['val_loss'])
plt.title('ViT Validation Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['val_loss'], loc='upper right')
plt.show()

# Score the current VIT_FE weights on the test iterator, report
# AUROC / F1 / precision / recall and plot the ROC curve.
EVAL_GEN = TEST_GEN
experiment_name = "Vision Transformer Trained (CelebDFv2 Test Set)"

# The iterator does not stop on its own, so draw exactly one pass worth of
# batches. Zipping with a bounded range ends the loop (and completes the
# progress bar) without pulling an extra batch.
n_batches = len(EVAL_GEN)
y_pred = []
y_true = []
for _, (X, y) in zip(tqdm(range(n_batches)), EVAL_GEN):
    # ravel() flattens the (batch, 1) sigmoid output into a list of scores
    y_pred += VIT_FE.predict(X, verbose=0).ravel().tolist()
    y_true += y.tolist()

y_true = np.array(y_true)
y_pred = np.array(y_pred)

# Get AUC
auroc = roc_auc_score(y_true, y_pred)

# If the model is worse than random it is ranking the classes the wrong way
# round, so invert its scores.
# Fix: the inversion now happens BEFORE the ROC curve and the thresholded
# predictions are derived, so AUROC, the plotted curve and F1/precision/recall
# all describe the same set of scores. Previously the hard predictions were
# taken from the un-inverted scores and the curve was built by swapping
# fpr/tpr, which is not the ROC of the inverted scores.
scores_inverted = auroc < .5
if scores_inverted:
    y_pred = 1.0 - y_pred
    auroc = roc_auc_score(y_true, y_pred)

fpr, tpr, _ = roc_curve(y_true, y_pred)
y_pred_rint = np.rint(y_pred)  # hard 0/1 predictions

# Get F1, Precision and Recall
f1 = f1_score(y_true, y_pred_rint)
prec = precision_score(y_true, y_pred_rint)
recall = recall_score(y_true, y_pred_rint)
print(
    f"{experiment_name}: AUROC={auroc:.3f} F1={f1:.3f} "
    f"Precision={prec:.3f} Recall={recall:.3f}"
    + (" (scores inverted because raw AUROC < 0.5)" if scores_inverted else "")
)

# Plot AUC
plt.figure()
lw = 2
plt.plot(fpr, tpr, color='magenta', lw=lw, label='ROC Curve (Area = %0.3f)' % auroc)
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance diagonal
extra_xylim = 0.025  # small margin so the curve is not drawn on the axes
plt.xlim([0.0 - extra_xylim, 1.0 + extra_xylim])
plt.ylim([0.0 - extra_xylim, 1.0 + extra_xylim])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f"AUROC of {experiment_name}")
plt.legend(loc="lower right")
plt.show()