Skip to content

Commit

Permalink
UNDERTOW-1463 Support proxy protocol v2
Browse files Browse the repository at this point in the history
  • Loading branch information
Ulrich Herberg authored and stuartwdouglas committed Dec 17, 2018
1 parent 90c4485 commit dce25f0
Show file tree
Hide file tree
Showing 2 changed files with 431 additions and 114 deletions.
Expand Up @@ -22,6 +22,7 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* Implementation of version 1 of the proxy protocol (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
Expand All @@ -30,6 +31,7 @@
* fragmentation of
*
* @author Stuart Douglas
* @author Ulrich Herberg
*/
class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel> {

Expand All @@ -38,7 +40,9 @@ class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel>
private static final byte[] NAME = "PROXY ".getBytes(StandardCharsets.US_ASCII);
private static final String UNKNOWN = "UNKNOWN";
private static final String TCP4 = "TCP4";
private static final String TCP_6 = "TCP6";
private static final String TCP6 = "TCP6";

private static final byte[] SIG = new byte[] {0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A};

private final StreamConnection streamConnection;
private final OpenListener openListener;
Expand Down Expand Up @@ -71,115 +75,29 @@ class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel>
@Override
public void handleEvent(StreamSourceChannel streamSourceChannel) {
PooledByteBuffer buffer = bufferPool.allocate();
boolean freeBuffer = true;
AtomicBoolean freeBuffer = new AtomicBoolean(true);
try {
for (; ; ) {
int res = streamSourceChannel.read(buffer.getBuffer());
if (res == -1) {
IoUtils.safeClose(streamConnection);
return;
} else if (res == 0) {
return;
} else {
buffer.getBuffer().flip();
while (buffer.getBuffer().hasRemaining()) {
char c = (char) buffer.getBuffer().get();
if (byteCount < NAME.length) {
//first we verify that we have the correct protocol
if (c != NAME[byteCount]) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else {
if (parsingUnknown) {
//we are parsing the UNKNOWN protocol
//we just ignore everything till \r\n
if (c == '\r') {
carriageReturnSeen = true;
} else if (c == '\n') {
if (!carriageReturnSeen) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
//we are done
if (buffer.getBuffer().hasRemaining()) {
freeBuffer = false;
proxyAccept(null, null, buffer);
} else {
proxyAccept(null, null, null);
}
return;
} else if (carriageReturnSeen) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else if (carriageReturnSeen) {
if (c == '\n') {
//we are done
SocketAddress s = new InetSocketAddress(sourceAddress, sourcePort);
SocketAddress d = new InetSocketAddress(destAddress, destPort);
if (buffer.getBuffer().hasRemaining()) {
freeBuffer = false;
proxyAccept(s, d, buffer);
} else {
proxyAccept(s, d, null);
}
return;
} else {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else switch (c) {
case ' ':
//we have a space
if (sourcePort != -1 || stringBuilder.length() == 0) {
//header was invalid, either we are expecting a \r or a \n, or the previous character was a space
throw UndertowMessages.MESSAGES.invalidProxyHeader();
} else if (protocol == null) {
protocol = stringBuilder.toString();
stringBuilder.setLength(0);
if (protocol.equals(UNKNOWN)) {
parsingUnknown = true;
} else if (!protocol.equals(TCP4) && !protocol.equals(TCP_6)) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else if (sourceAddress == null) {
sourceAddress = parseAddress(stringBuilder.toString(), protocol);
stringBuilder.setLength(0);
} else if (destAddress == null) {
destAddress = parseAddress(stringBuilder.toString(), protocol);
stringBuilder.setLength(0);
} else {
sourcePort = Integer.parseInt(stringBuilder.toString());
stringBuilder.setLength(0);
}
break;
case '\r':
if (destPort == -1 && sourcePort != -1 && !carriageReturnSeen && stringBuilder.length() > 0) {
destPort = Integer.parseInt(stringBuilder.toString());
stringBuilder.setLength(0);
carriageReturnSeen = true;
} else if (protocol == null) {
if (UNKNOWN.equals(stringBuilder.toString())) {
parsingUnknown = true;
carriageReturnSeen = true;
}
} else {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
break;
case '\n':
throw UndertowMessages.MESSAGES.invalidProxyHeader();
default:
stringBuilder.append(c);
}

}
byteCount++;
if (byteCount == MAX_HEADER_LENGTH) {
throw UndertowMessages.MESSAGES.headerSizeToLarge();
}
int res = streamSourceChannel.read(buffer.getBuffer());
if (res == -1) {
IoUtils.safeClose(streamConnection);
return;
} else if (res == 0) {
return;
} else {
buffer.getBuffer().flip();

if (buffer.getBuffer().hasRemaining()) {
byte firstByte = buffer.getBuffer().get(); // get first byte to determine whether Proxy Protocol V1 or V2 is used
byteCount++;
if (firstByte == SIG[0]) { // Could be Proxy Protocol V2
parseProxyProtocolV2(buffer, freeBuffer);
} else if ((char) firstByte == NAME[0]){ // Could be Proxy Protocol V1
parseProxyProtocolV1(buffer, freeBuffer);
} else {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}


}
return;
}

} catch (IOException e) {
Expand All @@ -189,13 +107,221 @@ public void handleEvent(StreamSourceChannel streamSourceChannel) {
UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(e));
IoUtils.safeClose(streamConnection);
} finally {
if (freeBuffer) {
if (freeBuffer.get()) {
buffer.close();
}
}
}



private void parseProxyProtocolV2(PooledByteBuffer buffer, AtomicBoolean freeBuffer) throws Exception {
while (byteCount < SIG.length) {
byte c = buffer.getBuffer().get();

//first we verify that we have the correct protocol
if (c != SIG[byteCount]) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
byteCount++;
}

byte ver_cmd = buffer.getBuffer().get();
byte fam = buffer.getBuffer().get();
int len = (buffer.getBuffer().getShort() & 0xffff);

if ((ver_cmd & 0xF0) != 0x20) { // expect version 2
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}

switch (ver_cmd & 0x0F) {
case 0x01: // PROXY command
switch (fam) {
case 0x11: { // TCP over IPv4
if (len < 12) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}

byte[] sourceAddressBytes = new byte[4];
buffer.getBuffer().get(sourceAddressBytes);
sourceAddress = InetAddress.getByAddress(sourceAddressBytes);

byte[] dstAddressBytes = new byte[4];
buffer.getBuffer().get(dstAddressBytes);
destAddress = InetAddress.getByAddress(dstAddressBytes);

sourcePort = buffer.getBuffer().getShort() & 0xffff;
destPort = buffer.getBuffer().getShort() & 0xffff;

if (len > 12) {
int skipAhead = len - 12;
int currentPosition = buffer.getBuffer().position();
buffer.getBuffer().position(currentPosition + skipAhead);
}

break;
}

case 0x21: { // TCP over IPv6
if (len < 36) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}

byte[] sourceAddressBytes = new byte[16];
buffer.getBuffer().get(sourceAddressBytes);
sourceAddress = InetAddress.getByAddress(sourceAddressBytes);

byte[] dstAddressBytes = new byte[16];
buffer.getBuffer().get(dstAddressBytes);
destAddress = InetAddress.getByAddress(dstAddressBytes);

sourcePort = buffer.getBuffer().getShort() & 0xffff;
destPort = buffer.getBuffer().getShort() & 0xffff;

if (len > 36) {
int skipAhead = len - 36;
int currentPosition = buffer.getBuffer().position();
buffer.getBuffer().position(currentPosition + skipAhead);
}

break;
}

default: // AF_UNIX sockets not supported
throw UndertowMessages.MESSAGES.invalidProxyHeader();

}
break;
case 0x00: // LOCAL command
if (len > 0) {
int skipAhead = len;
int currentPosition = buffer.getBuffer().position();
buffer.getBuffer().position(currentPosition + skipAhead);
}

if (buffer.getBuffer().hasRemaining()) {
freeBuffer.set(false);
proxyAccept(null, null, buffer);
} else {
proxyAccept(null, null, null);
}
return;
default:
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}


SocketAddress s = new InetSocketAddress(sourceAddress, sourcePort);
SocketAddress d = new InetSocketAddress(destAddress, destPort);
if (buffer.getBuffer().hasRemaining()) {
freeBuffer.set(false);
proxyAccept(s, d, buffer);
} else {
proxyAccept(s, d, null);
}
return;
}

private void parseProxyProtocolV1(PooledByteBuffer buffer, AtomicBoolean freeBuffer) throws Exception {
while (buffer.getBuffer().hasRemaining()) {
char c = (char) buffer.getBuffer().get();
if (byteCount < NAME.length) {
//first we verify that we have the correct protocol
if (c != NAME[byteCount]) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else {
if (parsingUnknown) {
//we are parsing the UNKNOWN protocol
//we just ignore everything till \r\n
if (c == '\r') {
carriageReturnSeen = true;
} else if (c == '\n') {
if (!carriageReturnSeen) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
//we are done
if (buffer.getBuffer().hasRemaining()) {
freeBuffer.set(false);
proxyAccept(null, null, buffer);
} else {
proxyAccept(null, null, null);
}
return;
} else if (carriageReturnSeen) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else if (carriageReturnSeen) {
if (c == '\n') {
//we are done
SocketAddress s = new InetSocketAddress(sourceAddress, sourcePort);
SocketAddress d = new InetSocketAddress(destAddress, destPort);
if (buffer.getBuffer().hasRemaining()) {
freeBuffer.set(false);
proxyAccept(s, d, buffer);
} else {
proxyAccept(s, d, null);
}
return;
} else {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else switch (c) {
case ' ':
//we have a space
if (sourcePort != -1 || stringBuilder.length() == 0) {
//header was invalid, either we are expecting a \r or a \n, or the previous character was a space
throw UndertowMessages.MESSAGES.invalidProxyHeader();
} else if (protocol == null) {
protocol = stringBuilder.toString();
stringBuilder.setLength(0);
if (protocol.equals(UNKNOWN)) {
parsingUnknown = true;
} else if (!protocol.equals(TCP4) && !protocol.equals(TCP6)) {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
} else if (sourceAddress == null) {
sourceAddress = parseAddress(stringBuilder.toString(), protocol);
stringBuilder.setLength(0);
} else if (destAddress == null) {
destAddress = parseAddress(stringBuilder.toString(), protocol);
stringBuilder.setLength(0);
} else {
sourcePort = Integer.parseInt(stringBuilder.toString());
stringBuilder.setLength(0);
}
break;
case '\r':
if (destPort == -1 && sourcePort != -1 && !carriageReturnSeen && stringBuilder.length() > 0) {
destPort = Integer.parseInt(stringBuilder.toString());
stringBuilder.setLength(0);
carriageReturnSeen = true;
} else if (protocol == null) {
if (UNKNOWN.equals(stringBuilder.toString())) {
parsingUnknown = true;
carriageReturnSeen = true;
}
} else {
throw UndertowMessages.MESSAGES.invalidProxyHeader();
}
break;
case '\n':
throw UndertowMessages.MESSAGES.invalidProxyHeader();
default:
stringBuilder.append(c);
}

}

byteCount++;
if (byteCount == MAX_HEADER_LENGTH) {
throw UndertowMessages.MESSAGES.headerSizeToLarge();
}

}
}


private void proxyAccept(SocketAddress source, SocketAddress dest, PooledByteBuffer additionalData) {
StreamConnection streamConnection = this.streamConnection;
if (source != null) {
Expand Down Expand Up @@ -275,5 +401,4 @@ public SocketAddress getLocalAddress() {
return dest;
}
}

}

0 comments on commit dce25f0

Please sign in to comment.