[ZeRo] Parameter group support in constructor #71347

@rohan-varma

Description

🚀 The feature, motivation and pitch

Motivated by: https://discuss.pytorch.org/t/adamw-zeroredundancyoptimizer-weight-decay-dictionary/141516

A user might create an optimizer with per-group options, such as:

import torch.optim as optim

optimizer = optim.AdamW(
    [
        {"params": gain_or_bias_params, "weight_decay": 0.0},
        {"params": rest_params, "weight_decay": args.wd},
    ],
    lr=args.lr,
    betas=(args.beta1, args.beta2),
    eps=args.eps,
)

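where the first argument is a list of dicts, each specifying a parameter group with its own options (here, weight_decay). However, the equivalent ZeroRedundancyOptimizer call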

from torch.distributed.optim import ZeroRedundancyOptimizer

optimizer = ZeroRedundancyOptimizer(
    [
        {"params": gain_or_bias_params, "weight_decay": 0.0},
        {"params": rest_params, "weight_decay": args.wd},
    ],
    optim.AdamW,
    lr=args.lr,
    betas=(args.beta1, args.beta2),
    eps=args.eps,
)

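does not work as expected. A workaround is to pass only one group (e.g. rest_params with weight_decay=args.wd) to the constructor and add the other group afterwards with add_param_group: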

optimizer.add_param_group({"params": gain_or_bias_params, "weight_decay": 0.})

but for a better developer experience, ZeroRedundancyOptimizer should accept parameter groups directly in its constructor, as the torch.optim optimizers do.
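To illustrate what constructor support involves, here is a minimal pure-Python sketch under stated assumptions: FakeParam, normalize_param_groups, and partition_groups are hypothetical names, not PyTorch APIs, and a parameter is modeled only by a name and an element count. The normalization step mirrors how torch.optim.Optimizer interprets params (a flat iterable becomes a single group, and options missing from a group are filled in from the keyword defaults). The sharding step is a simple greedy, size-balanced assignment in the spirit of ZeRO, not its actual implementation; the point is that each rank's shard keeps every group's own options, such as weight_decay.

```python
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List


@dataclass(frozen=True)
class FakeParam:
    """Hypothetical stand-in for a tensor: just a name and an element count."""
    name: str
    numel: int


def normalize_param_groups(
    params: Iterable[Any], defaults: Dict[str, Any]
) -> List[Dict[str, Any]]:
    """Accept an iterable of parameters or an iterable of param-group dicts."""
    params = list(params)
    if not params:
        raise ValueError("optimizer got an empty parameter list")
    if not isinstance(params[0], dict):
        params = [{"params": params}]  # a flat iterable becomes a single group
    groups: List[Dict[str, Any]] = []
    seen = set()
    for group in params:
        group = dict(group)  # shallow copy: do not mutate the caller's dict
        group["params"] = list(group["params"])
        for p in group["params"]:
            if id(p) in seen:
                raise ValueError(f"{p} appears in more than one parameter group")
            seen.add(id(p))
        for key, value in defaults.items():
            group.setdefault(key, value)  # per-group options win over defaults
        groups.append(group)
    return groups


def partition_groups(
    groups: List[Dict[str, Any]], world_size: int
) -> List[List[Dict[str, Any]]]:
    """Greedily shard each group's parameters across ranks, keeping its options."""
    sizes = [0] * world_size  # elements assigned to each rank so far
    per_rank: List[List[Dict[str, Any]]] = [[] for _ in range(world_size)]
    for group in groups:
        shards: List[List[FakeParam]] = [[] for _ in range(world_size)]
        # Largest parameters first, each to the currently least-loaded rank.
        for p in sorted(group["params"], key=lambda q: q.numel, reverse=True):
            rank = sizes.index(min(sizes))
            shards[rank].append(p)
            sizes[rank] += p.numel
        options = {k: v for k, v in group.items() if k != "params"}
        # Every rank gets every group (possibly empty) so indices stay aligned.
        for rank in range(world_size):
            per_rank[rank].append({**options, "params": shards[rank]})
    return per_rank


wd = 0.1  # placeholder for args.wd in the issue
gain_or_bias_params = [FakeParam("ln.weight", 8), FakeParam("ln.bias", 8)]
rest_params = [FakeParam("fc1.weight", 64), FakeParam("fc2.weight", 16)]

groups = normalize_param_groups(
    [
        {"params": gain_or_bias_params, "weight_decay": 0.0},
        {"params": rest_params},  # inherits weight_decay=wd from the defaults
    ],
    defaults={"lr": 1e-3, "weight_decay": wd},
)
shards = partition_groups(groups, world_size=2)
for rank, rank_groups in enumerate(shards):
    for g in rank_groups:
        print(rank, [p.name for p in g["params"]], g["weight_decay"])
```

Emitting every group on every rank, even when a rank's share of that group is empty, is a design choice of this sketch: it keeps group indices aligned across ranks, so a group's options can be looked up by the same index everywhere.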

Alternatives

No response

Additional context

No response

cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang

Metadata

Assignees

Labels

oncall: distributed, pt_distributed_rampup, triaged

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
