
zero_optim_toy

A minimal, from-scratch implementation of ZeRO Stage 1 (Distributed Optimizer) for educational purposes.

Chinese version

Architecture

Three-layer design following the real distributed optimizer stack:

Buffer  →  DDP  →  Distributed Optimizer
Layer                  File                      Responsibility
Buffer                 buffer.py                 Flat contiguous storage, padding, shard view, all-gather / reduce-scatter
DDP                    ddp.py                    Param / grad buffer creation, remapping .data and .main_grad, gradient and parameter sync
Distributed Optimizer  distributed_optimizer.py  fp32 Adam on 1/N shard, bf16 writeback
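The Buffer layer's padding and shard-view logic can be sketched in a few lines. This is a minimal illustration with plain Python lists standing in for tensors; the class and method names here are hypothetical, and the real API lives in buffer.py:

```python
import math

# Minimal illustration of the Buffer layer: flat storage padded so it
# splits evenly into world_size shards. Plain lists stand in for tensors;
# class/method names are hypothetical -- see buffer.py for the real thing.
class FlatBuffer:
    def __init__(self, numel: int, world_size: int):
        self.shard_numel = math.ceil(numel / world_size)
        self.padded_numel = self.shard_numel * world_size
        self.data = [0.0] * self.padded_numel

    def shard(self, rank: int):
        # The contiguous 1/N slice owned by `rank`; in the real buffer this
        # is a view that reduce-scatter writes into and all-gather reads from.
        start = rank * self.shard_numel
        return self.data[start:start + self.shard_numel]

buf = FlatBuffer(numel=10, world_size=4)
print(buf.padded_numel)    # 12 (10 padded up to 3 * 4)
print(len(buf.shard(3)))   # 3
```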

What ZeRO-1 shards

Only optimizer state (fp32 master params + Adam m/v) is sharded. Each rank keeps full copies of bf16 model parameters and gradients — no forward/backward hooks needed.
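The resulting savings are easy to estimate: with bf16 parameters and gradients replicated (2 bytes each per rank) and the fp32 master weights plus Adam m/v (4 bytes each) sharded, per-rank memory per parameter drops from 16 bytes to 4 + 12/N. A back-of-the-envelope check (the helper name is illustrative):

```python
def bytes_per_param(world_size: int) -> float:
    # Replicated on every rank: bf16 params (2 B) + bf16 grads (2 B).
    replicated = 2 + 2
    # Sharded 1/N per rank: fp32 master (4 B) + Adam m (4 B) + Adam v (4 B).
    sharded = (4 + 4 + 4) / world_size
    return replicated + sharded

print(bytes_per_param(1))   # 16.0 -- unsharded baseline
print(bytes_per_param(4))   # 7.0
print(bytes_per_param(8))   # 5.5
```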

Training step

forward           full bf16 params in param_buffer
backward          grads written to param.grad
sync_grads()      copy param.grad → grad_buffer, reduce-scatter(SUM)
optimizer.step()  grad shard bf16→fp32, Adam, fp32→bf16 writeback, all-gather
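The steps above can be condensed into a sketch of one training iteration. `sync_grads()` and `step()` come from the step list; `train_step`, `loss_fn`, and `zero_grad()` are assumed names, so check ddp.py and distributed_optimizer.py for the exact API:

```python
# Sketch of one ZeRO-1 training step, assuming a DDP wrapper and a
# distributed optimizer shaped like the step list above.
def train_step(ddp_model, optimizer, batch, loss_fn):
    out = ddp_model(batch)      # forward: full bf16 params in param_buffer
    loss = loss_fn(out)
    loss.backward()             # backward: grads written to param.grad
    ddp_model.sync_grads()      # param.grad -> grad_buffer, reduce-scatter(SUM)
    optimizer.step()            # bf16->fp32 grad shard, Adam, fp32->bf16
                                # writeback, all-gather updated params
    optimizer.zero_grad()       # assumed helper: clear grads for the next step
    return loss
```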

Mixed precision flow

grad_buffer (bf16, full)
  → reduce-scatter → grad shard (bf16, P/N)
  → float() → shard_fp32.grad (fp32, P/N)
  → Adam.step() → shard_fp32 (fp32, P/N)
  → bfloat16() → param_buffer shard (bf16, P/N)
  → all-gather → param_buffer (bf16, full)
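Keeping the master weights in fp32 matters because bf16 has only a 7-bit mantissa, so small Adam updates vanish if applied to bf16 weights directly. A self-contained simulation (truncating fp32 bit patterns to bf16; no PyTorch required):

```python
import struct

def to_bf16(x: float) -> float:
    # Simulate bf16 by truncating an fp32 bit pattern to its top 16 bits
    # (real hardware rounds to nearest even; truncation is close enough here).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]

update = 1e-4

# Updating the bf16 weight directly: the step is below bf16's ~2^-7
# resolution around 1.0, so the weight never moves.
w = to_bf16(1.0)
print(to_bf16(w + update) == w)   # True

# With an fp32 master copy, small updates accumulate and survive
# the final fp32 -> bf16 writeback.
master = 1.0
for _ in range(100):
    master += update
print(to_bf16(master) > 1.0)      # True
```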

Files

model.py                  Simple multi-layer MLP (test model)
buffer.py                 Buffer (contiguous storage + shard view)
ddp.py                    DistributedDataParallel (param/grad remapping + sync)
distributed_optimizer.py  DistributedOptimizer (fp32 Adam on shards)
test_zero.py              Multi-step training correctness tests
profile_memory.py         GPU memory profiling (baseline vs ZeRO-1)
DESIGN.md                 Detailed design document (Chinese)

Running tests

Requires multiple CUDA GPUs and the nccl backend:

# all tests
python -m pytest test_zero.py -v

# single test
python -m pytest test_zero.py::TestZeROTraining::test_multi_step_2gpu -v

# or directly
python test_zero.py

The tests verify that ZeRO-1 training produces parameters bit-exact with those of a single-process reference run that follows the same precision path.

Requirements

  • Python >= 3.8
  • PyTorch >= 2.1
  • Multiple CUDA GPUs
