From 8b2c382f77c6180081b2da68953f1a161e788395 Mon Sep 17 00:00:00 2001 From: zsdonghao Date: Wed, 27 Sep 2017 11:05:31 +0100 Subject: [PATCH] move pyplot into functions --- tensorlayer/visualize.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorlayer/visualize.py b/tensorlayer/visualize.py index d49ea10ae..44d7e1c65 100644 --- a/tensorlayer/visualize.py +++ b/tensorlayer/visualize.py @@ -9,7 +9,6 @@ # matplotlib.use('Agg') -import matplotlib.pyplot as plt import numpy as np import os from . import prepro @@ -114,6 +113,7 @@ def W(W=None, second=10, saveable=True, shape=[28,28], name='mnist', fig_idx=239 -------- >>> tl.visualize.W(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012) """ + import matplotlib.pyplot as plt if saveable is False: plt.ion() fig = plt.figure(fig_idx) # show all feature images @@ -177,6 +177,7 @@ def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=1283 >>> observation = env.reset() >>> tl.visualize.frame(observation) """ + import matplotlib.pyplot as plt if saveable is False: plt.ion() fig = plt.figure(fig_idx) # show all feature images @@ -215,6 +216,7 @@ def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362): -------- >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012) """ + import matplotlib.pyplot as plt # print(CNN.shape) # (5, 5, 3, 64) # exit() n_mask = CNN.shape[3] @@ -280,6 +282,7 @@ def images2d(images=None, second=10, saveable=True, name='images', dtype=None, >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False) >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212) """ + import matplotlib.pyplot as plt # print(images.shape) # (50000, 32, 32, 3) # exit() if dtype: @@ -350,6 +353,7 @@ def tsne_embedding(embeddings, reverse_dictionary, plot_only=500, >>> tl.visualize.tsne_embedding(final_embeddings, labels, reverse_dictionary, ... plot_only=500, second=5, saveable=False, name='tsne') """ + import matplotlib.pyplot as plt def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5, saveable=True, name='tsne', fig_idx=9862): assert low_dim_embs.shape[0] >= len(labels), "More labels than embeddings"