In [1]:
"""Script to finetune AlexNet using Tensorflow.

With this script you can finetune AlexNet as provided in the alexnet.py
class on any given dataset. Specify the configuration settings at the
beginning according to your problem.
This script was written for TensorFlow >= version 1.2rc0 and comes with a blog
post, which you can find here:

https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html

Author: Frederik Kratzert
contact: f.kratzert(at)gmail.com
"""

import os

import numpy as np
import tensorflow as tf

from alexnet import AlexNet
from datagenerator import ImageDataGenerator
from datetime import datetime
Iterator = tf.data.Iterator

print(tf.__version__)

"""
Configuration Part.
"""

# Path to the textfiles for the trainings and validation set
train_file = 'train.txt'
val_file = 'dev.txt'

# Learning params
learning_rate = 0.01
num_epochs = 100
batch_size = 128

# Network params
dropout_rate = 0.5
num_classes = 6
train_layers = ['fc8', 'fc7', 'fc6']

# How often we want to write the tf.summary data to disk
display_step = 100

# Path for tf.summary.FileWriter and to store model checkpoints
filewriter_path = "tmp/finetune_alexnet/tensorboard"
checkpoint_path = "tmp/finetune_alexnet/checkpoints"
#checkpoint_file = None
#checkpoint_file = "tmp/finetune_alexnet/checkpoints/model_epoch1012.ckpt"

1.7.0


In [None]:
"""
Main Part of the finetuning Script.
"""

# Create parent path if it doesn't exist
if not os.path.isdir(checkpoint_path):
    os.mkdir(checkpoint_path)
    
# Create parent path if it doesn't exist
if not os.path.isdir(filewriter_path):
    os.mkdir(filewriter_path)
    
# Create parent path if it doesn't exist
if not os.path.isdir(filewriter_path + "/dev"):
    os.mkdir(filewriter_path + "/dev")

# Create parent path if it doesn't exist
if not os.path.isdir(filewriter_path + "/train"):
    os.mkdir(filewriter_path + "/train")

# Place data loading and preprocessing on the cpu
with tf.device('/CPU:0'):
    tr_data = ImageDataGenerator(train_file,
                                 mode='training',
                                 batch_size=batch_size,
                                 num_classes=num_classes,
                                 shuffle=True)
    val_data = ImageDataGenerator(val_file,
                                  mode='inference',
                                  batch_size=batch_size,
                                  num_classes=num_classes,
                                  shuffle=False)

    # create an reinitializable iterator given the dataset structure
    iterator = Iterator.from_structure(tr_data.data.output_types,
                                       tr_data.data.output_shapes)
    next_batch = iterator.get_next()

# Ops for initializing the two different iterators
training_init_op = iterator.make_initializer(tr_data.data)
validation_init_op = iterator.make_initializer(val_data.data)

# TF placeholder for graph input and output
x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
y = tf.placeholder(tf.float32, [batch_size, num_classes])
keep_prob = tf.placeholder(tf.float32)

# Initialize model
model = AlexNet(x, keep_prob, num_classes, train_layers)

# Link variable to model output
score = model.fc8

# List of trainable variables of the layers we want to train
var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]

# Op for calculating the loss
with tf.name_scope("cross_ent"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=score, labels=y))

# Train op
with tf.name_scope("train"):
    # Get gradients of all trainable variables
    gradients = tf.gradients(loss, var_list)
    gradients = list(zip(gradients, var_list))

    # Create optimizer and apply gradient descent to the trainable variables
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

# Add gradients to summary
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary
for var in var_list:
    tf.summary.histogram(var.name, var)

# Add the loss to summary
tf.summary.scalar('cross_entropy', loss)


# Evaluation op: Accuracy of the model
with tf.name_scope("accuracy"):
    correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Add the accuracy to the summary
tf.summary.scalar('accuracy', accuracy)

# Merge all summaries together
merged_summary = tf.summary.merge_all()

# Initialize the FileWriters
trainWriter = tf.summary.FileWriter(filewriter_path + "/train")
devWriter = tf.summary.FileWriter(filewriter_path + "/dev")

# Initialize an saver for store model checkpoints
saver = tf.train.Saver()

# Get the number of training/validation steps per epoch
train_batches_per_epoch = int(np.floor(tr_data.data_size / batch_size))
val_batches_per_epoch = int(np.floor(val_data.data_size / batch_size))

global_step = tf.Variable(0, trainable=False, name='global_step')
increment_global_step = tf.assign_add(global_step, 1, name = 'increment_global_step')

# Start Tensorflow session
with tf.Session() as sess:
    
    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    # Load the pretrained weights into the non-trainable layer
    #if not checkpoint_file:
    model.load_initial_weights(sess)

    # Restore checkpoint
    if checkpoint_path:
        saver.restore(sess, tf.train.latest_checkpoint(checkpoint_path))
        
    currentStep = sess.run(global_step)
    print("Current step", currentStep)
    
    print("{} Start training...".format(datetime.now()))
    print("{} Open Tensorboard at --logdir {}".format(datetime.now(), filewriter_path))

    # Loop over number of epochs
    for epoch in range(num_epochs):

        print("{} Epoch number: {}".format(datetime.now(), epoch+1))

        # Initialize iterator with the training dataset
        sess.run(training_init_op)

        for step in range(train_batches_per_epoch):

            # get next batch of data
            img_batch, label_batch = sess.run(next_batch)

            # And run the training op
            sess.run([train_op, increment_global_step],
                     feed_dict={x: img_batch, y: label_batch, keep_prob: dropout_rate})

            # Generate summary with the current batch of data and write to file
            if step % display_step == 0:
                currentStep = sess.run(global_step)
                s = sess.run(merged_summary, feed_dict={x: img_batch, y: label_batch, keep_prob: 1.})
                trainWriter.add_summary(s, currentStep)

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        
        sess.run(validation_init_op)
        test_acc = 0.
        test_count = 0
        
        for _ in range(val_batches_per_epoch):
            img_batch, label_batch = sess.run(next_batch)
            acc = sess.run(accuracy, feed_dict={x: img_batch, y: label_batch, keep_prob: 1.})
            test_acc += acc
            test_count += 1
            
        test_acc /= test_count
        
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(), test_acc))
        
        # Output summary to dev writer
        currentStep = sess.run(global_step)
        s = sess.run(merged_summary, feed_dict={x: img_batch, y: label_batch, keep_prob: 1.})
        devWriter.add_summary(s, currentStep)
        
        # save checkpoint of the model
        checkpoint_name = os.path.join(checkpoint_path, 'model.ckpt')
        save_path = saver.save(sess, checkpoint_name, global_step=currentStep)

        print("{} Model checkpoint saved at {}".format(datetime.now(), checkpoint_name))


INFO:tensorflow:Summary name fc6/weights:0/gradient is illegal; using fc6/weights_0/gradient instead.
INFO:tensorflow:Summary name fc6/biases:0/gradient is illegal; using fc6/biases_0/gradient instead.
INFO:tensorflow:Summary name fc7/weights:0/gradient is illegal; using fc7/weights_0/gradient instead.
INFO:tensorflow:Summary name fc7/biases:0/gradient is illegal; using fc7/biases_0/gradient instead.
INFO:tensorflow:Summary name fc8/weights:0/gradient is illegal; using fc8/weights_0/gradient instead.
INFO:tensorflow:Summary name fc8/biases:0/gradient is illegal; using fc8/biases_0/gradient instead.
INFO:tensorflow:Summary name fc6/weights:0 is illegal; using fc6/weights_0 instead.
INFO:tensorflow:Summary name fc6/biases:0 is illegal; using fc6/biases_0 instead.
INFO:tensorflow:Summary name fc7/weights:0 is illegal; using fc7/weights_0 instead.
INFO:tensorflow:Summary name fc7/biases:0 is illegal; using fc7/biases_0 instead.
INFO:tensorflow:Summary name fc8/weights:0 is illegal; using f