
[Performance]: Empirical Measurement of how to broadcast python object in vLLM #4440

Closed
youkaichao opened this issue Apr 28, 2024 · 7 comments
Labels
performance Performance-related issues

Comments

@youkaichao
Sponsor Member

Proposal to improve performance

When we use tensor parallelism in vLLM, the driver worker needs to broadcast some metadata to all workers, such as the input, the LoRA requests, etc. This functionality is currently implemented in:

def broadcast_tensor_dict(

In essence, it uses torch.distributed.broadcast_object_list to broadcast a Python object. This function carries a lot of overhead. The overall procedure is:

[figure: diagram of the broadcast_object_list procedure]

There are three layers of overhead (a simplified sketch of the procedure follows the list below):

  1. Device memory movement: pickle only works on CPU memory, so the data has to be moved from CPU to device and back.
  2. The pickled data of multiple objects are concatenated, which costs an extra memory copy.
  3. Two broadcast operations are needed: one to broadcast the size of each pickled object, and another to broadcast the data itself.
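
For reference, here is a rough, simplified paraphrase of what broadcast_object_list does internally (the function name and details below are illustrative, not PyTorch's actual private helpers), showing where each of the three overheads comes from:

import pickle
import torch
import torch.distributed as dist

def broadcast_object_list_sketch(object_list, src=0, group=None):
    # Simplified paraphrase of torch.distributed.broadcast_object_list.
    rank = dist.get_rank()
    if rank == src:
        pickled = [pickle.dumps(obj) for obj in object_list]
        sizes = torch.tensor([len(p) for p in pickled], dtype=torch.long)
        # Overhead 2: all pickled objects are concatenated into one buffer.
        data = torch.frombuffer(bytearray(b"".join(pickled)), dtype=torch.uint8)
    else:
        sizes = torch.empty(len(object_list), dtype=torch.long)
    # Overhead 1: with the nccl backend these buffers are staged on the GPU
    # before broadcasting and copied back to the CPU afterwards (omitted here).
    # Overhead 3: two broadcasts, one for the sizes and one for the payload.
    dist.broadcast(sizes, src=src, group=group)
    if rank != src:
        data = torch.empty(int(sizes.sum().item()), dtype=torch.uint8)
    dist.broadcast(data, src=src, group=group)
    if rank != src:
        raw = data.numpy().tobytes()
        offset = 0
        for i, size in enumerate(sizes.tolist()):
            object_list[i] = pickle.loads(raw[offset:offset + size])
            offset += size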

The current vLLM implementation packs the data into a list of size one, so overhead 2 is eliminated:

torch.distributed.broadcast_object_list([metadata_list],
                                        src=src,
                                        group=group)

To remove overhead 1, we can use a CPU backend (gloo) to broadcast this kind of metadata.
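
A minimal sketch of this idea, assuming torch.distributed is already initialized (the gloo group and helper name below are illustrative):

import torch.distributed as dist

# A second process group over the same ranks, backed by gloo (CPU),
# used only for metadata, so nothing has to be staged on the GPU.
cpu_group = dist.new_group(backend="gloo")

def broadcast_metadata(metadata=None, src=0):
    # broadcast_object_list still issues two collectives (sizes + payload),
    # but both now operate on CPU tensors, so overhead 1 is gone.
    container = [metadata]
    dist.broadcast_object_list(container, src=src, group=cpu_group)
    return container[0]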

In addition, if we know the rough size of the pickled object, we can remove overhead 3 as well: only one broadcast is required, which is the optimal case for broadcasting a Python object.
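
A sketch of this single-broadcast scheme, assuming the pickled metadata always fits into an agreed-upon buffer size (MAX_BYTES and the helper name are illustrative, not vLLM code):

import pickle
import torch
import torch.distributed as dist

MAX_BYTES = 4096  # assumed upper bound on the pickled metadata size

def broadcast_metadata_once(metadata=None, src=0, group=None):
    buf = torch.empty(MAX_BYTES, dtype=torch.uint8)
    if dist.get_rank() == src:
        data = pickle.dumps(metadata)
        assert len(data) <= MAX_BYTES, "metadata larger than the agreed buffer"
        buf[:len(data)] = torch.frombuffer(bytearray(data), dtype=torch.uint8)
    # A single collective; no size broadcast is needed because the pickle
    # stream is self-terminating, so the trailing padding is simply ignored.
    dist.broadcast(buf, src=src, group=group)
    return pickle.loads(buf.numpy().tobytes())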

I have written some benchmark code at https://gist.github.com/youkaichao/b33fcd70286eb45a4a2d5a6dc32d096b and the results are at https://docs.google.com/spreadsheets/d/1c9xgR0fGvm6SROfk7vrjwOZdYnKQk9oOafWK4_KgOyo/edit?usp=sharing .

The short conclusion is:

  1. Using CPU (gloo) to broadcast the data indeed works better than NCCL (GPU). For small metadata, the broadcast time drops from about 400 µs to 300 µs.
  2. If we can estimate the rough size in advance, the broadcast time can be reduced to about 100 µs. That requires us to design the object to be broadcast accordingly.

Report of performance regression

No response

Misc discussion on performance

No response

Your current environment (if you think it is necessary)

The output of `python collect_env.py`
@youkaichao youkaichao added the performance Performance-related issues label Apr 28, 2024
@youkaichao
Sponsor Member Author

Note: the memory alignment feature depends on the fact that the pickle format is self-terminating:

import pickle
import pickletools

s = [1] * 5
d = pickle.dumps(s)
d = d + b"whatever"  # append arbitrary trailing bytes
pickletools.dis(d)   # the disassembly still ends cleanly at the STOP opcode

Output:

    0: \x80 PROTO      4
    2: \x95 FRAME      15
   11: ]    EMPTY_LIST
   12: \x94 MEMOIZE    (as 0)
   13: (    MARK
   14: K        BININT1    1
   16: K        BININT1    1
   18: K        BININT1    1
   20: K        BININT1    1
   22: K        BININT1    1
   24: e        APPENDS    (MARK at 13)
   25: .    STOP
highest protocol among opcodes = 4

There is a STOP opcode at the end, so it is safe to pad/align the pickled data.

@cadedaniel
Collaborator

The optimization makes sense to me (nice writeup!)

@AllenDou
Contributor

> Note: the memory alignment feature depends on the fact that the pickle format is self-terminating: […]

import pickle
import pickletools

s = ['s'] * 5
d = pickle.dumps(s)
d = d + b"whatever"
pickletools.dis(d)

    0: \x80 PROTO      4
    2: \x95 FRAME      17
   11: ]    EMPTY_LIST
   12: \x94 MEMOIZE    (as 0)
   13: (    MARK
   14: \x8c     SHORT_BINUNICODE 's'
   17: \x94     MEMOIZE    (as 1)
   18: h        BINGET     1
   20: h        BINGET     1
   22: h        BINGET     1
   24: h        BINGET     1
   26: e        APPENDS    (MARK at 13)
   27: .    STOP
highest protocol among opcodes = 4

The result of pickle.dumps does not always seem to be aligned to 4 bytes.

@youkaichao
Sponsor Member Author

> The result of pickle.dumps does not always seem to be aligned to 4 bytes.

It does not matter though. The point is that the format is self-terminating, so we can pad it with arbitrary bytes; padding does not affect unpickling.
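
A quick illustrative check with the string example above, padding to an arbitrary boundary:

import pickle

s = ['s'] * 5
d = pickle.dumps(s)
padded = d + b"\x00" * (-len(d) % 16)  # pad to a 16-byte boundary
# pickle.loads stops at the STOP opcode, so the trailing padding is ignored.
assert pickle.loads(padded) == s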

@sfc-gh-zhwang
Contributor

Very cool!
Do we know how much latency broadcast_tensor_dict contributes to overall inference?

@youkaichao
Sponsor Member Author

There are two broadcast_tensor_dict calls in vLLM: one broadcasts the blocks to copy/swap, and the other broadcasts the input data (tokens, block tables, etc.). The former takes about 0.4 ms; the latter takes longer, but I don't have a detailed measurement yet.

@youkaichao
Sponsor Member Author

The performance of broadcasting Python objects is largely resolved by #5399, at least in the single-node case.
