# In [None]:  (Jupyter notebook cell marker left over from export)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import os
import random
import time
from absl import app
from absl import flags
import model
import numpy as np
import tensorflow as tf
import util

# Short alias for TensorFlow's gfile module; main() uses it to check for and
# create the checkpoint directory.
gfile = tf.gfile

# ---------------------------------------------------------------------------
# Training configuration: module-level settings read by main() and train().
# ---------------------------------------------------------------------------

# Filesystem locations. The lowercase names are the values actually used;
# they start out equal to the DEFAULT_* constants.
DEFAULT_DATA_DIR = '/data1/depth/endo/endo/endo_pre2'
DEFAULT_CHECKPOINT_DIR = 'checkpoints_endo_resize'
data_dir = DEFAULT_DATA_DIR
checkpoint_dir = DEFAULT_CHECKPOINT_DIR

# Optional checkpoint to restore weights from before training starts.
# None means no pretrained restore is attempted.
pretrained_ckpt = None

# Optimizer settings forwarded to model.Model.
# NOTE(review): beta1 is presumably Adam's first-moment decay -- confirm in model.
learning_rate = 0.0002
beta1 = 0.9

# Weights forwarded to model.Model.
# NOTE(review): presumably the weights of the individual loss terms -- confirm in model.
reconstr_weight = 0.85
smooth_weight = 0.05
ssim_weight = 0.15

# Input batch geometry forwarded to model.Model.
batch_size = 32
img_height = img_width = 256
seq_length = 3

# Training-loop control: total number of steps and how often (in steps) the
# loss and summaries are reported.
train_steps = 200000
summary_freq = 100

# main() rejects legacy_mode combined with seq_length < 3.
legacy_mode = False

# Maximum number of checkpoints to keep.
MAX_TO_KEEP = 100

def main(_):
    """Seed the RNGs, build the training model and hand it to train().

    The single positional argument is ignored. All settings come from the
    module-level configuration variables.

    Raises:
        ValueError: If legacy_mode is enabled while seq_length is below 3.
    """
    # One fixed seed shared by TensorFlow, NumPy and `random` for repeatability.
    fixed_seed = 8964
    tf.set_random_seed(fixed_seed)
    np.random.seed(fixed_seed)
    random.seed(fixed_seed)

    if seq_length < 3 and legacy_mode:
        raise ValueError('Legacy mode supports sequence length > 2 only.')

    # Make sure the checkpoint directory exists before training touches it.
    if not gfile.Exists(checkpoint_dir):
        gfile.MakeDirs(checkpoint_dir)

    # Collect the model arguments in one place, then build the model.
    model_kwargs = dict(
        data_dir=data_dir,
        is_training=True,
        learning_rate=learning_rate,
        beta1=beta1,
        reconstr_weight=reconstr_weight,
        smooth_weight=smooth_weight,
        ssim_weight=ssim_weight,
        batch_size=batch_size,
        img_height=img_height,
        img_width=img_width,
        seq_length=seq_length,
        legacy_mode=legacy_mode,
    )
    net = model.Model(**model_kwargs)

    train(net, pretrained_ckpt, checkpoint_dir, train_steps, summary_freq)


def train(train_model, pretrained_ckpt, checkpoint_dir, train_steps, summary_freq):
    """Run the training loop, resuming from the newest checkpoint if one exists.

    Args:
        train_model: Model object. This function reads its `train_op`,
            `global_step`, `incr_global_step`, `total_loss` and
            `reader.steps_per_epoch` attributes.
        pretrained_ckpt: Checkpoint path to restore pretrained weights from,
            or None to skip that step.
        checkpoint_dir: Directory used as the Supervisor logdir, searched for
            the latest checkpoint, and written to when saving.
        train_steps: The loop runs while the step counter is <= this value.
        summary_freq: Every this many steps the loss and summaries are
            fetched, written to the summary writer and printed.
    """
    restore_pretrained = pretrained_ckpt is not None
    if restore_pretrained:
        pretrain_restorer = tf.train.Saver(util.get_vars_to_restore(pretrained_ckpt))

    saved_vars = util.get_vars_to_restore() + [train_model.global_step]
    saver = tf.train.Saver(saved_vars, max_to_keep=MAX_TO_KEEP)

    # The Supervisor gets no saver; checkpoints are written explicitly below
    # through `saver`, and summaries are written by hand in the loop.
    supervisor = tf.train.Supervisor(logdir=checkpoint_dir, save_summaries_secs=0, saver=None)
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    with supervisor.managed_session(config=session_config) as sess:
        if restore_pretrained:
            print('Restoring pretrained weights from %s'%pretrained_ckpt)
            pretrain_restorer.restore(sess, pretrained_ckpt)

        # A checkpoint found in checkpoint_dir is restored after (and so on
        # top of) any pretrained weights.
        print('Attempting to resume training from %s...' %checkpoint_dir)
        latest = tf.train.latest_checkpoint(checkpoint_dir)
        print('Last checkpoint found: %s' %latest)
        if latest:
            saver.restore(sess, latest)

        print('Training...')
        start_time = time.time()
        last_summary_time = time.time()
        steps_per_epoch = train_model.reader.steps_per_epoch

        step = 1
        while step <= train_steps:
            is_summary_step = step % summary_freq == 0

            fetches = {
                'train': train_model.train_op,
                'global_step': train_model.global_step,
                'incr_global_step': train_model.incr_global_step,
            }
            if is_summary_step:
                # Loss and summaries are only fetched on reporting steps.
                fetches['loss'] = train_model.total_loss
                fetches['summary'] = supervisor.summary_op

            results = sess.run(fetches)
            global_step = results['global_step']

            if is_summary_step:
                supervisor.summary_writer.add_summary(results['summary'], global_step)
                epoch = math.ceil(global_step / steps_per_epoch)
                step_in_epoch = global_step - (epoch - 1) * steps_per_epoch
                cycle_secs = time.time() - last_summary_time
                last_summary_time += cycle_secs
                elapsed_secs = time.time() - start_time
                print('Epoch: [%2d] [%5d/%5d] time: %4.2fs (%ds total) loss: %.3f'%(
                    epoch, step_in_epoch, steps_per_epoch, cycle_secs,
                    elapsed_secs, results['loss']))

            # Checkpoint whenever the step counter lands on an epoch boundary.
            if step % steps_per_epoch == 0:
                print('[*] Saving checkpoint to %s...'% checkpoint_dir)
                saver.save(sess, os.path.join(checkpoint_dir, 'model'), global_step=global_step)

            # Setting step to global_step allows for training for a total of
            # train_steps even if the program is restarted during training.
            step = global_step + 1


if __name__ == '__main__':
    # main() ignores its positional argument, so pass an explicit None.
    # The bare name `_` is only bound inside interactive shells (where it
    # holds the last output); run as a plain script, `main(_)` raises
    # NameError before any training starts.
    main(None)