
Conversation

@hgt312
Collaborator

@hgt312 hgt312 commented Feb 17, 2023

Intro

A simple ZeRO-1 implementation for torch-xla. It has a similar API to PyTorch's https://github.com/pytorch/pytorch/blob/master/torch/distributed/optim/zero_redundancy_optimizer.py, but a different implementation.

PyTorch uses the legacy logic from DeepSpeed: it separates the params into several groups, and each rank is responsible for one group. This makes the workload on each rank different and possibly unbalanced.
We are proposing to use the updated logic in DeepSpeed, which splits/slices each tensor into world-size partitions so that each rank is responsible for one partition and the workload on every rank is the same. Furthermore, we use reduce_scatter instead of all_reduce + split/chunk/slice, and we shard the params on CPU first, to reduce the generated graphs and achieve SPMD.
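As a rough illustration of the difference, here is a minimal sketch of the two gradient-sharding strategies, assuming a flat 1-D gradient whose length is divisible by the world size. The helper names are illustrative and this is not the PR's actual implementation:

import torch
import torch_xla.core.xla_model as xm

def shard_grad_reduce_scatter(flat_grad: torch.Tensor) -> torch.Tensor:
    # Each rank directly receives its 1/world_size slice of the summed gradient.
    world_size = xm.xrt_world_size()
    return xm.reduce_scatter(xm.REDUCE_SUM, flat_grad, scale=1.0 / world_size,
                             scatter_dim=0, shard_count=world_size)

def shard_grad_all_reduce(flat_grad: torch.Tensor) -> torch.Tensor:
    # Every rank materializes the full summed gradient, then slices out its own part.
    world_size = xm.xrt_world_size()
    summed = xm.all_reduce(xm.REDUCE_SUM, flat_grad, scale=1.0 / world_size)
    return summed.chunk(world_size)[xm.get_ordinal()]

Note that the all_reduce variant needs a rank-dependent slice at the end, which is one reason the reduce_scatter path fits the SPMD goal better.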

Usage

Just wrap the optimizer:

from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

optimizer_wrapper = ZeroRedundancyOptimizer(model.parameters(), SGD, ...)
xm.mark_step()
for ...:
    loss = model(inputs)
    loss.backward()
    xm.mark_step()
    ...
    optimizer_wrapper.step()
    xm.mark_step()
    ...

Misc

About ZeRO-1, see https://arxiv.org/abs/1910.02054

ZeRO-1 workflow (core steps):
model -> all-reduce grads of parameters ->
shard grads and assign them to the pre-sharded parameters -> optimizer step ->
all-gather the sharded parameters
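A rough pseudocode sketch of one wrapper step, assuming each parameter's size is divisible by the world size (no padding yet, see the status list below); the function and variable names are illustrative, not the actual code:

import torch_xla.core.xla_model as xm

def zero1_step(full_params, sharded_params, local_optimizer):
    world_size = xm.xrt_world_size()
    # 1. Reduce gradients across ranks, keeping only this rank's slice.
    for param, shard in zip(full_params, sharded_params):
        shard.grad = xm.reduce_scatter(
            xm.REDUCE_SUM, param.grad.flatten(), scale=1.0 / world_size,
            scatter_dim=0, shard_count=world_size)
    # 2. The wrapped optimizer updates only the local parameter shards.
    local_optimizer.step()
    # 3. All-gather the updated shards back into the full parameters.
    for param, shard in zip(full_params, sharded_params):
        param.data.copy_(xm.all_gather(shard.data, dim=0).view_as(param))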

Status

  • tested on Colab; it works
  • an option to enable/disable grad clipping (see the sketch after this list)
  • an option to choose the optimizer dtype
  • missing functionality: padding
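A hedged example of wiring up the options above, with model as in the Usage snippet; the keyword names (optimizer_dtype, grad_clipping, max_norm) are assumptions about the final API and should be checked against the merged code:

import torch
from torch.optim import SGD
from torch_xla.distributed.zero_redundancy_optimizer import ZeroRedundancyOptimizer

optimizer_wrapper = ZeroRedundancyOptimizer(
    model.parameters(),
    SGD,
    lr=1e-2,
    optimizer_dtype=torch.float32,  # dtype used for the sharded optimizer states (assumed kwarg)
    grad_clipping=True,             # enable/disable gradient clipping (assumed kwarg)
    max_norm=1.0)                   # clipping threshold when clipping is enabled (assumed kwarg)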

@jspisak
Contributor

jspisak commented Feb 17, 2023

@mrshenli - any thoughts here on usage and whether to land this in torchxla?

@JackCaoG JackCaoG requested a review from Liyang90 February 17, 2023 20:35
@miladm miladm requested a review from mrshenli February 17, 2023 23:14

@mrshenli mrshenli left a comment


Sorry about the delay. LGTM! Left some minor comments. Question: will FSDP ZeRO-2 always perform better than doing this in the optimizer? The former also runs reduce_scatter + all_gather, but those comm ops can better overlap with forward and backward computations.

Hey @wconstab, are there any items that we need to check before landing?

"""
xm.unlazy(self.params)
for param in self.params:
shard_data = param.data.to(device="cpu") # move to cpu


Curious, any reason this needs to be moved to CPU before sharding? I saw the following explanation in the PR summary:

shard the params on CPU first, to reduce the generated graphs and achieve SPMD

Could you please elaborate on why moving to CPU has an impact on generated graphs and SPMD? Does it mean there is no way to disable lazy op recording besides moving things to CPU?

Collaborator Author


That part of the summary was a mistake: it actually has no impact on the number of generated graphs (1 vs 1).
In our use case there is usually a mark_step after model init, which generates one graph:

model, optimizer = init_model_and_optimizer(...)
xm.mark_step()

As the shard depends on the rank, sharding on the XLA device would make each rank generate a different graph. By moving this step to CPU, all processes can share a single compiled graph.
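A minimal sketch of that idea, assuming each parameter's size is divisible by the world size; the names here are illustrative rather than the actual code:

import torch_xla.core.xla_model as xm

def shard_params_on_cpu(params):
    world_size = xm.xrt_world_size()
    rank = xm.get_ordinal()
    shards = []
    for param in params:
        cpu_data = param.data.to(device="cpu")        # pull the value out of the lazy XLA graph
        shard = cpu_data.chunk(world_size)[rank]      # rank-dependent slicing happens eagerly on CPU
        shards.append(shard.to(device=xm.xla_device()))  # only a rank-agnostic upload is traced
    return shards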

@hgt312 hgt312 force-pushed the add_zero1 branch 2 times, most recently from 7654718 to 3d9fcd2 on February 21, 2023 01:50
@hgt312
Collaborator Author

hgt312 commented Feb 21, 2023

Sorry about the delay. LGTM! Left some minor comments. Question: will FSDP ZeRO-2 always perform better than doing this in the optimizer? The former also runs reduce_scatter + all_gather, but those comm ops can better overlap with forward and backward computations.

Hey @wconstab, are there any items that we need to check before landing?

In our use cases, the scripts are like:

for ...:
    loss = model(inputs)
    loss.backward()
    xm.mark_step()
    if ...:  # meet grad acc boundary
        optimizer_wrapper.step()
        xm.mark_step()
    ...

fwd+bwd are in one graph, the optimizer and some misc ops (norm, grad clipping) are in another graph, and the two graphs do not overlap.

For the performance comparison between ZeRO-1 and ZeRO-2: for some model configs, ZeRO-1 is better on GPU.
For XLA, we have not tested them, and we do not have a ZeRO-2 implementation for XLA yet.

We'd like to use FSDP ZeRO-2, but FSDP in XLA only has ZeRO-3.

@wconstab
Collaborator

I'm wondering why this lands in pytorch/pytorch instead of pytorch/xla - after reading the code it seems specific to XLA.

Otherwise, it looks OK to me. (Similar questions as @hgt312 about perf; I'm curious what the overall plan for this is and whether someone is also working on a version with overlap.)

@JackCaoG
Collaborator

@wconstab this PR is going to land in pytorch/xla, though the branch name pytorch:master is a bit misleading.

@JackCaoG
Collaborator

I will land this change to the master, but I will most likely leave this one out of the 2.0 release.

@hgt312
Collaborator Author

hgt312 commented Feb 21, 2023

I will land this change to the master, but I will most likely leave this one out of the 2.0 release.

Thanks for your reply! Which release do you have in mind for this?

And do you have a plan to update FSDP in ptxla to align with FSDP in PyTorch and support more features?

@JackCaoG
Collaborator

It will most likely be in the 2.1 release; we are shifting our development focus to SPMD/DTensor. We will maintain the FSDP code and add features as users request them. This might change when the compiler version of FSDP is developed upstream; hopefully we can share the same implementation by then.

@JackCaoG JackCaoG merged commit 1d313bb into pytorch:master Feb 21, 2023
mateuszlewko pushed a commit that referenced this pull request Mar 15, 2023
* init

* test

* lint

* address comments
@alanwaketan
Collaborator

@JackCaoG Does it work on TPU? So far, the test (test/test_zero1.py) crashes on my local TPU run.

root@t1v-n-307ffe96-w-0:/workspaces/work/pytorch/xla# PJRT_DEVICE=TPU python test/test_zero1.py
2023-04-20 17:57:03.565392: F ./third_party/xla_client/debug_macros.h:20] Non-OK-status: status.status() status: INTERNAL: during context [pre-optimization]: RET_CHECK failure (third_party/tensorflow/compiler/xla/service/hlo_verifier.cc:368) replica_count == 1 || n == replica_count In kCrossReplica mode, replica groups should contain 4 replicas, but found 1: %reduce-scatter.17 = f32[8,8]{1,0} reduce-scatter(f32[8,8]{1,0} %add.12), replica_groups={{0}}, constrain_layout=true, dimensions={0}, to_apply=%AddComputation.13
*** Begin stack trace ***
        tsl::CurrentStackTrace[abi:cxx11]()
        std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > ConsumeValue<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >(tsl::StatusOr<std::unique_ptr<xla::PjRtLoadedExecutable, std::default_delete<xla::PjRtLoadedExecutable> > >&&)
        xla::PjRtComputationClient::Compile(std::vector<xla::ComputationClient::CompileInstance, std::allocator<xla::ComputationClient::CompileInstance> >)
        torch_xla::XLAGraphExecutor::Compile(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > > const&, absl::lts_20220623::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorCollection const&, torch::lazy::LazyGraphExecutor::PostOrderData*, std::vector<torch::lazy::Value, std::allocator<torch::lazy::Value> > const&)
        torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20220623::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, torch::lazy::LazyGraphExecutor::SyncTensorsConfig const&, bool)
        torch_xla::XLAGraphExecutor::SyncTensorsGraph(std::vector<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> >, std::allocator<c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > > >*, absl::lts_20220623::Span<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const>, bool, bool, bool)
        torch_xla::XLATensor::ApplyPendingGraph()
        torch_xla::XLATensor::GetXlaData()
        torch_xla::XLATensor::ToTensor(bool)
        torch_xla::XLANativeFunctions::_to_copy(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)




        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)



        at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)





        at::_ops::_to_copy::redispatch(c10::DispatchKeySet, at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)





        at::_ops::_to_copy::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, c10::optional<c10::MemoryFormat>)


        at::native::to(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, bool, c10::optional<c10::MemoryFormat>)



        at::_ops::to_dtype_layout::call(at::Tensor const&, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, bool, bool, c10::optional<c10::MemoryFormat>)




        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall


        PyObject_Repr

        PyObject_Repr

        PyObject_Repr

        PyObject_Str


        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall

        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyObject_FastCallDict
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyObject_FastCallDict
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyFunction_Vectorcall

        PyVectorcall_Call
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyObject_FastCallDict
        _PyObject_Call_Prepend

        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        _PyObject_FastCallDict
        _PyObject_Call_Prepend


        _PyObject_MakeTpCall
        _PyEval_EvalFrameDefault
        _PyEval_EvalCodeWithName
        PyEval_EvalCodeEx
        PyEval_EvalCode



        PyRun_SimpleFileExFlags

        Py_BytesMain
        __libc_start_main
        _start
*** End stack trace ***

https://symbolize.stripped_domain/r/?trace=7f5c4e2a87bb,7f5c4e5d272f,7f5adc7ec918,7f5adc7f1c49,7f5aeeb245c6,7f5aeeb0bbaa,7f5aeeb08235,7f5aef620da0,7f5aef61f69c,7f5aef6276d7,7f5aef3a6991,7f5af059299d,7f5af05933f5,7f5c39d06061,7f5c39d06f3b,7f5c39a89861,7f5c3ac5b917,7f5c3ac5befa,7f5c39d06061,7f5c39a89180,7f5c388eeee4,7f5c388ebaa7,7f5c388f7f2a,7f5c39d06061,7f5c39d06f3b,7f5c39a89861,7f5c3d17938c,7f5c3d016180,7f5c3d015566,7f5c3d01670d,7f5c39d06061,7f5c39a89180,7f5c388eeee4&map=0f4b7d67a422054324b2fbb70065a3735113a8bd:7f5ad9538000-7f5aed1d12f0 
*** SIGABRT received by PID 1770238 (TID 1770238) on cpu 239 from PID 1770238; stack trace: ***
PC: @     0x7f5c4e2a87bb  (unknown)  raise
    @     0x7f5ad89e4c5a       1152  (unknown)
    @     0x7f5c4e5d2730  (unknown)  (unknown)
    @     0x7f5adc7ec919        496  ConsumeValue<>()
    @     0x7f5adc7f1c4a       4592  xla::PjRtComputationClient::Compile()
    @     0x7f5aeeb245c7      13280  torch_xla::XLAGraphExecutor::Compile()
    @     0x7f5aeeb0bbab       3040  torch_xla::XLAGraphExecutor::SyncTensorsGraphInternal()
    @     0x7f5aeeb08236       1120  torch_xla::XLAGraphExecutor::SyncTensorsGraph()
    @     0x7f5aef620da1        432  torch_xla::XLATensor::ApplyPendingGraph()
    @     0x7f5aef61f69d       2352  torch_xla::XLATensor::GetXlaData()
    @     0x7f5aef6276d8        928  torch_xla::XLATensor::ToTensor()
    @     0x7f5aef3a6992       1488  torch_xla::XLANativeFunctions::_to_copy()
    @     0x7f5af059299e        224  at::(anonymous namespace)::(anonymous namespace)::wrapper_XLA___to_copy()
    @     0x7f5af05933f6        416  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f5c39d06062        288  c10::callUnboxedKernelFunction<>()
    @     0x7f5c39d06f3c        688  c10::Dispatcher::redispatch<>()
    @     0x7f5c39a89862        448  at::_ops::_to_copy::redispatch()
    @     0x7f5c3ac5b918        320  at::(anonymous namespace)::_to_copy()
    @     0x7f5c3ac5befb        368  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f5c39d06062        288  c10::callUnboxedKernelFunction<>()
    @     0x7f5c39a89181       1216  at::_ops::_to_copy::call()
    @     0x7f5c388eeee5        176  at::_to_copy()
    @     0x7f5c388ebaa8        336  _to_copy_functionalize()
    @     0x7f5c388f7f2b        368  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f5c39d06062        288  c10::callUnboxedKernelFunction<>()
    @     0x7f5c39d06f3c        688  c10::Dispatcher::redispatch<>()
    @     0x7f5c39a89862        448  at::_ops::_to_copy::redispatch()
    @     0x7f5c3d17938d        208  at::redispatch::_to_copy()
    @     0x7f5c3d016181        208  torch::autograd::VariableType::(anonymous namespace)::_to_copy()::$_49::operator()()
    @     0x7f5c3d015567        912  torch::autograd::VariableType::(anonymous namespace)::_to_copy()
    @     0x7f5c3d01670e        416  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f5c39d06062        288  c10::callUnboxedKernelFunction<>()
    @     0x7f5c39a89181       1216  at::_ops::_to_copy::call()
    @     0x7f5c388eeee5        176  at::_to_copy()
    @     0x7f5c3931d829        224  at::native::to_impl()
    @     0x7f5c3931d081        240  at::native::to()
    @     0x7f5c3b397383        192  at::(anonymous namespace)::(anonymous namespace)::wrapper_CompositeImplicitAutograd_dtype_layout_to()
    @     0x7f5c3b49c8a5        384  c10::impl::wrap_kernel_functor_unboxed_<>::call()
    @     0x7f5c3a104cca        304  c10::callUnboxedKernelFunction<>()
    @     0x7f5c39f12a7d       1248  at::_ops::to_dtype_layout::call()
    @     0x7f5c4c514793        208  at::Tensor::to()
    @     0x7f5c4c4ef4af        192  torch::autograd::dispatch_to()
    @     0x7f5c4c42937a        880  torch::autograd::THPVariable_to()
    @     0x7f5c4e7492f7  (unknown)  method_vectorcall_VARARGS_KEYWORDS
https://symbolize.stripped_domain/r/?trace=7f5c4e2a87bb,7f5ad89e4c59,7f5c4e5d272f,7f5adc7ec918,7f5adc7f1c49,7f5aeeb245c6,7f5aeeb0bbaa,7f5aeeb08235,7f5aef620da0,7f5aef61f69c,7f5aef6276d7,7f5aef3a6991,7f5af059299d,7f5af05933f5,7f5c39d06061,7f5c39d06f3b,7f5c39a89861,7f5c3ac5b917,7f5c3ac5befa,7f5c39d06061,7f5c39a89180,7f5c388eeee4,7f5c388ebaa7,7f5c388f7f2a,7f5c39d06061,7f5c39d06f3b,7f5c39a89861,7f5c3d17938c,7f5c3d016180,7f5c3d015566,7f5c3d01670d,7f5c39d06061,7f5c39a89180,7f5c388eeee4,7f5c3931d828,7f5c3931d080,7f5c3b397382,7f5c3b49c8a4,7f5c3a104cc9,7f5c39f12a7c,7f5c4c514792,7f5c4c4ef4ae,7f5c4c429379,7f5c4e7492f6&map=0f4b7d67a422054324b2fbb70065a3735113a8bd:7f5ad9538000-7f5aed1d12f0,53db522eddd332defd5077af3b470d11:7f5acd633000-7f5ad8bfb980 
E0420 17:57:04.111062 1770238 coredump_hook.cc:408] RAW: Remote crash data gathering hook invoked.
E0420 17:57:04.111081 1770238 client.cc:265] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0420 17:57:04.111096 1770238 coredump_hook.cc:506] RAW: Sending fingerprint to remote end.
E0420 17:57:04.111106 1770238 coredump_socket.cc:119] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0420 17:57:04.111110 1770238 coredump_hook.cc:514] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0420 17:57:04.111121 1770238 coredump_hook.cc:566] RAW: Dumping core locally.
E0420 17:57:47.103827 1770238 process_state.cc:784] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

I will skip the test for TPU.

alanwaketan pushed a commit that referenced this pull request May 2, 2023
implementation and test in previous PR #4648

reduce local norm across shards
@hgt312 hgt312 mentioned this pull request Jun 8, 2023