From 460a1deafee68481832e3f105868e86ddee2a994 Mon Sep 17 00:00:00 2001 From: Yang Date: Mon, 12 Jun 2017 13:43:22 +0800 Subject: [PATCH] Create utils.py --- tensorlayer/utils.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tensorlayer/utils.py b/tensorlayer/utils.py index f1d171322..0dbdbdc12 100644 --- a/tensorlayer/utils.py +++ b/tensorlayer/utils.py @@ -130,15 +130,15 @@ def fit(sess, network, train_op, cost, X_train, y_train, x, y_, acc=None, batch_ result = sess.run(merged, feed_dict=feed_dict) train_writer.add_summary(result, tensorboard_train_index) tensorboard_train_index += 1 - - for X_val_a, y_val_a in iterate.minibatches( + if (X_val is not None) and (y_val is not None): + for X_val_a, y_val_a in iterate.minibatches( X_val, y_val, batch_size, shuffle=True): - dp_dict = dict_to_one( network.all_drop ) # disable noise layers - feed_dict = {x: X_val_a, y_: y_val_a} - feed_dict.update(dp_dict) - result = sess.run(merged, feed_dict=feed_dict) - val_writer.add_summary(result, tensorboard_val_index) - tensorboard_val_index += 1 + dp_dict = dict_to_one( network.all_drop ) # disable noise layers + feed_dict = {x: X_val_a, y_: y_val_a} + feed_dict.update(dp_dict) + result = sess.run(merged, feed_dict=feed_dict) + val_writer.add_summary(result, tensorboard_val_index) + tensorboard_val_index += 1 if epoch + 1 == 1 or (epoch + 1) % print_freq == 0: if (X_val is not None) and (y_val is not None):