Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add mnist swamp loss figure #97

Merged
merged 3 commits into from
Dec 20, 2018
Merged

add mnist swamp loss figure #97

merged 3 commits into from
Dec 20, 2018

Conversation

yuyicg
Copy link
Contributor

@yuyicg yuyicg commented Dec 19, 2018

Fix #87
add swamp (loss / timestamp) figure using matplotlib
steps to run:

  1. exec cd ${elasticdl_root}/experimental/swamp-optimization/mnist.
  2. exec python mnist.py.
  3. you would get a loss curve image file with default name loss.png in current dir.
  4. pull the png image to your local pc to show.

more metric figures like accuracy will be added later.

zou000
zou000 previously approved these changes Dec 19, 2018
Copy link
Collaborator

@wangkuiyi wangkuiyi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. 没有更新 README.md,说明如何使用这个修改过的程序。
  2. 没有在这个PR的description写”Fix https://github.com/wangkuiyi/elasticdl/issues/87“

time_costs.append(round(time.time() - start_time))
losses.append(round(loss.item(), 4))

if batch_idx % args.log_interval == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然这个程序可以自己画图了,就不用输出log来依赖其他工具(例如gnuplot)画图了吧?那么 args.log_interval貌似可以不要了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

args.log_interval 这个参数用于打印当前运行的进度(epoch, batch id),从而让使用者有一个合理的预期(当前进度,预计结束时间),防止因长时间运行没有任何输出,导致调用者误认为程序hung住。

plot.ylabel('loss')
plot.legend(loc=7)
plot.title('swamp training of mnist data')
plot.savefig(args.loss_file_prefix + '_' + str(tid) + '.png')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这样是每个 trainer 都会画一个 png 出来?按照 #87 的描述,应该是一共只有一个 png 文件,每个 trianer 以及 ps 对应其中一条曲线。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的,我合并一下绘制成一张png图片。

help='metrics-dir')
parser.add_argument('--loss-file-prefix', default='loss',
help='the name of loss figure file')
parser.add_argument('--accuracy-file', default='accuracy.png',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

accuracy_file 没有被用到


if args.metrics_loss_enabled and batch_idx % args.metrics_sample_interval == 0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

既然loss都算出来了,为什么不都加入到 losses 里,还要sample呢?我理解 args.metrics_sample_interval 这个参数可以不要了。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

本意是通过sample降低用户绘图数据集的大小,减小process的内存压力,可以改为用全量loss进行绘图。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

用全量的loss数据测试后发现程序运行会变慢一些,由于点比较多,随着trainer个数的增加,绘制出曲线图的可读性有些下降。

@@ -159,14 +178,26 @@ def main():
help='batch size for validation dataset in ps')
parser.add_argument('--validate_max_batch', default=5,
help='max batch for validate model in ps')
parser.add_argument('--metrics-dir', default='./',
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metrics_dir 这个参数没有被用到。应该删了。

help='the name of accuracy figure file')
parser.add_argument('--metrics-sample-interval', type=int, default=10, metavar='N',
help='how many batches to wait before sampling a metircs value')
parser.add_argument('--metrics-loss-enabled', type=bool, default=True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

metrics-loss-enabled 和 loss-file-prefix 二者留一个就行了。 loss-file-prefix => loss_file 如果不是 "",则画图。

@@ -122,7 +141,7 @@ def ps(args, up, down):
model_and_score = d
score = s
updates = updates + 1
print("updated", updates, score.data.item(), double_check_loss)
print("updated", updates, round(score.data.item(), 4), round(double_check_loss, 4))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

round(*, 4) 里的4是什么意思?我理解这一行log也不需要打印了,因为我们自己画图了?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

原意是让日志更简洁,只保留loss小数点后四位,用图表方式确实不用打印这条log了。

Copy link
Collaborator

@wangkuiyi wangkuiyi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请按上面comments修改之后再merge。

Copy link
Collaborator

@wangkuiyi wangkuiyi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the change!

Approved with some comments. We could fire a new PR in the future to fix the comments if necessary.

loss_val = eval_loss / max_batch
return loss_val

def time_costs(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么要定义 time_costs 这个 method呢?直接写 ps.time_costs 访问 data member 看上去更简单?

def tid(self):
return self._tid

class Ps(object):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PS 是 parameter server 的缩写,所以两个字母都应该大写。

def losses(self):
return self._losses

def join(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

join 操作是指 几个 threads 的 join -- 也就是等到都结束。这个函数做的只是设置一个变量值,使得 thread 里的循环停止 —— 这个操作不是 join,而是 terminate。

self._losses = []
self._exit = False

def run(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

run 函数太长了。得拆成几个函数。Python 用 indentation 标志程序结构,每个函数的行数不能超过屏幕的一半高度,否则就看不明白了。


trainers = []
trainer_threads = []
for t in range(4):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 4 可能值得弄成一个command line argument?

self._time_costs = []
self._losses = []

def train(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

train 行数太长了。得拆成几个函数。

if up != None:
up.put(pickle.dumps(
{"model": model.state_dict(), "opt": optimizer.state_dict(), "loss": loss.data}))
class Trainer(object):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

赞引入 class来分组函数!

for thread in trainer_threads:
thread.join()

ps.join()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是你想做的是让 class PS 从 Thread 派生出来,然后 重载 PS.join 来设置 _exit = True 并且调用 super class Thread 的 join ?

@wangkuiyi
Copy link
Collaborator

另外,我按照 README.md 运行这个程序,报错没有 matplotlib:

metrics_figure)$ docker run --rm -it -v $PWD:/work -w /work pytorch/pytorch python mnist.py
Traceback (most recent call last):
  File "mnist.py", line 15, in <module>
    from matplotlib import pyplot as plot
ModuleNotFoundError: No module named 'matplotlib'

所以得有一个Dockerfile,FROM pytorch/pytorch:latest,并且安装 matplotlib;并且得更新 README.md

@yuyicg yuyicg merged commit fd07dd9 into develop Dec 20, 2018
@yuyicg yuyicg deleted the add_swamp_metrics_figure branch December 20, 2018 17:02
@yuyicg yuyicg mentioned this pull request Dec 21, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants