diff --git a/tensorlayer/utils.py b/tensorlayer/utils.py
index f341780f3..dcbe0f74d 100644
--- a/tensorlayer/utils.py
+++ b/tensorlayer/utils.py
@@ -1,12 +1,16 @@
 #! /usr/bin/python
 # -*- coding: utf8 -*-
 import tensorflow as tf
+import tensorlayer as tl
 from . import iterate
 import numpy as np
 import time
+import math
 
-def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True):
+def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100,
+        n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True,
+        tensorboard=False, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True, tensorboard_graph_vis=True):
     """Traing a given non time-series network by the given cost function, training data, batch_size, n_epoch etc.
 
     Parameters
@@ -39,17 +43,69 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
         the target of validation data
     eval_train : boolean
         if X_val and y_val are not None, it refects whether to evaluate the training data
-
+    tensorboard : boolean
+        if True, summary data will be stored to the logs/ directory for visualization with tensorboard.
+        See also the detailed tensorboard_X settings for feature-specific configuration. (default False)
+        Also runs tl.layers.initialize_global_variables(sess) internally in fit() to set up the summary nodes; see the Note below.
+    tensorboard_epoch_freq : int
+        how many epochs between storing tensorboard checkpoints for visualization in the logs/ directory (default 5)
+    tensorboard_weight_histograms : boolean
+        if True, updates tensorboard data in the logs/ directory for visualization
+        of the weight histograms every tensorboard_epoch_freq epochs (default True)
+    tensorboard_graph_vis : boolean
+        if True, stores the graph in the tensorboard summaries saved to logs/ (default True)
 
     Examples
     --------
     >>> see tutorial_mnist_simple.py
     >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
     ...            acc=acc, batch_size=500, n_epoch=200, print_freq=5,
     ...            X_val=X_val, y_val=y_val, eval_train=False)
+    >>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
+    ...            acc=acc, batch_size=500, n_epoch=200, print_freq=5,
+    ...            X_val=X_val, y_val=y_val, eval_train=False,
+    ...            tensorboard=True, tensorboard_weight_histograms=True, tensorboard_graph_vis=True)
+
+    Note
+    --------
+    If tensorboard=True, the global variables initializer is run inside the fit function
+    in order to initialize the automatically generated summary nodes used for tensorboard visualization;
+    the behaviour of a tf.global_variables_initializer().run() issued before the fit() call is therefore undefined.
""" assert X_train.shape[0] >= batch_size, "Number of training examples should be bigger than the batch size" + + if(tensorboard): + print("Setting up tensorboard ...") + #Set up tensorboard summaries and saver + tl.files.exists_or_mkdir('logs/') + + #Only write summaries for more recent TensorFlow versions + if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'): + if tensorboard_graph_vis: + train_writer = tf.summary.FileWriter('logs/train',sess.graph) + val_writer = tf.summary.FileWriter('logs/validation',sess.graph) + else: + train_writer = tf.summary.FileWriter('logs/train') + val_writer = tf.summary.FileWriter('logs/validation') + + #Set up summary nodes + if(tensorboard_weight_histograms): + for param in network.all_params: + if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'): + print('Param name ', param.name) + tf.summary.histogram(param.name, param) + + if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'): + tf.summary.scalar('cost', cost) + + merged = tf.summary.merge_all() + + #Initalize all variables and summaries + tl.layers.initialize_global_variables(sess) + print("Finished! use $tensorboard --logdir=logs/ to start server") + print("Start training the network ...") start_time_begin = time.time() + tensorboard_train_index, tensorboard_val_index = 0, 0 for epoch in range(n_epoch): start_time = time.time() loss_ep = 0; n_step = 0 @@ -62,6 +118,26 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_ n_step += 1 loss_ep = loss_ep/ n_step + if tensorboard and hasattr(tf, 'summary'): + if epoch+1 == 1 or (epoch+1) % tensorboard_epoch_freq == 0: + for X_train_a, y_train_a in iterate.minibatches( + X_train, y_train, batch_size, shuffle=True): + dp_dict = dict_to_one( network.all_drop ) # disable noise layers + feed_dict = {x: X_train_a, y_: y_train_a} + feed_dict.update(dp_dict) + result = sess.run(merged, feed_dict=feed_dict) + train_writer.add_summary(result, tensorboard_train_index) + tensorboard_train_index += 1 + + for X_val_a, y_val_a in iterate.minibatches( + X_val, y_val, batch_size, shuffle=True): + dp_dict = dict_to_one( network.all_drop ) # disable noise layers + feed_dict = {x: X_val_a, y_: y_val_a} + feed_dict.update(dp_dict) + result = sess.run(merged, feed_dict=feed_dict) + val_writer.add_summary(result, tensorboard_val_index) + tensorboard_val_index += 1 + if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: if (X_val is not None) and (y_val is not None): print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))