In [1]:
import numpy as np
import pandas as pd
import os
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.applications.resnet import ResNet152
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.models import load_model
import wandb
from wandb.keras import WandbCallback

In [None]:
# Default hyperparameters for this run; they are handed to wandb.init below.
config_defaults = dict(
    epochs=30,
    batch_size=32,
    learning_rate=1e-4,
    dropout=0.5,
    regularization=1e-4,
)

# Start the Weights & Biases run and keep a handle on its config object,
# which the following cells read the hyperparameters from.
wandb.init(
    project="deepfake-resnet",
    entity="dat550",
    config=config_defaults,
)
config = wandb.config

In [2]:
# Root folder of the pre-processed dataset; the generator cells read its
# train / validation / test sub-folders.
data_dir = "./data/processed"

# Images are resized to img_size x img_size pixels; the batch size comes from
# the wandb run configuration.
img_size = 128
batch_size = config.batch_size

In [3]:
# Class labels are the sub-folder names of the training split.
# os.listdir() returns entries in arbitrary order and also lists stray files
# (e.g. .DS_Store), so keep directories only and sort them. Sorting mirrors the
# alphanumeric order flow_from_directory uses when it assigns class indices.
train_dir = f'{data_dir}/train'
class_names = sorted(
    entry
    for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)
class_names

['FAKE', 'REAL']

In [18]:
# Random augmentations applied to the training images.
augmentation = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.2,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Rescale the tensor values to [0, 1] and apply the augmentations above.
train_datagen = ImageDataGenerator(rescale=1/255, **augmentation)

In [25]:
# Stream shuffled RGB batches with binary labels from the training folder,
# resizing every image to img_size x img_size.
train_generator = train_datagen.flow_from_directory(
    f'{data_dir}/train',
    batch_size=batch_size,
    target_size=(img_size, img_size),
    class_mode="binary",
    color_mode="rgb",
    shuffle=True,
)

Found 2002 images belonging to 2 classes.


In [26]:
# Validation images are only rescaled to [0, 1]; no augmentation is configured.
val_datagen = ImageDataGenerator(rescale=1/255)

In [27]:
# Stream RGB batches with binary labels from the validation folder, using the
# same image size and batch size as the training generator.
val_generator = val_datagen.flow_from_directory(
    f'{data_dir}/validation',
    batch_size=batch_size,
    target_size=(img_size, img_size),
    class_mode="binary",
    color_mode="rgb",
    shuffle=True,
)

Found 402 images belonging to 2 classes.


In [56]:
def _l2_penalty():
    """Return a fresh L2 regularizer with the strength set in the wandb config."""
    return tf.keras.regularizers.L2(config.regularization)


# ImageNet-pretrained ResNet152 backbone without its classification head;
# pooling='max' reduces the final feature map to one vector per image.
# `classes` is intentionally not passed: Keras only uses it when
# include_top=True, so the previous `classes=2` had no effect.
pretrained_model = ResNet152(
    include_top=False,
    weights='imagenet',
    input_shape=(img_size, img_size, 3),
    pooling='max',
)

# NOTE(review): the generators feed images rescaled to [0, 1]; the ImageNet
# weights presumably expect tf.keras.applications.resnet.preprocess_input
# instead -- confirm whether the rescaling is intentional.

# Classification head: two L2-regularised dense layers with dropout after the
# first, and a single sigmoid unit for the binary FAKE/REAL output.
resnet_model = Sequential([
    pretrained_model,
    Dense(
        units=512,
        activation='relu',
        kernel_regularizer=_l2_penalty(),
        bias_regularizer=_l2_penalty(),
    ),
    Dropout(config.dropout),
    Dense(
        units=128,
        activation='relu',
        kernel_regularizer=_l2_penalty(),
        bias_regularizer=_l2_penalty(),
    ),
    Dense(units=1, activation='sigmoid'),
])

In [58]:
# Print the layer-by-layer summary (output shapes and parameter counts).
# NOTE(review): the saved output below lists a `resnet50` backbone although the
# model cell builds ResNet152 -- the output looks stale; re-run to refresh it.
resnet_model.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 resnet50 (Functional)       (None, 2048)              23587712  
                                                                 
 dense_3 (Dense)             (None, 512)               1049088   
                                                                 
 dropout_1 (Dropout)         (None, 512)               0         
                                                                 
 dense_4 (Dense)             (None, 128)               65664     
                                                                 
 dense_5 (Dense)             (None, 1)                 129       
                                                                 
Total params: 24,702,593
Trainable params: 24,649,473
Non-trainable params: 53,120
_________________________________________________________________


In [None]:
# Binary cross-entropy pairs with the single sigmoid output unit; the Adam
# learning rate is taken from the wandb run configuration.
resnet_model.compile(
    loss='binary_crossentropy',
    metrics=['accuracy'],
    optimizer=Adam(learning_rate=config.learning_rate),
)

In [None]:
# The best checkpoint is written under models/, named after the wandb run.
model_file = f'models/{wandb.run.name}_model.h5'

# Stop training once val_loss has gone 5 epochs without improving.
stop_on_plateau = EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    verbose=1,
)

# Save the model only when val_loss improves on the best value so far.
save_best = ModelCheckpoint(
    filepath=model_file,
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    verbose=1,
)

# Callbacks passed to fit(): early stopping, checkpointing and wandb logging.
custom_callbacks = [stop_on_plateau, save_best, WandbCallback()]

In [None]:
# Upper bound on training epochs, taken from the wandb run configuration.
num_epochs = config.epochs

# Each epoch makes one full pass over the training generator and, for
# validation, one full pass over the validation generator.
history = resnet_model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=num_epochs,
    steps_per_epoch=len(train_generator),
    validation_steps=len(val_generator),
    verbose=1,
    callbacks=custom_callbacks,
)

In [None]:
# Test images are only rescaled to [0, 1], like the validation images.
test_datagen = ImageDataGenerator(
    rescale = 1/255    #rescale the tensor values to [0,1]
)

# Unlabelled (class_mode=None), unshuffled stream over the test folder, one
# image per batch, so predictions stay aligned with test_generator.filenames.
# The class list is given in the same order the training generators infer
# (alphanumeric: FAKE -> 0, REAL -> 1). The previous ['REAL', 'FAKE'] order
# inverted test_generator.class_indices / .classes relative to the labels the
# model was trained on.
test_generator = test_datagen.flow_from_directory(
    directory = f"{data_dir}/test",
    classes=['FAKE', 'REAL'],
    target_size = (img_size, img_size),
    color_mode = "rgb",
    class_mode = None,
    batch_size = 1,
    shuffle = False
)

In [None]:
# Reload the checkpoint that ModelCheckpoint saved for the best val_loss.
best_model = load_model(model_file)

# Rewind the iterator before predicting so the outputs line up with
# test_generator.filenames.
test_generator.reset()
preds = best_model.predict(test_generator, verbose=1)

# One row per test image: its relative file path and the model's sigmoid output.
test_results = pd.DataFrame(
    {
        "Filename": test_generator.filenames,
        "Prediction": preds.ravel(),
    }
)

test_results

In [None]:
# Hard 0/1 label: the sigmoid score rounded to the nearest integer.
test_results["Rounded"] = np.round(test_results["Prediction"])

In [None]:
# Ground truth comes from the class folder at the start of each file path;
# the predicted class comes from the rounded score (0 -> FAKE, 1 -> REAL).
is_fake_file = test_results['Filename'].str.startswith('FAKE')
is_real_file = test_results['Filename'].str.startswith('REAL')
predicted_fake = test_results['Rounded'] == 0
predicted_real = test_results['Rounded'] == 1

# Count rows by summing boolean masks. The previous `.count()[0]` indexed a
# label-indexed Series with an integer position, which pandas has deprecated.
# int() yields plain Python ints for logging.
true_positive_fake = int((is_fake_file & predicted_fake).sum())
false_positive_fake = int((is_real_file & predicted_fake).sum())

true_positive_real = int((is_real_file & predicted_real).sum())
false_positive_real = int((is_fake_file & predicted_real).sum())

# Rows: predicted FAKE / predicted REAL; columns: actual FAKE / actual REAL.
# A plain ndarray replaces np.matrix, which NumPy no longer recommends.
np.array([
    [true_positive_fake, false_positive_fake],
    [false_positive_real, true_positive_real]
])

In [None]:
# Log the four confusion-matrix counts to the wandb run.
confusion_counts = dict(
    true_positive_fake=true_positive_fake,
    false_positive_fake=false_positive_fake,
    true_positive_real=true_positive_real,
    false_positive_real=false_positive_real,
)
wandb.log(confusion_counts)