From 7d5e21b864cd8be84f541ace52c87657bcf1567a Mon Sep 17 00:00:00 2001
From: xioinghhcs
Date: Sun, 26 Nov 2017 19:27:17 +0800
Subject: [PATCH] Fix a bug in tl.utils.predict: when the data size cannot be
 exactly divided by batch_size, some prediction results are lost.

---
 tensorlayer/utils.py | 16 +++++++++++++++-
 1 file changed, 15 insertions(+), 1 deletion(-)

diff --git a/tensorlayer/utils.py b/tensorlayer/utils.py
index 32824882c..2efe74c88 100644
--- a/tensorlayer/utils.py
+++ b/tensorlayer/utils.py
@@ -281,7 +281,21 @@ def predict(sess, network, X, x, y_op, batch_size=None):
             if result is None:
                 result = result_a
             else:
-                result = np.hstack((result, result_a))
+                result = np.vstack((result, result_a))
+        if result is None:
+            if len(X) % batch_size != 0:
+                dp_dict = dict_to_one(network.all_drop)
+                feed_dict = {x: X[-(len(X) % batch_size):, :], }
+                feed_dict.update(dp_dict)
+                result_a = sess.run(y_op, feed_dict=feed_dict)
+                result = result_a
+        else:
+            if len(X) != len(result) and len(X) % batch_size != 0:
+                dp_dict = dict_to_one(network.all_drop)
+                feed_dict = {x: X[-(len(X) % batch_size):, :], }
+                feed_dict.update(dp_dict)
+                result_a = sess.run(y_op, feed_dict=feed_dict)
+                result = np.vstack((result, result_a))
         return result
 
 