Fail-safe and partial redundancy for HSDP on unreliable compute #561

@evkogs

Description

I'd like to propose a feature implementing fail-safe mechanisms and partial redundancy in FSDP2 (arguably this is closer to HSDP than plain FSDP) to allow for more robust training on unreliable compute resources, such as cloud spot instances. The main goal is to make training more resilient to node failures, GPU issues, and other interruptions.

Key points:

  1. Implement an abstraction over DDP and FSDP with configurable parameters for redundancy.
  2. Allow for partial redundancy, similar to RAID5 or RAID6 concepts, where full redundancy would be equivalent to DDP and zero redundancy would be equivalent to FSDP full-shard (ZeRO-3).
  3. Mitigate node failures and individual GPU failures by storing additional fractions (e.g., 1/8 or 1/4) of other nodes' optimizer states on each node.
  4. Accept a trade-off between memory usage and all-reduce overhead (an estimated 10-20%) in exchange for increased training resilience.
  5. Implement automatic downscaling with resharding and automatic upscaling with sharding, governed by a configurable overlapping sharding parameter (0.0 to 1.0); a hypothetical API sketch follows this list.
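
To make the shape of the proposal concrete, here is a minimal, purely hypothetical sketch of the configuration surface. None of these names (`RedundancyConfig`, `shard_with_redundancy`, `overlap_factor`) exist in PyTorch today; they are placeholders for discussion only.

```python
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class RedundancyConfig:
    # 0.0 -> pure FSDP full-shard (ZeRO-3); 1.0 -> pure DDP (full replication).
    overlap_factor: float = 0.125
    # Whether the redundant fraction also covers optimizer state, not just parameters.
    replicate_optimizer_state: bool = True
    # Whether to automatically reshard after a membership change (downscale/upscale).
    reshard_on_membership_change: bool = True


def shard_with_redundancy(model: nn.Module, config: RedundancyConfig) -> nn.Module:
    """Placeholder for the proposed wrapper: shard parameters as FSDP/HSDP does,
    then additionally place an `overlap_factor` fraction of each node's shards
    (and optionally its optimizer state) onto peer nodes, RAID-style, so that
    training can continue after a node loss."""
    raise NotImplementedError("illustrative sketch only")
```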

Use case examples:

  1. Training on cloud spot instances that may be terminated mid-training.
  2. Giant model training on 99.9% reliable hardware, protecting against network adapter failures, power outages, etc.
  3. Enabling cross-regional model training on spot instances or multi-region clusters for colossal models.
  4. Supporting distributed training methods like DisTrO (https://github.com/NousResearch/DisTrO) that allow training over the internet with much lower throughput requirements than traditional all-reduce approaches.

This feature would greatly enhance the flexibility and reliability of large-scale distributed training, especially in scenarios where compute resources are not guaranteed to be stable throughout the entire training process.

A key aspect of this implementation would be an overlapping factor, ranging from 0.0 to 1.0, which determines the degree of redundancy. For example, with 64 GPUs across 8 nodes (a quick arithmetic check follows this list):

  • An overlapping factor of 0.0 would be equivalent to standard FSDP (no redundancy).
  • An overlapping factor of 0.125 (1/8) would allow for one node failure without interrupting training.
  • An overlapping factor of 0.25 (1/4) would provide resilience against two simultaneous node failures.
  • An overlapping factor of 1.0 would be equivalent to full DDP (complete redundancy).
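
A quick back-of-the-envelope check of those numbers, assuming the redundant copies are spread evenly across distinct peer nodes and that the extra per-node memory grows linearly with the factor:

```python
NODES = 8  # the 64-GPU / 8-node example above


def survivable_node_failures(overlap_factor: float, nodes: int = NODES) -> int:
    """Roughly: each node holds an extra `overlap_factor` node-worth of its peers'
    state, so about overlap_factor * nodes full node-copies exist as spares.
    At 1.0 (full DDP-style replication) any single surviving node suffices."""
    if overlap_factor >= 1.0:
        return nodes - 1
    return int(overlap_factor * nodes)


for factor in (0.0, 0.125, 0.25, 1.0):
    print(
        f"overlap={factor:<5}  extra per-node memory ~{factor:.0%}  "
        f"survivable node failures ~{survivable_node_failures(factor)}"
    )
```

Whether the guarantee is exactly `floor(overlap_factor * nodes)` lost nodes, or something parity-based as in RAID5/RAID6, is one of the design questions this proposal would need to settle.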

The system would need to integrate downscaling with resharding and automatic restoration, as well as upscaling with automatic sharding, all governed by the specified overlapping factor (likely orchestrated with Kubernetes and torchx, for example).
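
For the orchestration piece, today's elastic launchers already restart workers on membership changes, so a per-worker skeleton under that assumption might look like the sketch below. `build_model`, `build_optimizer`, `restore_state`, and `run_training` are hypothetical helpers; the part this proposal changes is `restore_state`, which could rebuild lost shards from surviving nodes' redundant copies instead of re-reading a checkpoint.

```python
# Sketch of the per-worker entrypoint an elastic launcher (e.g. torchrun with
# --nnodes=MIN:MAX --max-restarts=K) re-invokes after every membership change.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main(build_model, build_optimizer, restore_state, run_training):
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Re-shard over whatever the *current* world size is after a failure or rejoin.
    model = FSDP(build_model().cuda())
    optimizer = build_optimizer(model)

    # Today: reload the latest sharded checkpoint.
    # Proposal: reconstruct missing shards from peers' redundant copies.
    restore_state(model, optimizer)

    run_training(model, optimizer)
    dist.destroy_process_group()
```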

I'd be happy to discuss this further and provide more details if needed! Looking forward to your thoughts on this proposal!
