In [None]:
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from tensorflow.keras import regularizers
from tensorflow.keras import layers
from tensorflow.keras.initializers import RandomNormal
import numpy as np
import matplotlib.pyplot as plt
import os
import datetime
import time
import random
import os, os.path
#import tensorflow_addons as tfa
import math
import json
from statistics import median
import sklearn.metrics

from contextlib import redirect_stdout
import seaborn as sns

In [None]:
# Let tf.data pick the degree of parallelism / prefetch buffer size at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Target size and channel count of full (non-patch) images.
IMG_WIDTH = 224
IMG_HEIGHT = 224
IMG_CHANNELS = 1

# exctract_roi() reads this flag: when True every ROI is resized/padded to
# IMG_HEIGHT x IMG_WIDTH. The flag was referenced but never defined, which
# raises a NameError on a fresh kernel. The pipelines in this notebook crop
# PATCH_SIZE patches from the ROI, hence False.
# NOTE(review): confirm that False is the value used for the original runs.
FULL_IMAGE = False

# Number of validation samples (half fake, half live) and the training batch size.
VALIDATION_SET_SIZE = 256
BATCH_SIZE = 32

# Edge length of the square patches fed to the models.
PATCH_SIZE = 64

# Prefix of the glob patterns used to locate the training images.
PATH_TRAINING = ''

In [None]:
# Create shuffled tf.data Datasets containing the paths to the files
def shuffle_paths(ds_paths, shuffle_loops):
    """Read all paths into memory, shuffle them and return them as a new dataset.

    The elements of ds_paths are collected into a list, which is shuffled
    `shuffle_loops` times. Before every pass the global `random` generator is
    re-seeded with 4. The result is a tf.data.Dataset built from the list.

    See https://stackoverflow.com/questions/46444018/meaning-of-buffer-size-in-dataset-map-dataset-prefetch-and-dataset-shuffle/48096625#48096625
    """
    materialised_paths = [path for path in ds_paths.as_numpy_iterator()]
    for _ in range(shuffle_loops):
        random.seed(4)
        random.shuffle(materialised_paths)
    shuffled_dataset = tf.data.Dataset.from_tensor_slices(materialised_paths)
    return shuffled_dataset

# Training paths are built from live fingerprints only.
ds_train_paths = shuffle_paths(
    tf.data.Dataset.list_files(str(PATH_TRAINING + '*/Live/*.png'), seed=4),
    10,
)

# One half of the validation set: shuffled fake fingerprints.
ds_fake_validation_paths = shuffle_paths(
    tf.data.Dataset.list_files(str(PATH_TRAINING + '*/Fake/*.png'), seed=4),
    10,
).take(int(VALIDATION_SET_SIZE/2))

# Other half: the first live paths, which are then removed from the training paths.
ds_live_validation_paths = ds_train_paths.take(int(VALIDATION_SET_SIZE/2))
ds_train_paths = ds_train_paths.skip(int(VALIDATION_SET_SIZE/2))

# Validation order: all fake samples first, then all live samples.
ds_validation_paths = ds_fake_validation_paths.concatenate(ds_live_validation_paths)

# Data Input & Preprocessing

In [None]:
def get_label(file_path):
    """Derive the class label from the directory that contains the file.

    Returns 1 when the second-to-last path component is 'Live', otherwise 0.
    """
    # Split the path into its components; the class directory sits directly
    # above the file name.
    path_components = tf.strings.split(file_path, os.path.sep)
    class_directory = path_components[-2]
    return 1 if class_directory == 'Live' else 0

def decode_bmp(file_path):
    """Read the file at file_path and decode it as a single-channel BMP image."""
    return tf.image.decode_bmp(tf.io.read_file(file_path), channels=1)

def decode_png(file_path):
    """Read the file at file_path and decode it as a single-channel PNG image."""
    return tf.image.decode_png(tf.io.read_file(file_path), channels=1)

def equalize_hist(img):
    """Histogram-equalise an integer image while ignoring pure-white pixels.

    The histogram is built only from pixels whose value is not 255. Its
    cumulative distribution is stretched to 0..255 and then used as a lookup
    table for every pixel of `img` (the 255 pixels included).

    Args:
        img: integer NumPy array with values in 0..255 (used as lookup indices).

    Returns:
        uint8 array with the same shape as `img`.
    """
    foreground = img != 255
    hist, _ = np.histogram(img[foreground], 256, [0, 255])
    cdf = hist.cumsum()

    spread = cdf.max() - cdf.min()
    if spread == 0:
        # Flat CDF (e.g. no foreground pixels at all, or all of them black):
        # normalising would divide by zero and yield a NaN lookup table, so
        # fall back to the identity mapping and return the image unchanged.
        lookup = np.arange(256, dtype='uint8')
    else:
        lookup = ((cdf - cdf.min()) * 255 / spread).astype('uint8')
    return lookup[img]

def exctract_roi(img):
    """Crop an image to the bounding box of its non-white content and invert it.

    A row counts as empty when the sum of its values equals the image width,
    a column when its sum equals the image height (i.e. every pixel is 1.0).
    The image is cropped to the span from the first to the last non-empty
    row/column, both inclusive, and inverted (1 - value) so that the former
    white background becomes 0. If no content is found, the image is not cropped.

    Compared to the previous row-by-row scan this version
      * keeps the last content row/column (it used to be cut off by one), and
      * also considers index 0 when looking for the last content row/column.

    Args:
        img: eager float tensor of shape (height, width, channels) with white
            encoded as 1.0, as passed in by process_path. Must be eager because
            the bounds are converted to Python ints.

    Returns:
        The cropped, inverted image; resized and padded to
        IMG_HEIGHT x IMG_WIDTH when FULL_IMAGE is truthy.
    """
    img_height = img.shape[0]
    img_width = img.shape[1]

    # One reduction per direction instead of one TF call per row/column:
    # sum over every axis except the one that is being scanned.
    rank = len(img.shape)
    row_sums = tf.math.reduce_sum(img, axis=list(range(1, rank)))
    col_sums = tf.math.reduce_sum(img, axis=[0] + list(range(2, rank)))

    # Indices of the rows/columns that contain at least one non-white pixel.
    content_rows = tf.where(
        tf.math.not_equal(row_sums, tf.cast(img_width, row_sums.dtype)))[:, 0]
    content_cols = tf.where(
        tf.math.not_equal(col_sums, tf.cast(img_height, col_sums.dtype)))[:, 0]

    # Defaults keep the whole image when it is entirely white.
    y_start, y_stop = 0, img_height
    x_start, x_stop = 0, img_width

    if int(tf.size(content_rows)) > 0:
        y_start = int(content_rows[0])
        y_stop = int(content_rows[-1]) + 1  # slice end is exclusive
    if int(tf.size(content_cols)) > 0:
        x_start = int(content_cols[0])
        x_stop = int(content_cols[-1]) + 1  # slice end is exclusive

    img = img[y_start:y_stop, x_start:x_stop]
    # Invert so that the (white) background becomes 0.
    img = 1.0 - img

    if FULL_IMAGE:
        img = tf.image.resize_with_pad(img, IMG_HEIGHT, IMG_WIDTH)

    return img

def get_random_patch(image):
    """Randomly crop a PATCH_SIZE x PATCH_SIZE patch, preferring patches with content.

    At most 11 crops are drawn. The first crop with at least 1100 non-zero
    values is returned; if none qualifies, the 11th crop is returned as is.
    """
    min_non_zero = 1100  # crops with fewer non-zero values count as mainly void
    max_attempts = 11

    for _ in range(max_attempts):
        cropped = tf.image.random_crop(image, [PATCH_SIZE, PATCH_SIZE, IMG_CHANNELS], seed=None, name=None)
        if tf.math.count_nonzero(cropped) >= min_non_zero:
            break
    return cropped


def process_path(file_path):
    """Load one PNG file and return the preprocessed image together with its label.

    The image is decoded, converted to float32, reduced to its region of
    interest by exctract_roi and, when IMG_CHANNELS is 3, replicated along
    axis 2.
    """
    label = get_label(file_path)

    # convert_image_dtype scales the decoded integer image to floats.
    as_float = tf.image.convert_image_dtype(decode_png(file_path), tf.float32)
    roi = exctract_roi(as_float)

    # Networks pre-trained on RGB images expect three channels.
    if IMG_CHANNELS == 3:
        roi = tf.concat([roi, roi, roi], 2)

    return roi, label

In [None]:
# Define mappable wrapper functions that run custom Python code inside tf.data pipelines

def mappable_get_random_patch(image,label):
    """Dataset-mappable wrapper that runs get_random_patch via tf.py_function.

    After the call the patch is given the static shape
    (PATCH_SIZE, PATCH_SIZE, IMG_CHANNELS) and the label a scalar shape.
    Returns the (patch, label) tuple.
    """
    patch = tf.py_function(func=get_random_patch, inp=[image], Tout=(tf.float32))
    patch.set_shape((PATCH_SIZE, PATCH_SIZE, IMG_CHANNELS))
    label.set_shape(())
    return patch, label

def mappable_fn(x):
    """Dataset-mappable wrapper around process_path for full-size images.

    Runs process_path through tf.py_function and sets the static shapes
    (IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS) for the image and () for the label.
    Returns the py_function outputs (image, label).
    """
    outputs = tf.py_function(func=process_path, inp=[x], Tout=(tf.float32, tf.uint8))
    image, label = outputs[0], outputs[1]
    image.set_shape((IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
    label.set_shape(())
    return outputs

def mappable_fn_patch(x):
    """Dataset-mappable wrapper around process_path for the patch pipelines.

    Unlike mappable_fn, no static shapes are set on the outputs.
    """
    return tf.py_function(func=process_path, inp=[x], Tout=(tf.float32, tf.uint8))

In [None]:
def augment(image,label):
    """Randomly flip the image left/right and rescale its intensity.

    The image is multiplied by a factor drawn uniformly from [0.75, 1.25) and
    clipped to [0, 1]. The label is passed through unchanged.
    (A random rotation based on tensorflow_addons is currently disabled.)
    """
    flipped = tf.image.random_flip_left_right(image, seed=None)

    gain = tf.random.uniform([], minval=0.75, maxval=1.25, dtype=tf.dtypes.float32, seed=None, name=None)
    augmented = tf.clip_by_value(flipped * gain, 0, 1, name=None)

    return augmented, label

In [None]:
def center_crop_validation(img, label):
    """Keep the central 90% of the image and bring it back to IMG_HEIGHT x IMG_WIDTH.

    The central fraction 0.90 was annotated as "200/224" in the original code.
    The label is passed through unchanged.
    """
    central_part = tf.image.central_crop(img, 0.90)
    return tf.image.resize_with_pad(central_part, IMG_HEIGHT, IMG_WIDTH), label

In [None]:
def label_is_train_data(img, label):
    """Discard the label and use the image itself as the target: returns (img, img)."""
    target = img
    return img, target

In [None]:
def gan_rescale(image, label):
    """Apply x * 2 - 1 to the image (maps [0, 1] onto [-1, 1]); the label is unchanged."""
    rescaled = image * 2 - 1
    return rescaled, label

## Prepare Datasets

In [None]:
def _validation_patches():
    """Common front part of both validation pipelines: load, cache, crop a patch, rescale."""
    return (
        ds_validation_paths
        .map(mappable_fn_patch, num_parallel_calls=AUTOTUNE)
        .cache()
        .map(mappable_get_random_patch)
        .map(gan_rescale)
    )


# Training: cached images are shuffled, augmented and cropped to a random patch;
# the image itself becomes the target.
ds_patch_train = (
    ds_train_paths
    .map(mappable_fn_patch, num_parallel_calls=AUTOTUNE)
    .cache()
    .shuffle(buffer_size=4000)
    .map(augment, num_parallel_calls=AUTOTUNE)
    .map(mappable_get_random_patch)
    .map(gan_rescale)
    .map(label_is_train_data)
    .batch(BATCH_SIZE, drop_remainder=True)
    .prefetch(buffer_size=AUTOTUNE)
)

# Validation with the image as target, delivered as one single batch.
ds_patch_validation = (
    _validation_patches()
    .map(label_is_train_data)
    .batch(VALIDATION_SET_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

# Same validation data, but keeping the class labels.
ds_patch_validation_with_labels = (
    _validation_patches()
    .batch(VALIDATION_SET_SIZE)
    .prefetch(buffer_size=AUTOTUNE)
)

## Autoencoder

In [None]:
def predict(autoencoder, test_data):
    """Score every sample of test_data individually with autoencoder.evaluate.

    Each sample is reshaped to a batch of one and evaluated against itself as
    the target. The values returned by evaluate are collected in sample order
    and returned as a list.
    """
    shp = test_data.shape
    single_sample_shape = [1, shp[1], shp[2], shp[3]]

    def score_sample(sample):
        batch_of_one = tf.reshape(sample, single_sample_shape)
        return autoencoder.evaluate(batch_of_one, batch_of_one, verbose=0)

    return [score_sample(sample) for sample in test_data]

def predict_new(autoencoder, data):
    """Reconstruct `data` with a single predict call and return the error per sample.

    The mean squared error between reconstruction and input is averaged over
    axes 1 and 2.
    """
    reconstruction = autoencoder.predict(data)
    squared_error = tf.keras.losses.mean_squared_error(reconstruction, data)
    return tf.reduce_mean(squared_error, axis=[1,2])

def mean_squared_error_w(y_true, y_pred):
    """Squared-error loss in which errors at or above a mean + c * std cutoff are zeroed.

    NOTE(review): neither `K` nor `__c` is defined or imported in the visible
    part of this file, so calling this function as-is would raise a NameError.
    `K` is presumably the Keras backend module and `__c` a scalar weight for
    the standard deviation - TODO confirm and define both.
    """
    if not K.is_tensor(y_pred):
        y_pred = K.constant(y_pred)
    y_true = K.cast(y_true, y_pred.dtype)
    # Squared error averaged over the last axis.
    mses = K.mean(K.square(y_pred - y_true), axis=-1)
    # Cutoff = mean + __c * standard deviation, both taken over axes 1 and 2.
    std_of_mses = K.std(mses, axis=[1,2])
    const = K.mean(mses, axis = [1,2]) + (__c * std_of_mses)
    # 1.0 where the error lies below the cutoff, 0.0 elsewhere.
    # NOTE(review): `const` has axes 1 and 2 reduced away while `mses` still
    # has them - verify that this comparison broadcasts as intended.
    mask = K.cast(K.less(mses, const), dtype = "float32")
    return mask * mses


def visualize_ae_result(show_id, model, dump_path=None):
    """Plot one validation patch next to the model's reconstruction of it.

    Args:
        show_id: index of the sample inside the first batch of
            ds_patch_validation_with_labels.
        model: model whose predict() output is shown on the right.
        dump_path: optional path; when given, the figure is also written there.
    """
    image_batch, _ = next(iter(ds_patch_validation_with_labels))
    sample = image_batch[show_id]
    prediction = model.predict(tf.reshape(sample, [1, PATCH_SIZE,PATCH_SIZE,1]))

    f, axarr = plt.subplots(1,2,figsize=(10,10))
    axarr[0].imshow(sample)
    axarr[1].imshow(prediction[0])

    # Save before showing, and save this specific figure: once plt.show() has
    # run, the current figure may already be closed, and plt.savefig() would
    # then write an empty image.
    if dump_path is not None:
        f.savefig(dump_path)

    plt.show()

In [None]:
class ValidationCallback(tf.keras.callbacks.Callback):
    """Keras callback that periodically evaluates an autoencoder as a live/fake classifier.

    Every `dump_frequency` epochs (starting with epoch 0) the callback
      * plots one validation sample and its reconstruction,
      * computes the reconstruction error of every validation sample,
      * classifies a sample as bona fide (1) when its error lies below the
        midpoint between the mean error of both classes, and
      * records the mean errors together with APCER, BPCER and their average.

    Args:
        validation_data: batch of validation images.
        validation_labels: one label per image, 1 = bona fide, 0 = attack.
        autoencoder: the model to evaluate.
        dump_frequency: evaluate on every n-th epoch.
    """

    def __init__(self, validation_data, validation_labels, autoencoder, dump_frequency=2):
        # Let the Keras base class set up its own attributes before ours.
        super().__init__()
        self.validation_data = validation_data
        self.validation_labels = validation_labels
        self.autoencoder = autoencoder
        self.real_reconstruction_error_history = []
        self.fake_reconstruction_error_history = []
        self.APCER_history = []
        # Holds the BPCER values; the name (and the 'BCER' key below) is kept
        # because existing logs and plotting code use it.
        self.BCER_history = []
        self.ACR_history = []
        self.dump_frequency = dump_frequency

    def dump_stats(self, file):
        """Write all recorded histories and the dump frequency to `file` as JSON."""
        history_dump = {
            'real_reconstruction_error': self.real_reconstruction_error_history,
            'fake_reconstruction_error': self.fake_reconstruction_error_history,
            'APCER': self.APCER_history,
            'BCER': self.BCER_history,
            'ACR': self.ACR_history,
            'dump_frequency': self.dump_frequency
            }

        with open(file, 'w') as outfile:
            json.dump(history_dump, outfile)

    def on_epoch_end(self, epoch, logs=None):
        """Run the evaluation described in the class docstring on every n-th epoch."""
        if epoch % self.dump_frequency != 0:
            return

        visualize_ae_result(100, self.autoencoder)
        predictions = predict(self.autoencoder, self.validation_data)

        # Split the scores by label instead of assuming that exactly the first
        # 128 samples are fake and the following 128 are real.
        scores = np.asarray(predictions, dtype=float)
        labels = np.asarray(self.validation_labels).reshape(-1)
        real_reconstruction_error = float(scores[labels == 1].mean())
        fake_reconstruction_error = float(scores[labels == 0].mean())
        print('\nrecunstruction error real: {} fake: {}'.format(real_reconstruction_error, fake_reconstruction_error))

        classification_threshold = (fake_reconstruction_error+real_reconstruction_error)/2

        # Error below the threshold -> predicted bona fide (1), otherwise attack (0).
        custom_threshhold_predictions = tf.cast(
            tf.math.less(predictions, classification_threshold), dtype=tf.int32)

        # num_classes=2 keeps the matrix 2x2 even if one class is never predicted.
        matrix = tf.math.confusion_matrix(
            self.validation_labels, custom_threshhold_predictions,
            num_classes=2, dtype=tf.dtypes.int32)

        print(matrix)

        # Rows are true labels, columns are predictions:
        # matrix[0][1] = attacks accepted as bona fide,
        # matrix[1][0] = bona fide samples rejected as attacks.
        APCER_count = int(matrix[0][1])
        BPCER_count = int(matrix[1][0])
        APCER_SFNR = APCER_count / (APCER_count + int(matrix[0][0]))
        BPCER_SFPR = BPCER_count / (BPCER_count + int(matrix[1][1]))
        ACR = (APCER_SFNR + BPCER_SFPR) / 2

        print("APCER: " + str(APCER_SFNR) + " BPCER: " + str(BPCER_SFPR) + " ACR: " + str(ACR))

        self.real_reconstruction_error_history.append(real_reconstruction_error)
        self.fake_reconstruction_error_history.append(fake_reconstruction_error)
        self.APCER_history.append(APCER_SFNR)
        self.BCER_history.append(BPCER_SFPR)
        self.ACR_history.append(ACR)

# Patch based

In [None]:
# Draw one batch (VALIDATION_SET_SIZE samples) from the labelled validation
# pipeline; it serves as the fixed validation set handed to ValidationCallback.
validation_data, validation_labels = next(iter(ds_patch_validation_with_labels))

# GAN Experiment

In [None]:
# Kernel initialiser shared by the generator and discriminator layers defined below.
init = RandomNormal(stddev=0.02)

def make_generator_model():
    """Build the generator: a 200-dimensional input is upsampled to a 64x64x1 image.

    A dense projection to 4x4x1024 is followed by three stride-2 transposed
    convolutions (512, 256, 128 filters), each with batch normalisation and
    ReLU, and a final stride-2 transposed convolution with tanh activation.
    """
    model = tf.keras.Sequential()

    # Projection of the input vector to a 4x4x1024 feature map.
    model.add(layers.Dense(4*4*1024, use_bias=False, input_shape=(200,), kernel_initializer=init))
    model.add(layers.BatchNormalization())
    model.add(layers.ReLU())
    model.add(layers.Reshape((4, 4, 1024)))

    # Each stage doubles the spatial size: 4 -> 8 -> 16 -> 32.
    for filters, side in ((512, 8), (256, 16), (128, 32)):
        model.add(layers.Conv2DTranspose(filters, (5, 5), strides=(2, 2), padding='same',
                                         use_bias=False, kernel_initializer=init))
        assert model.output_shape == (None, side, side, filters)
        model.add(layers.BatchNormalization())
        model.add(layers.ReLU())

    # Output stage: 32 -> 64, a single channel squashed by tanh.
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                                     use_bias=False, activation='tanh', kernel_initializer=init))
    assert model.output_shape == (None, 64, 64, 1)

    return model

def make_discriminator_model():
    """Build the discriminator: a 64x64x1 image is reduced to one sigmoid output.

    Four stride-2 convolutions (64, 128, 256, 512 filters) with LeakyReLU;
    all but the first are followed by batch normalisation, and the last two
    stages additionally use dropout. A flatten layer and a single sigmoid
    unit form the head.
    """
    model = tf.keras.Sequential()

    # First stage: no batch normalisation and no dropout.
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                            input_shape=[64, 64, 1], kernel_initializer=init))
    assert model.output_shape == (None, 32, 32, 64)
    model.add(layers.LeakyReLU(alpha=0.2))

    # Remaining stages halve the spatial size: 32 -> 16 -> 8 -> 4.
    for filters, side, with_dropout in ((128, 16, False), (256, 8, True), (512, 4, True)):
        model.add(layers.Conv2D(filters, (5, 5), strides=(2, 2), padding='same',
                                kernel_initializer=init))
        assert model.output_shape == (None, side, side, filters)
        model.add(layers.BatchNormalization())
        model.add(layers.LeakyReLU(alpha=0.2))
        if with_dropout:
            model.add(layers.Dropout(0.3))

    # Classification head.
    model.add(layers.Flatten())
    model.add(layers.Dense(1,activation='sigmoid'))

    return model

In [None]:
# Rebuild the discriminator and restore its weights from './disc/'.
discriminator = make_discriminator_model()
discriminator.load_weights('./disc/')
# Reuse every layer except the last two (the Flatten layer and the sigmoid Dense head).
discriminator_base = tf.keras.Sequential(discriminator.layers[:-2])
discriminator_base.trainable = True
discriminator_base.summary()

In [None]:
# Encoder: the discriminator base followed by a new convolution with 1024 filters.
inputs = keras.Input(shape = (64,64,1))
# The base is called with training=False although its weights are trainable.
x = discriminator_base(inputs, training = False)
x = layers.Conv2D(1024, (5, 5), padding='same', kernel_initializer=init)(x) #latent vector
discriminator_new = keras.Model(inputs, x)
discriminator_new.summary()

In [None]:
# Rebuild the generator and restore its weights from './gen/'.
generator = make_generator_model()
generator.load_weights('./gen/')
# Drop the first four layers (Dense, BatchNormalization, ReLU, Reshape) so that
# the remaining stack can be fed a 4x4x1024 tensor directly.
generator_base = tf.keras.Sequential(generator.layers[4:])
generator_base.trainable = True

# Decoder: maps a 4x4x1024 tensor to an image; the base is called with training=False.
inputs = keras.Input(shape = (4, 4, 1024))
x = generator_base(inputs, training = False)
generator_new = keras.Model(inputs, x)

In [None]:
# Autoencoder = encoder (from the discriminator) followed by decoder (from the generator),
# trained with a mean-squared-error reconstruction loss.
gan_ae = tf.keras.models.Sequential([discriminator_new, generator_new])
gan_ae.compile(optimizer=tf.keras.optimizers.Adam(0.00001, beta_1=0.5), loss='mean_squared_error')
# Evaluate on the fixed validation batch every 10 epochs.
validationCallback = ValidationCallback(validation_data, validation_labels, gan_ae, 10)

In [None]:
# Train on the patch dataset; validation runs inside validationCallback.
history = gan_ae.fit(ds_patch_train,
                          epochs=6000,
                          callbacks=[validationCallback])

In [None]:
# Print the layer overview of the combined autoencoder.
gan_ae.summary()

# Validation

In [None]:
# Model to export and the directory all artefacts of this run are written to.
model_to_store = gan_ae
model_path = './ae_run_1/'

In [None]:
# Create the output directory; exist_ok lets the cell be re-run without failing.
os.makedirs(model_path, exist_ok=True)

# Write the model summary and the main training settings to a text file.
with open(model_path + 'architecture.txt', 'w') as f:
    with redirect_stdout(f):
        # summary() prints by itself and returns None, so wrapping it in
        # print() only appended a stray "None" line to the file.
        model_to_store.summary()
        print('Batch_size:' + str(BATCH_SIZE))
        print('Optimizer:' + 'adam 0.00001')
        print('Loss:' + 'mse')

# Store an example reconstruction next to the architecture file.
visualize_ae_result(100, model_to_store, model_path + 'plot1')

In [None]:
# Persist the validation history (JSON) and the trained weights.
validationCallback.dump_stats(model_path + 'train_log')
model_to_store.save_weights(model_path + 'weights/')

In [None]:
# Save the complete model (not only its weights).
model_to_store.save(model_path + 'saved_model/model')

In [None]:
# Reload the history that ValidationCallback.dump_stats wrote to disk.
with open(model_path + 'train_log') as json_file:
    data = json.load(json_file)

# Error rates per evaluation; the legend label equals the key in the log.
for _metric in ('APCER', 'BCER', 'ACR'):
    plt.plot(data[_metric], label=_metric)
plt.legend()
plt.grid()
plt.show()

# Mean reconstruction error of both classes per evaluation.
for _key, _caption in (('real_reconstruction_error', "real"),
                       ('fake_reconstruction_error', "fake")):
    plt.plot(data[_key], label=_caption)
plt.grid()
plt.legend()
plt.show()

In [None]:
# First 300 evaluations of the reconstruction error, saved as 'training_progres'.
for _key, _caption in (('real_reconstruction_error', "Bona Fide"),
                       ('fake_reconstruction_error', "Presentation Attack")):
    plt.plot(data[_key][:300], label=_caption)
plt.ylim([0,0.15])
plt.grid()
plt.legend()
# The figure is saved before plt.show() is called.
plt.savefig('training_progres')
plt.show()