Continuous batching for single GPU LLM inference #2628

Merged
merged 49 commits into master from feature/continous_batching_for_streaming on Oct 4, 2023
49 commits
83b88aa
Make test/pytest/test_handler.py run stand-alone
mreso Jun 1, 2023
5138de5
Refactor if else statement
mreso Jun 1, 2023
afd183a
First working poc for streaming inference with continuous batching
mreso Jun 2, 2023
490f676
WIP: stopping criteria + caching
mreso Jun 3, 2023
ae3d64e
FE add continuousbatching
lxning Aug 31, 2023
5395cc5
fmt
lxning Aug 31, 2023
c73c46c
Fmt
lxning Aug 31, 2023
fd7b93a
Fix continuous batching PoC; remove batchDelay if jobs are being proc…
mreso Sep 15, 2023
e94ffc4
Add model_config.yaml
mreso Sep 19, 2023
e81b79d
Added ipynb for generate_next_token
mreso Sep 20, 2023
7eef111
Update notebook and move to right subfolder
mreso Sep 20, 2023
8f19e00
Fix buffer underruns; wait until enough bytes are in for reading
mreso Sep 20, 2023
cef6e33
Add bandaid for bug in our otf
mreso Sep 20, 2023
c6aece6
Added test for otf protocol with context
mreso Sep 21, 2023
2883a30
Fix buffer underrun; handle batch quota of zero correctly
mreso Sep 21, 2023
e1f000f
Initial implementation of prefill + decode without kv caching for now
mreso Sep 21, 2023
e4f9b56
adds missing __init__py files
mreso Sep 21, 2023
2f4ef20
WIP kv caching
mreso Sep 21, 2023
6db970c
Fixed kv cache; missing tuple;
mreso Sep 22, 2023
8c3a890
Cleaned up streaming handler code
mreso Sep 22, 2023
481ce10
Added cache cleaning
mreso Sep 22, 2023
6beb42c
clean up aggregator jobs forcontrol cmd
lxning Sep 26, 2023
eb396a5
fmt
lxning Sep 26, 2023
be28f29
fix streaming handler test
mreso Sep 26, 2023
1bc2154
Rename streaming test into continuous batching test
mreso Sep 26, 2023
b67942e
fmt
lxning Sep 26, 2023
1ec6982
Enable gpu usage in continuous batching unit test
mreso Sep 27, 2023
ae205be
Add llama to stream notebook
mreso Sep 27, 2023
48694a8
skip pull mgmt job if jobs is not empty
lxning Sep 27, 2023
67ec104
set pollMgmtJobStatus init value as false
lxning Sep 27, 2023
b2b33f1
fmt
lxning Sep 27, 2023
267bf07
only take describe request if jobsrepo is empty
lxning Sep 28, 2023
387f548
init job
lxning Sep 28, 2023
5f9b8fe
Remove cont batching job if connection to client gets closed
mreso Sep 29, 2023
e6d1df0
Fix and reenable cached request logic
mreso Sep 29, 2023
68829ba
Fix linter error
mreso Sep 29, 2023
2825ae0
merege origin/feature/continous_batching_for_streaming commit 68829ba
lxning Sep 29, 2023
289c702
fmt
lxning Sep 29, 2023
c5e7f7e
remove llama2-13b stream_handler.py
lxning Sep 29, 2023
088f330
revert otf
lxning Sep 30, 2023
c8b4604
update maxDelay logic
lxning Oct 2, 2023
d78ce15
replace size checking with isEmpty
lxning Oct 3, 2023
835e17d
Use handler section
mreso Oct 3, 2023
0b6e309
Fix linter errors
mreso Oct 3, 2023
b774002
Fix linter error in oft mesg handler
mreso Oct 3, 2023
8ac8626
Fix linter error in test_otf_codec_protocol.py
mreso Oct 3, 2023
7af5f2a
Merge remote-tracking branch 'origin/master' into feature/continous_b…
mreso Oct 3, 2023
4e802f0
Merge branch 'master' into feature/continous_batching_for_streaming
mreso Oct 3, 2023
7855a9c
Merge branch 'master' into feature/continous_batching_for_streaming
mreso Oct 3, 2023
Changes from all commits
@@ -55,6 +55,9 @@ public class ModelConfig {
*/
private boolean useJobTicket;

/** continuousBatching is a flag to enable continuous batching. */
private boolean continuousBatching;

public static ModelConfig build(Map<String, Object> yamlMap) {
ModelConfig modelConfig = new ModelConfig();
yamlMap.forEach(
@@ -158,6 +161,15 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
logger.warn("Invalid useJobTicket: {}, should be true or false", v);
}
break;
case "continuousBatching":
if (v instanceof Boolean) {
modelConfig.setContinuousBatching((boolean) v);
} else {
logger.warn(
"Invalid continuousBatching: {}, should be true or false",
v);
}
break;
default:
break;
}
@@ -313,6 +325,14 @@ public void setUseJobTicket(boolean useJobTicket) {
this.useJobTicket = useJobTicket;
}

public boolean isContinuousBatching() {
return continuousBatching;
}

public void setContinuousBatching(boolean continuousBatching) {
this.continuousBatching = continuousBatching;
}

public enum ParallelType {
NONE(""),
PP("pp"),
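
The new continuousBatching flag above is read straight out of the model's YAML config by ModelConfig.build. A minimal sketch of how the flag is consumed (the package name org.pytorch.serve.archive.model, the wrapper class, and the main method are illustrative assumptions; build, setContinuousBatching, and isContinuousBatching come from the diff):

import java.util.HashMap;
import java.util.Map;
import org.pytorch.serve.archive.model.ModelConfig; // assumed location of ModelConfig

public class ContinuousBatchingConfigSketch {
    public static void main(String[] args) {
        // Equivalent of a model_config.yaml that contains "continuousBatching: true".
        Map<String, Object> yamlMap = new HashMap<>();
        yamlMap.put("continuousBatching", true);

        ModelConfig config = ModelConfig.build(yamlMap);
        System.out.println(config.isContinuousBatching()); // prints "true"
    }
}

A non-boolean value only logs a warning and leaves the flag at its default of false.
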
@@ -146,4 +146,10 @@ public void sendError(int status, String error) {
.asRuntimeException());
}
}

@Override
public boolean isOpen() {
return !((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
.isCancelled();
}
}
13 changes: 10 additions & 3 deletions frontend/server/src/main/java/org/pytorch/serve/job/Job.java
@@ -44,9 +44,14 @@ public WorkerCommands getCmd() {
}

public boolean isControlCmd() {
return !WorkerCommands.PREDICT.equals(cmd)
&& !WorkerCommands.STREAMPREDICT.equals(cmd)
&& !WorkerCommands.DESCRIBE.equals(cmd);
switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case DESCRIBE:
return false;
default:
return true;
}
}

public RequestInput getPayload() {
@@ -73,4 +78,6 @@ public abstract void response(
Map<String, String> responseHeaders);

public abstract void sendError(int status, String error);

public abstract boolean isOpen();
}
@@ -4,6 +4,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
@@ -258,4 +259,10 @@ public CompletableFuture<byte[]> getResponsePromise() {
public void setResponsePromise(CompletableFuture<byte[]> responsePromise) {
this.responsePromise = responsePromise;
}

@Override
public boolean isOpen() {
Channel c = ctx.channel();
return c.isOpen();
}
}
@@ -1,7 +1,8 @@
package org.pytorch.serve.util.codec;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
@@ -10,24 +11,27 @@ public final class CodecUtils {

public static final int END = -1;
public static final int BUFFER_UNDER_RUN = -3;
public static final long TIMEOUT_IN_MILLIS = 100;

private CodecUtils() {}

public static int readLength(ByteBuf byteBuf, int maxLength) {
int size = byteBuf.readableBytes();

if (size < 4) {
return BUFFER_UNDER_RUN;
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}

int len = byteBuf.readInt();
if (len > maxLength) {
throw new CorruptedFrameException(
throw new TooLongFrameException(
"Message size exceed limit: "
+ len
+ "\nConsider increasing the 'max_response_size' in 'config.properties' to fix.");
}

if (len > byteBuf.readableBytes()) {
return BUFFER_UNDER_RUN;
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}
return len;
}
@@ -38,7 +42,7 @@ public static String readString(ByteBuf byteBuf, int len) {

public static byte[] read(ByteBuf in, int len) {
if (len < 0) {
throw new CorruptedFrameException("Invalid message size: " + len);
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}

byte[] buf = new byte[len];
@@ -49,9 +53,19 @@ public static byte[] read(ByteBuf in, int len) {
public static Map<String, String> readMap(ByteBuf in, int len) {
HashMap<String, String> ret = new HashMap<>();
for (; len > 0; len--) {
int l = readLength(in, in.readableBytes());
int l =
readLength(
in,
6500000); // We replace len here with 6500000 as a workaround until we
// can fix the whole otf protocol. Basically, we're mixing up bytes
// (expected by readLength) and the number of entries (given to
// readMap). If we only have a small number of entries, the values
// in the map are not allowed to be very big, because we compare the
// given number of entries with the byte size we expect after reading
// the length of the next message.
String key = readString(in, l);
l = readLength(in, in.readableBytes());
l = readLength(in, 6500000);
String val = readString(in, l);
ret.put(key, val);
}
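
The underrun handling in CodecUtils changes from returning the BUFFER_UNDER_RUN sentinel to throwing NotEnoughDataDecoderException, which the response decoder below catches before resetting its reader index and waiting for more bytes. A small test-style sketch of the two failure modes, not part of the PR (the wrapper class and printed messages are illustrative):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import org.pytorch.serve.util.codec.CodecUtils;

public class ReadLengthSketch {
    public static void main(String[] args) {
        // Fewer than 4 readable bytes: the length prefix is incomplete, so readLength
        // now throws instead of returning BUFFER_UNDER_RUN.
        ByteBuf shortBuf = Unpooled.buffer();
        shortBuf.writeByte(1);
        try {
            CodecUtils.readLength(shortBuf, 6500000);
        } catch (NotEnoughDataDecoderException e) {
            System.out.println("underrun, wait for more bytes");
        }

        // A length prefix larger than maxLength still signals an oversized frame.
        ByteBuf bigBuf = Unpooled.buffer();
        bigBuf.writeInt(10_000);
        try {
            CodecUtils.readLength(bigBuf, 100);
        } catch (TooLongFrameException e) {
            System.out.println("frame too long: " + e.getMessage());
        }
    }
}
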
@@ -76,6 +76,12 @@ private void encodeRequest(RequestInput req, ByteBuf out) {
out.writeInt(buf.length);
out.writeBytes(buf);

if (req.isCached()) {
out.writeInt(-1); // End of List
out.writeInt(-1); // End of List
return;
}

for (Map.Entry<String, String> entry : req.getHeaders().entrySet()) {
encodeField(entry.getKey(), out);
encodeField(entry.getValue(), out);
@@ -86,6 +92,7 @@
encodeParameter(input, out);
}
out.writeInt(-1); // End of List
req.setCached(true);
}

private void encodeParameter(InputParameter parameter, ByteBuf out) {
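
With continuous batching an in-flight request is re-encoded for the backend on every batch iteration, so the encoder above marks a RequestInput as cached after the first full encode and from then on emits only the two end-of-list markers. A standalone sketch of that wire-format decision (it mirrors the logic but does not reuse the real ModelRequestEncoder or encodeField, and the field layout is simplified):

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import java.nio.charset.StandardCharsets;
import java.util.LinkedHashMap;
import java.util.Map;

public class CachedEncodeSketch {
    static void encode(String requestId, Map<String, String> headers, boolean cached, ByteBuf out) {
        byte[] id = requestId.getBytes(StandardCharsets.UTF_8);
        out.writeInt(id.length);
        out.writeBytes(id);
        if (cached) {
            out.writeInt(-1); // end of headers list
            out.writeInt(-1); // end of parameters list
            return;
        }
        for (Map.Entry<String, String> e : headers.entrySet()) {
            byte[] k = e.getKey().getBytes(StandardCharsets.UTF_8);
            byte[] v = e.getValue().getBytes(StandardCharsets.UTF_8);
            out.writeInt(k.length);
            out.writeBytes(k);
            out.writeInt(v.length);
            out.writeBytes(v);
        }
        out.writeInt(-1); // end of headers list
        out.writeInt(-1); // end of parameters list (no input parameters in this sketch)
    }

    public static void main(String[] args) {
        Map<String, String> headers = new LinkedHashMap<>();
        headers.put("content-type", "application/json");
        ByteBuf first = Unpooled.buffer();
        ByteBuf repeat = Unpooled.buffer();
        encode("req-1", headers, false, first);  // full payload on the first iteration
        encode("req-1", headers, true, repeat);  // id plus end-of-list markers afterwards
        System.out.println(first.readableBytes() + " bytes vs " + repeat.readableBytes() + " bytes");
    }
}

In the real encoder, setCached(true) runs only after the full payload has been written, so the first transmission always carries the headers and input parameters.
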
@@ -3,6 +3,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.util.ArrayList;
import java.util.List;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
@@ -82,6 +83,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
resp.setPredictions(predictions);
out.add(resp);
completed = true;
} catch (NotEnoughDataDecoderException e) {
} finally {
if (!completed) {
in.resetReaderIndex();
@@ -13,6 +13,7 @@ public class RequestInput {
private Map<String, String> headers;
private List<InputParameter> parameters;
private long clientExpireTS;
private boolean cached;

public RequestInput(String requestId) {
this.requestId = requestId;
@@ -71,4 +72,12 @@ public void setClientExpireTS(long clientTimeoutInMills) {
this.clientExpireTS = System.currentTimeMillis() + clientTimeoutInMills;
}
}

public boolean isCached() {
return cached;
}

public void setCached(boolean cached) {
this.cached = cached;
}
}
@@ -17,8 +17,10 @@ public class BatchAggregator {

private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);

private Model model;
private Map<String, Job> jobs;
protected Model model;
protected Map<String, Job> jobs;

public BatchAggregator() {}

public BatchAggregator(Model model) {
this.model = model;
@@ -171,4 +173,10 @@ public void sendError(BaseModelRequest message, String error, int status) {
}
jobs.clear();
}

public void cleanJobs() {
if (jobs != null) {
jobs.clear();
}
}
}
@@ -0,0 +1,149 @@
package org.pytorch.serve.wlm;

import java.util.Map;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelLoadModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.Predictions;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ContinuousBatching extends BatchAggregator {
private static final Logger logger = LoggerFactory.getLogger(ContinuousBatching.class);

public ContinuousBatching(Model model) {
super(model);
}

public BaseModelRequest getRequest(String threadName, WorkerState state)
throws InterruptedException {
int batchQuota = model.getBatchSize() - jobs.size();

ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());

pollBatch(threadName, state, batchQuota);

if (model.isUseJobTicket() && jobs.isEmpty()) {
model.decNumJobTickets();
return req;
}

for (Job j : jobs.values()) {
if (j.isControlCmd()) {
if (jobs.size() > 1) {
throw new IllegalStateException(
"Received more than 1 control command. "
+ "Control messages should be processed/retrieved one at a time.");
}
RequestInput input = j.getPayload();
int gpuId = -1;
String gpu = input.getStringParameter("gpu");
if (gpu != null) {
gpuId = Integer.parseInt(gpu);
}
return new ModelLoadModelRequest(model, gpuId);
} else {
if (j.getCmd() == WorkerCommands.STREAMPREDICT) {
req.setCommand(WorkerCommands.STREAMPREDICT);
}
j.setScheduled();
req.addRequest(j.getPayload());
}
}
return req;
}

/**
* @param message a response to a batch of inference requests
* @return true if a non-streaming response or the last streaming response has been sent;
*     false if an intermediate streaming response (not the last one) has been sent
*/
public boolean sendResponse(ModelWorkerResponse message) {
// TODO: Handle prediction level code
if (message.getCode() == 200) {
if (message.getPredictions().isEmpty()) {
// The jobs map always has exactly one entry in the case of a control command
for (Map.Entry<String, Job> j : jobs.entrySet()) {
Job job = j.getValue();
if (job.isControlCmd()) {
jobs.clear();
return true;
}
}
}
for (Predictions prediction : message.getPredictions()) {
String jobId = prediction.getRequestId();
Job job = jobs.get(jobId);

if (job == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with 200 status code: " + jobId);
}

if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
job.response(
prediction.getResp(),
prediction.getContentType(),
prediction.getStatusCode(),
prediction.getReasonPhrase(),
prediction.getHeaders());
} else {
logger.warn(
"Drop response for inference request {} due to client timeout",
job.getPayload().getRequestId());
}
String streamNext =
prediction
.getHeaders()
.get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT);
if (streamNext != null && streamNext.equals("false")) {
jobs.remove(jobId);
} else if (!job.isOpen()) {
jobs.remove(job.getJobId());
logger.info(
"Connection to client got closed; Removing job: {}",
job.getPayload().getRequestId());
}
}
} else {
for (Map.Entry<String, Job> j : jobs.entrySet()) {
if (j.getValue() == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with non 200 status code: "
+ j.getKey());
}
Job job = j.getValue();
if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
job.sendError(message.getCode(), message.getMessage());
} else {
logger.warn(
"Drop error response for inference request {} due to client timeout",
job.getPayload().getRequestId());
}
}
jobs.clear();
}

return true;
}

private void pollBatch(String threadName, WorkerState state, int batchSize)
throws InterruptedException {
boolean pollMgmtJobStatus = false;
if (jobs.isEmpty()) {
pollMgmtJobStatus =
model.pollMgmtJob(
threadName,
(state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE,
jobs);
}

if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
model.pollInferJob(jobs, batchSize);
}
}
}
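
Taken together, getRequest and sendResponse implement the scheduling core of continuous batching: each worker iteration keeps unfinished streaming jobs in the jobs map, tops the batch up with at most batchSize - jobs.size() newly polled requests, and evicts a job once its stream ends (TS_STREAM_NEXT == "false") or the client connection closes. A simplified, self-contained sketch of that loop (the Req class, token counts, and printed output are illustrative assumptions, not TorchServe code):

import java.util.ArrayDeque;
import java.util.HashMap;
import java.util.Map;
import java.util.Queue;

public class ContinuousBatchingSketch {
    static class Req {
        final String id;
        int tokensLeft; // decode steps remaining until this request finishes streaming
        Req(String id, int tokensLeft) { this.id = id; this.tokensLeft = tokensLeft; }
    }

    public static void main(String[] args) {
        int batchSize = 2;
        Queue<Req> pending = new ArrayDeque<>();
        pending.add(new Req("a", 3));
        pending.add(new Req("b", 1));
        pending.add(new Req("c", 2));

        Map<String, Req> jobs = new HashMap<>(); // the in-flight batch
        while (!pending.isEmpty() || !jobs.isEmpty()) {
            // Top the batch up with at most (batchSize - jobs.size()) new requests,
            // mirroring the batchQuota computation in getRequest().
            while (jobs.size() < batchSize && !pending.isEmpty()) {
                Req r = pending.poll();
                jobs.put(r.id, r);
            }
            // One forward pass: every in-flight request emits one token.
            System.out.println("step with batch " + jobs.keySet());
            // Finished streams leave the batch, freeing their slot for waiting requests.
            jobs.values().removeIf(r -> --r.tokensLeft == 0);
        }
    }
}

Unlike plain batch aggregation, the batch composition can change on every forward pass, so short requests finish and free their slot while longer generations keep streaming.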