Skip to content

Commit

Permalink
UNDERTOW-1167 HTTP/2 does not correctly send/receive trailers
Browse files Browse the repository at this point in the history
  • Loading branch information
stuartwdouglas committed Aug 28, 2017
1 parent 1842b27 commit e014c4c
Show file tree
Hide file tree
Showing 6 changed files with 162 additions and 41 deletions.
3 changes: 3 additions & 0 deletions core/src/main/java/io/undertow/UndertowMessages.java
Expand Up @@ -557,4 +557,7 @@ public interface UndertowMessages {


@Message(id = 180, value = "PROXY protocol header exceeded max size of 107 bytes") @Message(id = 180, value = "PROXY protocol header exceeded max size of 107 bytes")
IOException headerSizeToLarge(); IOException headerSizeToLarge();

@Message(id = 181, value = "HTTP/2 trailers too large for single buffer")
RuntimeException http2TrailerToLargeForSingleBuffer();
} }
Expand Up @@ -35,8 +35,11 @@
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;


import io.undertow.client.ClientStatistics; import io.undertow.client.ClientStatistics;
import io.undertow.protocols.http2.Http2DataStreamSinkChannel;
import io.undertow.protocols.http2.Http2GoAwayStreamSourceChannel; import io.undertow.protocols.http2.Http2GoAwayStreamSourceChannel;
import io.undertow.protocols.http2.Http2PushPromiseStreamSourceChannel; import io.undertow.protocols.http2.Http2PushPromiseStreamSourceChannel;
import io.undertow.server.protocol.http.HttpAttachments;
import io.undertow.util.HeaderMap;
import io.undertow.util.HeaderValues; import io.undertow.util.HeaderValues;
import io.undertow.util.Methods; import io.undertow.util.Methods;
import io.undertow.util.NetworkUtils; import io.undertow.util.NetworkUtils;
Expand Down Expand Up @@ -200,6 +203,12 @@ public void sendRequest(ClientRequest request, ClientCallback<ClientExchange> cl
Http2ClientExchange exchange = new Http2ClientExchange(this, sinkChannel, request); Http2ClientExchange exchange = new Http2ClientExchange(this, sinkChannel, request);
currentExchanges.put(sinkChannel.getStreamId(), exchange); currentExchanges.put(sinkChannel.getStreamId(), exchange);


sinkChannel.setTrailersProducer(new Http2DataStreamSinkChannel.TrailersProducer() {
@Override
public HeaderMap getTrailers() {
return exchange.getAttachment(HttpAttachments.RESPONSE_TRAILERS);
}
});


if(clientCallback != null) { if(clientCallback != null) {
clientCallback.completed(exchange); clientCallback.completed(exchange);
Expand Down Expand Up @@ -364,6 +373,12 @@ public void handleEvent(Http2Channel channel) {
Channels.drain(result, Long.MAX_VALUE); Channels.drain(result, Long.MAX_VALUE);
return; return;
} }
((Http2StreamSourceChannel) result).setTrailersHandler(new Http2StreamSourceChannel.TrailersHandler() {
@Override
public void handleTrailers(HeaderMap headerMap) {
request.putAttachment(HttpAttachments.REQUEST_TRAILERS, headerMap);
}
});


result.addCloseTask(new ChannelListener<AbstractHttp2StreamSourceChannel>() { result.addCloseTask(new ChannelListener<AbstractHttp2StreamSourceChannel>() {
@Override @Override
Expand Down
Expand Up @@ -18,16 +18,16 @@


package io.undertow.protocols.http2; package io.undertow.protocols.http2;


import java.io.IOException; import io.undertow.UndertowMessages;
import java.nio.ByteBuffer; import io.undertow.connector.PooledByteBuffer;

import io.undertow.server.protocol.framed.SendFrameHeader;
import io.undertow.util.HeaderMap;
import io.undertow.util.ImmediatePooledByteBuffer; import io.undertow.util.ImmediatePooledByteBuffer;
import org.xnio.ChannelListener; import org.xnio.ChannelListener;
import org.xnio.ChannelListeners; import org.xnio.ChannelListeners;
import io.undertow.connector.PooledByteBuffer;


import io.undertow.server.protocol.framed.SendFrameHeader; import java.io.IOException;
import io.undertow.util.HeaderMap; import java.nio.ByteBuffer;


/** /**
* Headers channel * Headers channel
Expand All @@ -44,6 +44,7 @@ public class Http2DataStreamSinkChannel extends Http2StreamSinkChannel implement


private final int frameType; private final int frameType;
private boolean completionListenerReady; private boolean completionListenerReady;
private TrailersProducer trailersProducer;


Http2DataStreamSinkChannel(Http2Channel channel, int streamId, int frameType) { Http2DataStreamSinkChannel(Http2Channel channel, int streamId, int frameType) {
this(channel, streamId, new HeaderMap(), frameType); this(channel, streamId, new HeaderMap(), frameType);
Expand All @@ -56,6 +57,14 @@ public class Http2DataStreamSinkChannel extends Http2StreamSinkChannel implement
this.frameType = frameType; this.frameType = frameType;
} }


public TrailersProducer getTrailersProducer() {
return trailersProducer;
}

public void setTrailersProducer(TrailersProducer trailersProducer) {
this.trailersProducer = trailersProducer;
}

@Override @Override
protected SendFrameHeader createFrameHeaderImpl() { protected SendFrameHeader createFrameHeaderImpl() {
//TODO: this is a mess WRT re-using between headers and push_promise, sort out a more reasonable abstraction //TODO: this is a mess WRT re-using between headers and push_promise, sort out a more reasonable abstraction
Expand Down Expand Up @@ -85,6 +94,14 @@ protected SendFrameHeader createFrameHeaderImpl() {
ByteBuffer firstBuffer = firstHeaderBuffer.getBuffer(); ByteBuffer firstBuffer = firstHeaderBuffer.getBuffer();
boolean firstFrame = false; boolean firstFrame = false;


HeaderMap trailers = null;
if(finalFrame && this.trailersProducer != null) {
trailers = this.trailersProducer.getTrailers();
if(trailers != null && trailers.size() == 0) {
trailers = null;
}
}

if (first) { if (first) {
firstFrame = true; firstFrame = true;
first = false; first = false;
Expand All @@ -102,14 +119,14 @@ protected SendFrameHeader createFrameHeaderImpl() {
firstBuffer.put((byte) (paddingBytes & 0xFF)); firstBuffer.put((byte) (paddingBytes & 0xFF));
} }
writeBeforeHeaderBlock(firstBuffer); writeBeforeHeaderBlock(firstBuffer);

HeaderMap headers = this.headers;
HpackEncoder.State result = encoder.encode(headers, firstBuffer); HpackEncoder.State result = encoder.encode(headers, firstBuffer);
PooledByteBuffer current = firstHeaderBuffer; PooledByteBuffer current = firstHeaderBuffer;
int headerFrameLength = firstBuffer.position() - 9 + paddingBytes; int headerFrameLength = firstBuffer.position() - 9 + paddingBytes;
firstBuffer.put(0, (byte) ((headerFrameLength >> 16) & 0xFF)); firstBuffer.put(0, (byte) ((headerFrameLength >> 16) & 0xFF));
firstBuffer.put(1, (byte) ((headerFrameLength >> 8) & 0xFF)); firstBuffer.put(1, (byte) ((headerFrameLength >> 8) & 0xFF));
firstBuffer.put(2, (byte) (headerFrameLength & 0xFF)); firstBuffer.put(2, (byte) (headerFrameLength & 0xFF));
firstBuffer.put(4, (byte) ((isFinalFrameQueued() && !getBuffer().hasRemaining() && frameType == Http2Channel.FRAME_TYPE_HEADERS ? Http2Channel.HEADERS_FLAG_END_STREAM : 0) | (result == HpackEncoder.State.COMPLETE ? Http2Channel.HEADERS_FLAG_END_HEADERS : 0 ) | (paddingBytes > 0 ? Http2Channel.HEADERS_FLAG_PADDED : 0))); //flags firstBuffer.put(4, (byte) ((isFinalFrameQueued() && !getBuffer().hasRemaining() && frameType == Http2Channel.FRAME_TYPE_HEADERS && trailers == null ? Http2Channel.HEADERS_FLAG_END_STREAM : 0) | (result == HpackEncoder.State.COMPLETE ? Http2Channel.HEADERS_FLAG_END_HEADERS : 0 ) | (paddingBytes > 0 ? Http2Channel.HEADERS_FLAG_PADDED : 0))); //flags
ByteBuffer currentBuffer = firstBuffer; ByteBuffer currentBuffer = firstBuffer;


if(currentBuffer.remaining() < paddingBytes) { if(currentBuffer.remaining() < paddingBytes) {
Expand All @@ -126,30 +143,16 @@ protected SendFrameHeader createFrameHeaderImpl() {


allHeaderBuffers = allocateAll(allHeaderBuffers, current); allHeaderBuffers = allocateAll(allHeaderBuffers, current);
current = allHeaderBuffers[allHeaderBuffers.length - 1]; current = allHeaderBuffers[allHeaderBuffers.length - 1];
//continuation frame result = encodeContinuationFrame(headers, current);
//note that if the buffers are small we may not actually need a continuation here
//but it greatly reduces the code complexity
//back fill the length
currentBuffer = current.getBuffer();
currentBuffer.put((byte) 0);
currentBuffer.put((byte) 0);
currentBuffer.put((byte) 0);
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_CONTINUATION); //type
currentBuffer.put((byte) 0); //back fill the flags
Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
result = encoder.encode(headers, currentBuffer);
int contFrameLength = currentBuffer.position() - 9;
currentBuffer.put(0, (byte) ((contFrameLength >> 16) & 0xFF));
currentBuffer.put(1, (byte) ((contFrameLength >> 8) & 0xFF));
currentBuffer.put(2, (byte) (contFrameLength & 0xFF));
currentBuffer.put(4, (byte) (result == HpackEncoder.State.COMPLETE ? Http2Channel.HEADERS_FLAG_END_HEADERS : 0 )); //flags
} }
} }


PooledByteBuffer currentPooled = allHeaderBuffers == null ? firstHeaderBuffer : allHeaderBuffers[allHeaderBuffers.length - 1]; PooledByteBuffer currentPooled = allHeaderBuffers == null ? firstHeaderBuffer : allHeaderBuffers[allHeaderBuffers.length - 1];
ByteBuffer currentBuffer = currentPooled.getBuffer(); ByteBuffer currentBuffer = currentPooled.getBuffer();
ByteBuffer trailer = null; ByteBuffer trailer = null;
int remainingInBuffer = 0; int remainingInBuffer = 0;
boolean requiresTrailers = false;


if (getBuffer().remaining() > 0) { if (getBuffer().remaining() > 0) {
if (fcWindow > 0) { if (fcWindow > 0) {
Expand All @@ -168,7 +171,14 @@ protected SendFrameHeader createFrameHeaderImpl() {
currentBuffer.put((byte) ((fcWindow >> 8) & 0xFF)); currentBuffer.put((byte) ((fcWindow >> 8) & 0xFF));
currentBuffer.put((byte) (fcWindow & 0xFF)); currentBuffer.put((byte) (fcWindow & 0xFF));
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_DATA); //type currentBuffer.put((byte) Http2Channel.FRAME_TYPE_DATA); //type
currentBuffer.put((byte) ((finalFrame ? Http2Channel.DATA_FLAG_END_STREAM : 0) | (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0))); //flags if(trailers == null) {
currentBuffer.put((byte) ((finalFrame ? Http2Channel.DATA_FLAG_END_STREAM : 0) | (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0))); //flags
} else {
if(finalFrame) {
requiresTrailers = true;
}
currentBuffer.put((byte) (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0)); //flags
}
Http2ProtocolUtils.putInt(currentBuffer, getStreamId()); Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
if(dataPaddingBytes > 0) { if(dataPaddingBytes > 0) {
currentBuffer.put((byte) (dataPaddingBytes & 0xFF)); currentBuffer.put((byte) (dataPaddingBytes & 0xFF));
Expand All @@ -182,13 +192,46 @@ protected SendFrameHeader createFrameHeaderImpl() {
currentBuffer.put((byte) ((fcWindow >> 8) & 0xFF)); currentBuffer.put((byte) ((fcWindow >> 8) & 0xFF));
currentBuffer.put((byte) (fcWindow & 0xFF)); currentBuffer.put((byte) (fcWindow & 0xFF));
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_DATA); //type currentBuffer.put((byte) Http2Channel.FRAME_TYPE_DATA); //type
currentBuffer.put((byte) ((Http2Channel.HEADERS_FLAG_END_STREAM & 0xFF)| (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0))); //flags if (trailers == null) {
currentBuffer.put((byte) ((Http2Channel.HEADERS_FLAG_END_STREAM & 0xFF) | (dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0))); //flags
} else {
requiresTrailers = true;
currentBuffer.put((byte) ((dataPaddingBytes > 0 ? Http2Channel.DATA_FLAG_PADDED : 0))); //flags
}
Http2ProtocolUtils.putInt(currentBuffer, getStreamId()); Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
if(dataPaddingBytes > 0) { if (dataPaddingBytes > 0) {
currentBuffer.put((byte) (dataPaddingBytes & 0xFF)); currentBuffer.put((byte) (dataPaddingBytes & 0xFF));
trailer = ByteBuffer.allocate(dataPaddingBytes); trailer = ByteBuffer.allocate(dataPaddingBytes);
} }
} }

if (requiresTrailers) {
PooledByteBuffer firstTrailerBuffer = getChannel().getBufferPool().allocate();
if (trailer != null) {
firstTrailerBuffer.getBuffer().put(trailer);
}
firstTrailerBuffer.getBuffer().put((byte) 0);
firstTrailerBuffer.getBuffer().put((byte) 0);
firstTrailerBuffer.getBuffer().put((byte) 0);
firstTrailerBuffer.getBuffer().put((byte) Http2Channel.FRAME_TYPE_HEADERS); //type
firstTrailerBuffer.getBuffer().put((byte) (Http2Channel.HEADERS_FLAG_END_STREAM | Http2Channel.HEADERS_FLAG_END_HEADERS)); //back fill the flags

Http2ProtocolUtils.putInt(firstTrailerBuffer.getBuffer(), getStreamId());
HpackEncoder.State result = encoder.encode(trailers, firstTrailerBuffer.getBuffer());
if (result != HpackEncoder.State.COMPLETE) {
throw UndertowMessages.MESSAGES.http2TrailerToLargeForSingleBuffer();
}
int headerFrameLength = firstTrailerBuffer.getBuffer().position() - 9;
firstTrailerBuffer.getBuffer().put(0, (byte) ((headerFrameLength >> 16) & 0xFF));
firstTrailerBuffer.getBuffer().put(1, (byte) ((headerFrameLength >> 8) & 0xFF));
firstTrailerBuffer.getBuffer().put(2, (byte) (headerFrameLength & 0xFF));
firstTrailerBuffer.getBuffer().flip();
int size = firstTrailerBuffer.getBuffer().remaining();
trailer = ByteBuffer.allocate(size);
trailer.put(firstTrailerBuffer.getBuffer());
trailer.flip();
firstTrailerBuffer.close();
}
if (allHeaderBuffers == null) { if (allHeaderBuffers == null) {
//only one buffer required //only one buffer required
currentBuffer.flip(); currentBuffer.flip();
Expand Down Expand Up @@ -219,6 +262,28 @@ protected SendFrameHeader createFrameHeaderImpl() {


} }


private HpackEncoder.State encodeContinuationFrame(HeaderMap headers, PooledByteBuffer current) {
ByteBuffer currentBuffer;
HpackEncoder.State result;//continuation frame
//note that if the buffers are small we may not actually need a continuation here
//but it greatly reduces the code complexity
//back fill the length
currentBuffer = current.getBuffer();
currentBuffer.put((byte) 0);
currentBuffer.put((byte) 0);
currentBuffer.put((byte) 0);
currentBuffer.put((byte) Http2Channel.FRAME_TYPE_CONTINUATION); //type
currentBuffer.put((byte) 0); //back fill the flags
Http2ProtocolUtils.putInt(currentBuffer, getStreamId());
result = encoder.encode(headers, currentBuffer);
int contFrameLength = currentBuffer.position() - 9;
currentBuffer.put(0, (byte) ((contFrameLength >> 16) & 0xFF));
currentBuffer.put(1, (byte) ((contFrameLength >> 8) & 0xFF));
currentBuffer.put(2, (byte) (contFrameLength & 0xFF));
currentBuffer.put(4, (byte) (result == HpackEncoder.State.COMPLETE ? Http2Channel.HEADERS_FLAG_END_HEADERS : 0 )); //flags
return result;
}

@Override @Override
public boolean flush() throws IOException { public boolean flush() throws IOException {
if(completionListenerReady && completionListener != null) { if(completionListenerReady && completionListener != null) {
Expand Down Expand Up @@ -266,4 +331,8 @@ public ChannelListener<Http2DataStreamSinkChannel> getCompletionListener() {
public void setCompletionListener(ChannelListener<Http2DataStreamSinkChannel> completionListener) { public void setCompletionListener(ChannelListener<Http2DataStreamSinkChannel> completionListener) {
this.completionListener = completionListener; this.completionListener = completionListener;
} }

public interface TrailersProducer {
HeaderMap getTrailers();
}
} }
Expand Up @@ -62,6 +62,8 @@ public class Http2StreamSourceChannel extends AbstractHttp2StreamSourceChannel i


private long contentLengthRemaining; private long contentLengthRemaining;


private TrailersHandler trailersHandler;

Http2StreamSourceChannel(Http2Channel framedChannel, PooledByteBuffer data, long frameDataRemaining, HeaderMap headers, int streamId) { Http2StreamSourceChannel(Http2Channel framedChannel, PooledByteBuffer data, long frameDataRemaining, HeaderMap headers, int streamId) {
super(framedChannel, data, frameDataRemaining); super(framedChannel, data, frameDataRemaining);
this.headers = headers; this.headers = headers;
Expand Down Expand Up @@ -89,6 +91,10 @@ protected void handleHeaderData(FrameHeaderData headerData) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} }
} else if(parser instanceof Http2HeadersParser) {
if(trailersHandler != null) {
trailersHandler.handleTrailers(((Http2HeadersParser) parser).getHeaderMap());
}
} }
handleFinalFrame(data); handleFinalFrame(data);
} }
Expand Down Expand Up @@ -248,6 +254,14 @@ boolean isHeadersEndStream() {
return headersEndStream; return headersEndStream;
} }


public TrailersHandler getTrailersHandler() {
return trailersHandler;
}

public void setTrailersHandler(TrailersHandler trailersHandler) {
this.trailersHandler = trailersHandler;
}

@Override @Override
public String toString() { public String toString() {
return "Http2StreamSourceChannel{" + return "Http2StreamSourceChannel{" +
Expand All @@ -273,4 +287,8 @@ void updateContentSize(long frameLength, boolean last) {
} }
} }


public interface TrailersHandler {
void handleTrailers(HeaderMap headerMap);
}

} }
Expand Up @@ -29,6 +29,7 @@
import io.undertow.server.Connectors; import io.undertow.server.Connectors;
import io.undertow.server.HttpHandler; import io.undertow.server.HttpHandler;
import io.undertow.server.HttpServerExchange; import io.undertow.server.HttpServerExchange;
import io.undertow.server.protocol.http.HttpAttachments;
import io.undertow.util.HeaderMap; import io.undertow.util.HeaderMap;
import io.undertow.util.HeaderValues; import io.undertow.util.HeaderValues;
import io.undertow.util.Headers; import io.undertow.util.Headers;
Expand Down Expand Up @@ -139,6 +140,18 @@ private void handleRequests(Http2Channel channel, Http2StreamSourceChannel frame




final HttpServerExchange exchange = new HttpServerExchange(connection, dataChannel.getHeaders(), dataChannel.getResponseChannel().getHeaders(), maxEntitySize); final HttpServerExchange exchange = new HttpServerExchange(connection, dataChannel.getHeaders(), dataChannel.getResponseChannel().getHeaders(), maxEntitySize);
dataChannel.getResponseChannel().setTrailersProducer(new Http2DataStreamSinkChannel.TrailersProducer() {
@Override
public HeaderMap getTrailers() {
return exchange.getAttachment(HttpAttachments.RESPONSE_TRAILERS);
}
});
dataChannel.setTrailersHandler(new Http2StreamSourceChannel.TrailersHandler() {
@Override
public void handleTrailers(HeaderMap headerMap) {
exchange.putAttachment(HttpAttachments.REQUEST_TRAILERS, headerMap);
}
});
connection.setExchange(exchange); connection.setExchange(exchange);
dataChannel.setMaxStreamSize(maxEntitySize); dataChannel.setMaxStreamSize(maxEntitySize);
exchange.setRequestScheme(exchange.getRequestHeaders().getFirst(SCHEME)); exchange.setRequestScheme(exchange.getRequestHeaders().getFirst(SCHEME));
Expand Down

0 comments on commit e014c4c

Please sign in to comment.