stateful inference #2513
Conversation
I'm not familiar with the implementation details, so I can only comment on the API.
```python
self.sequence_ids = {}
results = []
for idx, row in enumerate(data):
```
To confirm, is it the case that `batchSize` is the least upper bound of `len(data)`, i.e. `len(data) <= batchSize`, and for all `l` such that `len(data) <= l`, `batchSize <= l`?
Is it possible for two separate requests to get batched to this worker? If so, suppose there are two separate streaming requests that are batched to this worker. What happens if one client is much much faster than the other? Do we throttle the faster client to match the speed of the slower one by buffering the faster client's messages?
- Q1: Yes, `len(data) <= batchSize`. `data` is a batch of requests received in real time.
- Q2: Yes, a batch of requests comes from different sequences. For example, `len(data) == 4` means there are 4 sequences in the batch. Each sequence has its own dedicated jobQ, and only the `maxBatchDelay` parameter decides how many milliseconds are spent batching a group of requests from different sequences. In other words, differing traffic volumes across sequences have no impact on batching latency (see the sketch below).
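To illustrate the idea, here is a Python sketch only; the actual aggregator lives in the Java frontend, and the queue setup, sequence ids, and 100 ms delay below are made up for illustration:

```python
import queue
import time

max_batch_delay_s = 0.1           # assumed maxBatchDelay of 100 ms
sequence_queues = {               # one dedicated jobQ per sequence
    "seq-0": queue.Queue(),
    "seq-1": queue.Queue(),
    "seq-2": queue.Queue(),
    "seq-3": queue.Queue(),
}

def poll_batch():
    """Collect at most one pending request per sequence after one delay window,
    so len(batch) <= number of sequences <= batchSize."""
    time.sleep(max_batch_delay_s)  # the only knob that adds batching latency in this sketch
    batch = []
    for seq_id, q in sequence_queues.items():
        try:
            batch.append((seq_id, q.get_nowait()))
        except queue.Empty:
            pass                   # this sequence sent nothing in the window
    return batch
```

The point of the sketch is that a slow sequence simply drops out of a given batch; it never delays the requests of the other sequences beyond the `maxBatchDelay` window.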
Ok, but if two streams produce data at drastically different rates, how do you keep the batch index coherent? For instance, fix a stateful worker. At time `t_0`, the worker receives data `d_0_0` and `d_1_0` from two streams, so `len(data) == 2`, `data[0]` is the payload for stream 0, and `data[1]` is the payload for stream 1.

At `t_1`, stream 0 does not produce any data because it took longer than `maxBatchDelay`, but stream 1 produces data `d_1_1`. So then `len(data) == 1` and `data[0]` is the payload for stream 1. In the line below, `idx == 0`, so you fetch the sequence ID for index 0:

```python
sequence_id = self.context.get_sequence_id(idx)
```

It seems like this would fetch the sequence ID for stream 0, but you actually want the sequence ID for stream 1. Am I understanding the API semantics correctly? Perhaps I am misunderstanding how `context.get_sequence_id` works. Does it keep track of which stream corresponds to the elements of the `data` list passed to the handler?
Each request's sequence id is added to its headers with the key `ts_request_sequence_id`, so the backend can get a request's sequence id from its header. This guarantees that we always get the correct sequence id, regardless of whether the real batch size changes or a sequence's request lands in a different batch slot.
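For concreteness, a minimal handler-method sketch under those semantics (the per-sequence counter and the response payload are invented for illustration; the sketch assumes `context.get_sequence_id(idx)` resolves the id from the `ts_request_sequence_id` header of whichever request sits in batch slot `idx`):

```python
def handle(self, data, context):
    results = []
    for idx, row in enumerate(data):
        # Look up the sequence id of the request currently in batch slot idx,
        # rather than assuming slot idx always belongs to the same stream, so
        # the mapping stays correct when the batch shrinks or reorders.
        sequence_id = context.get_sequence_id(idx)

        # Illustrative per-sequence state: count requests seen per sequence.
        state = self.sequence_ids.setdefault(sequence_id, {"count": 0})
        state["count"] += 1

        results.append({"sequence_id": sequence_id, "count": state["count"]})
    return results
```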
Codecov Report
```diff
@@            Coverage Diff             @@
##           master    #2513      +/-   ##
==========================================
- Coverage   72.44%   72.43%   -0.02%
==========================================
  Files          85       85
  Lines        3963     3965       +2
  Branches       58       58
==========================================
+ Hits         2871     2872       +1
- Misses       1088     1089       +1
  Partials        4        4
```
```java
                        && model.getParallelLevel() > 1
                        && model.getParallelType() != ModelConfig.ParallelType.PP)
                ? model.getParallelLevel()
                : 1;
List<CompletableFuture<Void>> futureRequests = new ArrayList<>(repeats);
for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) {
```
Got it. In that case we should move the check out of the loop condition and start over from the beginning. Otherwise we get an undefined delay before we retry sending the job, because we fall through to the check for results (which cannot be there, since we never sent the request).
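Roughly the shape being suggested, sketched in Python for brevity (the real code is Java; `backend_channel`, `requests`, and `send` here are stand-ins, not names from this PR):

```python
def send_batch(backend_channel, requests, send):
    # Check the channel once, up front, instead of inside the loop condition.
    # If it is already gone, fail fast (or restart from the beginning) rather
    # than silently skipping the sends and then waiting on results that will
    # never arrive.
    if not backend_channel:
        raise RuntimeError("backend channel closed; restart from the beginning")
    return [send(backend_channel[0], request) for request in requests]
```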
```java
CompletableFuture.runAsync(
        () -> {
            Job job = jobGroup.pollJob((long) model.getMaxBatchDelay());
            if (job != null) {
```
Can you change this part to push the jobs instead of polling for them?
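For clarity on what "push instead of poll" means here, a Python sketch (the actual change would be on the Java side, and these class and callback names are made up):

```python
import queue

# Poll model (roughly what the snippet above does): the worker repeatedly asks
# the job group for work and may block for up to maxBatchDelay per attempt.
def poll_loop(job_queue, handle, max_batch_delay_s):
    while True:
        try:
            job = job_queue.get(timeout=max_batch_delay_s)
        except queue.Empty:
            continue              # woke up with nothing; poll again
        handle(job)

# Push model (what is being asked for): whoever enqueues a job invokes the
# worker's registered callback directly, so there is no polling loop at all.
class JobGroup:
    def __init__(self, on_job):
        self.on_job = on_job      # callback registered by the worker

    def add_job(self, job):
        self.on_job(job)          # push the job to the worker immediately
```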
We're already in good shape; I left some comments.
```java
    break;
}

if (cmd == WorkerCommands.STREAMPREDICT2) {
```
duplicate still persists
Resolved review threads:
- examples/large_models/Huggingface_accelerate/llama2/custom_handler_code.py (outdated)
- frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java (outdated)
- frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java (outdated)
- frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java
- frontend/server/src/main/java/org/pytorch/serve/wlm/SequenceBatchAggregator.java (outdated)
LGTM
This should be deleted.
Description
Fixes #(issue)
Type of change
Feature/Issue validation/testing
Regression test
reg.txt
Normal Sequential Inference
Checklist: