Skip to content

Commit

Permalink
UNDERTOW-1166: make PROXY protocol implementation compliant with spec
Browse files Browse the repository at this point in the history
  • Loading branch information
spaletta committed Aug 24, 2017
1 parent fe85429 commit 18ac4a2
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 48 deletions.
24 changes: 13 additions & 11 deletions core/src/main/java/io/undertow/Undertow.java
Expand Up @@ -76,7 +76,7 @@ public final class Undertow {
* Will be true when a {@link XnioWorker} instance was NOT provided to the {@link Builder}. * Will be true when a {@link XnioWorker} instance was NOT provided to the {@link Builder}.
* When true, a new worker will be created during {@link Undertow#start()}, * When true, a new worker will be created during {@link Undertow#start()},
* and shutdown when {@link Undertow#stop()} is called. * and shutdown when {@link Undertow#stop()} is called.
* * <p>
* Will be false when a {@link XnioWorker} instance was provided to the {@link Builder}. * Will be false when a {@link XnioWorker} instance was provided to the {@link Builder}.
* When false, the provided {@link #worker} will be used instead of creating a new one in {@link Undertow#start()}. * When false, the provided {@link #worker} will be used instead of creating a new one in {@link Undertow#start()}.
* Also, when false, the {@link #worker} will NOT be shutdown when {@link Undertow#stop()} is called. * Also, when false, the {@link #worker} will NOT be shutdown when {@link Undertow#stop()} is called.
Expand Down Expand Up @@ -153,7 +153,7 @@ public synchronized void start() {
openListener.setRootHandler(rootHandler); openListener.setRootHandler(rootHandler);


final ChannelListener<StreamConnection> finalListener; final ChannelListener<StreamConnection> finalListener;
if(listener.useProxyProtocol) { if (listener.useProxyProtocol) {
finalListener = new ProxyProtocolOpenListener(openListener, null, buffers, OptionMap.EMPTY); finalListener = new ProxyProtocolOpenListener(openListener, null, buffers, OptionMap.EMPTY);
} else { } else {
finalListener = openListener; finalListener = openListener;
Expand All @@ -170,12 +170,12 @@ public synchronized void start() {
if (listener.type == ListenerType.HTTP) { if (listener.type == ListenerType.HTTP) {
HttpOpenListener openListener = new HttpOpenListener(buffers, undertowOptions); HttpOpenListener openListener = new HttpOpenListener(buffers, undertowOptions);
HttpHandler handler = rootHandler; HttpHandler handler = rootHandler;
if(http2) { if (http2) {
handler = new Http2UpgradeHandler(handler); handler = new Http2UpgradeHandler(handler);
} }
openListener.setRootHandler(handler); openListener.setRootHandler(handler);
final ChannelListener<StreamConnection> finalListener; final ChannelListener<StreamConnection> finalListener;
if(listener.useProxyProtocol) { if (listener.useProxyProtocol) {
finalListener = new ProxyProtocolOpenListener(openListener, null, buffers, OptionMap.EMPTY); finalListener = new ProxyProtocolOpenListener(openListener, null, buffers, OptionMap.EMPTY);
} else { } else {
finalListener = openListener; finalListener = openListener;
Expand All @@ -193,9 +193,9 @@ public synchronized void start() {
HttpOpenListener httpOpenListener = new HttpOpenListener(buffers, undertowOptions); HttpOpenListener httpOpenListener = new HttpOpenListener(buffers, undertowOptions);
httpOpenListener.setRootHandler(rootHandler); httpOpenListener.setRootHandler(rootHandler);


if(http2) { if (http2) {
AlpnOpenListener alpn = new AlpnOpenListener(buffers, undertowOptions, httpOpenListener); AlpnOpenListener alpn = new AlpnOpenListener(buffers, undertowOptions, httpOpenListener);
if(http2) { if (http2) {
Http2OpenListener http2Listener = new Http2OpenListener(buffers, undertowOptions); Http2OpenListener http2Listener = new Http2OpenListener(buffers, undertowOptions);
http2Listener.setRootHandler(rootHandler); http2Listener.setRootHandler(rootHandler);
alpn.addProtocol(Http2OpenListener.HTTP2, http2Listener, 10); alpn.addProtocol(Http2OpenListener.HTTP2, http2Listener, 10);
Expand All @@ -212,15 +212,15 @@ public synchronized void start() {
} else { } else {
OptionMap.Builder builder = OptionMap.builder(); OptionMap.Builder builder = OptionMap.builder();
builder.addAll(listener.overrideSocketOptions); builder.addAll(listener.overrideSocketOptions);
if(!listener.overrideSocketOptions.contains(Options.SSL_PROTOCOL)) { if (!listener.overrideSocketOptions.contains(Options.SSL_PROTOCOL)) {
builder.set(Options.SSL_PROTOCOL, "TLSv1.2"); builder.set(Options.SSL_PROTOCOL, "TLSv1.2");
} }
xnioSsl = new UndertowXnioSsl(xnio, OptionMap.create(Options.USE_DIRECT_BUFFERS, true), JsseSslUtils.createSSLContext(listener.keyManagers, listener.trustManagers, new SecureRandom(), builder.getMap())); xnioSsl = new UndertowXnioSsl(xnio, OptionMap.create(Options.USE_DIRECT_BUFFERS, true), JsseSslUtils.createSSLContext(listener.keyManagers, listener.trustManagers, new SecureRandom(), builder.getMap()));
} }


OptionMap socketOptionsWithOverrides = OptionMap.builder().addAll(socketOptions).addAll(listener.overrideSocketOptions).getMap(); OptionMap socketOptionsWithOverrides = OptionMap.builder().addAll(socketOptions).addAll(listener.overrideSocketOptions).getMap();
AcceptingChannel<? extends StreamConnection> sslServer; AcceptingChannel<? extends StreamConnection> sslServer;
if(listener.useProxyProtocol) { if (listener.useProxyProtocol) {
ChannelListener<AcceptingChannel<StreamConnection>> acceptListener = ChannelListeners.openListenerAdapter(new ProxyProtocolOpenListener(openListener, xnioSsl, buffers, socketOptionsWithOverrides)); ChannelListener<AcceptingChannel<StreamConnection>> acceptListener = ChannelListeners.openListenerAdapter(new ProxyProtocolOpenListener(openListener, xnioSsl, buffers, socketOptionsWithOverrides));
sslServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(listener.host), listener.port), (ChannelListener) acceptListener, socketOptionsWithOverrides); sslServer = worker.createStreamConnectionServer(new InetSocketAddress(Inet4Address.getByName(listener.host), listener.port), (ChannelListener) acceptListener, socketOptionsWithOverrides);
} else { } else {
Expand Down Expand Up @@ -270,21 +270,21 @@ public XnioWorker getWorker() {
} }


public List<ListenerInfo> getListenerInfo() { public List<ListenerInfo> getListenerInfo() {
if(listenerInfo == null) { if (listenerInfo == null) {
throw UndertowMessages.MESSAGES.serverNotStarted(); throw UndertowMessages.MESSAGES.serverNotStarted();
} }
return Collections.unmodifiableList(listenerInfo); return Collections.unmodifiableList(listenerInfo);
} }





public enum ListenerType { public enum ListenerType {
HTTP, HTTP,
HTTPS, HTTPS,
AJP AJP
} }


private static class ListenerConfig { private static class ListenerConfig {

final ListenerType type; final ListenerType type;
final int port; final int port;
final String host; final String host;
Expand Down Expand Up @@ -333,6 +333,7 @@ private ListenerConfig(final ListenerBuilder listenerBuilder) {
} }


public static final class ListenerBuilder { public static final class ListenerBuilder {

ListenerType type; ListenerType type;
int port; int port;
String host; String host;
Expand Down Expand Up @@ -485,6 +486,7 @@ public Builder addAjpListener(int port, String host, HttpHandler rootHandler) {
listeners.add(new ListenerConfig(ListenerType.AJP, port, host, null, null, rootHandler)); listeners.add(new ListenerConfig(ListenerType.AJP, port, host, null, null, rootHandler));
return this; return this;
} }

public Builder setBufferSize(final int bufferSize) { public Builder setBufferSize(final int bufferSize) {
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
return this; return this;
Expand Down Expand Up @@ -536,7 +538,7 @@ public <T> Builder setWorkerOption(final Option<T> option, final T value) {
* when {@link Undertow#start()} is called. * when {@link Undertow#start()} is called.
* Additionally, this newly created worker will be shutdown when {@link Undertow#stop()} is called. * Additionally, this newly created worker will be shutdown when {@link Undertow#stop()} is called.
* <br/> * <br/>
* * <p>
* When non-null, the provided {@link XnioWorker} will be reused instead of creating a new {@link XnioWorker} * When non-null, the provided {@link XnioWorker} will be reused instead of creating a new {@link XnioWorker}
* when {@link Undertow#start()} is called. * when {@link Undertow#start()} is called.
* Additionally, the provided {@link XnioWorker} will NOT be shutdown when {@link Undertow#stop()} is called. * Additionally, the provided {@link XnioWorker} will NOT be shutdown when {@link Undertow#stop()} is called.
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/io/undertow/UndertowMessages.java
Expand Up @@ -552,7 +552,7 @@ public interface UndertowMessages {
@Message(id = 178, value = "Buffer pool is too small, min size is %s") @Message(id = 178, value = "Buffer pool is too small, min size is %s")
IllegalArgumentException bufferPoolTooSmall(int minSize); IllegalArgumentException bufferPoolTooSmall(int minSize);


@Message(id = 179, value = "Invalid proxy header") @Message(id = 179, value = "Invalid PROXY protocol header")
IOException invalidProxyHeader(); IOException invalidProxyHeader();


@Message(id = 180, value = "PROXY protocol header exceeded max size of 107 bytes") @Message(id = 180, value = "PROXY protocol header exceeded max size of 107 bytes")
Expand Down
Expand Up @@ -35,7 +35,7 @@ class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel>
private static final int MAX_HEADER_LENGTH = 107; private static final int MAX_HEADER_LENGTH = 107;


private static final byte[] NAME = "PROXY ".getBytes(StandardCharsets.US_ASCII); private static final byte[] NAME = "PROXY ".getBytes(StandardCharsets.US_ASCII);
private static final String UNKOWN = "UNKOWN"; private static final String UNKNOWN = "UNKNOWN";
private static final String TCP = "TCP"; private static final String TCP = "TCP";
private static final String TCP_6 = "TCP6"; private static final String TCP_6 = "TCP6";


Expand All @@ -53,7 +53,7 @@ class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel>
private int destPort = -1; private int destPort = -1;
private StringBuilder stringBuilder = new StringBuilder(); private StringBuilder stringBuilder = new StringBuilder();
private boolean carriageReturnSeen = false; private boolean carriageReturnSeen = false;
private boolean parsingUnkown = false; private boolean parsingUnknown = false;




ProxyProtocolReadListener(StreamConnection streamConnection, OpenListener openListener, UndertowXnioSsl ssl, ByteBufferPool bufferPool, OptionMap sslOptionMap) { ProxyProtocolReadListener(StreamConnection streamConnection, OpenListener openListener, UndertowXnioSsl ssl, ByteBufferPool bufferPool, OptionMap sslOptionMap) {
Expand Down Expand Up @@ -89,8 +89,8 @@ public void handleEvent(StreamSourceChannel streamSourceChannel) {
throw UndertowMessages.MESSAGES.invalidProxyHeader(); throw UndertowMessages.MESSAGES.invalidProxyHeader();
} }
} else { } else {
if (parsingUnkown) { if (parsingUnknown) {
//we are parsing the UNKOWN protocol //we are parsing the UNKNOWN protocol
//we just ignore everything till \r\n //we just ignore everything till \r\n
if (c == '\r') { if (c == '\r') {
carriageReturnSeen = true; carriageReturnSeen = true;
Expand Down Expand Up @@ -124,41 +124,49 @@ public void handleEvent(StreamSourceChannel streamSourceChannel) {
} else { } else {
throw UndertowMessages.MESSAGES.invalidProxyHeader(); throw UndertowMessages.MESSAGES.invalidProxyHeader();
} }
} else if (c == ' ') { } else switch (c) {
//we have a space case ' ':
if (sourcePort != -1 || stringBuilder.length() == 0) { //we have a space
//header was invalid, either we are expecting a \r or a \n, or the previous character was a space if (sourcePort != -1 || stringBuilder.length() == 0) {
throw UndertowMessages.MESSAGES.invalidProxyHeader(); //header was invalid, either we are expecting a \r or a \n, or the previous character was a space
} else if (protocol == null) {
protocol = stringBuilder.toString();
stringBuilder.setLength(0);
if (protocol.equals(UNKOWN)) {
parsingUnkown = true;
} else if (!protocol.equals(TCP) && !protocol.equals(TCP_6)) {
throw UndertowMessages.MESSAGES.invalidProxyHeader(); throw UndertowMessages.MESSAGES.invalidProxyHeader();
} else if (protocol == null) {
protocol = stringBuilder.toString();
stringBuilder.setLength(0);
if (protocol.equals(UNKNOWN)) {
parsingUnknown = true;
} else if (!protocol.equals(TCP) && !protocol.equals(TCP_6)) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else if (sourceAddress == null) {
sourceAddress = parseAddress(stringBuilder.toString(), protocol);
stringBuilder.setLength(0);
} else if (destAddress == null) {
destAddress = parseAddress(stringBuilder.toString(), protocol);
stringBuilder.setLength(0);
} else {
sourcePort = Integer.parseInt(stringBuilder.toString());
stringBuilder.setLength(0);
} }
} else if (sourceAddress == null) { break;
sourceAddress = parseAddress(stringBuilder.toString(), protocol); case '\r':
stringBuilder.setLength(0); if (destPort == -1 && sourcePort != -1 && !carriageReturnSeen && stringBuilder.length() > 0) {
} else if (destAddress == null) { destPort = Integer.parseInt(stringBuilder.toString());
destAddress = parseAddress(stringBuilder.toString(), protocol); stringBuilder.setLength(0);
stringBuilder.setLength(0); carriageReturnSeen = true;
} else { } else if (protocol == null) {
sourcePort = Integer.parseInt(stringBuilder.toString()); if (UNKNOWN.equals(stringBuilder.toString())) {
stringBuilder.setLength(0); parsingUnknown = true;
} carriageReturnSeen = true;
} else if (c == '\r') { }
if (destPort == -1 && sourcePort != -1 && !carriageReturnSeen && stringBuilder.length() > 0) { } else {
destPort = Integer.parseInt(stringBuilder.toString()); throw UndertowMessages.MESSAGES.invalidProxyHeader();
stringBuilder.setLength(0); }
carriageReturnSeen = true; break;
} else { case '\n':
throw UndertowMessages.MESSAGES.invalidProxyHeader(); throw UndertowMessages.MESSAGES.invalidProxyHeader();
} default:
} else if (c == '\n') { stringBuilder.append(c);
throw UndertowMessages.MESSAGES.invalidProxyHeader();
} else {
stringBuilder.append(c);
} }


} }
Expand Down
Expand Up @@ -81,4 +81,48 @@ public void handleRequest(HttpServerExchange exchange) throws Exception {
undertow.stop(); undertow.stop();
} }
} }

@Test
public void testProxyProtocolUnknownEmpty() throws Exception {
doTestProxyProtocolUnknown("");
}

@Test
public void testProxyProtocolUnknownSpace() throws Exception {
doTestProxyProtocolUnknown(" ");
}

@Test
public void testProxyProtocolUnknownJunk() throws Exception {
doTestProxyProtocolUnknown(" mekmitasdigoat");
}

public void doTestProxyProtocolUnknown(String extra) throws Exception {
Undertow undertow = Undertow.builder().addListener(
new Undertow.ListenerBuilder()
.setType(Undertow.ListenerType.HTTP)
.setHost(DefaultServer.getHostAddress())
.setUseProxyProtocol(true)
.setPort(0)
)
.setHandler(new HttpHandler() {
@Override
public void handleRequest(HttpServerExchange exchange) throws Exception {
exchange.setPersistent(false);
exchange.getResponseHeaders().put(new HttpString("result"), exchange.getSourceAddress().toString() + " " + exchange.getDestinationAddress().toString());
}
})
.build();
try {
undertow.start();
InetSocketAddress serverAddress = (InetSocketAddress) undertow.getListenerInfo().get(0).getAddress();
Socket s = new Socket(serverAddress.getAddress(), serverAddress.getPort());
String expected = String.format("result: /%s:%d /%s:%d", s.getLocalAddress().getHostAddress(), s.getLocalPort(), serverAddress.getAddress().getHostAddress(), serverAddress.getPort());
s.getOutputStream().write(("PROXY UNKNOWN" + extra + "\r\nGET / HTTP/1.0\r\n\r\n").getBytes(StandardCharsets.US_ASCII));
String result = FileUtils.readFile(s.getInputStream());
Assert.assertTrue(result, result.contains(expected));
} finally {
undertow.stop();
}
}
} }

0 comments on commit 18ac4a2

Please sign in to comment.