Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 78 additions & 2 deletions tensorlayer/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
#! /usr/bin/python
# -*- coding: utf8 -*-
import tensorflow as tf
import tensorlayer as tl
from . import iterate
import numpy as np
import time
import math


def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100, n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True):
def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_size=100,
n_epoch=100, print_freq=5, X_val=None, y_val=None, eval_train=True,
tensorboard=False, tensorboard_epoch_freq=5, tensorboard_weight_histograms=True, tensorboard_graph_vis=True):
"""Training a given non time-series network by the given cost function, training data, batch_size, n_epoch etc.

Parameters
Expand Down Expand Up @@ -39,17 +43,69 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
the target of validation data
eval_train : boolean
if X_val and y_val are not None, it reflects whether to evaluate the training data

tensorboard : boolean
if True, summary data will be stored to the logs/ directory for visualization with tensorboard.
See also the detailed tensorboard_X settings for specific configurations of features. (default False)
Also runs tl.layers.initialize_global_variables(sess) internally in fit() to set up the summary nodes, see the Note section.
tensorboard_epoch_freq : int
how many epochs between storing tensorboard checkpoints for visualization to the logs/ directory (default 5)
tensorboard_weight_histograms : boolean
if True, updates tensorboard data in the logs/ directory for visualization
of the weight histograms every tensorboard_epoch_freq epochs (default True)
tensorboard_graph_vis : boolean
if True, stores the graph in the tensorboard summaries saved to logs/ (default True)
Examples
--------
>>> see tutorial_mnist_simple.py
>>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
... acc=acc, batch_size=500, n_epoch=200, print_freq=5,
... X_val=X_val, y_val=y_val, eval_train=False)
>>> tl.utils.fit(sess, network, train_op, cost, X_train, y_train, x, y_,
... acc=acc, batch_size=500, n_epoch=200, print_freq=5,
... X_val=X_val, y_val=y_val, eval_train=False,
... tensorboard=True, tensorboard_weight_histograms=True, tensorboard_graph_vis=True)

Note
--------
If tensorboard=True, the global_variables_initializer will be run inside the fit function
in order to initialize the automatically generated summary nodes used for tensorboard visualization,
thus running tf.global_variables_initializer().run() before the fit() call is redundant, since the variables are initialized again inside fit().
"""
assert X_train.shape[0] >= batch_size, "Number of training examples should be bigger than the batch size"

if(tensorboard):
print("Setting up tensorboard ...")
#Set up the tensorboard log directory and summary writers
tl.files.exists_or_mkdir('logs/')

#Only write summaries for more recent TensorFlow versions
if hasattr(tf, 'summary') and hasattr(tf.summary, 'FileWriter'):
if tensorboard_graph_vis:
train_writer = tf.summary.FileWriter('logs/train',sess.graph)
val_writer = tf.summary.FileWriter('logs/validation',sess.graph)
else:
train_writer = tf.summary.FileWriter('logs/train')
val_writer = tf.summary.FileWriter('logs/validation')

#Set up summary nodes
if(tensorboard_weight_histograms):
for param in network.all_params:
if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'):
print('Param name ', param.name)
tf.summary.histogram(param.name, param)

if hasattr(tf, 'summary') and hasattr(tf.summary, 'histogram'):
tf.summary.scalar('cost', cost)

merged = tf.summary.merge_all()

#Initialize all variables and summaries
tl.layers.initialize_global_variables(sess)
print("Finished! use $tensorboard --logdir=logs/ to start server")

print("Start training the network ...")
start_time_begin = time.time()
tensorboard_train_index, tensorboard_val_index = 0, 0
for epoch in range(n_epoch):
start_time = time.time()
loss_ep = 0; n_step = 0
Expand All @@ -62,6 +118,26 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_
n_step += 1
loss_ep = loss_ep/ n_step

if tensorboard and hasattr(tf, 'summary'):
if epoch+1 == 1 or (epoch+1) % tensorboard_epoch_freq == 0:
for X_train_a, y_train_a in iterate.minibatches(
X_train, y_train, batch_size, shuffle=True):
dp_dict = dict_to_one( network.all_drop ) # disable noise layers
feed_dict = {x: X_train_a, y_: y_train_a}
feed_dict.update(dp_dict)
result = sess.run(merged, feed_dict=feed_dict)
train_writer.add_summary(result, tensorboard_train_index)
tensorboard_train_index += 1

for X_val_a, y_val_a in iterate.minibatches(
X_val, y_val, batch_size, shuffle=True):
dp_dict = dict_to_one( network.all_drop ) # disable noise layers
feed_dict = {x: X_val_a, y_: y_val_a}
feed_dict.update(dp_dict)
result = sess.run(merged, feed_dict=feed_dict)
val_writer.add_summary(result, tensorboard_val_index)
tensorboard_val_index += 1

if epoch + 1 == 1 or (epoch + 1) % print_freq == 0:
if (X_val is not None) and (y_val is not None):
print("Epoch %d of %d took %fs" % (epoch + 1, n_epoch, time.time() - start_time))
Expand Down