Skip to content

Commit

Permalink
[UNDERTOW-2327] At Http2Channel, store the rst streams in a cache so …
Browse files Browse the repository at this point in the history
…responses to rst streams can be handled correctly.

The cache is cleaned after a while (current default value is set to 1 minute). If, during that time, a response to a canceled request stream is received from the server, the channel will be able to detect it is not a protocol error but just a matter of timing: the server responded the request before receiving and processing the rst frame

Signed-off-by: Flavia Rainone <frainone@redhat.com>
  • Loading branch information
fl4via committed Oct 16, 2023
1 parent 017b15f commit a6861ee
Showing 1 changed file with 92 additions and 11 deletions.
103 changes: 92 additions & 11 deletions core/src/main/java/io/undertow/protocols/http2/Http2Channel.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.xnio.channels.StreamSinkChannel;
import org.xnio.ssl.SslConnection;

import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
Expand All @@ -52,13 +53,15 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import javax.net.ssl.SSLSession;

/**
* HTTP2 channel.
Expand Down Expand Up @@ -128,6 +131,9 @@ public class Http2Channel extends AbstractFramedChannel<Http2Channel, AbstractHt
public static final int DEFAULT_MAX_FRAME_SIZE = 16384;
public static final int MAX_FRAME_SIZE = 16777215;
public static final int FLOW_CONTROL_MIN_WINDOW = 2;
// the time a discarded stream is kept at the stream cache before actually being discarded
// (used for handling responses to streams that were closed via rst)
private static final int STREAM_CACHE_EVICTION_TIME_MS = 60000;


private Http2FrameHeaderParser frameParser;
Expand Down Expand Up @@ -198,6 +204,8 @@ public class Http2Channel extends AbstractFramedChannel<Http2Channel, AbstractHt
*/
private volatile int receiveWindowSize;

private final StreamCache sentRstStreams = new StreamCache();


public Http2Channel(StreamConnection connectedStreamChannel, String protocol, ByteBufferPool bufferPool, PooledByteBuffer data, boolean clientSide, boolean fromUpgrade, OptionMap settings) {
this(connectedStreamChannel, protocol, bufferPool, data, clientSide, fromUpgrade, true, null, settings);
Expand Down Expand Up @@ -390,8 +398,13 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra
//this is an existing stream
//make sure it exists
StreamHolder existing = currentStreams.get(frameParser.streamId);
if (existing == null) {
existing = sentRstStreams.find(frameParser.streamId);
}
if(existing == null || existing.sourceClosed) {
sendGoAway(ERROR_PROTOCOL_ERROR);
if (existing != null || sentRstStreams.find(frameParser.streamId) == null) {
sendGoAway(ERROR_PROTOCOL_ERROR);
}
frameData.close();
return null;
} else if (existing.sourceChannel != null ){
Expand Down Expand Up @@ -423,8 +436,13 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra

StreamHolder holder = currentStreams.get(frameParser.streamId);
if(holder == null) {
receiveConcurrentStreamsAtomicUpdater.getAndIncrement(this);
currentStreams.put(frameParser.streamId, holder = new StreamHolder((Http2StreamSourceChannel) channel));
holder = sentRstStreams.find(frameParser.streamId);
if (holder != null) {
holder.sourceChannel = (Http2StreamSourceChannel) channel;
} else {
receiveConcurrentStreamsAtomicUpdater.getAndIncrement(this);
currentStreams.put(frameParser.streamId, holder = new StreamHolder((Http2StreamSourceChannel) channel));
}
} else {
holder.sourceChannel = (Http2StreamSourceChannel) channel;
}
Expand Down Expand Up @@ -633,8 +651,12 @@ protected void handleBrokenSinkChannel(Throwable e) {

@Override
protected void closeSubChannels() {
closeSubChannels(currentStreams);
closeSubChannels(sentRstStreams.getStreamHolders());
}

for (Map.Entry<Integer, StreamHolder> e : currentStreams.entrySet()) {
private void closeSubChannels(Map<Integer, StreamHolder> streams) {
for (Map.Entry<Integer, StreamHolder> e : streams.entrySet()) {
StreamHolder holder = e.getValue();
AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel;
if(receiver != null) {
Expand Down Expand Up @@ -763,7 +785,7 @@ public void handleWindowUpdate(int streamId, int deltaWindowSize) throws IOExcep
StreamHolder holder = currentStreams.get(streamId);
Http2StreamSinkChannel stream = holder != null ? holder.sinkChannel : null;
if (stream == null) {
if(isIdle(streamId)) {
if (sentRstStreams.find(streamId) == null && isIdle(streamId)) {
sendGoAway(ERROR_PROTOCOL_ERROR);
}
} else {
Expand Down Expand Up @@ -1115,15 +1137,15 @@ public void sendRstStream(int streamId, int statusCode) {
//no point sending if the channel is closed
return;
}
handleRstStream(streamId);
sentRstStreams.store(streamId, handleRstStream(streamId));
if(UndertowLogger.REQUEST_IO_LOGGER.isDebugEnabled()) {
UndertowLogger.REQUEST_IO_LOGGER.debugf(new ClosedChannelException(), "Sending rststream on channel %s stream %s", this, streamId);
}
Http2RstStreamSinkChannel channel = new Http2RstStreamSinkChannel(this, streamId, statusCode);
flushChannelIgnoreFailure(channel);
}

private void handleRstStream(int streamId) {
private StreamHolder handleRstStream(int streamId) {
StreamHolder holder = currentStreams.remove(streamId);
if(holder != null) {
if(streamId % 2 == (isClient() ? 1 : 0)) {
Expand All @@ -1138,6 +1160,7 @@ private void handleRstStream(int streamId) {
holder.sourceChannel.rstStream();
}
}
return holder;
}

/**
Expand Down Expand Up @@ -1173,8 +1196,9 @@ public boolean isThisGoneAway() {

Http2StreamSourceChannel removeStreamSource(int streamId) {
StreamHolder existing = currentStreams.get(streamId);
if(existing == null){
return null;
if (existing == null) {
existing = sentRstStreams.find(streamId);
return existing == null? null : existing.sourceChannel;
}
existing.sourceClosed = true;
Http2StreamSourceChannel ret = existing.sourceChannel;
Expand All @@ -1193,7 +1217,10 @@ Http2StreamSourceChannel removeStreamSource(int streamId) {
Http2StreamSourceChannel getIncomingStream(int streamId) {
StreamHolder existing = currentStreams.get(streamId);
if(existing == null){
return null;
existing = sentRstStreams.find(streamId);
if (existing == null) {
return null;
}
}
return existing.sourceChannel;
}
Expand Down Expand Up @@ -1246,4 +1273,58 @@ private static final class StreamHolder {
this.sinkChannel = sinkChannel;
}
}

// cache that keeps track of streams until they can be evicted @see Http2Channel#RST_STREAM_EVICATION_TIME
private static final class StreamCache {
private Map<Integer, StreamHolder> streamHolders = new ConcurrentHashMap<>();
// entries are sorted per creation time
private Queue<StreamCacheEntry> entries = new ConcurrentLinkedQueue<>();

private void store(int streamId, StreamHolder streamHolder) {
if (streamHolder == null) {
return;
}
streamHolders.put(streamId, streamHolder);
entries.add(new StreamCacheEntry(streamId));
}
private StreamHolder find(int streamId) {
for (Iterator<StreamCacheEntry> iterator = entries.iterator(); iterator.hasNext();) {
StreamCacheEntry entry = iterator.next();
if (entry.shouldEvict()) {
iterator.remove();
StreamHolder holder = streamHolders.remove(entry.streamId);
AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel;
if(receiver != null) {
IoUtils.safeClose(receiver);
}
Http2StreamSinkChannel sink = holder.sinkChannel;
if(sink != null) {
if (sink.isWritesShutdown()) {
ChannelListeners.invokeChannelListener(sink.getIoThread(), sink, ((ChannelListener.SimpleSetter) sink.getWriteSetter()).get());
}
IoUtils.safeClose(sink);
}
} else break;
}
return streamHolders.get(streamId);
}

private Map<Integer, StreamHolder> getStreamHolders() {
return streamHolders;
}
}

private static class StreamCacheEntry {
int streamId;
long time;

StreamCacheEntry(int streamId) {
this.streamId = streamId;
this.time = System.currentTimeMillis();
}

public boolean shouldEvict() {
return System.currentTimeMillis() - time > STREAM_CACHE_EVICTION_TIME_MS;
}
}
}

0 comments on commit a6861ee

Please sign in to comment.