diff --git a/tensorlayer/utils.py b/tensorlayer/utils.py index 32824882c..2efe74c88 100644 --- a/tensorlayer/utils.py +++ b/tensorlayer/utils.py @@ -281,7 +281,21 @@ def predict(sess, network, X, x, y_op, batch_size=None): if result is None: result = result_a else: - result = np.hstack((result, result_a)) + result = np.vstack((result, result_a)) + if result is None: + if len(X) % batch_size != 0: + dp_dict = dict_to_one(network.all_drop) + feed_dict = {x: X[-(len(X) % batch_size):, :], } + feed_dict.update(dp_dict) + result_a = sess.run(y_op, feed_dict=feed_dict) + result = result_a + else: + if len(X) != len(result) and len(X) % batch_size != 0: + dp_dict = dict_to_one(network.all_drop) + feed_dict = {x: X[-(len(X) % batch_size):, :], } + feed_dict.update(dp_dict) + result_a = sess.run(y_op, feed_dict=feed_dict) + result = np.vstack((result, result_a)) return result