Skip to content

Commit

Permalink
added worker status check.
Browse files Browse the repository at this point in the history
Signed-off-by: Andrews Arokiam <andrews.arokiam@ideas2it.com>
  • Loading branch information
andyi2it committed Oct 11, 2023
1 parent 29168dc commit 6c967fd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class ServerInitializer extends ChannelInitializer<Channel> {

private ConnectorType connectorType;
private SslContext sslCtx;
private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);
private static final Logger logger = LoggerFactory.getLogger(ServerInitializer.class);

/**
* Creates a new {@code HttpRequestHandler} instance.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@
import org.pytorch.serve.wlm.Model;
import org.pytorch.serve.wlm.ModelManager;
import org.pytorch.serve.wlm.WorkerInitializationException;
import org.pytorch.serve.wlm.WorkerState;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.gson.JsonObject;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import com.google.gson.JsonArray;
import org.pytorch.serve.wlm.WorkerThread;

/**
* A class handling inbound HTTP requests to the Kserve's Open Inference
Expand All @@ -56,7 +58,7 @@
*/
public class OpenInferenceProtocolRequestHandler extends HttpRequestHandlerChain {

private static final Logger logger = LoggerFactory.getLogger(InferenceRequestHandler.class);
private static final Logger logger = LoggerFactory.getLogger(OpenInferenceProtocolRequestHandler.class);
private static final String TS_VERSION_FILE_PATH = "ts/version.txt";
private static final String SERVER_METADATA_API = "/v2";
private static final String SERVER_LIVE_API = "/v2/health/live";
Expand All @@ -80,12 +82,12 @@ public void handleRequest(
if (concatenatedSegments.equals(SERVER_READY_API)) {
// for serve ready check
JsonObject response = new JsonObject();
response.addProperty("ready", true);
response.addProperty("ready", getWorkerStatus());
NettyUtils.sendJsonResponse(ctx, response);
} else if (concatenatedSegments.equals(SERVER_LIVE_API)) {
// for serve live check
JsonObject response = new JsonObject();
response.addProperty("live", true);
response.addProperty("live", getWorkerStatus());
NettyUtils.sendJsonResponse(ctx, response);
} else if (concatenatedSegments.equals(SERVER_METADATA_API)) {
// For fetch server metadata
Expand All @@ -108,6 +110,7 @@ public void handleRequest(
}

private String getTsVersion() {
String tsVersion = "";
try {
BufferedReader reader = new BufferedReader(new FileReader(TS_VERSION_FILE_PATH));
String version = reader.readLine();
Expand All @@ -116,8 +119,22 @@ private String getTsVersion() {
} catch (IOException e) {
e.printStackTrace();
}
return null;
return tsVersion;

}

private boolean getWorkerStatus() {
boolean isTsWorkerStarted = false;
ModelManager modelManager = ModelManager.getInstance();
Map<Integer, WorkerThread> workersMap = modelManager.getWorkers();

List<WorkerThread> workers = new ArrayList<>(workersMap.values());

for (WorkerThread worker : workers) {
if (worker.getState() == WorkerState.WORKER_MODEL_LOADED) {
isTsWorkerStarted = true;
}
}
return isTsWorkerStarted;
}
}

0 comments on commit 6c967fd

Please sign in to comment.