In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys, os
import tensorflow as tf
import pandas as pd
import glob
import pydicom
ROOT_DIR = '/home/gagliardi/rnsa'
sys.path.append(os.path.join(ROOT_DIR, 'SSD-Tensorflow'))  # To find local version of the library
from datasets import dataset_factory, dataset_utils
from nets import ssd_vgg_300, ssd_vgg_512, ssd_common, np_methods
from preprocessing import ssd_vgg_preprocessing
from datasets import pascalvoc_to_tfrecords
import tf_utils
import ntpath
import numpy as np
import parse_tfrecord

slim = tf.contrib.slim

## Locate the TFRecord data

In [None]:
# TFRecord files (presumably one shard per file) consumed by
# parse_tfrecord.get_iterator in the pipeline cell. glob returns matches in
# filesystem order, which is not guaranteed; sorting makes the input order
# reproducible across machines and runs.
TFRECORD_GLOB = './data/tf_record_data/*'
filenames = sorted(glob.glob(TFRECORD_GLOB))

# Build the TensorFlow graph

In [None]:
# =========================================================================== #
# Flag re-initialisation: delete every flag already registered so this cell can
# be re-run in the same kernel without duplicate-flag errors, then define an
# `f` flag (presumably because Jupyter starts the kernel with `-f <file>`, which
# flag parsing would otherwise reject).
# =========================================================================== #
from absl import flags
for name in list(flags.FLAGS):
    delattr(flags.FLAGS, name)
    
# =========================================================================== #
# SSD Network flags.
# =========================================================================== #
tf.app.flags.DEFINE_float(
    'loss_alpha', 1., 'Alpha parameter in the loss function.')
tf.app.flags.DEFINE_float(
    'negative_ratio', 3., 'Negative ratio in the loss function.')
tf.app.flags.DEFINE_float(
    'match_threshold', 0.5, 'Matching threshold in the loss function.')

tf.app.flags.DEFINE_string('f', '', 'kernel')    
# =========================================================================== #
# Learning Rate Flags.
# =========================================================================== #
tf.app.flags.DEFINE_string(
    'learning_rate_decay_type',
    'exponential',
    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
    ' or "polynomial"')
tf.app.flags.DEFINE_float('learning_rate', 0.001, 'Initial learning rate.')
tf.app.flags.DEFINE_float(
    'end_learning_rate', 0.0001,
    'The minimal end learning rate used by a polynomial decay learning rate.')
tf.app.flags.DEFINE_float(
    'label_smoothing', 0.0, 'The amount of label smoothing.')
tf.app.flags.DEFINE_float(
    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
tf.app.flags.DEFINE_float(
    'num_epochs_per_decay', 2.0,
    'Number of epochs after which learning rate decays.')
tf.app.flags.DEFINE_float(
    'moving_average_decay', None,
    'The decay to use for the moving average. '
    'If left as None, then moving averages are not used.')
# =========================================================================== #
# Dataset Flags.
# =========================================================================== #
tf.app.flags.DEFINE_integer('batch_size', 16, 'The number of samples in each batch.')
tf.app.flags.DEFINE_integer( 'num_readers', 4, 'The number of parallel readers that read data from the dataset.')

# =========================================================================== #
# Optimization Flags.
# =========================================================================== #
tf.app.flags.DEFINE_float(
    'weight_decay', 0.00004, 'The weight decay on the model weights.')
tf.app.flags.DEFINE_string(
    'optimizer', 'rmsprop',
    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
    '"ftrl", "momentum", "sgd" or "rmsprop".')
tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')

# =========================================================================== #
# Summary Flags.
# =========================================================================== #
tf.app.flags.DEFINE_string('summaries_dir','./logdir/','Summary')

# =========================================================================== #
# Fine-Tuning Flags.
# NOTE(review): none of these are read by the cells visible in this notebook;
# training starts from the variable initializers regardless of their values.
# =========================================================================== #
tf.app.flags.DEFINE_string(
    'checkpoint_path', None,
    'The path to a checkpoint from which to fine-tune.')
tf.app.flags.DEFINE_string(
    'checkpoint_model_scope', None,
    'Model scope in the checkpoint. None if the same as the trained model.')
tf.app.flags.DEFINE_string(
    'checkpoint_exclude_scopes', None,
    'Comma-separated list of scopes of variables to exclude when restoring '
    'from a checkpoint.')
tf.app.flags.DEFINE_string(
    'trainable_scopes', None,
    'Comma-separated list of scopes to filter the set of variables to train. '
    'By default, None would train all the variables.')
tf.app.flags.DEFINE_boolean(
    'ignore_missing_vars', False,
    'When restoring a checkpoint would ignore missing variables.')

# =========================================================================== #
# epochs Flags.
# NOTE(review): the training cell hard-codes its own epoch count instead of
# reading FLAGS.epochs — confirm which value is intended.
# =========================================================================== #
tf.app.flags.DEFINE_integer(
    'epochs', 10,
    'The number of epochs.')

FLAGS = tf.app.flags.FLAGS

### SSD - preprocessing

In [None]:
# Start from an empty default graph so that re-running this cell (and the ones
# after it) rebuilds the pipeline instead of stacking a second copy, with its
# own queue runners, onto the stale graph.
tf.reset_default_graph()

num_preprocessing_threads = 8

ssd_net = ssd_vgg_512.SSDNet()

# Network input size and the default anchor boxes for that size.
ssd_shape = ssd_net.params.img_shape
ssd_anchors = ssd_net.anchors(ssd_shape)

# Resize images to the network's own input size instead of a hard-coded
# 512x512, so swapping in another SSDNet (e.g. ssd_vgg_300) needs no other edit.
out_shape = ssd_shape
# Num_samples x Height x Width x Channels
DATA_FORMAT = 'NHWC'

iterator = parse_tfrecord.get_iterator(filenames)
# `filename` is not used further; image, labels and bboxes feed the pipeline.
filename, image, labels, bboxes = iterator.get_next()

# Training-mode preprocessing of the image together with its labels and bboxes.
image, glabels, gbboxes = ssd_vgg_preprocessing.preprocess_image(
    image, labels, bboxes, out_shape,
    data_format=DATA_FORMAT, is_training=True)

# Encode ground-truth labels and bboxes against the anchors.
gclasses, glocalisations, gscores = ssd_net.bboxes_encode(glabels, gbboxes, ssd_anchors)
# Layout used by tf_utils.reshape_list to regroup the flat batch list:
# 1 image tensor, then len(ssd_anchors) tensors each for gclasses,
# glocalisations and gscores.
batch_shape = [1] + [len(ssd_anchors)] * 3

# Training batches, assembled by queue-runner threads (started in the
# training cell via tf.train.start_queue_runners).
r = tf.train.batch(
    tf_utils.reshape_list([image, gclasses, glocalisations, gscores]),
    batch_size=FLAGS.batch_size,
    num_threads=num_preprocessing_threads,
    capacity=5 * FLAGS.batch_size)

b_image, b_gclasses, b_glocalisations, b_gscores = tf_utils.reshape_list(r, batch_shape)

# Small prefetch queue between batch assembly and the training step, so a
# ready batch is waiting when the model dequeues.
batch_queue = slim.prefetch_queue.prefetch_queue(
    tf_utils.reshape_list([b_image, b_gclasses, b_glocalisations, b_gscores]),
    capacity=2)

### Define the model

In [None]:
# =================================================================== #
# SSD network and its training losses, built on a single batch pulled
# from the prefetch queue.
# =================================================================== #

# Dequeue one prefetched batch and regroup the flat tensor list into the
# image plus the per-anchor-layer class / localisation / score lists.
dequeued_batch = batch_queue.dequeue()
b_image, b_gclasses, b_glocalisations, b_gscores = tf_utils.reshape_list(
    dequeued_batch, batch_shape)

# Layer defaults for the network, configured with the weight decay flag and
# the data layout chosen in the pipeline cell.
arg_scope = ssd_net.arg_scope(weight_decay=FLAGS.weight_decay,
                              data_format=DATA_FORMAT)

# Forward pass in training mode.
with slim.arg_scope(arg_scope):
    net_outputs = ssd_net.net(b_image, is_training=True)
predictions, localisations, logits, end_points = net_outputs

# Loss hyper-parameters, all read from the flags.
loss_hparams = dict(
    match_threshold=FLAGS.match_threshold,
    negative_ratio=FLAGS.negative_ratio,
    alpha=FLAGS.loss_alpha,
    label_smoothing=FLAGS.label_smoothing,
)

# Build the SSD training losses. Presumably they are registered in
# tf.GraphKeys.LOSSES, which is where the next cell collects them.
ssd_net.losses(logits, localisations,
               b_gclasses, b_glocalisations, b_gscores,
               **loss_hparams)

### SSD - total loss, optimizer and summaries

In [None]:
# Number of training examples, passed to tf_utils.configure_learning_rate
# (presumably to turn FLAGS.num_epochs_per_decay into a step count).
# NOTE(review): hard-coded — confirm it matches the number of records in the
# TFRecord files.
num_samples = 5659

global_step = tf.train.get_or_create_global_step()
summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

learning_rate = tf_utils.configure_learning_rate(FLAGS, num_samples, global_step)
optimizer = tf_utils.configure_optimizer(FLAGS, learning_rate)

summaries.add(tf.summary.scalar('global_step', global_step))
summaries.add(tf.summary.scalar('learning_rate', learning_rate))

# Total objective = SSD losses (tf.GraphKeys.LOSSES) + regularization losses.
# FLAGS.weight_decay reaches the model only through ssd_net.arg_scope; any
# penalty terms it creates live in tf.GraphKeys.REGULARIZATION_LOSSES, which an
# add_n over LOSSES alone silently left out.
total_loss = tf.losses.get_total_loss(add_regularization_losses=True)

# Passing global_step makes every update increment it. Without it the step
# stays at 0, so the global_step-driven learning-rate schedule never advances.
train = optimizer.minimize(total_loss, global_step=global_step)

# Scalar summary for each individual loss term.
for loss in tf.get_collection(tf.GraphKeys.LOSSES):
    summaries.add(tf.summary.scalar(loss.op.name, loss))

# Histogram summaries for every model variable.
for variable in slim.get_model_variables():
    summaries.add(tf.summary.histogram(variable.op.name, variable))

summaries.add(tf.summary.scalar('total_loss', total_loss))

# merge_all() gathers every summary in the graph's SUMMARIES collection, which
# includes all of the summaries created above.
merged = tf.summary.merge_all()

### Run session

In [None]:
# Train for EPOCHS passes of n_batches steps each, write summaries to
# FLAGS.summaries_dir, and save one checkpoint at the end.
# NOTE(review): FLAGS.epochs (default 10) is defined in the flags cell, but this
# hard-coded value is what is actually used — consider EPOCHS = FLAGS.epochs.
EPOCHS = 45
# Full batches per epoch. Reuses num_samples from the optimizer cell instead of
# repeating the literal sample count.
n_batches = num_samples // FLAGS.batch_size
assert n_batches > 0, 'FLAGS.batch_size is larger than num_samples'

with tf.Session() as sess:
    saver = tf.train.Saver()
    train_writer = tf.summary.FileWriter(FLAGS.summaries_dir, sess.graph)
    # NOTE(review): no checkpoint is restored here, so every variable starts
    # from its initializer (the fine-tuning flags are not consulted).
    sess.run(tf.global_variables_initializer())
    # Initialise the tf.data iterator before the queue runners start pulling
    # from it.
    sess.run(iterator.initializer)

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        print('Training...')
        # NOTE(review): if parse_tfrecord's dataset does not repeat, the input
        # queue closes after one pass and sess.run raises OutOfRangeError
        # before EPOCHS complete — confirm.
        for i in range(EPOCHS):
            tot_loss = 0
            for j in range(n_batches):
                summary_str, _, loss_value = sess.run([merged, train, total_loss])
                tot_loss += loss_value
                # Monotonic step across epochs; `j` alone restarts at 0 each
                # epoch and makes TensorBoard overlay the epochs.
                train_writer.add_summary(summary_str, i * n_batches + j)
            print("Iter: {}, Loss: {:.4f}".format(i, tot_loss / n_batches))
        saver.save(sess, os.path.join(FLAGS.summaries_dir, 'ssd_512.ckpt'))
    finally:
        # Stop and join the input threads before the session closes (otherwise
        # they are left blocked on the queues), and flush pending events.
        coord.request_stop()
        coord.join(threads)
        train_writer.close()