In [1]:
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [2]:
import tensorflow as tf
print(tf.__version__)

2.1.0


In [0]:
from tensorflow.keras import layers

In [0]:
import numpy as np
import os
import time

In [0]:
import tensorflow_addons as tfa

# Models

In [0]:
# Kernel initializer shared by the layers below: zero-mean normal, std 0.02.
initializer = tf.random_normal_initializer(mean=0., stddev=0.02)

### CRU Node

In [0]:
class CRU(tf.keras.Model):
  """Convolutional residual unit: three conv -> group-norm -> ReLU stages.

  Each stage's activation is added to that stage's input (residual
  connection). Every convolution emits 40 channels, so the input must already
  have 40 channels for the additions to be valid.

  The normalisation layers keep their historical `batchN` attribute names so
  that existing checkpoints still match.
  """

  def __init__(self):
    super(CRU, self).__init__()

    self.conv1 = layers.Conv2D(40, (3, 3), strides=1, padding='SAME', kernel_initializer=initializer, use_bias=False)
    self.batch1 = tfa.layers.GroupNormalization(groups=2, axis=3)

    self.conv2 = layers.Conv2D(40, (3, 3), strides=1, padding='SAME', kernel_initializer=initializer, use_bias=False)
    self.batch2 = tfa.layers.GroupNormalization(groups=2, axis=3)

    self.conv3 = layers.Conv2D(40, (3, 3), strides=1, padding='SAME', kernel_initializer=initializer, use_bias=False)
    self.batch3 = tfa.layers.GroupNormalization(groups=2, axis=3)

  def call(self, x, pooling=True):
    """Applies the three residual stages.

    Args:
      x: Feature map with 40 channels on axis 3.
      pooling: When True, the result is down-sampled with `MaxPool2D()`.

    Returns:
      The output of the third residual stage, max-pooled if `pooling` is set.
    """
    stage_1 = self.conv1(x)
    stage_1 = self.batch1(stage_1)
    stage_1 = layers.ReLU()(stage_1)
    stage_1 = x + stage_1

    stage_2 = self.conv2(stage_1)
    stage_2 = self.batch2(stage_2)
    stage_2 = layers.ReLU()(stage_2)
    stage_2 = stage_1 + stage_2

    # The third stage must consume the second stage's output and thread its
    # own activations through conv3 -> batch3 -> ReLU. Previously conv3 read
    # stage 1 and its result was overwritten, so conv3/batch3 never
    # contributed to the output.
    stage_3 = self.conv3(stage_2)
    stage_3 = self.batch3(stage_3)
    stage_3 = layers.ReLU()(stage_3)

    model = stage_3 + stage_2

    if pooling:
      model = layers.MaxPool2D()(model)

    return model

### Projection

In [0]:
class ProjectionLayer(tf.keras.Model):
  """Projects flattened features onto a single learned direction `v`."""

  def __init__(self, layer_index, input_dim=5120):
    """Creates the trainable `(input_dim, 1)` vector, named `v_<layer_index>`."""
    super(ProjectionLayer, self).__init__()
    initial_v = initializer(shape=(input_dim, 1), dtype='float32')
    self.v = tf.Variable(initial_value=initial_v, trainable=True, name="v_{}".format(layer_index))

  def call(self, compressed_x):
    """Returns `(compressed_x @ unit_v, unit_v)`.

    `unit_v` is `v` divided by its norm; the small epsilon keeps the division
    finite if the norm is zero.
    """
    v_length = tf.norm(self.v) + 1e-8
    unit_v = self.v / v_length

    return tf.matmul(compressed_x, unit_v), unit_v

### TRU Node

In [0]:
class TRU(tf.keras.Model):
  """Compresses a feature map into a flat, batch-normalised vector.

  Pipeline: 1x1 convolution to `filters` channels, bilinear resize to
  `size` x `size`, flatten to `size * size * filters` values, then batch
  normalisation without a learned scale.
  """

  def __init__(self, filters=20, size=16):
    super(TRU, self).__init__()

    target_size = (size, size)
    flat_dim = size * size * filters

    def bilinear_resize(feature_map):
      return tf.image.resize(feature_map, target_size, method=tf.image.ResizeMethod.BILINEAR)

    self.conv = layers.Conv2D(filters, (1, 1), strides=1, padding='SAME', kernel_initializer=initializer)
    self.resize = layers.Lambda(bilinear_resize)
    self.reshape = layers.Reshape((flat_dim,))
    self.batch = layers.BatchNormalization(scale=False)

  def call(self, x):
    """Returns the `(batch, size * size * filters)` compressed representation of `x`."""
    compressed = self.resize(self.conv(x))
    flattened = self.reshape(compressed)

    return self.batch(flattened)

### SFL Node

In [0]:
class SFL(tf.keras.Model):
  """Leaf head with two outputs: a per-sample score and a one-channel map."""

  def __init__(self):
    super(SFL, self).__init__()

    # Map branch: a 1x1 convolution with sigmoid giving one value per position.
    self.conv1 = layers.Conv2D(filters=1, kernel_size=(1, 1), activation="sigmoid", kernel_initializer=initializer)

    # Classification branch: two 3x3 convolutions, flatten, two dense layers.
    self.conv2 = layers.Conv2D(filters=40, kernel_size=(3, 3), kernel_initializer=initializer)
    self.conv3 = layers.Conv2D(filters=40, kernel_size=(3, 3), kernel_initializer=initializer)
    self.flat = layers.Flatten()

    self.dense1 = layers.Dense(units=500, kernel_initializer=initializer)
    self.dense2 = layers.Dense(units=1, activation="sigmoid", kernel_initializer=initializer)

  def call(self, x):
    """Returns `(classification, mask)` computed from the feature map `x`."""
    mask = self.conv1(x)

    features = self.flat(self.conv3(self.conv2(x)))
    classification = self.dense2(self.dense1(features))

    return classification, mask

### Unsupervised loss functions

In [0]:
# Route-loss coefficients used by `compute_tru_loss`: `alpha` scales the
# exponent of the projected-variance term, `beta` weights the trace term.
alpha, beta = 1e-3, 1e-2

# Weights of the route loss (`a3`) and the unique loss (`a4`) in the TRU loss.
a3, a4 = 2.0, 0.001

In [0]:
def compute_tru_loss(spoof_x, not_x, unit_v, training):
  """Unsupervised routing loss for one tree node.

  Args:
    spoof_x: `(n, dim)` features of the samples the loss is centred on.
    not_x: `(m, dim)` features of the remaining samples; may have zero rows.
    unit_v: `(dim, 1)` unit-norm projection vector of the node.
    training: The loss is only computed when this is truthy.

  Returns:
    A scalar: `a3 * route_loss + a4 * unique_loss`, or `0.0` when not training
    or when `spoof_x` has no rows. (The value used to be shaped `(1, 1)` on
    the non-zero path; it is now a scalar on both paths.)
  """
  loss = 0.0
  if training and tf.shape(spoof_x)[0] > 0:
    # v^T (X^T X) v == ||X v||^2 and trace(X^T X) == sum(X ** 2), so neither
    # quantity needs the (dim x dim) matrix X^T X to be materialised.
    spoof_projection_sq = tf.square(tf.matmul(spoof_x, unit_v))
    eigenvalue = tf.reduce_sum(spoof_projection_sq)
    trace = tf.reduce_sum(tf.square(spoof_x))

    route_loss = tf.exp(-alpha * eigenvalue) + beta * trace

    spoof_x_loss = -tf.reduce_mean(spoof_projection_sq)

    # A plain reduce_mean over zero rows is NaN and would poison the whole
    # loss; divide_no_nan makes an empty `not_x` contribute exactly 0.
    not_projection_sq = tf.square(tf.matmul(not_x, unit_v))
    not_x_count = tf.cast(tf.size(not_projection_sq), not_projection_sq.dtype)
    not_x_loss = tf.math.divide_no_nan(tf.reduce_sum(not_projection_sq), not_x_count)

    unique_loss = spoof_x_loss + not_x_loss

    loss = (a3 * route_loss) + (a4 * unique_loss)

  return loss

In [0]:
class DTN(tf.keras.Model):
  """Three-level binary tree of routing nodes ending in eight SFL leaves.

  Every internal node runs a CRU (features), a TRU (flattened features) and a
  ProjectionLayer (routing score). The sign of a node's routing score selects
  the samples assigned to its left (< 0) or right (>= 0) child. All nodes are
  evaluated on the full batch; routing is expressed through boolean masks.

  Attribute names and creation order are kept stable so that checkpoints
  written with earlier revisions still match.
  """

  def __init__(self):
    super(DTN, self).__init__()

    # 1x1 convolution producing the 40 channels the CRU residual sums require.
    self.input_layer = layers.Conv2D(40, (1, 1))

    # CRUs: 1 = root, 2-3 = level 1, 4-7 = level 2, 8-15 = leaves.
    self.cru_1 = CRU()

    self.cru_2 = CRU()
    self.cru_3 = CRU()

    self.cru_4 = CRU()
    self.cru_5 = CRU()
    self.cru_6 = CRU()
    self.cru_7 = CRU()

    self.cru_8 = CRU()
    self.cru_9 = CRU()
    self.cru_10 = CRU()
    self.cru_11 = CRU()
    self.cru_12 = CRU()
    self.cru_13 = CRU()
    self.cru_14 = CRU()
    self.cru_15 = CRU()

    # TRUs: the root flattens to 32 * 32 * 10 = 10240 values, the others use
    # the defaults (16 * 16 * 20 = 5120), matching the projection input_dims.
    self.tru_1 = TRU(filters=10, size=32)

    self.tru_2 = TRU()
    self.tru_3 = TRU()

    self.tru_4 = TRU()
    self.tru_5 = TRU()
    self.tru_6 = TRU()
    self.tru_7 = TRU()

    # One SFL head per leaf.
    self.sfl_1 = SFL()
    self.sfl_2 = SFL()
    self.sfl_3 = SFL()
    self.sfl_4 = SFL()
    self.sfl_5 = SFL()
    self.sfl_6 = SFL()
    self.sfl_7 = SFL()
    self.sfl_8 = SFL()

    self.projection_root_layer = ProjectionLayer(layer_index=1, input_dim=10240)

    self.projection_left_1_1_layer = ProjectionLayer(layer_index=2, input_dim=5120)
    self.projection_right_1_2_layer = ProjectionLayer(layer_index=3, input_dim=5120)

    self.projection_left_2_1_layer = ProjectionLayer(layer_index=4, input_dim=5120)
    self.projection_right_2_2_layer = ProjectionLayer(layer_index=5, input_dim=5120)
    self.projection_left_2_3_layer = ProjectionLayer(layer_index=6, input_dim=5120)
    self.projection_right_2_4_layer = ProjectionLayer(layer_index=7, input_dim=5120)

  def _routing_node(self, cru, tru, projection, parent_features, visit_mask, bypass_mask, training):
    """Runs one internal node; returns (features, routing scores, TRU loss).

    `visit_mask` / `bypass_mask` are `(batch,)` boolean masks selecting the
    samples the parent routed to this node and to its sibling respectively.
    """
    features = cru(parent_features)
    compressed = tru(features)
    scores, unit_v = projection(compressed)

    visiting = tf.boolean_mask(compressed, visit_mask)
    bypassing = tf.boolean_mask(compressed, bypass_mask)
    node_loss = compute_tru_loss(visiting, bypassing, unit_v, training)

    return features, scores, node_loss

  def _leaf_node(self, cru, sfl, parent_features):
    """Runs one leaf (CRU without pooling, then SFL); returns (class, map)."""
    return sfl(cru(parent_features, pooling=False))

  @tf.function
  def call(self, x, labels=None, training=True):
    """Evaluates every tree node on the batch.

    Args:
      x: Input batch; presumably channels-last images (the CRU normalisation
        uses axis 3) - TODO confirm against the data pipeline.
      labels: `(batch,)` labels where 1 marks spoof and 0 marks live samples.
        Only read when `training` is truthy, so it may be omitted at inference.
      training: Selects whether the unsupervised loss is computed and returned.

    Returns:
      `(classes_pred, maps_pred, masks, unsupervised_loss)` when training,
      otherwise `(classes_pred, maps_pred, masks)`. The first three are lists
      with one entry per leaf; `masks` holds `(batch, 1)` boolean tensors.

    Raises:
      ValueError: If `training` is truthy and `labels` is None.
    """
    if training and labels is None:
      raise ValueError("`labels` must be provided when training=True")

    input_layer = self.input_layer(x)

    # Root level ------------------------------------------------------------
    cru_root = self.cru_1(input_layer)
    tru_root = self.tru_1(cru_root)
    projection_root, root_unit_v = self.projection_root_layer(tru_root)

    if training:
      # The root is the only node whose loss is split by ground-truth label.
      spoof_data = tf.boolean_mask(tru_root, labels == 1)
      live_data = tf.boolean_mask(tru_root, labels == 0)
      unsupervised_root_loss = compute_tru_loss(spoof_data, live_data, root_unit_v, training)
    else:
      unsupervised_root_loss = 0.0

    # Level 1 ---------------------------------------------------------------
    # Squeeze only the trailing axis: a bare tf.squeeze would also drop the
    # batch axis for a batch of one and make the mask a scalar.
    right_mask = tf.squeeze(projection_root >= 0, axis=-1)
    left_mask = tf.squeeze(projection_root < 0, axis=-1)

    cru_left_1_1, projection_left_1_1, unsupervised_left_1_1_loss = self._routing_node(
        self.cru_2, self.tru_2, self.projection_left_1_1_layer,
        cru_root, left_mask, right_mask, training)

    cru_right_1_2, projection_right_1_2, unsupervised_right_1_2_loss = self._routing_node(
        self.cru_3, self.tru_3, self.projection_right_1_2_layer,
        cru_root, right_mask, left_mask, training)

    # Level 2, children of left_1_1 -------------------------------------------
    # NOTE(review): masks at this level and at the leaves come from the direct
    # parent's score only; they are not intersected with the ancestors'
    # routing decisions - confirm this is intended.
    right_mask = tf.squeeze(projection_left_1_1 >= 0, axis=-1)
    left_mask = tf.squeeze(projection_left_1_1 < 0, axis=-1)

    cru_left_2_1, projection_left_2_1, unsupervised_left_2_1_loss = self._routing_node(
        self.cru_4, self.tru_4, self.projection_left_2_1_layer,
        cru_left_1_1, left_mask, right_mask, training)

    cru_right_2_2, projection_right_2_2, unsupervised_right_2_2_loss = self._routing_node(
        self.cru_5, self.tru_5, self.projection_right_2_2_layer,
        cru_left_1_1, right_mask, left_mask, training)

    # Level 2, children of right_1_2 ------------------------------------------
    right_mask = tf.squeeze(projection_right_1_2 >= 0, axis=-1)
    left_mask = tf.squeeze(projection_right_1_2 < 0, axis=-1)

    cru_left_2_3, projection_left_2_3, unsupervised_left_2_3_loss = self._routing_node(
        self.cru_6, self.tru_6, self.projection_left_2_3_layer,
        cru_right_1_2, left_mask, right_mask, training)

    cru_right_2_4, projection_right_2_4, unsupervised_right_2_4_loss = self._routing_node(
        self.cru_7, self.tru_7, self.projection_right_2_4_layer,
        cru_right_1_2, right_mask, left_mask, training)

    # Level 3 (leaves) --------------------------------------------------------
    class_left_3_1, map_left_3_1 = self._leaf_node(self.cru_8, self.sfl_1, cru_left_2_1)
    class_right_3_2, map_right_3_2 = self._leaf_node(self.cru_9, self.sfl_2, cru_left_2_1)
    class_left_3_3, map_left_3_3 = self._leaf_node(self.cru_10, self.sfl_3, cru_right_2_2)
    class_right_3_4, map_right_3_4 = self._leaf_node(self.cru_11, self.sfl_4, cru_right_2_2)
    class_left_3_5, map_left_3_5 = self._leaf_node(self.cru_12, self.sfl_5, cru_left_2_3)
    class_right_3_6, map_right_3_6 = self._leaf_node(self.cru_13, self.sfl_6, cru_left_2_3)
    class_left_3_7, map_left_3_7 = self._leaf_node(self.cru_14, self.sfl_7, cru_right_2_4)
    class_right_3_8, map_right_3_8 = self._leaf_node(self.cru_15, self.sfl_8, cru_right_2_4)

    # Outputs ---------------------------------------------------------------
    classes_pred = [class_left_3_1, class_right_3_2, class_left_3_3, class_right_3_4,
                    class_left_3_5, class_right_3_6, class_left_3_7, class_right_3_8]

    maps_pred = [map_left_3_1, map_right_3_2, map_left_3_3, map_right_3_4,
                 map_left_3_5, map_right_3_6, map_left_3_7, map_right_3_8]

    # Leaf membership: left leaves take score < 0, right leaves score >= 0.
    # The comparisons are already boolean, so no tf.where is needed.
    masks = [projection_left_2_1 < 0, projection_left_2_1 >= 0,
             projection_right_2_2 < 0, projection_right_2_2 >= 0,
             projection_left_2_3 < 0, projection_left_2_3 >= 0,
             projection_right_2_4 < 0, projection_right_2_4 >= 0]

    unsupervised_loss = unsupervised_root_loss + unsupervised_left_1_1_loss + unsupervised_right_1_2_loss
    unsupervised_loss += unsupervised_left_2_1_loss + unsupervised_right_2_2_loss + unsupervised_left_2_3_loss
    unsupervised_loss += unsupervised_right_2_4_loss

    if training:
      return classes_pred, maps_pred, masks, unsupervised_loss
    else:
      return classes_pred, maps_pred, masks

# Training

In [0]:
# Build the Deep Tree Network instance used by the optimizer, checkpoint and
# training cells below.
model = DTN()

In [0]:
# Training schedule.
epochs = 15
batch_size = 32

# Adam optimizer with learning rate 1e-3 and first-moment decay 0.5.
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3, beta_1=0.5)

### Checkpoints

In [0]:
# Checkpoints track both the optimizer state and the model weights and are
# written under `checkpoint_dir` with the file prefix "ckpt".
checkpoint_dir = './training_checkpoints'
checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

### Supervised Loss Functions

In [0]:
# Loss objects used by `compute_sfl_loss`: mean absolute error for the leaf
# maps and binary cross-entropy for the leaf classification scores.
L1_loss = tf.keras.losses.MeanAbsoluteError()
binary_cross_entropy = tf.keras.losses.BinaryCrossentropy()

In [0]:
# Supervised loss weights: `a1` for the classification term, `a2` for the
# map (L1) term in `compute_sfl_loss`.
a1, a2 = 0.001, 1.0

In [0]:
def compute_sfl_loss(labels, masks_true, classes_pred, masks_pred, masks):
  """Supervised loss summed over all leaves.

  For each leaf, only the samples selected by that leaf's boolean mask
  contribute: binary cross-entropy between their labels and the leaf's class
  scores, plus the L1 distance between their ground-truth maps and the leaf's
  predicted maps.

  Args:
    labels: `(batch,)` ground-truth labels.
    masks_true: Ground-truth maps, one per sample, with the same per-sample
      shape as each leaf's predicted map.
    classes_pred: List with one class-score tensor per leaf.
    masks_pred: List with one predicted-map tensor per leaf.
    masks: List with one boolean membership mask per leaf, holding one value
      per sample.

  Returns:
    Sum over non-empty leaves of `a1 * classification_loss + a2 * mask_loss`.
  """
  total_loss = 0.0

  for pred, mask_pred, mask in zip(classes_pred, masks_pred, masks):
    # Flatten to (batch,). A bare tf.squeeze would also drop the batch axis
    # for a batch of one and yield a scalar mask that boolean_mask rejects.
    node_mask = tf.reshape(mask, [-1])
    data_in_node = tf.reduce_sum(tf.cast(node_mask, tf.float32))

    if data_in_node > 0:
      masked_labels = tf.boolean_mask(labels, node_mask)
      masked_classes = tf.boolean_mask(pred, node_mask)
      masked_maps_pred = tf.boolean_mask(mask_pred, node_mask)
      masked_maps_true = tf.boolean_mask(masks_true, node_mask)

      classification_loss = binary_cross_entropy(masked_labels, masked_classes)
      # Compare this leaf's maps for its own samples. The previous code
      # passed the full list of all leaves' maps and the unmasked ground
      # truth, so the routing mask had no effect on the map loss.
      mask_loss = L1_loss(masked_maps_true, masked_maps_pred)

      total_loss += (a1 * classification_loss) + (a2 * mask_loss)

  return total_loss

### Dummy data

In [0]:
# One random 0/1 label vector of 32 samples, repeated to simulate 3 batches.
label = np.random.randint(2, size=32)
labels_batches = tf.cast([label] * 3, dtype=tf.float32)

In [0]:
# 3 batches of 32 random 256x256 inputs with 6 channels; integer values drawn
# from [0, 255) and stored as float32.
images_batches = tf.cast(
    tf.random.uniform([3, 32, 256, 256, 6], minval=0, maxval=255, dtype=tf.int64),
    dtype=tf.float32,
)

In [0]:
# 3 batches of 32 random binary 32x32 single-channel maps, stored as float32.
masks_batches = tf.cast(
    tf.random.uniform([3, 32, 32, 32, 1], minval=0, maxval=2, dtype=tf.int64),
    dtype=tf.float32,
)

### Train functions

In [0]:
# Number of optimisation steps trained on the supervised loss alone before the
# unsupervised tree loss is added.
UNSUPERVISED_WARMUP_STEPS = 10000


@tf.function
def _train_step_graph(images, labels, masks_true, use_unsupervised):
  """Graph-compiled training step.

  `use_unsupervised` is a Python bool, so this function is traced at most
  twice (once per value) and the branch below is resolved at trace time.
  """
  with tf.GradientTape() as tape:
    classes_pred, maps_pred, masks, unsupervised_loss = model(images, labels, training=True)

    supervised_loss = compute_sfl_loss(labels, masks_true, classes_pred, maps_pred, masks)

    if use_unsupervised:
      loss = supervised_loss + unsupervised_loss
    else:
      loss = supervised_loss

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  return loss


def train_step(images, labels, masks_true, total_steps):
  """Runs one optimisation step on a batch and returns its loss.

  Args:
    images: Input batch passed to the model.
    labels: Ground-truth labels for the batch.
    masks_true: Ground-truth maps for the batch.
    total_steps: Number of steps taken so far. The unsupervised loss is added
      once this exceeds `UNSUPERVISED_WARMUP_STEPS`.

  Returns:
    The loss tensor that was minimised for this batch.
  """
  # Passing the ever-changing Python step counter straight into a tf.function
  # forces a fresh trace of the whole model on every call. Reducing it to a
  # bool first keeps the same branch semantics with at most two traces.
  use_unsupervised = bool(total_steps > UNSUPERVISED_WARMUP_STEPS)

  return _train_step_graph(images, labels, masks_true, use_unsupervised)

In [0]:
# Training bookkeeping.
total_steps = 0    # steps taken so far; passed to `train`
train_steps = 3    # batches per epoch, shown in the progress line
loss_results = []  # receives the last batch loss of every epoch

In [0]:
def train(epochs, total_steps):
  """Trains the model for `epochs` passes over the in-memory batches.

  After every epoch the last batch loss is appended to `loss_results` and a
  checkpoint is saved.

  Args:
    epochs: Number of passes over the batches.
    total_steps: Number of optimisation steps taken before this call; it is
      forwarded to `train_step` and incremented once per batch.

  Returns:
    The updated step count, so a later call can resume from it
    (`total_steps = train(epochs, total_steps)`). The counter is a local
    copy; the caller's variable is not modified.
  """
  for epoch in range(epochs):
    batch_time = time.time()
    epoch_time = time.time()
    step = 0
    loss = None  # stays None if the epoch yields no batches

    epoch_count = f"{epoch + 1:02d}/{epochs}"

    # Change for a true generator
    for images, labels, masks_true in zip(images_batches, labels_batches, masks_batches):
      loss = train_step(images, labels, masks_true, total_steps)

      # .item() accepts both a scalar and a single-element array such as (1, 1).
      loss = loss.numpy().item()
      step += 1

      print('\r', 'Epoch', epoch_count, '| Step', f"{step}/{train_steps}",
            '| loss:', f"{loss:.5f}", "| Step time:", f"{time.time() - batch_time:.2f}", end='')

      batch_time = time.time()
      total_steps += 1

    if loss is not None:
      loss_results.append(loss)

    checkpoint.save(file_prefix=checkpoint_prefix)

    # The epoch summary previously printed the '| loss:' label without a value.
    loss_text = f"{loss:.5f}" if loss is not None else "n/a"
    print('\r', 'Epoch', epoch_count, '| Step', f"{step}/{train_steps}",
          '| loss:', loss_text, "| Epoch time:", f"{time.time() - epoch_time:.2f}")

  return total_steps

In [0]:
# Run the training loop on the dummy batches. `total_steps` is passed by value,
# so the module-level counter is not changed by this call.
train(epochs, total_steps)