### 🚀 The feature, motivation and pitch
Models are getting harder to fit on a limited number of GPUs, and Adam doesn't help: its memory overhead is 2N, where N is the number of parameters in the model.
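To make that overhead concrete, here's a back-of-envelope sketch; the 7B-parameter fp32 model is a hypothetical example for illustration, not a number from this issue:

```python
# Back-of-envelope arithmetic for Adam's extra state.
# Hypothetical example: a 7B-parameter model trained in fp32.
n_params = 7e9        # model parameters
bytes_per_el = 4      # fp32

weights_gb = n_params * bytes_per_el / 1e9         # ~28 GB of weights
adam_state_gb = 2 * n_params * bytes_per_el / 1e9  # exp_avg + exp_avg_sq: ~56 GB

print(f"weights: {weights_gb:.0f} GB, Adam state: {adam_state_gb:.0f} GB")
```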
We don't like to merge optimizers into core because they rarely stand the test of time, but Adam has, and so has a memory-efficient alternative: Adafactor, which is in use at many startups I've talked to, at Twitter, and at larger companies (https://github.com/google-research/t5x/blob/main/t5x/adafactor.py). See the discussion here: https://twitter.com/HamelHusain/status/1702004478369743125 - it's also come up in a GitHub issue before: #98434
Assigning this to myself since I want a starter task in optimizers.
### Alternatives
There is a working implementation in fairseq, which Hugging Face has borrowed: https://github.com/facebookresearch/fairseq/blob/main/fairseq/optim/adafactor.py#L66. It's a good starting point; a usage sketch follows.
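A minimal usage sketch, assuming the Hugging Face port of that fairseq implementation (`transformers.optimization.Adafactor`); the linear layer is just a stand-in model:

```python
import torch
from transformers.optimization import Adafactor

model = torch.nn.Linear(1024, 1024)  # stand-in model for illustration

optimizer = Adafactor(
    model.parameters(),
    lr=None,               # None => use Adafactor's own relative-step schedule
    scale_parameter=True,  # scale updates by each parameter's RMS
    relative_step=True,    # lr_t = min(1e-2, 1/sqrt(t))
    warmup_init=False,
)

loss = model(torch.randn(8, 1024)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```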
### Additional context
You might be thinking: why merge yet another optimizer? Optimizers are plagued by a lack of reproducibility and sensitivity to hyperparameters.
However, Adam has stood the test of time. It also has a high memory overhead: for each parameter you need to store the first and second moments, so if your model has N parameters, your optimizer state is 2N. Adafactor avoids most of this by factoring the second moment (see the sketch after the links below). Also, few people change Adam's hyperparameters; in fact, the defaults have remained the same since the Lua torch days:
- https://github.com/pytorch/pytorch/blob/main/torch/optim/adam.py
- https://github.com/torch/optim/blob/master/adam.lua
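For intuition, here is a rough sketch of the factored second moment from the Adafactor paper (Shazeer & Stern, 2018): for an n×m weight matrix, it keeps row and column statistics of the squared gradient, O(n+m) memory, instead of Adam's full n×m second moment. This is a simplified illustration (fixed beta2 instead of the paper's step-dependent decay), not the proposed PyTorch API:

```python
import torch

# Adafactor's factored second moment, simplified.
# For a weight of shape (n, m), Adam stores a full (n, m) second-moment
# tensor; Adafactor stores a row vector (n,) and a column vector (m,).
n, m = 4096, 4096
grad = torch.randn(n, m)
beta2 = 0.999  # simplification: the paper uses a step-dependent decay
eps = 1e-30

# Exponential moving averages of row/column means of grad**2.
row = torch.zeros(n)  # O(n) instead of O(n*m)
col = torch.zeros(m)  # O(m)
row.mul_(beta2).add_((grad**2 + eps).mean(dim=1), alpha=1 - beta2)
col.mul_(beta2).add_((grad**2 + eps).mean(dim=0), alpha=1 - beta2)

# Rank-1 reconstruction of the second moment, materialized only
# transiently to compute the update.
v_hat = torch.outer(row / row.mean(), col)
update = grad / v_hat.sqrt()
```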
So, as Blueberries projects move toward fine-tuning after the PyTorch conference, and as more members of the community try to fit larger models on their GPUs, it's critical that we find a memory-efficient alternative to Adam. That's why Adafactor might be a strong candidate for the default optimizer we recommend to people fine-tuning with a limited number of GPUs - including ourselves in the torchbench CI. Adafactor also comes with its own LR scheduler (sketched below) and will help us avoid dramatically reducing batch sizes in our benchmarks.
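On that LR scheduler: a sketch of the relative-step schedule from the paper, matching what the fairseq/Hugging Face implementation computes per step (the function name here is mine, not part of any API):

```python
import math

def adafactor_lr(step: int, param_rms: float, eps2: float = 1e-3,
                 warmup_init: bool = False) -> float:
    """Relative step size from Shazeer & Stern (2018):
    rho_t = min(1e-2, 1/sqrt(t)), optionally warming up as 1e-6 * t,
    scaled by the parameter's RMS (floored at eps2)."""
    min_step = 1e-6 * step if warmup_init else 1e-2
    rho_t = min(min_step, 1.0 / math.sqrt(step))
    return max(eps2, param_rms) * rho_t
```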