In [1]:
"""Script to finetune AlexNet using Tensorflow.

With this script you can finetune AlexNet as provided in the alexnet.py
class on any given dataset. Specify the configuration settings at the
beginning according to your problem.
This script was written for TensorFlow >= version 1.2rc0 and comes with a blog
post, which you can find here:

https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html

Author: Frederik Kratzert
contact: f.kratzert(at)gmail.com
"""

import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
#import cv2

from alexnet import AlexNet
from datagenerator import ImageDataGenerator
from datetime import datetime
from tensorflow.contrib.data import Iterator

"""
Configuration Part.
"""



  from ._conv import register_converters as _register_converters


'\nConfiguration Part.\n'

In [2]:
# Path to the textfiles for the trainings and validation set
train_file = './train.txt'
val_file = './testoh.txt'

# Learning params
learning_rate = 0.02
num_epochs = 30
batch_size = 64

# Network params
dropout_rate = 0.5
num_classes = 2
train_layers = ['fc8','fc7','conv1']

# How often we want to write the tf.summary data to disk
display_step = 20

# Path for tf.summary.FileWriter and to store model checkpoints
filewriter_path = "/tensorboard"
checkpoint_path = "/checkpoints"

"""
Main Part of the finetuning Script.
"""

# Create parent path if it doesn't exist
if not os.path.isdir(checkpoint_path):
    os.mkdir(checkpoint_path)

# Place data loading and preprocessing on the cpu


In [3]:
with tf.device('/gpu:0'):
    tr_data = ImageDataGenerator(train_file,
                                 mode='training',
                                 batch_size=batch_size,
                                 num_classes=num_classes,
                                 shuffle=True)
    val_data = ImageDataGenerator(val_file,
                                  mode='inference',
                                  batch_size=batch_size,
                                  num_classes=num_classes,
                                  shuffle=False)
    


Instructions for updating:
Use `tf.data.Dataset.from_tensor_slices()`.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.


In [4]:
# create an reinitializable iterator given the dataset structure
iterator = Iterator.from_structure(tr_data.data.output_types,
                                       tr_data.data.output_shapes)

next_batch = iterator.get_next()

# Ops for initializing the two different iterators
training_init_op = iterator.make_initializer(tr_data.data)
validation_init_op = iterator.make_initializer(val_data.data)

# TF placeholder for graph input and output
x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
y = tf.placeholder(tf.float32, [batch_size, num_classes])
keep_prob = tf.placeholder(tf.float32)

In [5]:
# Initialize model
model = AlexNet(x, keep_prob, num_classes, train_layers)

# Link variable to model output
score = model.fc8

# List of trainable variables of the layers we want to train
var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]

In [6]:
# Op for calculating the loss
with tf.name_scope("cross_ent"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=score,
                                                                  labels=y))

# Train op
with tf.name_scope("train"):
    # Get gradients of all trainable variables
    gradients = tf.gradients(loss, var_list)
    gradients = list(zip(gradients, var_list))

    # Create optimizer and apply gradient descent to the trainable variables
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

# Add gradients to summary
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary
for var in var_list:
    tf.summary.histogram(var.name, var)

# Add the loss to summary
tf.summary.scalar('cross_entropy', loss)

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

INFO:tensorflow:Summary name conv1/weights:0/gradient is illegal; using conv1/weights_0/gradient instead.
INFO:tensorflow:Summary name conv1/biases:0/gradient is illegal; using conv1/biases_0/gradient instead.
INFO:tensorflow:Summary name fc7/weights:0/gradient is illegal; using fc7/weights_0/gradient instead.
INFO:tensorflow:Summary name fc7/biases:0/gradient is illegal; using fc7/biases_0/gradient instead.
INFO:tensorflow:Summary name fc8/weights:0/gradient is illegal; using fc8/weights_0/gradient instead.
INFO:tensorflow:Summary name fc8/biases:0/gradient is illegal; using fc8/biases_0/gradient instead.
INFO:tensorflow:Summary name conv1/weights:0 is illegal; using conv1/weights_0 instead.
INFO:tensorflow:Summary name conv1/biases:0 is illegal; using conv1/biases_0 instead.
INFO:tensorflow:Summary

<tf.Tensor 'cross_entropy:0' shape=() dtype=string>

In [7]:
# Evaluation op: Accuracy of the model
with tf.name_scope("accuracy"):
    correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Add the accuracy to the summary
tf.summary.scalar('accuracy', accuracy)

# Merge all summaries together
merged_summary = tf.summary.merge_all()

# Initialize the FileWriter
writer = tf.summary.FileWriter(filewriter_path)

# Initialize an saver for store model checkpoints
saver = tf.train.Saver()

# Get the number of training/validation steps per epoch
train_batches_per_epoch = int(np.floor(tr_data.data_size/batch_size))
val_batches_per_epoch = int(np.floor(val_data.data_size / batch_size))

In [8]:
# Start Tensorflow session
config=tf.ConfigProto(allow_soft_placement = True)
with tf.Session(config=config) as sess:

    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    # Add the model graph to TensorBoard
    writer.add_graph(sess.graph)

    # Load the pretrained weights into the non-trainable layer
    model.load_initial_weights(sess)

    print("{} Start training...".format(datetime.now()))
    print("{} Open Tensorboard at --logdir {}".format(datetime.now(),
                                                      filewriter_path))

    
    # Loop over number of epochs
    for epoch in range(num_epochs):

        print("{} Epoch number: {}".format(datetime.now(), epoch+1))

        # Initialize iterator with the training dataset
        sess.run(training_init_op)

        for step in range(train_batches_per_epoch):

            # get next batch of data
            img_batch, label_batch = sess.run(next_batch)       

            # And run the training op
            sess.run(train_op, feed_dict={x: img_batch,
                                          y: label_batch,
                                          keep_prob: dropout_rate})

            # Generate summary with the current batch of data and write to file
            if step % display_step == 0:
                s = sess.run(merged_summary, feed_dict={x: img_batch,
                                                        y: label_batch,
                                                        keep_prob: 1.})
                writer.add_summary(s, epoch*train_batches_per_epoch + step)
                print("{} step".format(datetime.now(), step))

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        sess.run(validation_init_op)
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch):

            img_batch, label_batch = sess.run(next_batch)
            acc = sess.run(accuracy, feed_dict={x: img_batch,
                                                y: label_batch,
                                                keep_prob: 1.})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(),
                                                       test_acc))
        print("{} Saving checkpoint of model...".format(datetime.now()))

        # save checkpoint of the model
        checkpoint_name = os.path.join(checkpoint_path,
                                       'model_epoch'+str(epoch+1)+'.ckpt')
        save_path = saver.save(sess, checkpoint_name)

        print("{} Model checkpoint saved at {}".format(datetime.now(),
                                                       checkpoint_name))

2018-05-31 05:12:41.085727 Start training...
2018-05-31 05:12:41.085843 Open Tensorboard at --logdir /tensorboard
2018-05-31 05:12:41.085891 Epoch number: 1
2018-05-31 05:12:45.181931 step
2018-05-31 05:12:48.079719 step
2018-05-31 05:12:51.038301 step
2018-05-31 05:12:52.920098 Start validation
2018-05-31 05:12:53.508046 Validation Accuracy = 0.4125
2018-05-31 05:12:53.508192 Saving checkpoint of model...
2018-05-31 05:12:54.082870 Model checkpoint saved at /checkpoints/model_epoch1.ckpt
2018-05-31 05:12:54.082987 Epoch number: 2
2018-05-31 05:12:56.524282 step
2018-05-31 05:12:59.616335 step
2018-05-31 05:13:02.576897 step
2018-05-31 05:13:04.538124 Start validation
2018-05-31 05:13:04.986214 Validation Accuracy = 0.4625
2018-05-31 05:13:04.986358 Saving checkpoint of model...
2018-05-31 05:13:05.416200 Model checkpoint saved at /checkpoints/model_epoch2.ckpt
2018-05-31 05:13:05.416297 Epoch number: 3
2018-05-31 05:13:07.887440 step
2018-05-31 05:13:11.019986 step
2018-05-31 05:13:14

2018-05-31 05:16:46.291306 step
2018-05-31 05:16:49.296005 step
2018-05-31 05:16:51.200283 Start validation
2018-05-31 05:16:51.663827 Validation Accuracy = 0.1000
2018-05-31 05:16:51.664008 Saving checkpoint of model...
2018-05-31 05:16:52.156512 Model checkpoint saved at /checkpoints/model_epoch22.ckpt
2018-05-31 05:16:52.156620 Epoch number: 23
2018-05-31 05:16:54.596580 step
2018-05-31 05:16:57.682251 step
2018-05-31 05:17:00.740138 step
2018-05-31 05:17:02.632243 Start validation
2018-05-31 05:17:03.035037 Validation Accuracy = 0.1000
2018-05-31 05:17:03.035213 Saving checkpoint of model...
2018-05-31 05:17:03.531601 Model checkpoint saved at /checkpoints/model_epoch23.ckpt
2018-05-31 05:17:03.531720 Epoch number: 24
2018-05-31 05:17:06.015046 step
2018-05-31 05:17:09.146887 step
2018-05-31 05:17:12.107381 step
2018-05-31 05:17:14.028778 Start validation
2018-05-31 05:17:14.452726 Validation Accuracy = 0.1156
2018-05-31 05:17:14.452934 Saving checkpoint of model...
2018-05-31 05:1