Skip to content
This repository has been archived by the owner on Jan 6, 2023. It is now read-only.

Parallelize membership discovery and train step, in order to improve the elastic performance. #60

Closed
wants to merge 3 commits into from

Conversation

umialpha
Copy link
Contributor

Due to my performance test, the overhead of coordinator.rendezvous_barrier is non-negligible. It is impractical for cluster scheduler to scale workers in minutes.

It is possible to parallelize the rendezvous_barrier and train_step because they are independent in theory.

I propose a simple solution about it.

@kiukchung
Copy link
Contributor

kiukchung commented Mar 19, 2020

Thanks for the design doc. This makes sense. We are currently in the phase of redesigning the architecture of torchelastic in our upcoming 0.2.0 release. In short, we are moving out the "elastic" functionality to an elastic-agent process which is responsible for detecting membership changes, monitoring and restarting worker processes. Now instead of the user having to manually implement State and train_step and be aware of the intricacies of worker failures, restarts, consistency etc, you can simply write a vanilla distributed pytorch job with good checkpoints and the agent will take care of ensuring that:

  1. All the information necessary to trivially create a torch process group is in the environment variable (the user just has to call torch.distributed.init_process_group() with no arguments
  2. workers are restarted as a group. e.g. if one fails everyone will be killed and restarted.

This is a much simpler model to understand than what we currently have (which is an in-process agent running inside the user space).

Take a sneak peak here in this unittest: https://github.com/pytorch/elastic/blob/master/test/agent/server/local_elastic_agent_test.py

In this new design re-rendezvous naturally runs in parallel to the train_step.

@umialpha
Copy link
Contributor Author

Do you have any timeline about the new release? I am eager to read about the new design and integrate the new release into our framework.

@kiukchung
Copy link
Contributor

The agent is already committed (https://github.com/pytorch/elastic/blob/master/torchelastic/agent/server/local_elastic_agent.py#L52)

I just published a PR with the launcher that is completely compatible and similar in usage with torch.distributed.launch. #65.

Check out the docs here: https://github.com/pytorch/elastic/pull/65/files#diff-d337650690ddced88d1c0c7187c979f9R17

Would love your feedback, would you be open to sharing your use-case? I'm curious about the setup (cloud - aws, gcp, azure - , on-prem, using k8 or not) the scale of your jobs and what makes elasticity important for you. It would really help us prioritize features and improve user experience.

@umialpha
Copy link
Contributor Author

umialpha commented Mar 23, 2020

Hi @kiukchung , I've just read your agent feature. I have some questions/concerns in mind.

  1. It is hard to specify the "nproc_per_node" on each node. Nodes within a cluster are sometimes heterogeneous. There is always trade-off between throughput
    and GPU efficiency, Consider the following scenario,

with 3 nodes, Node A and Node B with 4 gpus, Node C with 8 gpus, Now C is fully occupied, but A and B are totally free. A user submits a job requiring 8 gpus. If we set "nproc_per_node" to 8, the job will be hold. But if we set "nproc_per_node" to 4, we cannot benefit from GPU locality with Node C when it is available.

  1. What's the benefit of LocalWorkGroup? Is it for performance reason? To be honest, I am not familiar with the backend of pytorch.distributed. I would like to learn more about it. From the code, I haven't found place to 1) reduce between nodes and 2)sync state when new node comes in. Is this covered within the Store? Or every time we restart the agent(scale out | scale in), we have to restart from the latest checkpoint?

  2. What's the relationship between agent and the existing train_loop? Is it an alternative way or it will replace train_loop in the future?

@umialpha
Copy link
Contributor Author

Would love your feedback, would you be open to sharing your use-case? I'm curious about the setup (cloud - aws, gcp, azure - , on-prem, using k8 or not) the scale of your jobs and what makes elasticity important for you. It would really help us prioritize features and improve user experience.

Sure, I am glad to share some info about our use-case. I am in the Scheduling Team belonged to an internal cluster service in Microsoft. We used to offer the scheduling abilities to the traditional/ general jobs. Nowadays, we are going to support AI workload. It's still in POC phase. IMO, elasticity is one of the most features we need. It could offer us a lot of abilities. I will list some of examples.

  1. More flexible adaption to cluster load variations. Jobs are scaled in when the cluster is over-loaded and scaled out when the cluster is under-loaded.

  2. Job migration. With elasticity, we can migrate workers into one PCIe or one node to improve performance. Also, we can migrate workers to reduce internal fragment.

  3. Straggler mitigation. With elasticity, we can easily kill the straggler.

There are a lot of more benefits of elasticity. But we have to consider the following things.

  1. elasticity overhead. It is impossible to make agile and smart response in seconds if elasticity overhead takes many seconds.
  2. ease of use. From the imagenet example, it seems that user has to maintain the start_index of dataset. IMO, it is better to cover this dataset partition within framework.

@kiukchung
Copy link
Contributor

Hi @kiukchung , I've just read your agent feature. I have some questions/concerns in mind.

1. It is hard to specify the "nproc_per_node" on each node. Nodes within a cluster are sometimes heterogeneous. 

Ideally for a given job you would use homogeneous nodes (even though the cluster itself is heterogeneous). This is especially true with GPUs as you never want to mix GPU architectures or number of GPUs per LocalWorkerGroup. The former is due to the fact that CUDA operations/features different between GPU architectures (e.g. there are no tensor-cores on GPUs prior to Volta, and FMA was added on the Pascal). For the later, see my reply about why we run multiple workers per node rather than a single worker per node/container.

1. What's the benefit of `LocalWorkGroup`? Is it for performance reason? To be honest, I am not familiar with the backend of `pytorch.distributed`. I would like to learn more about it. From the code, I haven't found place to 1) `reduce between nodes` and 2)`sync state` when new node comes in. Is this covered within the `Store`? Or every time we restart the agent(scale out | scale in), we have to restart from the latest checkpoint?

Yes this is for performance reasons. The two most important ones are:

  1. Creating more nodes/containers is more expensive (in terms of resources) than packing more local workers onto fewer nodes. The overhead of running more nodes is more than you think, for instance, each node will use up an IP address (in the subnet).

  2. In the case of GPUs (but the same would apply to CPUs as well), by running multiple workers (one per GPU) on the same node you can leverage certain hardware features like NVLink which allows you to do GPU-to-GPU P2P data transfers and bypassing the CPU completely. This dramatically improves performance. So often times the collective operation is run in a hierarchical fashion: local then distributed.

2. What's the relationship between `agent` and the existing `train_loop`? Is it an alternative way or it will replace `train_loop` in the future?

The existing APIs: train_loop, State, etc are scheduled to be deprecated and removed in v0.2.0

@kiukchung
Copy link
Contributor

Thanks for sharing your use-case!

  1. elasticity-overhead: for a "true" elastic workload, the DL framework needs to support elasticity. AFAIK this is not the case for most DL frameworks that support distributed training (and even for many popular HPC frameworks like MPI). For instance torch assumes that once a process group is created, the workers are static and healthy. If anything happens to the workers, the entire process group is rendered in an "undefined" state and no work can progress. This manifests as either an exception or other workers getting "stuck". Map-reduce style frameworks like hadoop or spark naturally support elasticity because the workload itself maps nicely to partitions of data, unfortunately this is not the case in DL. Until frameworks natively support elasticity the best we can do is to ensure that workers are started/stopped as a group in an atomic fashion, which is what the agent does in our case.

  2. ease of use: Because the DL frameworks do not support elasticity natively, trying to add this yourself in the application is not entirely straight forward. This is exactly what we tried doing for our v0.1.0 and we realized that not only was it finicky at best but also that there were many things that the user had to implement correctly for everything to work in an elastic way.

Happy to discuss more, feel free to PM me kiuk@fb.com so that we can set up a meeting. Thanks!

@umialpha umialpha closed this Apr 22, 2020
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants