[Performance] Optimize e2e overheads: Reduce python allocations #7162
comaniac merged 4 commits into vllm-project:main
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge). To run full CI, you can do one of these:
This is one way to introduce incremental prepare input. Will review soon.
I took a brief look and have the following impressions:
Please correct me if I misunderstood anything. IIUC, it seems to me that this optimization should be done with the multi-step worker, as it focuses exactly on the scenarios listed above. For general cases, this optimization introduces decoding-specific branches (e.g.,
@comaniac and I discussed the relation to multi-step. In general, a micro-benchmark shows that the benefit from flash-attn decode optimization in this PR is around 2-3%, and most of the benefit comes from python object allocation reductions. There are two possible ways to go about this:
/ready
Thanks @alexm-neuralmagic this is great!
I think (1) isn't a concern since the proportion of steps in which reqs are added or removed is typically small.
I didn't look closely enough yet, but presumably it will still work in this mode once the prefill tokens have been exhausted, i.e. once the batch returns to a decode-only state?
Given my comments above I think both parts of this would still make sense to include, but perhaps split them into two separate PRs anyhow?
njhill left a comment
@alexm-neuralmagic the speed-up from allocation reduction is really encouraging!
Added some comments from quick glance through, will try to look more closely soon.
could de-dup a bit here by just setting those three values as vars in the if/else and then having a single call to the constructor?
Thanks for the suggestion, applied this change in #7206
Could make this a bit more concise (same for others below):
Original:
    if self.is_single_seq:
        if self.seqs[0].status == status:
            return self.seqs
        else:
            return []
    else:
        return [seq for seq in self.seqs if seq.status == status]
Suggested:
    if self.is_single_seq:
        return self.seqs if self.seqs[0].status == status else []
    return [seq for seq in self.seqs if seq.status == status]
Good idea, changed all of them to the form you proposed.
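For anyone reading this thread later, here is a minimal sketch of the single-sequence fast path in the suggested form. `Status`, `Seq`, and `SeqGroup` are hypothetical stand-ins, not vLLM's actual classes; only the `is_single_seq`, `seqs`, and `status` names come from the snippet above.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List


class Status(Enum):
    WAITING = auto()
    RUNNING = auto()


@dataclass
class Seq:
    """Hypothetical stand-in for a sequence with a status field."""
    seq_id: int
    status: Status


class SeqGroup:
    """Hypothetical stand-in for a sequence group."""

    def __init__(self, seqs: List[Seq]) -> None:
        self.seqs = seqs
        self.is_single_seq = len(seqs) == 1

    def get_seqs(self, status: Status) -> List[Seq]:
        if self.is_single_seq:
            # Fast path: hand back the existing list when the only
            # sequence matches, instead of building a filtered copy.
            return self.seqs if self.seqs[0].status == status else []
        return [seq for seq in self.seqs if seq.status == status]


group = SeqGroup([Seq(0, Status.RUNNING)])
running = group.get_seqs(Status.RUNNING)
waiting = group.get_seqs(Status.WAITING)
```

One consequence of this form is that, in the single-sequence case, the caller receives the group's own list object, so it should be treated as read-only.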
I feel like the above could be simplified... maybe collect all the things to be cleared and all the things to be zeroed and then have a single loop over the sequences? These lists themselves could possibly also be reused.
This one is harder to generalize because it uses different fields. I tried to do lists, but then you still need to separate to different fields and it complicates things again. I will see if I can make the code cleaner here.
Please note that we were planning to use a dataclass with kw_only=True, but that is only supported from Python 3.10 onward, so we fell back to the current implementation. We should eventually use a dataclass, though.
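To illustrate the trade-off, here is a rough sketch of a version-gated definition. `InterData` and `InterDataManual` are hypothetical names with made-up fields, not the classes in this PR; the only facts relied on are that `dataclass(kw_only=True)` needs Python 3.10+ and that a bare `*` in a hand-written `__init__` gives the same keyword-only behaviour on older versions.

```python
import sys
from dataclasses import dataclass, field
from typing import List, Optional


class InterDataManual:
    """Fallback for Python < 3.10: a bare `*` makes every argument keyword-only."""

    def __init__(self, *, request_id: str,
                 seq_ids: Optional[List[int]] = None) -> None:
        self.request_id = request_id
        self.seq_ids = [] if seq_ids is None else seq_ids


if sys.version_info >= (3, 10):

    @dataclass(kw_only=True)
    class InterData:
        request_id: str
        seq_ids: List[int] = field(default_factory=list)

else:
    InterData = InterDataManual

data = InterData(request_id="req-0", seq_ids=[0, 1])

try:
    InterData("req-0")  # positional arguments are rejected by both variants
    positional_rejected = False
except TypeError:
    positional_rejected = True
```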
The standard python way of doing this is:
input_tokens = [
    in_toks for inter_data in self.inter_data_list
    for in_toks in inter_data.input_tokens
]
I'd guess that this would be more efficient but can never be sure, would have to microbench it...
Yeah, this is what flatten_2d_lists(..) was doing, but extend() is faster.
We should change flatten_2d_lists to use extend then.
The prior code wasn't doing the same thing. Technically this isn't flattening 2d lists - it's a list of objects each having a list input_tokens field. So flatten_2d_lists() isn't appropriate here anyhow. If you look, the prior code is calling flatten_2d_lists twice, which is probably the reason for the excess allocations and slower speed.
Hmm yeah you're right. I was hoping there's a way to extract this logic to be a utility function, but seems not trivial.
I ran a benchmark, and I guess the extend version is still much faster! (more than 2x)
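A micro-benchmark along these lines could look like the sketch below. It is not the benchmark that was actually run: `InterData` is a hypothetical stand-in, and it assumes each object's `input_tokens` is a list of token lists. Timings depend on the machine and Python version, so no numbers are claimed here.

```python
import timeit
from typing import List


class InterData:
    """Hypothetical stand-in: one object holding a list of token lists."""

    def __init__(self, input_tokens: List[List[int]]) -> None:
        self.input_tokens = input_tokens


def flatten_comprehension(items: List[InterData]) -> List[int]:
    # Nested comprehension: appends one token at a time.
    return [
        tok for inter_data in items
        for in_toks in inter_data.input_tokens
        for tok in in_toks
    ]


def flatten_extend(items: List[InterData]) -> List[int]:
    # extend() copies each inner list in one call.
    out: List[int] = []
    for inter_data in items:
        for in_toks in inter_data.input_tokens:
            out.extend(in_toks)
    return out


data = [InterData([[i, i + 1, i + 2], [i + 3]]) for i in range(1000)]
same_result = flatten_comprehension(data) == flatten_extend(data)

t_comp = timeit.timeit(lambda: flatten_comprehension(data), number=200)
t_ext = timeit.timeit(lambda: flatten_extend(data), number=200)
print(f"comprehension: {t_comp:.4f}s, extend: {t_ext:.4f}s")
```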
Don't we continue adding/removing decoding requests at every step as long as there are new requests coming?
I'm not sure that exhausting prefill requests is a reasonable assumption, especially for high QPS. In summary, I feel your assumptions make sense in offline batching, but I'm not sure about online serving. Meanwhile, considering the code complexity this PR will introduce, I personally would prefer to isolate it to the scenario it fits. Of course we could still take the general Python code optimization.
Offline is getting more interest lately with SDG etc. It obviously depends on the particular workload in terms of input/output split sizes and request patterns, but I think even in the online serving case the batch constituency doesn't change for a significant proportion of steps. We should probably add metrics for this if we don't have them already :)
Again, I think this is very workload dependent. What about "Write me an essay about x" type use cases?
I agree about taking care to minimize additional complexity.
Extract this common part to be an inner function?
I'm a bit worried about this because block_table and block_table_ids are not strongly associated. If someone updates block_table somewhere else without using the update function, then these two attributes are mismatched. Can we make block_table and block_table_ids properties and update them with setters? That way you also don't need to change other parts of the block manager.
This is a bit complicated to do since "block_tables" is a dict keyed by seq_id. I have introduced update_block_table() and append_block() object functions to modify these two variables together, and made sure that every place in the class that needs to modify block_tables uses these two functions. I could also change the BlockTable type from a List[PhysicalTokenBlock] to a full class object, so it can hold the cached ids, like we did for block_manager_v2. Not sure what's better here.
@comaniac I was able to fix it by doing the same thing we did originally for block_manager_v2. By introducing a single class that holds the "ids" and overrides "list methods", the code in block_manager_v1 does not need to change anymore and is cleaner.
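As a rough illustration of that idea (not the code in this PR), a list-like class can keep the blocks and their cached ids in sync by routing every mutation through the same methods. The `PhysicalTokenBlock` below is a hypothetical one-field stand-in, and the method set is an assumption chosen for the example.

```python
from dataclasses import dataclass
from typing import Iterator, List, Optional


@dataclass
class PhysicalTokenBlock:
    """Hypothetical minimal block; the real class has more fields."""
    block_number: int


class BlockTable:
    """List-like container that keeps blocks and their ids in sync."""

    def __init__(self,
                 blocks: Optional[List[PhysicalTokenBlock]] = None) -> None:
        self._blocks: List[PhysicalTokenBlock] = []
        self._block_ids: List[int] = []
        for block in blocks or []:
            self.append(block)

    def append(self, block: PhysicalTokenBlock) -> None:
        self._blocks.append(block)
        self._block_ids.append(block.block_number)

    def __setitem__(self, index: int, block: PhysicalTokenBlock) -> None:
        self._blocks[index] = block
        self._block_ids[index] = block.block_number

    def __getitem__(self, index: int) -> PhysicalTokenBlock:
        return self._blocks[index]

    def __len__(self) -> int:
        return len(self._blocks)

    def __iter__(self) -> Iterator[PhysicalTokenBlock]:
        return iter(self._blocks)

    def copy(self) -> "BlockTable":
        # New container with its own block and id lists.
        return BlockTable(self._blocks)

    def ids(self) -> List[int]:
        # Cached ids: no per-call list comprehension over the blocks.
        return self._block_ids


table = BlockTable([PhysicalTokenBlock(3), PhysicalTokenBlock(7)])
table.append(PhysicalTokenBlock(9))
table[0] = PhysicalTokenBlock(4)
forked = table.copy()
forked.append(PhysicalTokenBlock(1))
```

Because callers only go through list-style methods, code that previously treated the block table as a plain list would not need to change in this sketch.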
- I feel we could make this cache a common class, and use it like metadata=PyObjectCache(SequenceGroupMetadata).
- Please add docstring to the class and methods.
Good idea! Introduced a single python object caching class and reused it in both cases.
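For context, a generic object cache of this kind could be sketched as below. This is an assumption-laden illustration, not the class added in this PR: the names `PyObjectCache`, `get_object`, `_obj_cache`, `_obj_builder`, and the initial size of 128 appear elsewhere in this conversation, while `reset()`, the doubling growth policy, and the `Metadata` stand-in are made up for the example.

```python
from typing import Callable, Generic, List, TypeVar

T = TypeVar("T")


class PyObjectCache(Generic[T]):
    """Pool of pre-built objects that are handed out again after reset()."""

    def __init__(self, obj_builder: Callable[[], T],
                 initial_size: int = 128) -> None:
        self._obj_builder = obj_builder
        self._obj_cache: List[T] = [obj_builder() for _ in range(initial_size)]
        self._index = 0

    def get_object(self) -> T:
        """Return the next pooled object, growing the pool if exhausted."""
        if self._index >= len(self._obj_cache):
            # Assumed policy: double the pool (at least one new object).
            grow_by = max(1, len(self._obj_cache))
            self._obj_cache.extend(self._obj_builder() for _ in range(grow_by))
        obj = self._obj_cache[self._index]
        self._index += 1
        return obj

    def reset(self) -> None:
        """Mark every pooled object as free again; nothing is deallocated."""
        self._index = 0


class Metadata:
    """Hypothetical stand-in for SequenceGroupMetadata."""

    def __init__(self) -> None:
        self.request_id = ""


cache = PyObjectCache(Metadata, initial_size=2)
first = cache.get_object()
second = cache.get_object()
third = cache.get_object()   # pool of 2 is exhausted, so it grows to 4
cache.reset()
reused = cache.get_object()  # same object as `first`, not a new allocation
```

Since objects are reused across iterations, the caller has to re-initialize every field after `get_object()`, and anything that holds on to an object past `reset()` will see it overwritten.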
It may be more straightforward to use a single API for this, such as
Original:
    seq_group_metadata = seq_group_metadata_cache.get_object()
    seq_group_metadata.__init__(
Suggested:
    seq_group_metadata = seq_group_metadata_cache.new_object(
Refactored to a single API call in both cases
Can we use a common cache class for this as well instead of introducing another similar class?
Refactored to a common class
Good catch, removed.
Can avoid allocating lists here, same below
Original:
    input_tokens.extend([0] * cuda_graph_pad_size)
    input_positions.extend([0] * cuda_graph_pad_size)
Suggested:
    input_tokens.extend(itertools.repeat(0, cuda_graph_pad_size))
    input_positions.extend(itertools.repeat(0, cuda_graph_pad_size))
Nice trick, did not know about this one.
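A small standalone check of the two padding forms, with made-up values for the token list and pad size: `[0] * n` builds a temporary list before `extend()` copies it, while `itertools.repeat(0, n)` yields the zeros lazily, so no intermediate list is allocated.

```python
import itertools

input_tokens = [11, 12, 13]
cuda_graph_pad_size = 5  # illustrative value only

padded_with_list = list(input_tokens)
padded_with_list.extend([0] * cuda_graph_pad_size)  # allocates a temp list

padded_with_repeat = list(input_tokens)
padded_with_repeat.extend(itertools.repeat(0, cuda_graph_pad_size))  # no temp list
```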
Original:
    for _, cache in cls.inter_data_cache.items():
Suggested:
    for cache in cls.inter_data_cache.values():
@youkaichao PTAL and let's try to merge this PR by today or tomorrow.
Original:
        self.block_tables[wait_seqs[0].seq_id] = block_table
    else:
        for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
            self.block_tables[seq.seq_id] = block_table.copy()
Suggested:
        self.block_tables[seq.seq_id] = block_table
    else:
        for seq in wait_seqs:
            self.block_tables[seq.seq_id] = block_table.copy()
does this make sense?
not related with this PR, but I think we can do it in a followup PR, to refactor the way we store the data, e.g. self.cached_data[seq_id].input_positions. Then we can have just one for-loop to iterate over seq_id, and set all fields in one pass.
youkaichao left a comment
thanks for the great optimization! my previous concern is multiple instances might share (and reuse) the same object. This is fixed now.
left one nit comment and one possible followup improvement. LGTM in general!
Force-pushed from cbaeefb to 1d2e873.
working on fixing tests
Force-pushed from dc4ce31 to 2ecc115.
fixed all of the failing tests, should be green now I hope.
Original:
    self._obj_cache = []
    for _ in range(128):
        self._obj_cache.append(self._obj_builder())
Suggested:
    self._obj_cache = [self._obj_builder() for _ in range(128)]
    return [seq for seq in self.seqs if not seq.is_finished()]

def get_finished_seqs(self) -> List[Sequence]:
Thanks @alexm-neuralmagic for this, it's a huge speedup! It looks like there are still a couple of unaddressed comments though? #7162 (comment) and #7162 (comment)
…-project#7162) Signed-off-by: Alvant <alvasian@yandex.ru>
…-project#7162) Signed-off-by: LeiWang1999 <leiwang1999@outlook.com>
This PR introduces a bunch of end-to-end overhead optimizations to reduce python object allocations/deallocations over scheduler iterations. In particular:
End-to-end throughput of Llama3 8B on 1xH100 is 24% faster with this PR. Command used:
python3 benchmark_throughput.py --model meta-llama/Meta-Llama-3.1-8B-Instruct --backend vllm --input-len 512 --output-len 256 --num-prompts 1000 --tensor-parallel 1
Main branch from 08/05/2024:
Throughput: 19.48 requests/s, 14962.18 tokens/s
This PR:
Throughput: 24.32 requests/s, 18123.51 tokens/s
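For reference, the relative gain can be recomputed from the two throughput lines above. The helper name `relative_gain` is just for this example; the four numbers are copied from the benchmark output.

```python
def relative_gain(before: float, after: float) -> float:
    """Percentage improvement of `after` over `before`."""
    return (after - before) / before * 100.0


main_req_s, pr_req_s = 19.48, 24.32
main_tok_s, pr_tok_s = 14962.18, 18123.51

req_gain = relative_gain(main_req_s, pr_req_s)
tok_gain = relative_gain(main_tok_s, pr_tok_s)
print(f"requests/s gain: {req_gain:.1f}%")
print(f"tokens/s gain:   {tok_gain:.1f}%")
```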