Skip to content

Commit

Permalink
perf(draw): 保存绘制图片
Browse files Browse the repository at this point in the history
  • Loading branch information
zjZSTU committed May 5, 2020
1 parent f16c6fa commit 5243103
Show file tree
Hide file tree
Showing 8 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/2_nn_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@
plt = Draw()
plt(solver.loss_history)
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
title='准确率', xlabel='迭代/次', ylabel='准确率')
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))
2 changes: 1 addition & 1 deletion examples/3_nn_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,5 @@
plt = Draw()
plt(solver.loss_history)
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
title='准确率', xlabel='迭代/次', ylabel='准确率')
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))
2 changes: 1 addition & 1 deletion examples/3_nn_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@
plt = Draw()
plt(solver.loss_history)
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
title='准确率', xlabel='迭代/次', ylabel='准确率')
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))
2 changes: 1 addition & 1 deletion examples/3_nn_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@
plt = Draw()
plt(solver.loss_history)
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
title='准确率', xlabel='迭代/次', ylabel='准确率')
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))
2 changes: 1 addition & 1 deletion examples/3_nn_orl.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@
plt = Draw()
plt(solver.loss_history)
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
title='准确率', xlabel='迭代/次', ylabel='准确率')
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))
2 changes: 1 addition & 1 deletion examples/lenet5_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@
plt = Draw()
plt(solver.loss_history)
plt.multi_plot((solver.train_acc_history, solver.val_acc_history), ('train', 'val'),
title='准确率', xlabel='迭代/次', ylabel='准确率')
title='准确率', xlabel='迭代/次', ylabel='准确率', save_path='acc.png')
print('best_train_acc: %f; best_val_acc: %f' % (solver.best_train_acc, solver.best_val_acc))
3 changes: 2 additions & 1 deletion examples/nin_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def nin_train():

draw = vision.Draw()
draw(loss_list, xlabel='迭代/20次')
draw.multi_plot((train_list, test_list), ('训练集', '测试集'), title='精度图', xlabel='迭代/20次', ylabel='精度值')
draw.multi_plot((train_list, test_list), ('训练集', '测试集'),
title='精度图', xlabel='迭代/20次', ylabel='精度值', save_path='acc.png')


if __name__ == '__main__':
Expand Down
6 changes: 4 additions & 2 deletions pynet/vision/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ class Draw(object):
def __call__(self, values, title='损失图', xlabel='迭代/次', ylabel='损失值'):
self.forward(values, title='损失图', xlabel='迭代/次', ylabel='损失值')

def forward(self, values, title='损失图', xlabel='迭代/次', ylabel='损失值'):
def forward(self, values, title='损失图', xlabel='迭代/次', ylabel='损失值', save_path='./loss.png'):
assert isinstance(values, list)
plt.title(title)
plt.ylabel(ylabel)
plt.xlabel(xlabel)
plt.plot(values)
plt.savefig(save_path)
plt.show()

def multi_plot(self, values_list, labels_list, title='损失图', xlabel='迭代/次', ylabel='损失值'):
def multi_plot(self, values_list, labels_list, title='损失图', xlabel='迭代/次', ylabel='损失值', save_path='./loss.png'):
assert isinstance(values_list, tuple)
assert isinstance(labels_list, tuple)
assert len(values_list) == len(labels_list)
Expand All @@ -31,4 +32,5 @@ def multi_plot(self, values_list, labels_list, title='损失图', xlabel='迭代
for i in range(len(values_list)):
plt.plot(values_list[i], label=labels_list[i])
plt.legend()
plt.savefig(save_path)
plt.show()

0 comments on commit 5243103

Please sign in to comment.