/
ProxyProtocolReadListener.java
280 lines (250 loc) · 12.4 KB
/
ProxyProtocolReadListener.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
package io.undertow.server.protocol.proxy;
import io.undertow.UndertowLogger;
import io.undertow.UndertowMessages;
import io.undertow.connector.ByteBufferPool;
import io.undertow.connector.PooledByteBuffer;
import io.undertow.protocols.ssl.UndertowXnioSsl;
import io.undertow.server.DelegateOpenListener;
import io.undertow.server.OpenListener;
import io.undertow.util.NetworkUtils;
import io.undertow.util.PooledAdaptor;
import org.xnio.ChannelListener;
import org.xnio.IoUtils;
import org.xnio.OptionMap;
import org.xnio.StreamConnection;
import org.xnio.channels.StreamSourceChannel;
import org.xnio.conduits.PushBackStreamSourceConduit;
import org.xnio.ssl.SslConnection;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.charset.StandardCharsets;
/**
 * Implementation of version 1 of the proxy protocol (https://www.haproxy.org/download/1.8/doc/proxy-protocol.txt)
 * <p>
 * Even though it is not required by the spec this implementation provides a stateful parser, that can handle
 * fragmentation of the header across several reads: all parse state is kept in instance fields, so
 * {@link #handleEvent(StreamSourceChannel)} may be invoked repeatedly until the terminating CRLF has been seen.
 * <p>
 * Once the header has been parsed the connection (wrapped so that it reports the addresses from the header,
 * unless the protocol was {@code UNKNOWN}) is handed to the configured {@link OpenListener}. Any bytes that were
 * read past the end of the header are passed along so they are not lost.
 *
 * @author Stuart Douglas
 */
class ProxyProtocolReadListener implements ChannelListener<StreamSourceChannel> {

    /** Maximum number of bytes a header may occupy before it is rejected. */
    private static final int MAX_HEADER_LENGTH = 107;
    /** Largest value accepted for a TCP port. */
    private static final int MAX_PORT = 65535;
    /** Fixed prefix every version 1 header must start with. */
    private static final byte[] NAME = "PROXY ".getBytes(StandardCharsets.US_ASCII);
    private static final String UNKNOWN = "UNKNOWN";
    private static final String TCP4 = "TCP4";
    private static final String TCP_6 = "TCP6";

    private final StreamConnection streamConnection;
    private final OpenListener openListener;
    private final UndertowXnioSsl ssl;
    private final ByteBufferPool bufferPool;
    private final OptionMap sslOptionMap;

    // ---- parser state, preserved between handleEvent invocations ----

    /** Number of header bytes consumed so far. */
    private int byteCount;
    /** Protocol token (TCP4 / TCP6 / UNKNOWN), or null while it has not been read yet. */
    private String protocol;
    private InetAddress sourceAddress;
    private InetAddress destAddress;
    /** -1 means "not parsed yet". */
    private int sourcePort = -1;
    /** -1 means "not parsed yet". */
    private int destPort = -1;
    /** Accumulates the characters of the token currently being read. */
    private final StringBuilder stringBuilder = new StringBuilder();
    private boolean carriageReturnSeen = false;
    private boolean parsingUnknown = false;

    /**
     * @param streamConnection the raw connection the header is read from
     * @param openListener     listener the connection is handed to once the header has been consumed
     * @param ssl              if non-null, the connection is wrapped with SSL before being handed on
     * @param bufferPool       pool used to allocate the read buffer; its buffers must hold at least
     *                         {@value #MAX_HEADER_LENGTH} bytes
     * @param sslOptionMap     options used for the SSL wrapping, {@code null} is treated as empty
     */
    ProxyProtocolReadListener(StreamConnection streamConnection, OpenListener openListener, UndertowXnioSsl ssl, ByteBufferPool bufferPool, OptionMap sslOptionMap) {
        this.streamConnection = streamConnection;
        this.openListener = openListener;
        this.ssl = ssl;
        this.bufferPool = bufferPool;
        this.sslOptionMap = sslOptionMap;
        if (bufferPool.getBufferSize() < MAX_HEADER_LENGTH) {
            throw UndertowMessages.MESSAGES.bufferPoolTooSmall(MAX_HEADER_LENGTH);
        }
    }

    /**
     * Reads whatever is available from the channel and feeds it to the header parser.
     * <p>
     * Returns without doing anything further when no data is available; closes the connection on end of stream
     * or on any parse/IO failure. When the header is complete the connection is passed to
     * {@link #proxyAccept(SocketAddress, SocketAddress, PooledByteBuffer)}.
     */
    @Override
    public void handleEvent(StreamSourceChannel streamSourceChannel) {
        PooledByteBuffer buffer = bufferPool.allocate();
        // stays true unless ownership of the buffer is handed over together with left-over data
        boolean freeBuffer = true;
        try {
            for (; ; ) {
                int res = streamSourceChannel.read(buffer.getBuffer());
                if (res == -1) {
                    IoUtils.safeClose(streamConnection);
                    return;
                } else if (res == 0) {
                    return;
                }
                buffer.getBuffer().flip();
                while (buffer.getBuffer().hasRemaining()) {
                    // the header is plain ASCII, so a byte maps directly onto a char
                    char c = (char) buffer.getBuffer().get();
                    if (parseCharacter(c)) {
                        // build the addresses before giving the buffer away, so a failure here
                        // still lets the finally block release it
                        SocketAddress s = null;
                        SocketAddress d = null;
                        if (!parsingUnknown) {
                            s = new InetSocketAddress(sourceAddress, sourcePort);
                            d = new InetSocketAddress(destAddress, destPort);
                        }
                        if (buffer.getBuffer().hasRemaining()) {
                            // data following the header belongs to the next protocol layer
                            freeBuffer = false;
                            proxyAccept(s, d, buffer);
                        } else {
                            proxyAccept(s, d, null);
                        }
                        return;
                    }
                    byteCount++;
                    if (byteCount == MAX_HEADER_LENGTH) {
                        throw UndertowMessages.MESSAGES.headerSizeToLarge();
                    }
                }
                // everything was consumed but the header is still incomplete: make room again,
                // otherwise the next read would target a buffer with no remaining space
                buffer.getBuffer().clear();
            }
        } catch (IOException e) {
            UndertowLogger.REQUEST_IO_LOGGER.ioException(e);
            IoUtils.safeClose(streamConnection);
        } catch (Exception e) {
            UndertowLogger.REQUEST_IO_LOGGER.ioException(new IOException(e));
            IoUtils.safeClose(streamConnection);
        } finally {
            if (freeBuffer) {
                buffer.close();
            }
        }
    }

    /**
     * Advances the parser by one character.
     *
     * @param c the next header character
     * @return {@code true} if this character was the line feed that terminates the header
     * @throws IOException if the character makes the header invalid
     */
    private boolean parseCharacter(char c) throws IOException {
        if (byteCount < NAME.length) {
            //first we verify that we have the correct protocol
            if (c != NAME[byteCount]) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            return false;
        }
        if (parsingUnknown) {
            //we are parsing the UNKNOWN protocol
            //we just ignore everything till \r\n
            if (c == '\r') {
                carriageReturnSeen = true;
            } else if (c == '\n') {
                if (!carriageReturnSeen) {
                    throw UndertowMessages.MESSAGES.invalidProxyHeader();
                }
                return true;
            } else if (carriageReturnSeen) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            return false;
        }
        if (carriageReturnSeen) {
            // after the \r that ended the destination port only \n is acceptable
            if (c == '\n') {
                return true;
            }
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        switch (c) {
            case ' ':
                endOfField();
                break;
            case '\r':
                endOfLine();
                break;
            case '\n':
                // a line feed that is not preceded by a carriage return
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            default:
                stringBuilder.append(c);
        }
        return false;
    }

    /**
     * Handles a space: stores the token collected so far into the next unset field
     * (protocol, source address, destination address, source port - in that order).
     */
    private void endOfField() throws IOException {
        if (sourcePort != -1 || stringBuilder.length() == 0) {
            //header was invalid, either we are expecting a \r or a \n, or the previous character was a space
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        final String token = stringBuilder.toString();
        stringBuilder.setLength(0);
        if (protocol == null) {
            protocol = token;
            if (protocol.equals(UNKNOWN)) {
                parsingUnknown = true;
            } else if (!protocol.equals(TCP4) && !protocol.equals(TCP_6)) {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
        } else if (sourceAddress == null) {
            sourceAddress = parseAddress(token, protocol);
        } else if (destAddress == null) {
            destAddress = parseAddress(token, protocol);
        } else {
            sourcePort = parsePort(token);
        }
    }

    /**
     * Handles a carriage return. It is only legal directly after the destination port, or directly after a
     * bare {@code UNKNOWN} protocol token; anywhere else the header is rejected.
     */
    private void endOfLine() throws IOException {
        if (destPort == -1 && sourcePort != -1 && stringBuilder.length() > 0) {
            destPort = parsePort(stringBuilder.toString());
            stringBuilder.setLength(0);
            carriageReturnSeen = true;
        } else if (protocol == null && UNKNOWN.equals(stringBuilder.toString())) {
            parsingUnknown = true;
            carriageReturnSeen = true;
        } else {
            // this includes a \r in the middle of the protocol token, which must not be silently dropped
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
    }

    /**
     * Parses a port made up of decimal digits only.
     *
     * @param value the token to parse
     * @return the port, in the range 0 to {@value #MAX_PORT}
     * @throws IOException if the token is empty, contains a non-digit or is out of range
     */
    private static int parsePort(String value) throws IOException {
        // more than five digits can never be a valid port; this also rules out int overflow below
        if (value.isEmpty() || value.length() > 5) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        int port = 0;
        for (int i = 0; i < value.length(); i++) {
            char digit = value.charAt(i);
            if (digit < '0' || digit > '9') {
                throw UndertowMessages.MESSAGES.invalidProxyHeader();
            }
            port = port * 10 + (digit - '0');
        }
        if (port > MAX_PORT) {
            throw UndertowMessages.MESSAGES.invalidProxyHeader();
        }
        return port;
    }

    /**
     * Hands the connection over once the header has been consumed.
     *
     * @param source         peer address from the header, or {@code null} to keep the connection's own addresses
     * @param dest           local address from the header (only used when {@code source} is non-null)
     * @param additionalData bytes read past the end of the header, or {@code null} if there were none
     */
    private void proxyAccept(SocketAddress source, SocketAddress dest, PooledByteBuffer additionalData) {
        StreamConnection streamConnection = this.streamConnection;
        if (source != null) {
            streamConnection = new AddressWrappedConnection(streamConnection, source, dest);
        }
        if (ssl != null) {
            //we need to apply the additional data before the SSL wrapping
            if (additionalData != null) {
                PushBackStreamSourceConduit conduit = new PushBackStreamSourceConduit(streamConnection.getSourceChannel().getConduit());
                conduit.pushBack(new PooledAdaptor(additionalData));
                streamConnection.getSourceChannel().setConduit(conduit);
            }
            SslConnection sslConnection = ssl.wrapExistingConnection(streamConnection, sslOptionMap == null ? OptionMap.EMPTY : sslOptionMap);
            UndertowXnioSsl.getSslEngine(sslConnection).setUseClientMode(false);
            streamConnection = sslConnection;
            // the left-over data has already been pushed back underneath the SSL layer
            callOpenListener(streamConnection, null);
        } else {
            callOpenListener(streamConnection, additionalData);
        }
    }

    /**
     * Invokes the open listener. A {@link DelegateOpenListener} receives the left-over buffer directly; for any
     * other listener the buffer (if present) is pushed back onto the source channel first.
     */
    private void callOpenListener(StreamConnection streamConnection, final PooledByteBuffer buffer) {
        if (openListener instanceof DelegateOpenListener) {
            ((DelegateOpenListener) openListener).handleEvent(streamConnection, buffer);
        } else {
            if (buffer != null) {
                PushBackStreamSourceConduit conduit = new PushBackStreamSourceConduit(streamConnection.getSourceChannel().getConduit());
                conduit.pushBack(new PooledAdaptor(buffer));
                streamConnection.getSourceChannel().setConduit(conduit);
            }
            openListener.handleEvent(streamConnection);
        }
    }

    /**
     * Parses an address token according to the protocol token of the header.
     *
     * @param addressString the address as it appeared in the header
     * @param protocol      the protocol token; {@code TCP4} selects IPv4 parsing, anything else IPv6 parsing
     */
    static InetAddress parseAddress(String addressString, String protocol) throws IOException {
        if (protocol.equals(TCP4)) {
            return NetworkUtils.parseIpv4Address(addressString);
        } else {
            return NetworkUtils.parseIpv6Address(addressString);
        }
    }

    /**
     * Connection wrapper that shares the delegate's IO thread and conduits but reports the peer and local
     * addresses that were supplied to it instead of the delegate's own.
     */
    private static final class AddressWrappedConnection extends StreamConnection {

        private final StreamConnection delegate;
        private final SocketAddress source;
        private final SocketAddress dest;

        AddressWrappedConnection(StreamConnection delegate, SocketAddress source, SocketAddress dest) {
            super(delegate.getIoThread());
            this.delegate = delegate;
            this.source = source;
            this.dest = dest;
            setSinkConduit(delegate.getSinkChannel().getConduit());
            setSourceConduit(delegate.getSourceChannel().getConduit());
        }

        @Override
        protected void notifyWriteClosed() {
            IoUtils.safeClose(delegate.getSinkChannel());
        }

        @Override
        protected void notifyReadClosed() {
            IoUtils.safeClose(delegate.getSourceChannel());
        }

        @Override
        public SocketAddress getPeerAddress() {
            return source;
        }

        @Override
        public SocketAddress getLocalAddress() {
            return dest;
        }
    }
}