Skip to content

Commit

Permalink
added worker status check.
Browse files Browse the repository at this point in the history
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>
  • Loading branch information
andyi2it committed Nov 8, 2023
1 parent aa5c80d commit 7cc7579
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class ServerInitializer extends ChannelInitializer<Channel> {

private ConnectorType connectorType;
private SslContext sslCtx;
private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);
private static final Logger logger = LoggerFactory.getLogger(ServerInitializer.class);

/**
* Creates a new {@code ServerInitializer} instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public void serverLive(ServerLiveRequest request, StreamObserver<ServerLiveRespo
});

ServerLiveResponse readyResponse = ServerLiveResponse.newBuilder()
.setLive(true)
.setLive(ApiUtils.getTsWorkerStatus())
.build();
responseObserver.onNext(readyResponse);
responseObserver.onCompleted();
Expand All @@ -80,7 +80,7 @@ public void serverReady(ServerReadyRequest request, StreamObserver<ServerReadyRe
});

ServerReadyResponse readyResponse = ServerReadyResponse.newBuilder()
.setReady(true)
.setReady(ApiUtils.getTsWorkerStatus())
.build();
responseObserver.onNext(readyResponse);
responseObserver.onCompleted();
Expand Down Expand Up @@ -236,7 +236,7 @@ public void modelInfer(ModelInferRequest request, StreamObserver<ModelInferRespo
Job job = new GRPCJob(responseObserver, modelName, modelVersion, inputData, WorkerCommands.PREDICT);

if (!modelManager.addJob(job)) {
String responseMessage = ApiUtils.getInferenceErrorResponseMessage(modelName, modelVersion);
String responseMessage = ApiUtils.getStreamingInferenceErrorResponseMessage(modelName, modelVersion);
InternalServerException e = new InternalServerException(responseMessage);
sendErrorResponse(
responseObserver, Status.INTERNAL, e, "InternalServerException.()");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.pytorch.serve.wlm.WorkerState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.gson.JsonObject;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import com.google.gson.JsonArray;
import org.pytorch.serve.wlm.WorkerThread;

/**
* A class handling inbound HTTP requests to KServe's Open Inference
Expand All @@ -56,7 +58,7 @@
*/
public class OpenInferenceProtocolRequestHandler extends HttpRequestHandlerChain {

private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);
private static final Logger logger = LoggerFactory.getLogger(OpenInferenceProtocolRequestHandler.class);
private static final String TS_VERSION_FILE_PATH = "ts/version.txt";
private static final String SERVER_METADATA_API = "/v2";
private static final String SERVER_LIVE_API = "/v2/health/live";
Expand All @@ -80,12 +82,12 @@ public void handleRequest(
if (concatenatedSegments.equals(SERVER_READY_API)) {
// for serve ready check
JsonObject response = new JsonObject();
response.addProperty("ready", true);
response.addProperty("ready", ApiUtils.getTsWorkerStatus());
NettyUtils.sendJsonResponse(ctx, response);
} else if (concatenatedSegments.equals(SERVER_LIVE_API)) {
// for serve live check
JsonObject response = new JsonObject();
response.addProperty("live", true);
response.addProperty("live", ApiUtils.getTsWorkerStatus());
NettyUtils.sendJsonResponse(ctx, response);
} else if (concatenatedSegments.equals(SERVER_METADATA_API)) {
// For fetch server metadata
Expand All @@ -108,6 +110,7 @@ public void handleRequest(
}

private String getTsVersion() {
String tsVersion = "";
try {
BufferedReader reader = new BufferedReader(new FileReader(TS_VERSION_FILE_PATH));
String version = reader.readLine();
Expand All @@ -116,8 +119,7 @@ private String getTsVersion() {
} catch (IOException e) {
e.printStackTrace();
}
return null;
return tsVersion;

}

}
16 changes: 16 additions & 0 deletions frontend/server/src/main/java/org/pytorch/serve/util/ApiUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -459,4 +459,20 @@ public static String getDescribeErrorResponseMessage(String modelName) {
"\" has no worker to serve describe request. Please use scale workers API to add workers.";
return responseMessage;
}

/**
 * Reports whether at least one worker has finished loading its model.
 *
 * <p>Fetches the workers known to the {@link ModelManager} singleton and compares each
 * worker's state against {@link WorkerState#WORKER_MODEL_LOADED}.
 *
 * @return {@code true} if any worker is in the {@code WORKER_MODEL_LOADED} state;
 *     {@code false} otherwise, including when no workers are registered
 */
public static boolean getTsWorkerStatus() {
    Map<Integer, WorkerThread> workersMap = ModelManager.getInstance().getWorkers();

    for (WorkerThread worker : workersMap.values()) {
        if (worker.getState() == WorkerState.WORKER_MODEL_LOADED) {
            // A single loaded worker settles the answer; skip the remaining workers.
            return true;
        }
    }
    return false;
}

}

0 comments on commit 7cc7579

Please sign in to comment.