Open Inference Protocol Implementation. #2609

Merged: 11 commits (Jan 24, 2024)
6 changes: 5 additions & 1 deletion .github/workflows/kserve_cpu_tests.yml
@@ -29,6 +29,10 @@ jobs:
with:
python-version: 3.8
architecture: x64
- name: Install grpcurl
run: |
sudo curl -sSL https://github.com/fullstorydev/grpcurl/releases/download/v1.8.0/grpcurl_1.8.0_linux_x86_64.tar.gz | sudo tar -xz -C /usr/local/bin grpcurl
sudo chmod +x /usr/local/bin/grpcurl
- name: Checkout TorchServe
uses: actions/checkout@v3
- name: Checkout kserve repo
@@ -37,5 +41,5 @@ jobs:
repository: kserve/kserve
ref: v0.11.1
path: kserve
- name: Validate torchserve-kfs
- name: Validate torchserve-kfs and Open Inference Protocol
run: ./kubernetes/kserve/tests/scripts/test_mnist.sh
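The grpcurl binary installed above is what lets the test script poke the new gRPC surface. As a rough sketch of the kind of check it enables, assuming TorchServe's default gRPC inference port (7070), plaintext transport, and the OIP spec's service name inference.GRPCInferenceService (none of which are pinned down in this diff), the health RPCs can be probed directly; if server reflection is not exposed, grpcurl would also need the OIP proto via -proto or -protoset:

# Hedged probe of the OIP health RPCs; the port, service name, and
# reflection support are assumptions, not facts established by this diff.
grpcurl -plaintext localhost:7070 inference.GRPCInferenceService/ServerLive
grpcurl -plaintext localhost:7070 inference.GRPCInferenceService/ServerReady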
11 changes: 11 additions & 0 deletions frontend/server/src/main/java/org/pytorch/serve/ModelServer.java
@@ -64,6 +64,7 @@ public class ModelServer {
private ServerGroups serverGroups;
private Server inferencegRPCServer;
private Server managementgRPCServer;
private Server OIPgRPCServer;
private List<ChannelFuture> futures = new ArrayList<>(2);
private AtomicBoolean stopped = new AtomicBoolean(false);
private ConfigManager configManager;
@@ -453,6 +454,16 @@ private Server startGRPCServer(ConnectorType connectorType) throws IOException {
GRPCServiceFactory.getgRPCService(connectorType),
new GRPCInterceptor()));

if (connectorType == ConnectorType.INFERENCE_CONNECTOR
&& ConfigManager.getInstance().isOpenInferenceProtocol()) {
s.maxInboundMessageSize(configManager.getMaxRequestSize())
.addService(
ServerInterceptors.intercept(
GRPCServiceFactory.getgRPCService(
ConnectorType.OPEN_INFERENCE_CONNECTOR),
new GRPCInterceptor()));
}

if (configManager.isGRPCSSLEnabled()) {
s.useTransportSecurity(
new File(configManager.getCertificateFile()),
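The effect of this change is that the OIP service is registered on the same gRPC server, and therefore the same port, as the existing inference service whenever the protocol is enabled. A hedged inference-call sketch against that shared port, again assuming port 7070, the spec-level service name, and a registered model named mnist, with a placeholder tensor rather than a real mnist input:

# ModelInfer per the OIP gRPC spec; the model name and tensor shape here
# are illustrative only.
grpcurl -plaintext \
  -d '{"model_name": "mnist", "inputs": [{"name": "input-0", "shape": [1, 2], "datatype": "FP32", "contents": {"fp32_contents": [0.5, 0.5]}}]}' \
  localhost:7070 inference.GRPCInferenceService/ModelInfer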
frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java
@@ -13,12 +13,15 @@
import org.pytorch.serve.http.api.rest.ApiDescriptionRequestHandler;
import org.pytorch.serve.http.api.rest.InferenceRequestHandler;
import org.pytorch.serve.http.api.rest.ManagementRequestHandler;
import org.pytorch.serve.http.api.rest.OpenInferenceProtocolRequestHandler;
import org.pytorch.serve.http.api.rest.PrometheusMetricsRequestHandler;
import org.pytorch.serve.servingsdk.impl.PluginsManager;
import org.pytorch.serve.util.ConfigManager;
import org.pytorch.serve.util.ConnectorType;
import org.pytorch.serve.workflow.api.http.WorkflowInferenceRequestHandler;
import org.pytorch.serve.workflow.api.http.WorkflowMgmtRequestHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* A special {@link io.netty.channel.ChannelInboundHandler} which offers an easy way to initialize a
@@ -29,6 +32,7 @@ public class ServerInitializer extends ChannelInitializer<Channel> {

private ConnectorType connectorType;
private SslContext sslCtx;
private static final Logger logger = LoggerFactory.getLogger(ServerInitializer.class);

/**
* Creates a new {@code HttpRequestHandler} instance.
@@ -65,6 +69,14 @@ public void initChannel(Channel ch) {
PluginsManager.getInstance().getInferenceEndpoints()));
httpRequestHandlerChain =
httpRequestHandlerChain.setNextHandler(new WorkflowInferenceRequestHandler());

// Register the Open Inference Protocol (OIP) handler on the inference connector
if (ConfigManager.getInstance().isOpenInferenceProtocol()) {
logger.info("OIP request handler added to the handler chain");
httpRequestHandlerChain =
httpRequestHandlerChain.setNextHandler(
new OpenInferenceProtocolRequestHandler());
}
}
if (ConnectorType.ALL.equals(connectorType)
|| ConnectorType.MANAGEMENT_CONNECTOR.equals(connectorType)) {
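Because OpenInferenceProtocolRequestHandler is appended to the existing inference handler chain, the OIP REST routes come up on the regular inference HTTP port. A sketch of the spec's v2 endpoints, assuming TorchServe's default inference port (8080) and a registered model named mnist; the request tensor is a placeholder, not a real mnist input:

# OIP v2 REST endpoints; the port and model name are assumptions.
curl http://localhost:8080/v2/health/live
curl http://localhost:8080/v2/health/ready
curl http://localhost:8080/v2/models/mnist/ready
curl -X POST http://localhost:8080/v2/models/mnist/infer \
  -H 'Content-Type: application/json' \
  -d '{"inputs": [{"name": "input-0", "shape": [1, 2], "datatype": "FP32", "data": [0.5, 0.5]}]}'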
frontend/server/src/main/java/org/pytorch/serve/grpcimpl/GRPCServiceFactory.java
@@ -16,6 +16,9 @@ public static BindableService getgRPCService(ConnectorType connectorType) {
case INFERENCE_CONNECTOR:
torchServeService = new InferenceImpl();
break;
case OPEN_INFERENCE_CONNECTOR:
torchServeService = new OpenInferenceProtocolImpl();
break;
default:
break;
}