Skip to content

Commit

Permalink
Added a new worker command to handle the KServe OIP inference request.
Browse files Browse the repository at this point in the history
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>
  • Loading branch information
andyi2it committed Jan 15, 2024
1 parent 572ec36 commit 0a2cef9
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,9 @@ public void modelReady(ModelReadyRequest request, StreamObserver<ModelReadyRespo
throw new ModelNotFoundException("Model not found: " + modelName);
}

// int numScaled = model.getMinWorkers();
// int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName());
// isModelReady = numHealthy >= numScaled;

List<WorkerThread> workers = modelManager.getWorkers(model.getModelVersionName());
for (WorkerThread worker : workers) {
isModelReady = worker.isRunning() && worker.getState() == WorkerState.WORKER_MODEL_LOADED;

}
int numScaled = model.getMinWorkers();
int numHealthy = modelManager.getNumHealthyWorkers(model.getModelVersionName());
isModelReady = numHealthy >= numScaled;

ModelReadyResponse modelReadyResponse = ModelReadyResponse.newBuilder()
.setReady(isModelReady)
Expand Down Expand Up @@ -247,7 +241,7 @@ public void modelInfer(ModelInferRequest request, StreamObserver<ModelInferRespo
try {
ModelManager modelManager = ModelManager.getInstance();
inputData.addParameter(new InputParameter("body", byteArray, contentsType));
Job job = new GRPCJob(responseObserver, modelName, modelVersion, inputData, WorkerCommands.PREDICT);
Job job = new GRPCJob(responseObserver, modelName, modelVersion, inputData, WorkerCommands.OIPPREDICT);

if (!modelManager.addJob(job)) {
String responseMessage = ApiUtils.getStreamingInferenceErrorResponseMessage(modelName, modelVersion);
Expand Down
118 changes: 57 additions & 61 deletions frontend/server/src/main/java/org/pytorch/serve/job/GRPCJob.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,62 +117,30 @@ public void response(
Map<String, String> responseHeaders) {
ByteString output = ByteString.copyFrom(body);
WorkerCommands cmd = this.getCmd();
Gson gson = new Gson();
String jsonResponse = output.toStringUtf8();
JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);

switch (cmd) {
case PREDICT:
case STREAMPREDICT:
case STREAMPREDICT2:
// condition for OIP grpc ModelInfer Call
if (ConfigManager.getInstance().isOpenInferenceProtocol() && isResponseStructureOIP(jsonObject)) {
if (((ServerCallStreamObserver<ModelInferResponse>) modelInferResponseObserver)
.isCancelled()) {
logger.warn(
"grpc client call already cancelled, not able to send this response for requestId: {}",
getPayload().getRequestId());
return;
}
ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder();
responseBuilder.setId(jsonObject.get("id").getAsString());
responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();

for (JsonElement element : jsonOutputs) {
InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder();
outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
outputBuilder.setDatatype(element.getAsJsonObject().get("datatype").getAsString());
JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray();
shapeArray.forEach(shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
setOutputContents(element, outputBuilder);
responseBuilder.addOutputs(outputBuilder);

}
modelInferResponseObserver.onNext(responseBuilder.build());
modelInferResponseObserver.onCompleted();
} else {
ServerCallStreamObserver<PredictionResponse> responseObserver =
(ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver;
cancelHandler(responseObserver);
PredictionResponse reply =
PredictionResponse.newBuilder().setPrediction(output).build();
responseObserver.onNext(reply);
if (cmd == WorkerCommands.PREDICT
|| (cmd == WorkerCommands.STREAMPREDICT
&& responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
responseObserver.onCompleted();
logQueueTime();
} else if (cmd == WorkerCommands.STREAMPREDICT2
&& (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null
|| responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
logQueueTime();
}
ServerCallStreamObserver<PredictionResponse> responseObserver =
(ServerCallStreamObserver<PredictionResponse>) predictionResponseObserver;
cancelHandler(responseObserver);
PredictionResponse reply =
PredictionResponse.newBuilder().setPrediction(output).build();
responseObserver.onNext(reply);
if (cmd == WorkerCommands.PREDICT
|| (cmd == WorkerCommands.STREAMPREDICT
&& responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
responseObserver.onCompleted();
logQueueTime();
} else if (cmd == WorkerCommands.STREAMPREDICT2
&& (responseHeaders.get(RequestInput.TS_STREAM_NEXT) == null
|| responseHeaders
.get(RequestInput.TS_STREAM_NEXT)
.equals("false"))) {
logQueueTime();
}
break;
case DESCRIBE:
Expand All @@ -193,6 +161,36 @@ public void response(
managementResponseObserver, Status.NOT_FOUND, e);
}
break;
case OIPPREDICT:
Gson gson = new Gson();
String jsonResponse = output.toStringUtf8();
JsonObject jsonObject = gson.fromJson(jsonResponse, JsonObject.class);
if (((ServerCallStreamObserver<ModelInferResponse>) modelInferResponseObserver)
.isCancelled()) {
logger.warn(
"grpc client call already cancelled, not able to send this response for requestId: {}",
getPayload().getRequestId());
return;
}
ModelInferResponse.Builder responseBuilder = ModelInferResponse.newBuilder();
responseBuilder.setId(jsonObject.get("id").getAsString());
responseBuilder.setModelName(jsonObject.get("model_name").getAsString());
responseBuilder.setModelVersion(jsonObject.get("model_version").getAsString());
JsonArray jsonOutputs = jsonObject.get("outputs").getAsJsonArray();

for (JsonElement element : jsonOutputs) {
InferOutputTensor.Builder outputBuilder = InferOutputTensor.newBuilder();
outputBuilder.setName(element.getAsJsonObject().get("name").getAsString());
outputBuilder.setDatatype(element.getAsJsonObject().get("datatype").getAsString());
JsonArray shapeArray = element.getAsJsonObject().get("shape").getAsJsonArray();
shapeArray.forEach(shapeElement -> outputBuilder.addShape(shapeElement.getAsLong()));
setOutputContents(element, outputBuilder);
responseBuilder.addOutputs(outputBuilder);

}
modelInferResponseObserver.onNext(responseBuilder.build());
modelInferResponseObserver.onCompleted();
break;
default:
break;
}
Expand Down Expand Up @@ -244,6 +242,14 @@ public void sendError(int status, String error) {
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
break;
case OIPPREDICT:
modelInferResponseObserver.onError(
responseStatus
.withDescription(error)
.augmentDescription(
"org.pytorch.serve.http.InternalServerException")
.asRuntimeException());
break;
default:
break;
}
Expand Down Expand Up @@ -317,14 +323,4 @@ private void setOutputContents(JsonElement element, InferOutputTensor.Builder ou
}
outputBuilder.setContents(inferTensorContents); // set output contents
}

/**
 * Checks whether the given JSON object contains all four top-level members
 * {@code id}, {@code model_name}, {@code model_version} and {@code outputs}.
 *
 * <p>Only the presence of each member is tested; their value types are not inspected.
 *
 * @param jsonObject the parsed JSON object to inspect
 * @return {@code true} if every one of the four members is present, {@code false} otherwise
 */
private boolean isResponseStructureOIP(JsonObject jsonObject) {
    // Evaluated left to right with short-circuiting, in the same order as before.
    return jsonObject.has("id")
            && jsonObject.has("model_name")
            && jsonObject.has("model_version")
            && jsonObject.has("outputs");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ public enum WorkerCommands {
@SerializedName("streampredict")
STREAMPREDICT("streampredict"),
@SerializedName("streampredict2")
STREAMPREDICT2("streampredict2");
STREAMPREDICT2("streampredict2"),
@SerializedName("oippredict") // for kserve open inference protocol
OIPPREDICT("oippredict");

private String command;

Expand Down

0 comments on commit 0a2cef9

Please sign in to comment.