In [1]:
"""Script to finetune AlexNet using Tensorflow.

With this script you can finetune AlexNet as provided in the alexnet.py
class on any given dataset. Specify the configuration settings at the
beginning according to your problem.
This script was written for TensorFlow >= version 1.2rc0 and comes with a blog
post, which you can find here:

https://kratzert.github.io/2017/02/24/finetuning-alexnet-with-tensorflow.html

Author: Frederik Kratzert
contact: f.kratzert(at)gmail.com
"""

import os

import numpy as np
import tensorflow as tf


from alexnet import AlexNet
from datagenerator import ImageDataGenerator
from datetime import datetime
from tensorflow.contrib.data import Iterator

"""
Configuration Part.
"""



  from ._conv import register_converters as _register_converters


'\nConfiguration Part.\n'

In [2]:
# Path to the textfiles for the trainings and validation set
train_file = './train.txt'
val_file = './test.txt'

# Learning params
learning_rate = 0.02
num_epochs = 30
batch_size = 64

# Network params
dropout_rate = 0.5
num_classes = 2
train_layers = ['fc8', 'fc7', 'fc6']

# How often we want to write the tf.summary data to disk
display_step = 20

# Path for tf.summary.FileWriter and to store model checkpoints
filewriter_path = "/tensorboard"
checkpoint_path = "/checkpoints"

"""
Main Part of the finetuning Script.
"""

# Create parent path if it doesn't exist
if not os.path.isdir(checkpoint_path):
    os.mkdir(checkpoint_path)

# Place data loading and preprocessing on the cpu


In [3]:
with tf.device('/gpu:0'):
    tr_data = ImageDataGenerator(train_file,
                                 mode='training',
                                 batch_size=batch_size,
                                 num_classes=num_classes,
                                 shuffle=True)
    val_data = ImageDataGenerator(val_file,
                                  mode='inference',
                                  batch_size=batch_size,
                                  num_classes=num_classes,
                                  shuffle=False)


./Torn/1.jpg
0

./Torn/2.jpg
0

./Torn/3.jpg
0

./Torn/4.jpg
0

./Torn/5.jpg
0

./Torn/6.jpg
0

./Torn/7.jpg
0

./Torn/8.jpg
0

./Torn/9.jpg
0

./Torn/10.jpg
0

./Torn/11.jpg
0

./Torn/12.jpg
0

./Torn/13.jpg
0

./Torn/14.jpg
0

./Torn/15.jpg
0

./Torn/16.jpg
0

./Torn/17.jpg
0

./Torn/18.jpg
0

./Torn/19.jpg
0

./Torn/20.jpg
0

./Torn/21.jpg
0

./Torn/22.jpg
0

./Torn/23.jpg
0

./Torn/24.jpg
0

./Torn/25.jpg
0

./Torn/26.jpg
0

./Torn/27.jpg
0

./Torn/28.jpg
0

./Torn/29.jpg
0

./Torn/30.jpg
0

./Torn/31.jpg
0

./Torn/32.jpg
0

./Torn/33.jpg
0

./Torn/34.jpg
0

./Torn/35.jpg
0

./Torn/36.jpg
0

./Torn/37.jpg
0

./Torn/38.jpg
0

./Torn/39.jpg
0

./Torn/40.jpg
0

./Torn/41.jpg
0

./Torn/42.jpg
0

./Torn/43.jpg
0

./Torn/44.jpg
0

./Torn/45.jpg
0

./Torn/46.jpg
0

./Torn/47.jpg
0

./Torn/48.jpg
0

./Torn/49.jpg
0

./Torn/50.jpg
0

./Torn/51.jpg
0

./Torn/52.jpg
0

./Torn/53.jpg
0

./Torn/54.jpg
0

./Torn/55.jpg
0

./Torn/56.jpg
0

./Torn/57.jpg
0

./Torn/58.jpg
0

./Torn/59.jpg
0

./Torn

0

./Torn/1737.jpg
0

./Torn/1738.jpg
0

./Torn/1739.jpg
0

./Torn/1740.jpg
0

./Torn/1741.jpg
0

./Torn/1742.jpg
0

./Torn/1743.jpg
0

./Torn/1744.jpg
0

./Torn/1745.jpg
0

./Torn/1746.jpg
0

./Torn/1747.jpg
0

./Torn/1748.jpg
0

./Torn/1749.jpg
0

./Torn/1750.jpg
0

./Torn/1751.jpg
0

./Torn/1752.jpg
0

./Torn/1753.jpg
0

./Torn/1754.jpg
0

./Torn/1755.jpg
0

./Torn/1756.jpg
0

./Torn/1757.jpg
0

./Torn/1758.jpg
0

./Torn/1759.jpg
0

./Torn/1760.jpg
0

./Torn/1761.jpg
0

./Torn/1762.jpg
0

./Torn/1763.jpg
0

./Torn/1764.jpg
0

./Torn/1765.jpg
0

./Torn/1766.jpg
0

./Torn/1767.jpg
0

./Torn/1768.jpg
0

./Torn/1769.jpg
0

./Torn/1770.jpg
0

./Torn/1771.jpg
0

./Torn/1772.jpg
0

./Torn/1773.jpg
0

./Torn/1774.jpg
0

./Torn/1775.jpg
0

./Torn/1776.jpg
0

./Torn/1777.jpg
0

./Torn/1778.jpg
0

./Torn/1779.jpg
0

./Torn/1780.jpg
0

./Torn/1781.jpg
0

./Torn/1782.jpg
0

./Torn/1783.jpg
0

./Torn/1784.jpg
0

./Torn/1785.jpg
0

./Torn/1786.jpg
0

./Torn/1787.jpg
0

./Torn/1788.jpg
0

./Torn/17

./Untorn/986.jpg
1

./Untorn/987.jpg
1

./Untorn/988.jpg
1

./Untorn/989.jpg
1

./Untorn/990.jpg
1

./Untorn/991.jpg
1

./Untorn/992.jpg
1

./Untorn/993.jpg
1

./Untorn/994.jpg
1

./Untorn/995.jpg
1

./Untorn/996.jpg
1

./Untorn/997.jpg
1

./Untorn/998.jpg
1

./Untorn/999.jpg
1

./Untorn/1000.jpg
1

./Untorn/1001.jpg
1

./Untorn/1002.jpg
1

./Untorn/1003.jpg
1

./Untorn/1004.jpg
1

./Untorn/1005.jpg
1

./Untorn/1006.jpg
1

./Untorn/1007.jpg
1

./Untorn/1008.jpg
1

./Untorn/1009.jpg
1

./Untorn/1010.jpg
1

./Untorn/1011.jpg
1

./Untorn/1012.jpg
1

./Untorn/1013.jpg
1

./Untorn/1014.jpg
1

./Untorn/1015.jpg
1

./Untorn/1016.jpg
1

./Untorn/1017.jpg
1

./Untorn/1018.jpg
1

./Untorn/1019.jpg
1

./Untorn/1020.jpg
1

./Untorn/1021.jpg
1

./Untorn/1022.jpg
1

./Untorn/1023.jpg
1

./Untorn/1024.jpg
1

./Untorn/1025.jpg
1

./Untorn/1026.jpg
1

./Untorn/1027.jpg
1

./Untorn/1028.jpg
1

./Untorn/1029.jpg
1

./Untorn/1030.jpg
1

./Untorn/1031.jpg
1

./Untorn/1032.jpg
1

./Untorn/1033.jpg
1

./Unto

Instructions for updating:
Use `tf.data.Dataset.from_tensor_slices()`.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
Instructions for updating:
Replace `num_threads=T` with `num_parallel_calls=T`. Replace `output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.
./torn/frame1.jpg
0

./torn/frame2.jpg
0

./torn/frame3.jpg
0

./torn/frame4.jpg
0

./torn/frame5.jpg
0

./torn/frame6.jpg
0

./torn/frame7.jpg
0

./torn/frame8.jpg
0

./torn/frame9.jpg
0

./torn/frame10.jpg
0

./torn/frame11.jpg
0

./torn/frame12.jpg
0

./torn/frame13.jpg
0

./torn/frame14.jpg
0

./torn/frame15.jpg
0

./torn/frame16.jpg
0

./torn/frame17.jpg
0

./torn/frame18.jpg
0

./torn/frame19.jpg
0

./torn/frame20.jpg
0

./torn/frame21.jpg
0

./torn/frame22.jpg
0

./torn/frame23.jpg
0

./torn/frame24.jpg
0

./torn/frame25.jpg
0

./torn/frame26.jpg
0

./torn/frame27.jpg
0

./torn/frame28.jpg
0

./torn

In [4]:
# create an reinitializable iterator given the dataset structure
iterator = Iterator.from_structure(tr_data.data.output_types,
                                       tr_data.data.output_shapes)

next_batch = iterator.get_next()

# Ops for initializing the two different iterators
training_init_op = iterator.make_initializer(tr_data.data)
validation_init_op = iterator.make_initializer(val_data.data)

# TF placeholder for graph input and output
x = tf.placeholder(tf.float32, [batch_size, 227, 227, 3])
y = tf.placeholder(tf.float32, [batch_size, num_classes])
keep_prob = tf.placeholder(tf.float32)

In [5]:
# Initialize model
model = AlexNet(x, keep_prob, num_classes, train_layers)

# Link variable to model output
score = model.fc8

# List of trainable variables of the layers we want to train
var_list = [v for v in tf.trainable_variables() if v.name.split('/')[0] in train_layers]

In [6]:
# Op for calculating the loss
with tf.name_scope("cross_ent"):
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=score,
                                                                  labels=y))

# Train op
with tf.name_scope("train"):
    # Get gradients of all trainable variables
    gradients = tf.gradients(loss, var_list)
    gradients = list(zip(gradients, var_list))

    # Create optimizer and apply gradient descent to the trainable variables
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.apply_gradients(grads_and_vars=gradients)

# Add gradients to summary
for gradient, var in gradients:
    tf.summary.histogram(var.name + '/gradient', gradient)

# Add the variables we train to the summary
for var in var_list:
    tf.summary.histogram(var.name, var)

# Add the loss to summary
tf.summary.scalar('cross_entropy', loss)

Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

INFO:tensorflow:Summary name fc6/weights:0/gradient is illegal; using fc6/weights_0/gradient instead.
INFO:tensorflow:Summary name fc6/biases:0/gradient is illegal; using fc6/biases_0/gradient instead.
INFO:tensorflow:Summary name fc7/weights:0/gradient is illegal; using fc7/weights_0/gradient instead.
INFO:tensorflow:Summary name fc7/biases:0/gradient is illegal; using fc7/biases_0/gradient instead.
INFO:tensorflow:Summary name fc8/weights:0/gradient is illegal; using fc8/weights_0/gradient instead.
INFO:tensorflow:Summary name fc8/biases:0/gradient is illegal; using fc8/biases_0/gradient instead.
INFO:tensorflow:Summary name fc6/weights:0 is illegal; using fc6/weights_0 instead.
INFO:tensorflow:Summary name fc6/biases:0 is illegal; using fc6/biases_0 instead.
INFO:tensorflow:Summary name fc7/weight

<tf.Tensor 'cross_entropy:0' shape=() dtype=string>

In [7]:
# Evaluation op: Accuracy of the model
with tf.name_scope("accuracy"):
    correct_pred = tf.equal(tf.argmax(score, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

# Add the accuracy to the summary
tf.summary.scalar('accuracy', accuracy)

# Merge all summaries together
merged_summary = tf.summary.merge_all()

# Initialize the FileWriter
writer = tf.summary.FileWriter(filewriter_path)

# Initialize an saver for store model checkpoints
saver = tf.train.Saver()

# Get the number of training/validation steps per epoch
train_batches_per_epoch = int(np.floor(tr_data.data_size/batch_size))
val_batches_per_epoch = int(np.floor(val_data.data_size / batch_size))

In [None]:
# Start Tensorflow session
config=tf.ConfigProto(allow_soft_placement = True)
with tf.Session(config=config) as sess:

    # Initialize all variables
    sess.run(tf.global_variables_initializer())

    # Add the model graph to TensorBoard
    writer.add_graph(sess.graph)

    # Load the pretrained weights into the non-trainable layer
    model.load_initial_weights(sess)

    print("{} Start training...".format(datetime.now()))
    print("{} Open Tensorboard at --logdir {}".format(datetime.now(),
                                                      filewriter_path))

    
    # Loop over number of epochs
    for epoch in range(num_epochs):

        print("{} Epoch number: {}".format(datetime.now(), epoch+1))

        # Initialize iterator with the training dataset
        sess.run(training_init_op)

        for step in range(train_batches_per_epoch):

            # get next batch of data
            img_batch, label_batch = sess.run(next_batch)       

            # And run the training op
            sess.run(train_op, feed_dict={x: img_batch,
                                          y: label_batch,
                                          keep_prob: dropout_rate})

            # Generate summary with the current batch of data and write to file
            if step % display_step == 0:
                s = sess.run(merged_summary, feed_dict={x: img_batch,
                                                        y: label_batch,
                                                        keep_prob: 1.})
                writer.add_summary(s, epoch*train_batches_per_epoch + step)
                print("{} step".format(datetime.now(), step))

        # Validate the model on the entire validation set
        print("{} Start validation".format(datetime.now()))
        sess.run(validation_init_op)
        test_acc = 0.
        test_count = 0
        for _ in range(val_batches_per_epoch):

            img_batch, label_batch = sess.run(next_batch)
            acc = sess.run(accuracy, feed_dict={x: img_batch,
                                                y: label_batch,
                                                keep_prob: 1.})
            test_acc += acc
            test_count += 1
        test_acc /= test_count
        print("{} Validation Accuracy = {:.4f}".format(datetime.now(),
                                                       test_acc))
        print("{} Saving checkpoint of model...".format(datetime.now()))

        # save checkpoint of the model
        checkpoint_name = os.path.join(checkpoint_path,
                                       'model_epoch'+str(epoch+1)+'.ckpt')
        save_path = saver.save(sess, checkpoint_name)

        print("{} Model checkpoint saved at {}".format(datetime.now(),
                                                       checkpoint_name))

2018-05-28 05:42:34.469138 Start training...
2018-05-28 05:42:34.469258 Open Tensorboard at --logdir /tensorboard
2018-05-28 05:42:34.469317 Epoch number: 1
2018-05-28 05:42:39.254477 step
2018-05-28 05:42:42.205042 step
2018-05-28 05:42:45.120394 step
2018-05-28 05:42:46.370815 Start validation
2018-05-28 05:42:46.873787 Validation Accuracy = 0.6781
2018-05-28 05:42:46.874002 Saving checkpoint of model...
2018-05-28 05:42:47.114536 Model checkpoint saved at /checkpoints/model_epoch1.ckpt
2018-05-28 05:42:47.114790 Epoch number: 2
2018-05-28 05:42:50.505771 step
2018-05-28 05:42:53.793557 step
2018-05-28 05:42:56.719671 step
2018-05-28 05:42:57.943222 Start validation
2018-05-28 05:42:58.437850 Validation Accuracy = 0.6031
2018-05-28 05:42:58.437967 Saving checkpoint of model...
2018-05-28 05:42:58.671368 Model checkpoint saved at /checkpoints/model_epoch2.ckpt
2018-05-28 05:42:58.671493 Epoch number: 3
2018-05-28 05:43:02.169972 step
2018-05-28 05:43:05.275709 step
2018-05-28 05:43:08

2018-05-28 05:46:44.018466 step
2018-05-28 05:46:46.919746 step
2018-05-28 05:46:48.186020 Start validation
2018-05-28 05:46:48.664555 Validation Accuracy = 0.7063
2018-05-28 05:46:48.664773 Saving checkpoint of model...
2018-05-28 05:46:48.915340 Model checkpoint saved at /checkpoints/model_epoch22.ckpt
2018-05-28 05:46:48.915613 Epoch number: 23
2018-05-28 05:46:52.268913 step
2018-05-28 05:46:55.407070 step
2018-05-28 05:46:58.285539 step
2018-05-28 05:46:59.533012 Start validation
2018-05-28 05:47:00.004149 Validation Accuracy = 0.6375
2018-05-28 05:47:00.004346 Saving checkpoint of model...
2018-05-28 05:47:00.262872 Model checkpoint saved at /checkpoints/model_epoch23.ckpt
2018-05-28 05:47:00.262981 Epoch number: 24
2018-05-28 05:47:03.592857 step
2018-05-28 05:47:06.800023 step
2018-05-28 05:47:09.713148 step
2018-05-28 05:47:10.952110 Start validation
2018-05-28 05:47:11.425377 Validation Accuracy = 0.7562
2018-05-28 05:47:11.425581 Saving checkpoint of model...
2018-05-28 05:4