
After setting stop_gradient on a parameter, the EMA-saved weights are inconsistent with the non-EMA weights #62622

Happy-zyy opened this issue Mar 11, 2024 · 1 comment
Happy-zyy commented Mar 11, 2024

[screenshot: word_embedding values and their EMA counterparts at different training steps]

As the screenshot above shows, word_embedding is frozen and identical across steps, but its corresponding EMA parameter differs between steps.
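One thing worth keeping in mind when comparing the two saved copies: under the standard EMA recurrence, the shadow value of even a completely frozen parameter only converges toward that parameter gradually, so a mismatch at a given step is not by itself proof that the parameter was modified. A minimal sketch of the recurrence (the decay value and zero initialization are illustrative assumptions, not necessarily what fluid.ExponentialMovingAverage does internally):

```python
def ema_steps(param, decay, steps, shadow=0.0):
    """Run the standard EMA recurrence for a constant parameter.

    shadow_t = decay * shadow_{t-1} + (1 - decay) * param
    Even though `param` never changes, `shadow` differs from it
    (and changes across steps) until it has converged.
    """
    for _ in range(steps):
        shadow = decay * shadow + (1 - decay) * param
    return shadow
```

For example, with decay 0.99 and a frozen parameter value of 1.0, the shadow after 10 steps is 1 - 0.99**10 ≈ 0.096, far from the parameter itself.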

The saving code is as follows:

def save_model_wrapper(args, exe, test_prog, graph_vars, model_name, step_name, ema_optim=None):
    save_path = os.path.join(args.output_dir, step_name, model_name)
    ema_save_path = os.path.join(args.output_dir, step_name, 'ema_' + model_name)
    try:
        # Save the raw weights; EMA is not applied here.
        fluid.io.save_persistables(exe, save_path, test_prog)
        if args.use_ema and ema_optim is not None:
            log.info('save_ema_model to %s' % ema_save_path)
            # Temporarily swap in the EMA shadow weights, save, then restore.
            with ema_optim.apply(exe, need_restore=True):
                fluid.io.save_inference_model(
                    ema_save_path,
                    feeded_var_names=graph_vars['infer_input_vars_name'],
                    target_vars=graph_vars['infer_output_vars'],
                    main_program=test_prog,
                    executor=exe)
        else:
            log.info('save non-ema model to %s' % ema_save_path)
            fluid.io.save_inference_model(
                ema_save_path,
                feeded_var_names=graph_vars['infer_input_vars_name'],
                target_vars=graph_vars['infer_output_vars'],
                main_program=test_prog,
                executor=exe)
    except Exception as e:
        log.error('Save Model Error: %s' % str(e))
JZ-LIANG (Contributor) commented:

You are training with the static graph, right? You can print the program and check whether any operator in it modifies the corresponding parameter.
For example:
check whether word_embedding is modified by an op such as adam.
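The suggested check can be sketched as a small scan over the program's operators. This assumes a PaddlePaddle static-graph Program-like object whose blocks expose `ops`, each op exposing `type` and `output_arg_names` (as in fluid's Operator API); the parameter name "word_embedding" is illustrative:

```python
def ops_writing_to(program, param_name):
    """Return the types of all ops that have `param_name` as an output.

    An optimizer op (e.g. 'adam') appearing here for a supposedly
    frozen parameter would explain why its EMA copy keeps changing.
    """
    writers = []
    for block in program.blocks:
        for op in block.ops:
            if param_name in op.output_arg_names:
                writers.append(op.type)
    return writers
```

Usage would be along the lines of `print(ops_writing_to(train_program, "word_embedding"))` after building the training program.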
