diff --git a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java index 3211275595..c86cba8d6c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java +++ b/frontend/server/src/main/java/org/pytorch/serve/ServerInitializer.java @@ -7,6 +7,7 @@ import io.netty.handler.codec.http.HttpServerCodec; import io.netty.handler.ssl.SslContext; import org.pytorch.serve.http.ApiDescriptionRequestHandler; +import org.pytorch.serve.http.ExtendedSSLHandler; import org.pytorch.serve.http.HttpRequestHandler; import org.pytorch.serve.http.HttpRequestHandlerChain; import org.pytorch.serve.http.InferenceRequestHandler; @@ -48,7 +49,7 @@ public void initChannel(Channel ch) { int maxRequestSize = ConfigManager.getInstance().getMaxRequestSize(); if (sslCtx != null) { - pipeline.addLast("ssl", sslCtx.newHandler(ch.alloc())); + pipeline.addLast("ssl", new ExtendedSSLHandler(sslCtx, connectorType)); } pipeline.addLast("http", new HttpServerCodec()); pipeline.addLast("aggregator", new HttpObjectAggregator(maxRequestSize)); diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/ExtendedSSLHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/ExtendedSSLHandler.java new file mode 100644 index 0000000000..99c092bdf2 --- /dev/null +++ b/frontend/server/src/main/java/org/pytorch/serve/http/ExtendedSSLHandler.java @@ -0,0 +1,46 @@ +package org.pytorch.serve.http; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.ssl.OptionalSslHandler; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslHandler; +import java.util.List; +import org.pytorch.serve.util.ConfigManager; +import org.pytorch.serve.util.ConnectorType; +import org.pytorch.serve.util.NettyUtils; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ExtendedSSLHandler extends OptionalSslHandler { + private static final Logger logger = LoggerFactory.getLogger(ExtendedSSLHandler.class); + /** the length of the ssl record header (in bytes) */ + private static final int SSL_RECORD_HEADER_LENGTH = 5; + + private ConnectorType connectorType; + + public ExtendedSSLHandler(SslContext sslContext, ConnectorType connectorType) { + super(sslContext); + this.connectorType = connectorType; + } + + @Override + protected void decode(ChannelHandlerContext context, ByteBuf in, List out) + throws Exception { + if (in.readableBytes() < SSL_RECORD_HEADER_LENGTH) { + return; + } + ConfigManager configMgr = ConfigManager.getInstance(); + if (SslHandler.isEncrypted(in) || !configMgr.isSSLEnabled(connectorType)) { + super.decode(context, in, out); + } else { + logger.error("Recieved HTTP request!"); + NettyUtils.sendJsonResponse( + context, + new StatusResponse( + "This TorchServe instance only accepts HTTPS requests", + HttpResponseStatus.FORBIDDEN.code())); + } + } +} diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java index bb3a6bc72e..37c1bfa749 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/ConfigManager.java @@ -105,9 +105,15 @@ public final class ConfigManager { public static final String PYTHON_EXECUTABLE = "python"; + public static final Pattern ADDRESS_PATTERN = + Pattern.compile( + "((https|http)://([^:^/]+)(:([0-9]+))?)|(unix:(/.*))", + Pattern.CASE_INSENSITIVE); + private static Pattern pattern = Pattern.compile("\\$\\$([^$]+[^$])\\$\\$"); + private Pattern blacklistPattern; private Properties prop; - private static Pattern pattern = Pattern.compile("\\$\\$([^$]+[^$])\\$\\$"); + private boolean snapshotDisabled; private static ConfigManager instance; @@ -718,6 +724,30 @@ public boolean isSnapshotDisabled() { return snapshotDisabled; } + public boolean isSSLEnabled(ConnectorType connectorType) { + String address = prop.getProperty(TS_INFERENCE_ADDRESS, "http://127.0.0.1:8080"); + switch (connectorType) { + case MANAGEMENT_CONNECTOR: + address = prop.getProperty(TS_MANAGEMENT_ADDRESS, "http://127.0.0.1:8081"); + break; + case METRICS_CONNECTOR: + address = prop.getProperty(TS_METRICS_ADDRESS, "http://127.0.0.1:8082"); + break; + default: + break; + } + // String inferenceAddress = prop.getProperty(TS_INFERENCE_ADDRESS, + // "http://127.0.0.1:8080"); + Matcher matcher = ConfigManager.ADDRESS_PATTERN.matcher(address); + if (!matcher.matches()) { + throw new IllegalArgumentException("Invalid binding address: " + address); + } + + String protocol = matcher.group(2); + + return "https".equalsIgnoreCase(protocol); + } + public int getIniitialWorkerPort() { return Integer.parseInt(prop.getProperty(TS_INITIAL_WORKER_PORT, "9000")); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/util/Connector.java b/frontend/server/src/main/java/org/pytorch/serve/util/Connector.java index ab4cb82f7a..b98ec8b28c 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/util/Connector.java +++ b/frontend/server/src/main/java/org/pytorch/serve/util/Connector.java @@ -24,16 +24,10 @@ import java.net.SocketAddress; import java.util.Objects; import java.util.regex.Matcher; -import java.util.regex.Pattern; import org.apache.commons.io.FileUtils; public class Connector { - private static final Pattern ADDRESS_PATTERN = - Pattern.compile( - "((https|http)://([^:^/]+)(:([0-9]+))?)|(unix:(/.*))", - Pattern.CASE_INSENSITIVE); - private static boolean useNativeIo = ConfigManager.getInstance().useNativeIo(); private boolean uds; @@ -75,7 +69,7 @@ private Connector( } public static Connector parse(String binding, ConnectorType connectorType) { - Matcher matcher = ADDRESS_PATTERN.matcher(binding); + Matcher matcher = ConfigManager.ADDRESS_PATTERN.matcher(binding); if (!matcher.matches()) { throw new IllegalArgumentException("Invalid binding address: " + binding); }