<a href="https://colab.research.google.com/github/starkdg/pyConvnetPhash/blob/master/train_deep_cae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from google.colab import drive
drive.mount('/gdrive')

import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf
import tensorflow_hub as hub

# Width of the autoencoder input: used for the input placeholder and the
# first weight matrix. Presumably the feature-vector length produced by the
# MobileNetV2 hub module below -- TODO confirm against the module.
n_inputs = 1792
# Weight applied to the Jacobian-norm penalty; passed to train_autoenc_model
# as lambda_reg.
jnorm_reg = 0.0

# model_tag is substituted into the frozen-model file name.
model_tag = 0
frozen_model = "/gdrive/My Drive/models/deepautoencoder/mobilenetv2_deep_autoenc_frozen_model{0}.pb".format(model_tag)

# Directories scanned for *.tfrecord files (see get_tfrecord_files).
training_files_dir = "/gdrive/My Drive/imageset/train"
validation_files_dir = "/gdrive/My Drive/imageset/validation"
testing_files_dir = "/gdrive/My Drive/imageset/test"

# Training hyper-parameters. `steps` caps the training iterations per epoch.
batch_size = 10
epochs = 15
steps = 2000
learning_rate = 0.0004

# NOTE(review): model_dir and module_inception_url are not referenced anywhere
# else in the visible source; only the MobileNetV2 URL is used.
model_dir = "/gdrive/My Drive/models"
module_inception_url = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/1"
module_mobilenetv2_url = "https://tfhub.dev/google/imagenet/mobilenet_v2_140_224/feature_vector/2"


# Module graph and session, used for extracting features from the hub module.
# It is kept separate from the autoencoder graph; features are fetched as
# arrays from module_sess and fed into the autoencoder's placeholder.
module_graph = tf.Graph()
module_sess = tf.Session(graph=module_graph)  
with module_graph.as_default():
  module = hub.Module(module_mobilenetv2_url)
  # Image size the module expects; _parse_example resizes every image to it.
  target_height, target_width = hub.get_expected_image_size(module)

# Autoencoder graph and session (holds the trainable weights).
aec_graph = tf.Graph()
aec_sess = tf.Session(graph=aec_graph)

In [0]:
def get_weights():
  """Create (or reuse) the autoencoder parameters under the "weights" scope.

  The encoder owns three weight matrices (w1..w3) with biases (b1..b3) for
  the n_inputs -> 1024 -> 512 -> 256 stack. The decoder weights (w4..w6) are
  not variables of their own: they are transposes of w3, w2 and w1, so only
  the decoder biases (b4..b6) add new parameters.

  Returns:
    dict mapping 'w1'..'w6' and 'b1'..'b6' to the corresponding tensors.
  """
  layer_dims = [n_inputs, 1024, 512, 256]
  params = dict()
  with tf.variable_scope("weights", reuse=tf.AUTO_REUSE):
    # Encoder: layer k maps layer_dims[k-1] -> layer_dims[k].
    for k in range(1, 4):
      fan_in, fan_out = layer_dims[k - 1], layer_dims[k]
      w_key, b_key = 'w{0}'.format(k), 'b{0}'.format(k)
      params[w_key] = tf.get_variable(w_key, shape=[fan_in, fan_out], trainable=True)
      params[b_key] = tf.get_variable(b_key, shape=[fan_out], trainable=True)

    # Decoder: layer k mirrors encoder layer (7 - k) and maps back to that
    # encoder layer's input width.
    for k in range(4, 7):
      mirror = 7 - k
      b_key = 'b{0}'.format(k)
      params['w{0}'.format(k)] = tf.transpose(params['w{0}'.format(mirror)])
      params[b_key] = tf.get_variable(b_key, shape=[layer_dims[mirror - 1]], trainable=True)

  return params

  
def load_weights(**wts):
  """Best-effort restore of pre-trained weights from three checkpoints.

  Each checkpoint stores its tensors under the names 'w1', 'b1' and 'b2'.
  They are mapped onto one encoder layer (weight + bias) and the bias of the
  matching decoder layer of the deep autoencoder:

    cae1 -> w1, b1, b6      cae2 -> w2, b2, b5      cae3 -> w3, b3, b4

  A checkpoint that cannot be restored is reported (with the reason) and
  skipped, leaving those variables at their current values.

  Args:
    **wts: the weight dictionary returned by get_weights().
  """
  # (checkpoint path, encoder layer index, decoder layer index)
  pretrained = [
      ('/gdrive/My Drive/models/cae1/cae.chp-1', 1, 6),
      ('/gdrive/My Drive/models/cae2/cae.chp-1', 2, 5),
      ('/gdrive/My Drive/models/cae3/cae.chp-1', 3, 4),
  ]

  for checkpoint_path, enc, dec in pretrained:
    saver = tf.train.Saver({'w1': wts['w{0}'.format(enc)],
                            'b1': wts['b{0}'.format(enc)],
                            'b2': wts['b{0}'.format(dec)]})
    try:
      saver.restore(aec_sess, checkpoint_path)
    except Exception as err:
      # Catch Exception rather than using a bare except so Ctrl-C still
      # interrupts the notebook, and surface why the restore failed.
      print("unable to restore layers {0}, {1}".format(enc, dec))
      print("  reason: {0}".format(err))
    else:
      print("restore layers {0}, {1}".format(enc, dec))
   
 
def plot_weights_and_biases(**weights):
  """Show stacked histograms of the encoder weights and of all six biases.

  Values are read from aec_sess. Figure 102 shows w1..w3 and figure 103
  shows b1..b6.

  Args:
    **weights: the weight dictionary returned by get_weights().
  """
  weight_keys = ['w1', 'w2', 'w3']
  bias_keys = ['b1', 'b2', 'b3', 'b4', 'b5', 'b6']

  # Fetch the current values and flatten each to 1-D for the histograms.
  with aec_sess.as_default():
    weight_values = [weights[key].eval().ravel() for key in weight_keys]
    bias_values = [weights[key].eval().ravel() for key in bias_keys]

  plt.figure(102)
  plt.hist(weight_values, bins=100, histtype='bar', stacked=True)
  plt.legend(weight_keys, loc='upper right')
  plt.title('Histogram of weights')
  plt.show()

  plt.figure(103)
  plt.hist(bias_values, bins=100, histtype='bar', stacked=True)
  plt.legend(bias_keys, loc='upper right')
  plt.title('Histogram of biases')
  plt.show()

In [0]:
def create_deep_autoencoder(learning_rate, lambda_reg):
  """Build the n_inputs -> 1024 -> 512 -> 256 tied-weight autoencoder in aec_graph.

  The encoder uses sigmoid layers; the decoder is linear and reuses the
  transposed encoder weights. The loss is the sigmoid cross-entropy between
  the decoder output (treated as logits) and the min/max-normalised input,
  plus lambda_reg times the summed per-layer Jacobian norms.

  Args:
    learning_rate: learning rate for the Adam optimizer.
    lambda_reg: weight of the Jacobian-norm penalty in the loss.

  Returns:
    Tuple (x, layer3, weights, avg_cost, avg_jnorm, train_op):
      x          input placeholder, shape (None, n_inputs), op name "input"
      layer3     256-wide encoder output, op name "output256"
      weights    dict returned by get_weights()
      avg_cost   batch mean of reconstruction cost + lambda_reg * jnorm
      avg_jnorm  batch mean of the summed Jacobian norms
      train_op   Adam update op minimising avg_cost
  """
  with aec_graph.as_default():
    weights = get_weights()
    x = tf.placeholder(tf.float32, shape=(None, n_inputs), name="input")

    reg_term = tf.constant(lambda_reg, tf.float32, name="jnorm_reg")

    # Rescale the input to [0, 1] using the min/max over the whole fed batch,
    # so it can serve as the cross-entropy target below. xdivy yields 0 where
    # the numerator is 0, which covers a constant input (max == min).
    x_min = tf.reduce_min(x)
    num_x = tf.subtract(x, x_min)
    den_x = tf.subtract(tf.reduce_max(x), x_min)
    norm_x = tf.math.xdivy(num_x, den_x, name="normalization")

    # Encoder
    # n_inputs -> 1024
    layer1 = tf.nn.sigmoid(tf.add(tf.matmul(norm_x, weights['w1']), weights['b1']), name="output1024")

    # 1024 -> 512
    layer2 = tf.nn.sigmoid(tf.add(tf.matmul(layer1, weights['w2']), weights['b2']), name="output512")

    # 512 -> 256
    layer3 = tf.nn.sigmoid(tf.add(tf.matmul(layer2, weights['w3']), weights['b3']), name="output256")

    # Decoder (linear reconstruction with tied weights)
    # 256 -> 512
    layer4 = tf.identity(tf.add(tf.matmul(layer3, weights['w4']), weights['b4']), name="layer4")

    # 512 -> 1024
    layer5 = tf.identity(tf.add(tf.matmul(layer4, weights['w5']), weights['b5']), name="layer5")

    # 1024 -> n_inputs
    y = tf.identity(tf.add(tf.matmul(layer5, weights['w6']), weights['b6']), name="y")

    # Per-layer Jacobian norm of a sigmoid layer h = sigmoid(a W + b) with
    # respect to its input a:
    #   sum_j (h_j * (1 - h_j))^2 * sum_i W_ij^2
    # The transposed (decoder) weight is used so that the reduction over
    # axis=1 sums over the layer's inputs for each hidden unit.
    dhi1 = tf.square(tf.multiply(layer1, tf.subtract(1., layer1)))                  # N x 1024
    dwj1 = tf.reduce_sum(tf.square(weights['w6']), axis=1, keepdims=True)           # 1024 x n_inputs => 1024 x 1
    jnorm1 = tf.matmul(dhi1, dwj1, name="jnorm1")                                   # N x 1

    dhi2 = tf.square(tf.multiply(layer2, tf.subtract(1., layer2)))                  # N x 512
    dwj2 = tf.reduce_sum(tf.square(weights['w5']), axis=1, keepdims=True)           # 512 x 1024 => 512 x 1
    jnorm2 = tf.matmul(dhi2, dwj2, name="jnorm2")                                   # N x 1

    dhi3 = tf.square(tf.multiply(layer3, tf.subtract(1., layer3)))                  # N x 256
    dwj3 = tf.reduce_sum(tf.square(weights['w4']), axis=1, keepdims=True)           # 256 x 512 => 256 x 1
    jnorm3 = tf.matmul(dhi3, dwj3, name="jnorm3")                                   # N x 1

    jnorm_total = tf.add_n([jnorm1, jnorm2, jnorm3])
    avg_jnorm = tf.reduce_mean(jnorm_total)
    # Per-example reconstruction cost, N x 1.
    cost = tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=norm_x, logits=y), axis=1, keepdims=True)
    avg_cost = tf.reduce_mean(tf.add(cost, tf.multiply(reg_term, jnorm_total)))

    with tf.variable_scope("opt", reuse=tf.AUTO_REUSE):
      optimizer = tf.train.AdamOptimizer(learning_rate)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
      with tf.control_dependencies(update_ops):
        # Minimise the full regularised objective. Minimising `cost` alone
        # would leave the Jacobian-norm penalty out of the gradients, making
        # lambda_reg a no-op for training.
        train_op = optimizer.minimize(avg_cost)

    return x, layer3, weights, avg_cost, avg_jnorm, train_op


In [0]:
def get_tfrecord_files(path):
  """Return the paths of the regular files directly under `path` ending in '.tfrecord'.

  Sub-directories are not searched, and the result follows the order in
  which os.scandir yields the entries.
  """
  return [item.path
          for item in os.scandir(path)
          if item.is_file() and item.name.endswith('.tfrecord')]
  
  
def _parse_example(example):
  """Decode one serialized tf.Example into a float image of the module's input size.

  The record must carry int64 'height' and 'width' features and a string
  'image_raw' feature holding the raw uint8 pixels, which are reshaped to
  (height, width, 3).

  Returns:
    float32 tensor of shape (target_height, target_width, 3).
  """
  schema = {'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string)}
  record = tf.parse_single_example(example, schema)

  pixels = tf.io.decode_raw(record['image_raw'], tf.uint8)
  rows = tf.cast(record['height'], tf.int32)
  cols = tf.cast(record['width'], tf.int32)
  image = tf.reshape(pixels, [rows, cols, 3])

  # uint8 -> float32, scaled by convert_image_dtype.
  image = tf.image.convert_image_dtype(image, dtype=tf.float32)

  # resize_bicubic works on batches: add a leading batch axis, resize,
  # then drop the axis again.
  batched = tf.expand_dims(image, 0)
  resized = tf.image.resize_bicubic(batched, [target_height, target_width])
  return tf.squeeze(resized, 0)


def input_function(path, batch_size=1, num_epochs=None, shuffle=False):
  """Build an initializable iterator over the decoded images stored under `path`.

  Args:
    path: directory containing the .tfrecord files.
    batch_size: number of images per batch.
    num_epochs: repeat count handed to Dataset.repeat (None repeats indefinitely).
    shuffle: when true, shuffle the examples with a 10000-element buffer
      before batching.

  Returns:
    An initializable iterator; its initializer must be run before get_next()
    can produce batches.
  """
  images = tf.data.TFRecordDataset(get_tfrecord_files(path)).map(_parse_example)
  if shuffle:
    images = images.shuffle(10000)
  batches = images.batch(batch_size)
  return batches.repeat(num_epochs).make_initializable_iterator()


In [0]:
def train_autoenc_model(training_files_dir,
                        validation_files_dir,
                        testing_files_dir,
                        batch_size, epochs, steps, learning_rate, lambda_reg):
    """Fine-tune the deep autoencoder on hub-module features and report the results.

    Images are read from the tfrecord directories and turned into feature
    vectors by the hub module in module_sess; those vectors are fed to the
    autoencoder in aec_sess. Pre-trained weights are restored (best effort)
    before training. Weight/bias histograms are plotted before and after
    training, a loss curve per epoch is plotted, and one batch of 100 test
    images is evaluated at the end.

    Args:
      training_files_dir: directory with the training .tfrecord files.
      validation_files_dir: directory with the validation .tfrecord files.
      testing_files_dir: directory with the testing .tfrecord files.
      batch_size: batch size for the training and validation iterators.
      epochs: number of epochs; both iterators are re-initialized each epoch.
      steps: an epoch stops once its iteration counter exceeds this value
        (or earlier if an iterator is exhausted).
      learning_rate: passed to create_deep_autoencoder.
      lambda_reg: passed to create_deep_autoencoder.
    """
    # Every period_size iterations a validation batch is evaluated and the
    # current training/validation costs are added to the epoch totals.
    period_size = 100
    
    # The second returned value (the encoder output) is not needed here.
    x, _encoded, wts, recon_cost, jnorm_cost, train_op = create_deep_autoencoder(learning_rate, lambda_reg)
    
    with module_graph.as_default():  
      training_iter = input_function(training_files_dir, batch_size)
      training_images = training_iter.get_next()
      training_features = module(training_images)
        
      validation_iter = input_function(validation_files_dir, batch_size)
      validation_images = validation_iter.get_next()
      validation_features = module(validation_images)
        
      testing_iter = input_function(testing_files_dir, 100)
      testing_images = testing_iter.get_next()
      testing_features = module(testing_images)
       
      module_init = tf.global_variables_initializer()  
    
    module_sess.run(module_init)
  
    with aec_graph.as_default():
      aec_init = tf.initializers.global_variables()
      
    aec_sess.run(aec_init)  
    
    # Restore after initialization so restored values are not overwritten.
    load_weights(**wts)
    
    print("plot weights and biases before fine-tuning")
    plot_weights_and_biases(**wts)
    
    train_losses = []
    valid_losses = []
    jnorm_losses = []
    print("Train Deep Autoencoder")
    for i in range(epochs):
      module_sess.run([training_iter.initializer, validation_iter.initializer])
      iteration = 0
      samples = 0   # number of cost samples accumulated this epoch
      total_cost = 0.
      total_val_cost = 0.
      total_val_jnorm_cost = 0.
      while True:
        try:
          Xtrain = module_sess.run(training_features)
          train_cost, _ = aec_sess.run([recon_cost, train_op], feed_dict={x: Xtrain})
          if (iteration % period_size == 0):
            Xvalid = module_sess.run(validation_features)
            validation_cost, valid_jnorm_cost = aec_sess.run([recon_cost, jnorm_cost], feed_dict={x: Xvalid})
            total_cost += train_cost
            total_val_cost += validation_cost
            total_val_jnorm_cost += valid_jnorm_cost
            samples += 1
          iteration = iteration + 1
        except tf.errors.OutOfRangeError:
          break
        if (iteration > steps):
          break
                
      # Average over the samples actually accumulated. Dividing by
      # iteration // period_size undercounts (the sample at iteration 0 is
      # not counted) and is zero for epochs shorter than period_size.
      if samples:
        avg_train_loss = total_cost/samples
        avg_val_loss = total_val_cost/samples
        avg_jnorm_loss = total_val_jnorm_cost/samples
      else:
        # No batch could be read at all; keep the curves aligned per epoch.
        avg_train_loss = avg_val_loss = avg_jnorm_loss = float('nan')
      print("epoch {0} training cost {1} valid. cost {2} (jnorm {3})".format(i+1, avg_train_loss, avg_val_loss, avg_jnorm_loss))
      train_losses.append(avg_train_loss)
      valid_losses.append(avg_val_loss)
      jnorm_losses.append(avg_jnorm_loss)
      
    plt.figure(101)
    plt.plot(train_losses)
    plt.plot(valid_losses)
    plt.plot(jnorm_losses)
    plt.title("Deep Autoencoder 1792->1024->512->256")
    plt.xlabel("epochs")
    plt.ylabel("cost")
    plt.legend(["training", "validation", "jnorm"], loc="upper right")
    plt.show()
    
    print("run test on 100 images")
    module_sess.run([testing_iter.initializer])
    Xtest = module_sess.run(testing_features)
    testing_cost, testing_jnorm = aec_sess.run([recon_cost, jnorm_cost], feed_dict={x: Xtest})
    print("test cost = {0} (jnorm = {1})".format(testing_cost, testing_jnorm))

    print("plot weights and biases after fine-tuning")
    plot_weights_and_biases(**wts)

In [0]:
def save_aec_graph_to_file():
  """Freeze the encoder part of the autoencoder and write it to `frozen_model`.

  Only the nodes needed to compute 'output256' are kept, training-only
  nodes are removed, and the variables are replaced by constants holding
  their current values in aec_sess.
  """
  output_nodes = ['output256']
  graph_util = tf.compat.v1.graph_util

  encoder_def = graph_util.extract_sub_graph(aec_graph.as_graph_def(), output_nodes)
  encoder_def = graph_util.remove_training_nodes(encoder_def)
  frozen_def = graph_util.convert_variables_to_constants(aec_sess, encoder_def, output_nodes)

  with tf.gfile.GFile(frozen_model, "wb") as out_file:
    out_file.write(frozen_def.SerializeToString())

In [0]:

# Echo the run configuration, train the autoencoder, then export the frozen encoder.
print("Train autoencoder")
for label, value in [("training files: ", training_files_dir),
                     ("validation files: ", validation_files_dir),
                     ("testing files: ", testing_files_dir),
                     ("learning_rate: ", learning_rate),
                     ("batch size: ", batch_size),
                     ("epochs: ", epochs),
                     ("steps: ", steps),
                     ("jnorm reg: ", jnorm_reg)]:
  print(label, value)

train_autoenc_model(training_files_dir=training_files_dir,
                    validation_files_dir=validation_files_dir,
                    testing_files_dir=testing_files_dir,
                    batch_size=batch_size,
                    epochs=epochs,
                    steps=steps,
                    learning_rate=learning_rate,
                    lambda_reg=jnorm_reg)

print("save autoencoder graph to file: ", frozen_model)
save_aec_graph_to_file()
  

  