Continuous batching for single GPU LLM inference (#2628)
* Make test/pytest/test_handler.py run stand-alone

* Refactor if else statement

* First working poc for streaming inference with continuous batching

* WIP: stopping criteria + caching

* FE: add continuousBatching

* fmt

* Fmt

* Fix continuous batching PoC; remove batchDelay if jobs are being processed

* Add model_config.yaml

* Added ipynb for generate_next_token

* Update notebook and move to right subfolder

* Fix buffer underruns; wait until enough bytes are in for reading

* Add bandaid for bug in our otf

* Added test for otf protocol with context

* Fix buffer underrun; handle batch quota of zero correctly

* Initial implementation of prefill + decode without kv caching for now

* adds missing __init__py files

* WIP kv caching

* Fixed kv cache; missing tuple;

* Cleaned up streaming handler code

* Added cache cleaning

* clean up aggregator jobs for control cmd

* fmt

* fix streaming handler test

* Rename streaming test into continuous batching test

* fmt

* Enable gpu usage in continuous batching unit test

* Add llama to stream notebook

* skip pull mgmt job if jobs is not empty

* set pollMgmtJobStatus init value as false

* fmt

* only take describe request if jobs repo is empty

* init job

* Remove cont batching job if connection to client gets closed

* Fix and reenable cached request logic

* Fix linter error

* fmt

* remove llama2-13b stream_handler.py

* revert otf

* update maxDelay logic

* replace size checking with isEmpty

* Use handler section

* Fix linter errors

* Fix linter error in otf message handler

* Fix linter error in test_otf_codec_protocol.py

---------

Co-authored-by: lxning <lninga@amazon.com>
mreso and lxning committed Oct 4, 2023
1 parent 28d9d99 commit 8d12993
Showing 26 changed files with 2,366 additions and 128 deletions.
@@ -55,6 +55,9 @@ public class ModelConfig {
*/
private boolean useJobTicket;

/** continuousBatching is a flag to enable continuous batching. */
private boolean continuousBatching;

public static ModelConfig build(Map<String, Object> yamlMap) {
ModelConfig modelConfig = new ModelConfig();
yamlMap.forEach(
@@ -158,6 +161,15 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
logger.warn("Invalid useJobTicket: {}, should be true or false", v);
}
break;
case "continuousBatching":
if (v instanceof Boolean) {
modelConfig.setContinuousBatching((boolean) v);
} else {
logger.warn(
"Invalid continuousBatching: {}, should be true or false",
v);
}
break;
default:
break;
}
@@ -313,6 +325,14 @@ public void setUseJobTicket(boolean useJobTicket) {
this.useJobTicket = useJobTicket;
}

public boolean isContinuousBatching() {
return continuousBatching;
}

public void setContinuousBatching(boolean continuousBatching) {
this.continuousBatching = continuousBatching;
}

public enum ParallelType {
NONE(""),
PP("pp"),
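For illustration only (not part of this diff): a minimal sketch of how the new flag is surfaced once a model_config.yaml has been parsed into a map. The map literal and the sketch class are hypothetical, and the ModelConfig import assumes its usual org.pytorch.serve.archive.model package.

import java.util.HashMap;
import java.util.Map;
import org.pytorch.serve.archive.model.ModelConfig; // assumed package

final class ContinuousBatchingConfigSketch {
    public static void main(String[] args) {
        // Hypothetical result of parsing a model_config.yaml that sets
        // continuousBatching: true for the model.
        Map<String, Object> yamlMap = new HashMap<>();
        yamlMap.put("continuousBatching", true);

        ModelConfig modelConfig = ModelConfig.build(yamlMap);
        // The frontend can branch on this flag when it picks a batch aggregator.
        System.out.println("continuous batching: " + modelConfig.isContinuousBatching());
    }
}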
@@ -146,4 +146,10 @@ public void sendError(int status, String error) {
.asRuntimeException());
}
}

@Override
public boolean isOpen() {
return ((ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver)
.isCancelled();
}
}
13 changes: 10 additions & 3 deletions frontend/server/src/main/java/org/pytorch/serve/job/Job.java
@@ -44,9 +44,14 @@ public WorkerCommands getCmd() {
}

public boolean isControlCmd() {
return !WorkerCommands.PREDICT.equals(cmd)
&& !WorkerCommands.STREAMPREDICT.equals(cmd)
&& !WorkerCommands.DESCRIBE.equals(cmd);
switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case DESCRIBE:
return false;
default:
return true;
}
}

public RequestInput getPayload() {
@@ -73,4 +78,6 @@ public abstract void response(
Map<String, String> responseHeaders);

public abstract void sendError(int status, String error);

public abstract boolean isOpen();
}
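For illustration only (not part of this diff): the new abstract Job.isOpen() lets an aggregator detect that the client connection behind a job has gone away and drop the job instead of continuing to generate tokens for it. A minimal sketch under that assumption; the helper class and method are hypothetical, and the jobs map stands in for the aggregator's internal state.

import java.util.Map;
import org.pytorch.serve.job.Job;

final class JobPruneSketch {
    // Remove entries whose client connection has been closed.
    static void pruneClosedJobs(Map<String, Job> jobs) {
        jobs.entrySet().removeIf(entry -> !entry.getValue().isOpen());
    }
}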
@@ -4,6 +4,7 @@

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.DefaultHttpContent;
@@ -258,4 +259,10 @@ public CompletableFuture<byte[]> getResponsePromise() {
public void setResponsePromise(CompletableFuture<byte[]> responsePromise) {
this.responsePromise = responsePromise;
}

@Override
public boolean isOpen() {
Channel c = ctx.channel();
return c.isOpen();
}
}
@@ -1,7 +1,8 @@
package org.pytorch.serve.util.codec;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.CorruptedFrameException;
import io.netty.handler.codec.TooLongFrameException;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
@@ -10,24 +11,27 @@ public final class CodecUtils {

public static final int END = -1;
public static final int BUFFER_UNDER_RUN = -3;
public static final long TIMEOUT_IN_MILLIS = 100;

private CodecUtils() {}

public static int readLength(ByteBuf byteBuf, int maxLength) {
int size = byteBuf.readableBytes();

if (size < 4) {
return BUFFER_UNDER_RUN;
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}

int len = byteBuf.readInt();
if (len > maxLength) {
throw new CorruptedFrameException(
throw new TooLongFrameException(
"Message size exceed limit: "
+ len
+ "\nConsider increasing the 'max_response_size' in 'config.properties' to fix.");
}

if (len > byteBuf.readableBytes()) {
return BUFFER_UNDER_RUN;
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}
return len;
}
@@ -38,7 +42,7 @@ public static String readString(ByteBuf byteBuf, int len) {

public static byte[] read(ByteBuf in, int len) {
if (len < 0) {
throw new CorruptedFrameException("Invalid message size: " + len);
throw new NotEnoughDataDecoderException("Did not receive enough data.");
}

byte[] buf = new byte[len];
@@ -49,9 +53,19 @@ public static byte[] read(ByteBuf in, int len) {
public static Map<String, String> readMap(ByteBuf in, int len) {
HashMap<String, String> ret = new HashMap<>();
for (; len > 0; len--) {
int l = readLength(in, in.readableBytes());
int l =
readLength(
in,
6500000); // We replace len here with 6500000 as a workaround before we
// can fix the whole otf. Basically, we're mixing up bytes
// (expected by readLength) and number of entries (given to
// readMap). If we only have a small number of entries our
// values in the map are not allowed to be very big as we
// compare the given number of entries with the byte size
// we're expecting after reading the length of the next
// message.
String key = readString(in, l);
l = readLength(in, in.readableBytes());
l = readLength(in, 6500000);
String val = readString(in, l);
ret.put(key, val);
}
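For illustration only (not part of this diff): CodecUtils now signals a buffer underrun by throwing Netty's NotEnoughDataDecoderException instead of returning BUFFER_UNDER_RUN, so callers are expected to mark the reader index, catch the exception, and rewind until more bytes arrive (the ModelResponseDecoder change below follows this pattern). A minimal decoder sketch under that assumption; the class name is hypothetical and the 6500000 byte cap is borrowed from readMap's workaround above.

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.util.List;
import org.pytorch.serve.util.codec.CodecUtils;

public class LengthPrefixedStringDecoder extends ByteToMessageDecoder {
    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
        in.markReaderIndex();
        try {
            int len = CodecUtils.readLength(in, 6500000); // byte cap, mirroring readMap's workaround
            if (len == CodecUtils.END) {
                return; // end-of-list marker; nothing to emit
            }
            out.add(CodecUtils.readString(in, len));
        } catch (NotEnoughDataDecoderException e) {
            // Not enough bytes buffered yet: rewind and wait for the next network read.
            in.resetReaderIndex();
        }
    }
}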
@@ -76,6 +76,12 @@ private void encodeRequest(RequestInput req, ByteBuf out) {
out.writeInt(buf.length);
out.writeBytes(buf);

if (req.isCached()) {
out.writeInt(-1); // End of List
out.writeInt(-1); // End of List
return;
}

for (Map.Entry<String, String> entry : req.getHeaders().entrySet()) {
encodeField(entry.getKey(), out);
encodeField(entry.getValue(), out);
@@ -86,6 +92,7 @@ private void encodeParameter(InputParameter parameter, ByteBuf out) {
encodeParameter(input, out);
}
out.writeInt(-1); // End of List
req.setCached(true);
}

private void encodeParameter(InputParameter parameter, ByteBuf out) {
@@ -3,6 +3,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;
import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException;
import java.util.ArrayList;
import java.util.List;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
@@ -82,6 +83,7 @@ protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) {
resp.setPredictions(predictions);
out.add(resp);
completed = true;
} catch (NotEnoughDataDecoderException e) {
} finally {
if (!completed) {
in.resetReaderIndex();
@@ -13,6 +13,7 @@ public class RequestInput {
private Map<String, String> headers;
private List<InputParameter> parameters;
private long clientExpireTS;
private boolean cached;

public RequestInput(String requestId) {
this.requestId = requestId;
@@ -71,4 +72,12 @@ public void setClientExpireTS(long clientTimeoutInMills) {
this.clientExpireTS = System.currentTimeMillis() + clientTimeoutInMills;
}
}

public boolean isCached() {
return cached;
}

public void setCached(boolean cached) {
this.cached = cached;
}
}
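For illustration only (not part of this diff): the new cached flag on RequestInput works together with the encoder change above. The first time a request is encoded, its headers and parameters are written in full and setCached(true) is set; on later continuous-batching iterations the encoder writes only the two end-of-list markers for that request. A hypothetical helper mirroring that decision, using only the accessors added here:

import org.pytorch.serve.util.messages.RequestInput;

final class RequestCachingSketch {
    // True only for the first encoding pass of a given request.
    static boolean needsFullEncoding(RequestInput req) {
        if (req.isCached()) {
            return false; // already sent to the backend once; only end-of-list markers follow
        }
        req.setCached(true); // the encoder marks the request after its first full encode
        return true;
    }
}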
@@ -17,8 +17,10 @@ public class BatchAggregator {

private static final Logger logger = LoggerFactory.getLogger(BatchAggregator.class);

private Model model;
private Map<String, Job> jobs;
protected Model model;
protected Map<String, Job> jobs;

public BatchAggregator() {}

public BatchAggregator(Model model) {
this.model = model;
@@ -171,4 +173,10 @@ public void sendError(BaseModelRequest message, String error, int status) {
}
jobs.clear();
}

public void cleanJobs() {
if (jobs != null) {
jobs.clear();
}
}
}
@@ -0,0 +1,149 @@
package org.pytorch.serve.wlm;

import java.util.Map;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.util.messages.BaseModelRequest;
import org.pytorch.serve.util.messages.ModelInferenceRequest;
import org.pytorch.serve.util.messages.ModelLoadModelRequest;
import org.pytorch.serve.util.messages.ModelWorkerResponse;
import org.pytorch.serve.util.messages.Predictions;
import org.pytorch.serve.util.messages.RequestInput;
import org.pytorch.serve.util.messages.WorkerCommands;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class ContinuousBatching extends BatchAggregator {
private static final Logger logger = LoggerFactory.getLogger(ContinuousBatching.class);

public ContinuousBatching(Model model) {
super(model);
}

public BaseModelRequest getRequest(String threadName, WorkerState state)
throws InterruptedException {
int batchQuota = model.getBatchSize() - jobs.size();

ModelInferenceRequest req = new ModelInferenceRequest(model.getModelName());

pollBatch(threadName, state, batchQuota);

if (model.isUseJobTicket() && jobs.isEmpty()) {
model.decNumJobTickets();
return req;
}

for (Job j : jobs.values()) {
if (j.isControlCmd()) {
if (jobs.size() > 1) {
throw new IllegalStateException(
"Received more than 1 control command. "
+ "Control messages should be processed/retrieved one at a time.");
}
RequestInput input = j.getPayload();
int gpuId = -1;
String gpu = input.getStringParameter("gpu");
if (gpu != null) {
gpuId = Integer.parseInt(gpu);
}
return new ModelLoadModelRequest(model, gpuId);
} else {
if (j.getCmd() == WorkerCommands.STREAMPREDICT) {
req.setCommand(WorkerCommands.STREAMPREDICT);
}
j.setScheduled();
req.addRequest(j.getPayload());
}
}
return req;
}

/**
* @param message a response to a batch of inference requests
* @return true if a non-stream response or the last stream response was sent; false if an
*     intermediate stream response (not the last one) was sent
*/
public boolean sendResponse(ModelWorkerResponse message) {
// TODO: Handle prediction level code
if (message.getCode() == 200) {
if (message.getPredictions().isEmpty()) {
// The jobs size is always 1 in the case of a control command
for (Map.Entry<String, Job> j : jobs.entrySet()) {
Job job = j.getValue();
if (job.isControlCmd()) {
jobs.clear();
return true;
}
}
}
for (Predictions prediction : message.getPredictions()) {
String jobId = prediction.getRequestId();
Job job = jobs.get(jobId);

if (job == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with 200 status code: " + jobId);
}

if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
job.response(
prediction.getResp(),
prediction.getContentType(),
prediction.getStatusCode(),
prediction.getReasonPhrase(),
prediction.getHeaders());
} else {
logger.warn(
"Drop response for inference request {} due to client timeout",
job.getPayload().getRequestId());
}
String streamNext =
prediction
.getHeaders()
.get(org.pytorch.serve.util.messages.RequestInput.TS_STREAM_NEXT);
if (streamNext != null && streamNext.equals("false")) {
jobs.remove(jobId);
} else if (!job.isOpen()) {
jobs.remove(job.getJobId());
logger.info(
"Connection to client got closed; Removing job: {}",
job.getPayload().getRequestId());
}
}
} else {
for (Map.Entry<String, Job> j : jobs.entrySet()) {
if (j.getValue() == null) {
throw new IllegalStateException(
"Unexpected job in sendResponse() with non 200 status code: "
+ j.getKey());
}
Job job = j.getValue();
if (job.getPayload().getClientExpireTS() > System.currentTimeMillis()) {
job.sendError(message.getCode(), message.getMessage());
} else {
logger.warn(
"Drop error response for inference request {} due to client timeout",
job.getPayload().getRequestId());
}
}
jobs.clear();
}

return true;
}

private void pollBatch(String threadName, WorkerState state, int batchSize)
throws InterruptedException {
boolean pollMgmtJobStatus = false;
if (jobs.isEmpty()) {
pollMgmtJobStatus =
model.pollMgmtJob(
threadName,
(state == WorkerState.WORKER_MODEL_LOADED) ? 0 : Long.MAX_VALUE,
jobs);
}

if (!pollMgmtJobStatus && state == WorkerState.WORKER_MODEL_LOADED) {
model.pollInferJob(jobs, batchSize);
}
}
}
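For illustration only (not part of this diff): ContinuousBatching keeps unfinished streaming jobs in the batch across worker iterations and tops the batch up with at most batchSize - jobs.size() new requests each time, removing a job only when its last stream chunk has been sent or its client connection closes. A hypothetical sketch of how a frontend could pick the aggregator from the parsed model config; the ModelConfig package and the selection wiring are assumptions, while the two constructors come from this commit.

import org.pytorch.serve.archive.model.ModelConfig; // assumed package
import org.pytorch.serve.wlm.BatchAggregator;
import org.pytorch.serve.wlm.ContinuousBatching;
import org.pytorch.serve.wlm.Model;

final class AggregatorSelectionSketch {
    static BatchAggregator aggregatorFor(Model model, ModelConfig config) {
        return config.isContinuousBatching()
                ? new ContinuousBatching(model)
                : new BatchAggregator(model);
    }
}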
