[Dist] Add ZeRO-1 Optimizer #4648
Conversation
@mrshenli - any thoughts here on usage and whether to land this in torchxla?
    
Sorry about the delay. LGTM! Left some minor comments. Question: will FSDP ZeRO2 always perform better than doing this in the optimizer? The former also runs reduce_scatter + all_gather, but those comm ops can better overlap with forward and backward computations.
Hey @wconstab, are there any items that we need to check before landing?
| """ | ||
| xm.unlazy(self.params) | ||
| for param in self.params: | ||
| shard_data = param.data.to(device="cpu") # move to cpu | 
Curious, any reason this needs to be moved to CPU before sharding? I saw the following explanation in the PR summary:
"shard params on CPU firstly to reduce generated graphs and achieve SPMD"
Could you please elaborate on why moving to CPU has an impact on the generated graphs and SPMD? Does it mean there is no way to disable lazy op recording besides moving things to CPU?
That was a mistake in the summary; it has no impact on the number of graphs (1 vs 1).
In our use case there is usually a mark_step after model init, which generates one graph:

    model, optimizer = init_model_and_optimizer(...)
    xm.mark_step()

Since the shard depends on the rank, if we shard on the XLA device each rank will generate a different graph. By moving this to CPU, we can share a single compiled graph between processes.
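A hedged sketch of that point (the shapes and the rank/world-size helpers below are illustrative, not the PR's code):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
rank, world_size = xm.get_ordinal(), xm.xrt_world_size()

param = torch.zeros(1024, device=device)

# Slicing on the XLA device would record a rank-dependent slice into the lazy
# trace, so each rank compiles its own graph:
#     shard = param.chunk(world_size)[rank]
#
# Slicing on CPU happens eagerly, outside the trace, so the lazy graph stays
# identical across ranks and a single compiled graph can be shared:
shard = param.cpu().chunk(world_size)[rank].to(device)
xm.mark_step()
```

The data moved back to the device is rank-specific, but nothing rank-dependent ends up recorded in the traced graph itself.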
In our use cases, the scripts are like:

    for ...:
        loss = model(inputs)
        loss.backward()
        xm.mark_step()
        if ...:  # hit the grad accumulation boundary
            optimizer_wrapper.step()
            xm.mark_step()
        ...

fwd+bwd are in one graph, the optimizer plus some misc ops (norm, grad clipping) are in another graph, and the two graphs do not overlap.

On the performance comparison between ZeRO-1 and ZeRO-2: for some model configs, ZeRO-1 is better on GPU. We'd like to use FSDP ZeRO-2, but FSDP in xla only has ZeRO-3.
    
I'm wondering why this lands in pytorch/pytorch instead of pytorch/xla - after reading the code it seems specific to xla. Otherwise, it looks OK to me. (Similar questions as @hgt312 about perf; I'm curious what the overall plan for this is and whether someone is also working on a version with overlap.)
    
@wconstab this PR is going to land in pytorch/xla, though the branch name pytorch:master is a bit misleading.
    
I will land this change to master, but I will most likely leave this one out of the 2.0 release.
    
          
Thanks for your reply! Which release do you have in mind for this? And do you have plans to update FSDP in ptxla to align with FSDP in PyTorch and support more features?
    
It will most likely be in the 2.1 release; we are shifting our development focus to SPMD/DTensor. We will maintain the FSDP code and add features as users request. This might change when the compiler version of FSDP is developed upstream; hopefully we can share the same implementation by then.
    
* init
* test
* lint
* address comments
@JackCaoG Does this work on TPU? So far, the test (test/test_zero1.py) crashes on my local TPU run. Will skip the test for TPU.
    
Implementation and test in previous PR #4648. Reduce local norm across shards.
Intro
A simple ZeRO-1 implementation for torch-xla. Similar API to PyTorch's: https://github.com/pytorch/pytorch/blob/master/torch/distributed/optim/zero_redundancy_optimizer.py, but a different implementation.
PyTorch uses the legacy logic from DeepSpeed: that approach separates params into several groups, and each rank is responsible for one group. This makes the workload on each rank different and potentially unbalanced.
We are proposing to use the updated logic in DeepSpeed, which splits/slices each tensor across the world size so that each rank is responsible for one partition and the workload on every rank is the same. Furthermore, we use reduce_scatter instead of all_reduce + split/chunk/slice, and shard params on CPU first so that all ranks can share a single compiled graph (SPMD).
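A rough sketch of this partitioning (a hypothetical helper, not the PR's exact code; the padding scheme is an assumption):

```python
import torch

def shard_parameter(param: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    """Return this rank's equal-sized slice of a flattened, padded parameter."""
    flat = param.detach().cpu().flatten()
    pad = (-flat.numel()) % world_size  # pad so the tensor splits evenly
    if pad:
        flat = torch.nn.functional.pad(flat, (0, pad))
    return flat.chunk(world_size)[rank].clone()
```

Because every rank holds a slice of the same size, the optimizer work is balanced; gradients are reduce_scatter'ed onto these slices and the updated slices are all_gather'ed back into full parameters, instead of all_reduce followed by local slicing.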
Usage
Just wrap the optimizer:
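For example (a minimal sketch; the import path and constructor arguments below are assumptions based on the upstream PyTorch API linked above, not taken verbatim from this PR):

```python
import torch
import torch_xla.core.xla_model as xm
# NOTE: import path is an assumption about where the wrapper lands in torch_xla.
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

device = xm.xla_device()
model = torch.nn.Linear(16, 16).to(device)

# Wrap the real optimizer class; each rank keeps only its shard of the
# optimizer state (and of the master parameters).
optimizer = ZeroRedundancyOptimizer(model.parameters(), torch.optim.SGD, lr=1e-3)

loss = model(torch.randn(4, 16, device=device)).sum()
loss.backward()
optimizer.step()  # reduce_scatter grads, update the local shard, all_gather params
xm.mark_step()
```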
Misc
About ZeRO-1, see https://arxiv.org/abs/1910.02054
Status