[RFC]: Refactor Worker and ModelRunner to consolidate control plane communication #5552

stephanie-wang opened this issue Jun 14, 2024 · 13 comments

@stephanie-wang
Contributor

stephanie-wang commented Jun 14, 2024

Motivation.

Currently, both the Worker and the ModelRunner classes contain multi-GPU control plane communication code, i.e. broadcast_tensor_dict calls. They look something like this:

class Worker:
  def execute_model(self, execute_model_req=None):
    # Do some broadcast here.
    ...
    return self.model_runner.execute_model(execute_model_req)

class ModelRunner:
  def execute_model(self, execute_model_req=None):
    # Do some more broadcast here.
    ...
    return model_executable(...)

Because the ModelRunner class contains both model execution code and multi-GPU control plane communication code, it is difficult to improve performance:

  • Cannot swap out the control plane mechanism, e.g., using NCCL vs CPU-based serialization to move the inputs from the LLMEngine to the Workers
  • Cannot switch to an SPMD design, where the rank 0 worker is moved off of the driver and executes the same code as the rest of the workers
  • Difficult to overlap GPU data movement with other compute on Workers, because these are all done ad-hoc in different ModelRunner implementations.
  • Difficult to optimize control plane performance, because the broadcast calls are scattered throughout the code.

Proposed Change.

Refactor the Worker and ModelRunner classes to consolidate all control plane communication into the Worker class. Both the Worker and ModelRunner classes should implement new worker-local methods to prepare inputs and execute the model. There should be no control plane communication within these methods; this could be enforced using runtime checks.
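
One possible form of such a runtime check, sketched here with a hypothetical guard (forbid_control_plane_comm is not part of the proposal, and the module that defines broadcast_tensor_dict is passed in rather than imported here):

import contextlib
from unittest import mock

@contextlib.contextmanager
def forbid_control_plane_comm(comm_module):
    """Fail loudly if broadcast_tensor_dict is called inside the guarded block."""
    def _raise(*args, **kwargs):
        raise RuntimeError(
            "Control plane communication is not allowed in worker-local methods")
    # Temporarily replace the collective op so accidental calls raise immediately.
    with mock.patch.object(comm_module, "broadcast_tensor_dict", _raise):
        yield

# Usage, assuming `comm` is the module that defines broadcast_tensor_dict:
#   with forbid_control_plane_comm(comm):
#       model_input = model_runner.prepare_model_input(seq_group_metadata_list)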

Here is the new proposed interface. The Worker and ModelRunner create a WorkerInput and a ModelInput respectively from the ExecuteModelRequest. The contract is that ExecuteModelRequest contains CPU-only metadata, while any tensors in WorkerInput and ModelInput should already be on the correct device. Now, the ModelRunnerBase class looks approximately like this:

class ModelRunnerBase(ABC, Generic[T]):
    """
    Model runner interface that abstracts a particular hardware and/or type of
    model. Model execution may communicate data with model runners in other
    processes, but it should not include control plane metadata communication.

    Each ModelRunnerBase subclass should define a corresponding ModelInput
    subclass.
    """
    @abstractmethod
    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> T:
        """
        Prepare the inputs to ModelRunnerBase.execute_model from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @torch.inference_mode()
    @abstractmethod
    def execute_model(
        self,
        model_input: T,
        kv_caches: Optional[List[torch.Tensor]],
    ) -> Optional[SamplerOutput]:
        """
        Execute the model on the given input.
        """
        raise NotImplementedError

This interface allows for a cleaner separation between control plane communication and single-GPU logic. Each ModelRunner needs to explicitly state what inputs it requires by defining a ModelInput subclass. This requires a bit more developer effort, but it should make it easier to introduce the optimizations discussed above.
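
For example, a GPU worker's input classes might look roughly like the sketch below (field names are hypothetical and shown only to illustrate the contract that any tensors are already on the target device):

import dataclasses
from typing import Optional

import torch

@dataclasses.dataclass
class WorkerInput:
    """Worker-level inputs, e.g. cache operations; tensors are already on-device."""
    num_seq_groups: int
    blocks_to_swap_in: Optional[torch.Tensor] = None
    blocks_to_swap_out: Optional[torch.Tensor] = None
    blocks_to_copy: Optional[torch.Tensor] = None

@dataclasses.dataclass
class GPUModelInput:  # a hypothetical ModelInput subclass for the GPU model runner
    """Everything the GPU model runner needs for one forward pass."""
    input_tokens: torch.Tensor
    input_positions: torch.Tensor
    attn_metadata: Optional[object] = None
    sampling_metadata: Optional[object] = None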

We also add a new LocalOrDistributedWorkerBase. The idea behind this class is that as long as the developer implements this interface plus a ModelRunnerBase, they will get support out-of-the-box for both local and distributed execution. This class has a default implementation for execute_model that contains all of the control plane communication needed for distributed execution.

class LocalOrDistributedWorkerBase(WorkerBase):
    """
    Partial implementation of WorkerBase that has a default `execute_model`
    definition to perform metadata transfer between workers when in distributed
    mode. Subclasses of this interface should use model runners that inherit
    from ModelRunnerBase, and should only need to implement worker-local logic.
    If custom control plane logic is needed to transfer metadata, or if the
    model runner cannot inherit from ModelRunnerBase, use WorkerBase instead.
    """
    @abstractmethod
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        """
        Prepare the inputs to WorkerBase.execute_worker from an execution
        request. This method may move data to the worker's local device. It is
        not allowed to communicate with other workers or devices.
        """
        raise NotImplementedError

    @abstractmethod
    def execute_worker(self, worker_input: WorkerInput) -> None:
        """
        Process an execution request.
        """
        raise NotImplementedError

    def execute_model(
        self, execute_model_req: Optional[ExecuteModelRequest] = None
    ) -> Optional[List[SamplerOutput]]:
      ...

Custom model runners: Workers and model runners that need custom logic can inherit directly from the generic WorkerBase and do not need to follow these interfaces. In that case, they are responsible for implementing their own control plane communication as well.

Speculative decoding: One complication is that the speculative decoding code goes back and forth between ExecuteModelRequest and ModelInput, whereas other workers only convert from ExecuteModelRequest to ModelInput. Thus, for the speculative decoding path, it's easier for now to keep the per-step broadcast. These extra k broadcasts could also be consolidated in the future, by either supporting ModelInput -> ExecuteModelRequest, or by making it possible to modify a ModelInput. Happily, the latter should be compatible with the solutions proposed in #5561.
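
For the second option, one lightweight direction (a sketch only, reusing the hypothetical GPUModelInput fields from above; not the interface proposed in #5561) is to derive each speculative step's input from the previous one with dataclasses.replace instead of re-broadcasting a new ExecuteModelRequest per step:

import dataclasses

import torch

def advance_one_step(model_input: "GPUModelInput",
                     next_tokens: torch.Tensor) -> "GPUModelInput":
    # Derive the next step's input locally; no control plane communication needed.
    return dataclasses.replace(
        model_input,
        input_tokens=next_tokens,
        input_positions=model_input.input_positions + 1,
    )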

Pipeline parallelism: In pipeline parallelism, workers before the last PP rank will return some intermediate tensor(s) instead of a SamplerOutput. To support this case, we should define an IntermediateOutput type for models that support PP. Then, we extend ModelRunnerBase.execute_model to return a Union[SamplerOutput, IntermediateOutput] instead of just a SamplerOutput.

Feedback Period.

One week. See #5408 for code.

CC List.

@youkaichao @zhuohan123 @zhisbug @cadedaniel @rkooo567

Any Other Things.

Checklist:

@andoorve
Contributor

Great RFC! Have you had the chance to verify this with PP? #4412

@stephanie-wang
Contributor Author

> Great RFC! Have you had the chance to verify this with PP? #4412

Yes, definitely. Actually ideally we should design these together. One possibility is to merge #4412 in two parts, one for changes to worker-local execution and then a second PR to add the control plane to glue different workers together. That way, we can follow the interfaces that I'm proposing here and keep the control plane separate from model execution.

Concretely, the first PR would then contain these changes:

  • support multiple cache engines in one Worker
  • support workers executing one shard of a model

The proposed Worker.execute_model_local function would directly return the hidden states from the model shard instead of using send/recv calls inside the model definition. We would then glue together all of the p2p connections in the Worker class as a separate PR.

This way, it will be easier to try different control plane methods. We can use the approach you have in #4412. Another option is a new backend in Ray that we've been developing to improve performance for static task graphs. I wrote an integration for this based off of an earlier version of #4412: graph definition and how to call it.
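
To make the split concrete, here is a rough sketch of the glue the second PR could add around the worker-local call (names like execute_model_local, pp_recv_shape, and pp_dtype are hypothetical, and rank/group handling is simplified):

import torch
import torch.distributed as dist

def execute_model_with_pp(worker, execute_model_req,
                          pp_rank: int, pp_world_size: int):
    # Non-first stages receive the previous stage's hidden states.
    hidden_states = None
    if pp_rank > 0:
        hidden_states = torch.empty(worker.pp_recv_shape,
                                    dtype=worker.pp_dtype,
                                    device=worker.device)
        dist.recv(hidden_states, src=pp_rank - 1)

    # Worker-local execution: the last stage returns a SamplerOutput, earlier
    # stages return their intermediate hidden states. No send/recv in the model.
    output = worker.execute_model_local(execute_model_req, hidden_states)

    if pp_rank < pp_world_size - 1:
        dist.send(output, dst=pp_rank + 1)
        return None
    return output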

@andoorve
Contributor

Yes, agreed, we should chat more about this; what you're suggesting makes sense to me. There are 3-4 optimizations that I know of that we can do on top of #4412, but my current plan is to have #4412 merged as a base on top of the current logic, in order to have basic PP fully functional as soon as possible before moving on to performance refactoring.

@youkaichao
Member

Does something like this make sense:

class WorkerBase:
    class WorkerInput:
        def get_broadcastable_data(self):
            pass
        
        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_worker_input(self, seq_group_metadata_list) -> WorkerInput:
        pass

    def execute_worker(self, input_data: WorkerInput):
        pass

    def execute_model(self, seq_group_metadata_list):
        if self.is_driver_worker:
            worker_input = self.prepare_worker_input(seq_group_metadata_list)
            self.execute_worker(worker_input)
            model_input = self.model_runner.prepare_model_input(seq_group_metadata_list)
            data_to_broadcast = worker_input.get_broadcastable_data()
            data_to_broadcast.update(model_input.get_broadcastable_data())
            broadcast_tensor_dict(data_to_broadcast, src=0)
        else:
            data_to_broadcast = broadcast_tensor_dict(src=0)
            worker_input = self.WorkerInput.from_broadcast_data(data_to_broadcast)
            self.execute_worker(worker_input)
            model_input = self.model_runner.ModelRunnerInput.from_broadcast_data(data_to_broadcast)
        return self.model_runner.execute_model(model_input)

class ModelRunnerBase:
    class ModelRunnerInput:
        def get_broadcastable_data(self):
            pass
        
        @staticmethod
        def from_broadcast_data(data):
            pass
    
    def prepare_model_input(self, seq_group_metadata_list) -> ModelRunnerInput:
        pass

    def execute_model(self, input_data: ModelRunnerInput):
        pass

Then, say we want to add GPU worker and GPUModelRunner:

class GPUWorker(WorkerBase):
    class WorkerInput(WorkerBase.WorkerInput):
        def get_broadcastable_data(self):
            pass
        
        @staticmethod
        def from_broadcast_data(data):
            pass

    def prepare_worker_input(self, seq_group_metadata_list) -> WorkerInput:
        pass

    def execute_worker(self, input_data: WorkerInput):
        pass

class GPUModelRunner(ModelRunnerBase):
    class ModelRunnerInput(ModelRunnerBase.ModelRunnerInput):
        def get_broadcastable_data(self):
            pass
        
        @staticmethod
        def from_broadcast_data(data):
            pass
    
    def prepare_model_input(self, seq_group_metadata_list) -> ModelRunnerInput:
        pass

    def execute_model(self, input_data: ModelRunnerInput):
        pass

The control plane communication is centralized in WorkerBase.execute_model.

@stephanie-wang
Contributor Author

> Yes, agreed, we should chat more about this; what you're suggesting makes sense to me. There are 3-4 optimizations that I know of that we can do on top of #4412, but my current plan is to have #4412 merged as a base on top of the current logic, in order to have basic PP fully functional as soon as possible before moving on to performance refactoring.

Great! Can you say a bit more about what optimizations you were thinking of? The reason I suggested splitting #4412 is that I think it will be easier to introduce some optimizations for PP if we can merge in this refactor first.

@stephanie-wang
Contributor Author

@youkaichao that sounds good to me. I can make those changes in #5408.

@andoorve
Contributor

> Great! Can you say a bit more about what optimizations you were thinking of? The reason I suggested splitting #4412 is that I think it will be easier to introduce some optimizations for PP if we can merge in this refactor first.

Off the top of my head, the following optimizations are possible:

  1. Merge cache engines and scheduler to take advantage of prefix caching.
  2. Reduce RPC overhead by sending/recving metadata instead of having multiple driver workers.
  3. Reduce CUDAGraph overhead with PP to make CUDAGraph more viable.

In general, I agree that it would potentially make optimizations easier (particularly 2 above). My concern is that we are prioritizing optimization prematurely here and further delaying the already-delayed PP feature. IMO we should have the functionality fully available first, and then implement this refactoring on top of that.

cc: @zhuohan123 - this is what we chatted about last week

@stephanie-wang
Contributor Author

Sounds good. Yes, I also don't want to block PP; I just think that it may actually be faster long-term to merge a version that's compatible with this refactor.

If you do not want to split #4412, at least I think we need to move the p2p communication out of the model definitions and into the Worker class. That will make the merge with this refactor smoother, plus it makes #4412 unit-testable.

@andoorve
Contributor

> Sounds good. Yes, I also don't want to block PP; I just think that it may actually be faster long-term to merge a version that's compatible with this refactor.

Makes sense! I'm not opposed to splitting #4412, but I just think that if we are going to do so, it's a good idea to coordinate more closely on the details so we have a good plan to get everything in with low friction. We can meet next week to talk in more detail if you are free.

> If you do not want to split #4412, at least I think we need to move the p2p communication out of the model definitions and into the Worker class. That will make the merge with this refactor smoother, plus it makes #4412 unit-testable.

To give more context, I thought about this initially but decided against it at first, since different models send/recv different numbers of tensors. For example, gpt2 sends only the hidden states, while LLaMA sends both hidden states and residuals. We would need to modify each model file to return a list of tensors to the ModelRunner layer instead of just hidden_states, which I found more intrusive than what I have now, which is localized to each model definition; at the time there wasn't a good reason to do so. Given this new refactor, it makes sense to me to make this change.
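
To make the trade-off concrete, here is a rough sketch of the "bubble the tensors up" approach (illustrative names only, not the code in #4412): a LLaMA-style shard where each decoder layer returns (hidden_states, residual), and a non-final stage hands both back to the ModelRunner instead of calling send/recv itself.

from typing import Dict, List, Optional, Union

import torch

IntermediateTensors = Dict[str, torch.Tensor]  # hypothetical container type

def run_model_shard(
    layers: List[torch.nn.Module],
    norm: torch.nn.Module,
    is_last_pp_rank: bool,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
    for layer in layers:
        # Each LLaMA-style decoder layer returns updated hidden states plus the
        # residual stream that the next layer (or the next PP stage) needs.
        hidden_states, residual = layer(hidden_states, residual)
    if is_last_pp_rank:
        # Final stage: produce the hidden states that the sampler consumes.
        return norm(hidden_states + residual)
    # Earlier stages: return both tensors so the Worker can decide where to
    # send them, keeping p2p communication out of the model definition.
    return {"hidden_states": hidden_states, "residual": residual}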

Given what you proposed above, it also gets a little more complicated in the context of this PR, if I understand correctly:

> The proposed Worker.execute_model_local function would directly return the hidden states from the model shard instead of using send/recv calls inside the model definition. We would then glue together all of the p2p connections in the Worker class as a separate PR.

We would have to bubble up hidden_states and residuals through the model definition to the ModelRunner and then to the Worker. If we calculated SamplerOutputs in the ModelRunner the way it's done now in this PR, we would also need to bubble those up to the Worker, or calculate SamplerOutputs in the Worker instead. I'm not opposed to doing it this way, but this is an example of one of those details it would be good to discuss further and brainstorm whether we can do it in a more elegant fashion.

@stephanie-wang
Contributor Author

Yes, will reach out to find some time to chat!

> We would have to bubble up hidden_states and residuals through the model definition to the ModelRunner and then to the Worker. If we calculated SamplerOutputs in the ModelRunner the way it's done now in this PR, we would also need to bubble those up to the Worker, or calculate SamplerOutputs in the Worker instead. I'm not opposed to doing it this way, but this is an example of one of those details it would be good to discuss further and brainstorm whether we can do it in a more elegant fashion.

For this, I don't think we need to bubble up the sampling procedure to the Worker. Let's discuss more offline, but I'm imagining the following interface, where ModelRunner.execute_model would return either a SamplerOutput or an IntermediateOutput, which can be a dict of tensors (or a separate class). Then the Worker can have the logic to either return the output directly or send it to a different worker.

IntermediateOutput = Dict[str, torch.Tensor]

class ModelRunner:
  def execute_model(self, model_input: ModelInput) -> Union[List[SamplerOutput], IntermediateOutput]:
    pass

@zhuohan123
Collaborator

I think this RFC makes a lot of sense. It's a great idea to put all of the communication logic in the same place. I previously misunderstood this as a bigger scope change that changes how we do control-plane communication. Some smaller questions about this RFC:

  • What is the difference and relationship between prepare_model_input_local and prepare_model_input? Can we just keep prepare_model_input?
  • What is the relationship between this RFC and the ray DAG change you are proposing?

@stephanie-wang
Contributor Author

> I think this RFC makes a lot of sense. It's a great idea to put all of the communication logic in the same place. I previously misunderstood this as a bigger scope change that changes how we do control-plane communication. Some smaller questions about this RFC:

Yes! For now this is just proposing a refactor that would allow changing the control plane more easily. This RFC doesn't propose any behavior changes (except for squashing some broadcasts).

> * What is the difference and relationship between `prepare_model_input_local` and `prepare_model_input`? Can we just keep `prepare_model_input`?

Ah, yes, I actually updated the RFC since @youkaichao suggested something similar. Now there is only prepare_model_input.

> * What is the relationship between this RFC and the ray DAG change you are proposing?

This would make it easier to integrate Ray DAG / any other control plane method that broadcasts the ExecuteModelRequest to all workers instead of broadcasting the ModelInput. If we want to support Ray DAG right now, we need to update the control flow in the worker and model runner, e.g., to skip the tensor broadcasts. We can do that for the main Worker codepath, but it makes the code pretty messy and we'd have to do the same thing for every other worker and model runner that we want to support. With the new APIs, we can just override Worker.execute_model and only call the worker-local methods.

class RayDAGWorkerBase(LocalOrDistributedWorkerBase):
  def execute_model(self, execute_model_req: ExecuteModelRequest):
    worker_input = self.prepare_worker_input(execute_model_req)
    self.execute_worker(worker_input)
    model_input = self.model_runner.prepare_model_input(execute_model_req)
    return self.model_runner.execute_model(model_input)
