add mnist swamp loss figure #97
Conversation
- README.md was not updated to explain how to use the modified program.
- The PR description does not say "Fix https://github.com/wangkuiyi/elasticdl/issues/87".
time_costs.append(round(time.time() - start_time))
losses.append(round(loss.item(), 4))

if batch_idx % args.log_interval == 0:
Since this program can now draw the figure itself, there is no need to print logs for plotting with external tools (e.g. gnuplot), right? In that case args.log_interval seems no longer necessary.
args.log_interval is used to print the current progress (epoch, batch id), giving the user a reasonable expectation (current progress, estimated finish time) and preventing the caller from mistaking a long-running program with no output for a hung one.
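The periodic progress line described above could look like the following minimal sketch; the function name `log_progress` and its signature are illustrative, not code from this PR:

```python
import time


def log_progress(epoch, batch_idx, num_batches, start_time, log_interval=100):
    """Print progress and a rough ETA every `log_interval` batches,
    so a long-running job never looks hung."""
    if batch_idx == 0 or batch_idx % log_interval != 0:
        return None
    elapsed = time.time() - start_time
    fraction_done = batch_idx / num_batches
    # Linear extrapolation of remaining time from throughput so far.
    eta = elapsed * (1 - fraction_done) / fraction_done
    msg = "epoch %d batch %d/%d eta %.0fs" % (epoch, batch_idx, num_batches, eta)
    print(msg)
    return msg
```

This keeps the "is it still making progress?" signal without relying on the log for plotting.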
plot.ylabel('loss')
plot.legend(loc=7)
plot.title('swamp training of mnist data')
plot.savefig(args.loss_file_prefix + '_' + str(tid) + '.png')
So every trainer draws its own png? According to the description in #87, there should be a single png file, with one curve for each trainer and for the ps.
OK, I will merge them into a single png figure.
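Merging all curves into one figure could be sketched as below. The data layout (a dict of worker name to `(times, losses)` series) is an assumption, not the PR's actual structure, and the plotting is guarded in case matplotlib is not installed:

```python
# Hypothetical per-worker series: name -> (timestamps, losses).
curves = {
    "trainer-0": ([0, 1, 2], [2.3, 1.1, 0.6]),
    "trainer-1": ([0, 1, 2], [2.4, 1.3, 0.7]),
    "ps":        ([0, 1, 2], [2.3, 1.0, 0.5]),
}

try:
    import matplotlib
    matplotlib.use("Agg")            # headless backend; no display needed
    import matplotlib.pyplot as plt

    for name, (times, losses) in sorted(curves.items()):
        plt.plot(times, losses, label=name)   # one curve per worker
    plt.xlabel("time (s)")
    plt.ylabel("loss")
    plt.title("swamp training of mnist data")
    plt.legend(loc=7)
    plt.savefig("loss.png")          # a single png holding all curves
    saved = True
except ImportError:
    saved = False                    # matplotlib missing; skip plotting
```

The key change versus the per-trainer version is that `savefig` is called once, after every worker's `plot` call, instead of once per trainer id.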
help='metrics-dir')
parser.add_argument('--loss-file-prefix', default='loss',
                    help='the name of loss figure file')
parser.add_argument('--accuracy-file', default='accuracy.png',
accuracy_file is never used.
if args.metrics_loss_enabled and batch_idx % args.metrics_sample_interval == 0:
Since the loss is computed anyway, why not append every value to losses instead of sampling? I think the args.metrics_sample_interval argument can be dropped.
The intent of sampling was to reduce the size of the plotting data set and relieve the process's memory pressure; it can be changed to plot the full loss history.
After testing with the full loss data, the program runs somewhat slower, and with so many points the readability of the curves degrades as the number of trainers grows.
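One way to reconcile the two comments is to record every loss during training but downsample once at plot time, so memory and readability stay bounded without a sampling interval flag. A stride-based sketch (the helper name `downsample` is hypothetical):

```python
def downsample(values, max_points=200):
    """Keep at most roughly `max_points` evenly strided samples of `values`,
    always retaining the final point so the curve's endpoint stays exact."""
    if len(values) <= max_points:
        return list(values)
    stride = len(values) // max_points + 1
    sampled = values[::stride]
    if sampled[-1] != values[-1]:
        sampled = sampled + [values[-1]]
    return sampled
```

A plot would then call, for example, `plt.plot(downsample(times), downsample(losses))` instead of plotting every recorded point.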
@@ -159,14 +178,26 @@ def main():
    help='batch size for validation dataset in ps')
parser.add_argument('--validate_max_batch', default=5,
    help='max batch for validate model in ps')
parser.add_argument('--metrics-dir', default='./',
The metrics_dir argument is never used and should be removed.
help='the name of accuracy figure file')
parser.add_argument('--metrics-sample-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before sampling a metrics value')
parser.add_argument('--metrics-loss-enabled', type=bool, default=True,
Keep only one of metrics-loss-enabled and loss-file-prefix: rename loss-file-prefix to loss_file, and draw the figure whenever it is not "".
@@ -122,7 +141,7 @@ def ps(args, up, down):
model_and_score = d
score = s
updates = updates + 1
print("updated", updates, score.data.item(), double_check_loss)
print("updated", updates, round(score.data.item(), 4), round(double_check_loss, 4))
What does the 4 in round(*, 4) mean? I think this log line no longer needs to be printed, since we draw the figure ourselves.
The intent was to keep the log concise by retaining only four decimal places of the loss; with the figure, this log line is indeed unnecessary.
Please address the comments above before merging.
Thank you for the change!
Approved with some comments. We could file a new PR in the future to fix the comments if necessary.
loss_val = eval_loss / max_batch
return loss_val

def time_costs(self):
Why define a time_costs method at all? Accessing the data member directly as ps.time_costs looks simpler.
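A middle ground between a getter method and a bare data member is Python's `@property`, which keeps the `ps.time_costs` call-site syntax while hiding the underscore attribute. A minimal sketch (the class and attribute names follow the PR's, but the property form is a suggestion):

```python
class PS:
    def __init__(self):
        self._time_costs = []
        self._losses = []

    @property
    def time_costs(self):
        # Read access without the ps.time_costs() call parentheses.
        return self._time_costs

    @property
    def losses(self):
        return self._losses
```

Callers then write `ps.losses` rather than `ps.losses()`, and the internal list can later be replaced (e.g. by a copy) without changing call sites.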
def tid(self):
    return self._tid

class Ps(object):
PS is the abbreviation of parameter server, so both letters should be capitalized.
def losses(self):
    return self._losses

def join(self):
A join operation means joining several threads, i.e. waiting until they all finish. What this function does is merely set a flag that stops the loop inside the thread; that operation is not join but terminate.
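Under the naming suggested above, the flag-setting method would be called `terminate`, and the standard `Thread.join` would keep its usual wait-for-exit meaning. A sketch, with a placeholder loop standing in for the real update logic:

```python
import threading
import time


class PS(threading.Thread):
    """`terminate` flips the exit flag; the inherited `Thread.join`
    then waits for `run` to return."""

    def __init__(self):
        super().__init__()
        self._exit = False

    def run(self):
        while not self._exit:
            time.sleep(0.01)   # placeholder for the real update loop

    def terminate(self):
        self._exit = True
```

Usage: `ps = PS(); ps.start(); ...; ps.terminate(); ps.join()`.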
self._losses = []
self._exit = False

def run(self):
The run function is too long and should be split into several functions. Python marks program structure with indentation; a function longer than half a screen becomes hard to follow.
trainers = []
trainer_threads = []
for t in range(4):
This 4 is probably worth turning into a command line argument?
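Turning the hard-coded 4 into a flag could look like this; the flag name `--trainer-number` is hypothetical, since the PR hard-codes `for t in range(4)`:

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # Hypothetical flag replacing the hard-coded trainer count.
    parser.add_argument("--trainer-number", type=int, default=4,
                        help="number of trainer threads to launch")
    return parser
```

The launch loop then becomes `for t in range(args.trainer_number): ...`, keeping 4 as the default behavior.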
self._time_costs = []
self._losses = []

def train(self):
train is also too long and should be split into several functions.
if up != None:
    up.put(pickle.dumps(
        {"model": model.state_dict(), "opt": optimizer.state_dict(), "loss": loss.data}))
class Trainer(object):
Nice use of a class to group the functions!
for thread in trainer_threads:
    thread.join()

ps.join()
Is what you actually want to derive class PS from Thread, and override PS.join to set _exit = True before calling the superclass Thread's join?
Also, when I ran the program following README.md, it failed because matplotlib is missing.
So there should be a Dockerfile, FROM pytorch/pytorch:latest, that installs matplotlib; and README.md needs to be updated.
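The subclass-and-override variant the reviewer describes could be sketched as follows, again with a placeholder loop standing in for the real parameter-server logic:

```python
import threading
import time


class PS(threading.Thread):
    """Override `join` so that asking the PS to join first signals
    its loop to stop, then waits for the thread to finish."""

    def __init__(self):
        super().__init__()
        self._exit = False

    def run(self):
        while not self._exit:
            time.sleep(0.01)   # placeholder for the real pull/update loop

    def join(self, timeout=None):
        self._exit = True          # stop the run loop...
        super().join(timeout)      # ...then wait for it to finish
```

With this design, callers shut the PS down with a single `ps.join()`, matching the join semantics of the trainer threads.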
Fix #87
add swamp (loss / timestamp) figure using matplotlib
steps to run:
- cd ${elasticdl_root}/experimental/swamp-optimization/mnist
- python mnist.py

more metric figures like accuracy will be added later.