Skip to content

Commit

Permalink
Merge pull request #1525 from fl4via/UNDERTOW-2323
Browse files Browse the repository at this point in the history
[UNDERTOW-2323] Fix for CVE-2023-44487 and related issues
  • Loading branch information
fl4via committed Oct 17, 2023
2 parents ef2c80b + 05d58fa commit eb372bb
Show file tree
Hide file tree
Showing 4 changed files with 1,000 additions and 16 deletions.
25 changes: 25 additions & 0 deletions core/src/main/java/io/undertow/UndertowOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,8 @@ public class UndertowOptions {
*/
public static final Option<Integer> HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = Option.simple(UndertowOptions.class, "HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS", Integer.class);

public static final int DEFAULT_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS = -1;

public static final Option<Integer> HTTP2_SETTINGS_INITIAL_WINDOW_SIZE = Option.simple(UndertowOptions.class, "HTTP2_SETTINGS_INITIAL_WINDOW_SIZE", Integer.class);
public static final Option<Integer> HTTP2_SETTINGS_MAX_FRAME_SIZE = Option.simple(UndertowOptions.class, "HTTP2_SETTINGS_MAX_FRAME_SIZE", Integer.class);

Expand Down Expand Up @@ -397,6 +399,29 @@ public class UndertowOptions {
*/
public static final Option<Boolean> TRACK_ACTIVE_REQUESTS = Option.simple(UndertowOptions.class, "TRACK_ACTIVE_REQUESTS", Boolean.class);

/**
* Default value of {@link #RST_FRAMES_TIME_WINDOW} option.
*/
public static final int DEFAULT_RST_FRAMES_TIME_WINDOW = 30000;
/**
* Default value of {@link #MAX_RST_FRAMES_PER_WINDOW} option.
*/
public static final int DEFAULT_MAX_RST_FRAMES_PER_WINDOW = 200;

/**
* Window of time per which the number of HTTP2 RST received frames is measured, in milliseconds.
* If a number of RST frames bigger than {@link #MAX_RST_FRAMES_PER_WINDOW} is received during this time window,
* the server will send a GO_AWAY frame with error code 11 ({@code ENHANCE_YOUR_CALM}) and it will close the connection.
*/
public static final Option<Integer> RST_FRAMES_TIME_WINDOW = Option.simple(UndertowOptions.class, "MAX_RST_STREAM_TIME_WINDOW", Integer.class);

/**
* Maximum number of HTTP2 RST frames received allowed during a time window.
* If a number of RST frames bigger than this limit is received during {@link #RST_FRAMES_TIME_WINDOW} milliseconds,
* the server will send a GO_AWAY frame with error code 11 ({@code ENHANCE_YOUR_CALM}) and it will close the connection.
*/
public static final Option<Integer> MAX_RST_FRAMES_PER_WINDOW = Option.simple(UndertowOptions.class, "MAX_RST_STREAMS_PER_TIME_WINDOW", Integer.class);

private UndertowOptions() {

}
Expand Down
138 changes: 122 additions & 16 deletions core/src/main/java/io/undertow/protocols/http2/Http2Channel.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import org.xnio.channels.StreamSinkChannel;
import org.xnio.ssl.SslConnection;

import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channel;
Expand All @@ -52,13 +53,15 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
import javax.net.ssl.SSLSession;

/**
* HTTP2 channel.
Expand Down Expand Up @@ -121,15 +124,16 @@ public class Http2Channel extends AbstractFramedChannel<Http2Channel, AbstractHt

public static final int DEFAULT_INITIAL_WINDOW_SIZE = 65535;

public static final int DEFAULT_MAX_CONCURRENT_STREAMS = -1;

static final byte[] PREFACE_BYTES = {
0x50, 0x52, 0x49, 0x20, 0x2a, 0x20, 0x48, 0x54,
0x54, 0x50, 0x2f, 0x32, 0x2e, 0x30, 0x0d, 0x0a,
0x0d, 0x0a, 0x53, 0x4d, 0x0d, 0x0a, 0x0d, 0x0a};
public static final int DEFAULT_MAX_FRAME_SIZE = 16384;
public static final int MAX_FRAME_SIZE = 16777215;
public static final int FLOW_CONTROL_MIN_WINDOW = 2;
// the time a discarded stream is kept at the stream cache before actually being discarded
// (used for handling responses to streams that were closed via rst)
private static final int STREAM_CACHE_EVICTION_TIME_MS = 60000;


private Http2FrameHeaderParser frameParser;
Expand All @@ -150,6 +154,16 @@ public class Http2Channel extends AbstractFramedChannel<Http2Channel, AbstractHt
private final int maxHeaders;
private final int maxHeaderListSize;

// the max number of rst frames received per window
private final int maxRstFramesPerWindow;
// the time window for counting rst frames received
private final long rstFramesTimeWindow;
// the time in milliseconds the last rst frame was received
private long lastRstFrameMillis = System.currentTimeMillis();
// the total number of received rst frames during current time windows
private int receivedRstFramesPerWindow;


private static final AtomicIntegerFieldUpdater<Http2Channel> sendConcurrentStreamsAtomicUpdater = AtomicIntegerFieldUpdater.newUpdater(
Http2Channel.class, "sendConcurrentStreams");

Expand Down Expand Up @@ -200,6 +214,8 @@ public class Http2Channel extends AbstractFramedChannel<Http2Channel, AbstractHt
*/
private volatile int receiveWindowSize;

private final StreamCache sentRstStreams = new StreamCache();


public Http2Channel(StreamConnection connectedStreamChannel, String protocol, ByteBufferPool bufferPool, PooledByteBuffer data, boolean clientSide, boolean fromUpgrade, OptionMap settings) {
this(connectedStreamChannel, protocol, bufferPool, data, clientSide, fromUpgrade, true, null, settings);
Expand All @@ -216,7 +232,7 @@ public Http2Channel(StreamConnection connectedStreamChannel, String protocol, By
pushEnabled = settings.get(UndertowOptions.HTTP2_SETTINGS_ENABLE_PUSH, true);
this.initialReceiveWindowSize = settings.get(UndertowOptions.HTTP2_SETTINGS_INITIAL_WINDOW_SIZE, DEFAULT_INITIAL_WINDOW_SIZE);
this.receiveWindowSize = initialReceiveWindowSize;
this.receiveMaxConcurrentStreams = settings.get(UndertowOptions.HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, DEFAULT_MAX_CONCURRENT_STREAMS);
this.receiveMaxConcurrentStreams = settings.get(UndertowOptions.HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, UndertowOptions.DEFAULT_HTTP2_SETTINGS_MAX_CONCURRENT_STREAMS);

this.protocol = protocol == null ? Http2OpenListener.HTTP2 : protocol;
this.maxHeaders = settings.get(UndertowOptions.MAX_HEADERS, clientSide ? -1 : UndertowOptions.DEFAULT_MAX_HEADERS);
Expand All @@ -230,6 +246,8 @@ public Http2Channel(StreamConnection connectedStreamChannel, String protocol, By
} else {
paddingRandom = null;
}
maxRstFramesPerWindow = settings.get(UndertowOptions.MAX_RST_FRAMES_PER_WINDOW, settings.get(UndertowOptions.MAX_RST_FRAMES_PER_WINDOW, UndertowOptions.DEFAULT_MAX_RST_FRAMES_PER_WINDOW));
rstFramesTimeWindow = settings.get(UndertowOptions.RST_FRAMES_TIME_WINDOW, settings.get(UndertowOptions.RST_FRAMES_TIME_WINDOW, UndertowOptions.DEFAULT_RST_FRAMES_TIME_WINDOW));

this.decoder = new HpackDecoder(encoderHeaderTableSize);
this.encoder = new HpackEncoder(encoderHeaderTableSize);
Expand Down Expand Up @@ -392,8 +410,13 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra
//this is an existing stream
//make sure it exists
StreamHolder existing = currentStreams.get(frameParser.streamId);
if (existing == null) {
existing = sentRstStreams.find(frameParser.streamId);
}
if(existing == null || existing.sourceClosed) {
sendGoAway(ERROR_PROTOCOL_ERROR);
if (existing != null || sentRstStreams.find(frameParser.streamId) == null) {
sendGoAway(ERROR_PROTOCOL_ERROR);
}
frameData.close();
return null;
} else if (existing.sourceChannel != null ){
Expand Down Expand Up @@ -425,8 +448,13 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra

StreamHolder holder = currentStreams.get(frameParser.streamId);
if(holder == null) {
receiveConcurrentStreamsAtomicUpdater.getAndIncrement(this);
currentStreams.put(frameParser.streamId, holder = new StreamHolder((Http2StreamSourceChannel) channel));
holder = sentRstStreams.find(frameParser.streamId);
if (holder != null) {
holder.sourceChannel = (Http2StreamSourceChannel) channel;
} else {
receiveConcurrentStreamsAtomicUpdater.getAndIncrement(this);
currentStreams.put(frameParser.streamId, holder = new StreamHolder((Http2StreamSourceChannel) channel));
}
} else {
holder.sourceChannel = (Http2StreamSourceChannel) channel;
}
Expand Down Expand Up @@ -463,7 +491,7 @@ protected AbstractHttp2StreamSourceChannel createChannelImpl(FrameHeaderData fra
throw new ConnectionErrorException(Http2Channel.ERROR_PROTOCOL_ERROR, UndertowMessages.MESSAGES.streamIdMustNotBeZeroForFrameType(FRAME_TYPE_RST_STREAM));
}
channel = new Http2RstStreamStreamSourceChannel(this, frameData, parser.getErrorCode(), frameParser.streamId);
handleRstStream(frameParser.streamId);
handleRstStream(frameParser.streamId, true);
if(isIdle(frameParser.streamId)) {
sendGoAway(ERROR_PROTOCOL_ERROR);
}
Expand Down Expand Up @@ -635,8 +663,12 @@ protected void handleBrokenSinkChannel(Throwable e) {

@Override
protected void closeSubChannels() {
closeSubChannels(currentStreams);
closeSubChannels(sentRstStreams.getStreamHolders());
}

for (Map.Entry<Integer, StreamHolder> e : currentStreams.entrySet()) {
private void closeSubChannels(Map<Integer, StreamHolder> streams) {
for (Map.Entry<Integer, StreamHolder> e : streams.entrySet()) {
StreamHolder holder = e.getValue();
AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel;
if(receiver != null) {
Expand Down Expand Up @@ -765,7 +797,7 @@ public void handleWindowUpdate(int streamId, int deltaWindowSize) throws IOExcep
StreamHolder holder = currentStreams.get(streamId);
Http2StreamSinkChannel stream = holder != null ? holder.sinkChannel : null;
if (stream == null) {
if(isIdle(streamId)) {
if (sentRstStreams.find(streamId) == null && isIdle(streamId)) {
sendGoAway(ERROR_PROTOCOL_ERROR);
}
} else {
Expand Down Expand Up @@ -1117,16 +1149,16 @@ public void sendRstStream(int streamId, int statusCode) {
//no point sending if the channel is closed
return;
}
handleRstStream(streamId);
sentRstStreams.store(streamId, handleRstStream(streamId, false));
if(UndertowLogger.REQUEST_IO_LOGGER.isDebugEnabled()) {
UndertowLogger.REQUEST_IO_LOGGER.debugf(new ClosedChannelException(), "Sending rststream on channel %s stream %s", this, streamId);
}
Http2RstStreamSinkChannel channel = new Http2RstStreamSinkChannel(this, streamId, statusCode);
flushChannelIgnoreFailure(channel);
}

private void handleRstStream(int streamId) {
StreamHolder holder = currentStreams.remove(streamId);
private StreamHolder handleRstStream(int streamId, boolean receivedRst) {
final StreamHolder holder = currentStreams.remove(streamId);
if(holder != null) {
if(streamId % 2 == (isClient() ? 1 : 0)) {
sendConcurrentStreamsAtomicUpdater.getAndDecrement(this);
Expand All @@ -1139,7 +1171,23 @@ private void handleRstStream(int streamId) {
if (holder.sourceChannel != null) {
holder.sourceChannel.rstStream();
}
if (receivedRst) {
long currentTimeMillis = System.currentTimeMillis();
// reset the window tracking
if (currentTimeMillis - lastRstFrameMillis >= rstFramesTimeWindow) {
lastRstFrameMillis = currentTimeMillis;
receivedRstFramesPerWindow = 1;
} else {
//
receivedRstFramesPerWindow ++;
if (receivedRstFramesPerWindow > maxRstFramesPerWindow) {
sendGoAway(Http2Channel.ERROR_ENHANCE_YOUR_CALM);
UndertowLogger.REQUEST_IO_LOGGER.debugf("Reached maximum number of rst frames %s during %s ms, sending GO_AWAY 11", maxRstFramesPerWindow, rstFramesTimeWindow);
}
}
}
}
return holder;
}

/**
Expand Down Expand Up @@ -1175,8 +1223,9 @@ public boolean isThisGoneAway() {

Http2StreamSourceChannel removeStreamSource(int streamId) {
StreamHolder existing = currentStreams.get(streamId);
if(existing == null){
return null;
if (existing == null) {
existing = sentRstStreams.find(streamId);
return existing == null? null : existing.sourceChannel;
}
existing.sourceClosed = true;
Http2StreamSourceChannel ret = existing.sourceChannel;
Expand All @@ -1195,7 +1244,10 @@ Http2StreamSourceChannel removeStreamSource(int streamId) {
Http2StreamSourceChannel getIncomingStream(int streamId) {
StreamHolder existing = currentStreams.get(streamId);
if(existing == null){
return null;
existing = sentRstStreams.find(streamId);
if (existing == null) {
return null;
}
}
return existing.sourceChannel;
}
Expand Down Expand Up @@ -1248,4 +1300,58 @@ private static final class StreamHolder {
this.sinkChannel = sinkChannel;
}
}

// cache that keeps track of streams until they can be evicted @see Http2Channel#RST_STREAM_EVICATION_TIME
private static final class StreamCache {
private Map<Integer, StreamHolder> streamHolders = new ConcurrentHashMap<>();
// entries are sorted per creation time
private Queue<StreamCacheEntry> entries = new ConcurrentLinkedQueue<>();

private void store(int streamId, StreamHolder streamHolder) {
if (streamHolder == null) {
return;
}
streamHolders.put(streamId, streamHolder);
entries.add(new StreamCacheEntry(streamId));
}
private StreamHolder find(int streamId) {
for (Iterator<StreamCacheEntry> iterator = entries.iterator(); iterator.hasNext();) {
StreamCacheEntry entry = iterator.next();
if (entry.shouldEvict()) {
iterator.remove();
StreamHolder holder = streamHolders.remove(entry.streamId);
AbstractHttp2StreamSourceChannel receiver = holder.sourceChannel;
if(receiver != null) {
IoUtils.safeClose(receiver);
}
Http2StreamSinkChannel sink = holder.sinkChannel;
if(sink != null) {
if (sink.isWritesShutdown()) {
ChannelListeners.invokeChannelListener(sink.getIoThread(), sink, ((ChannelListener.SimpleSetter) sink.getWriteSetter()).get());
}
IoUtils.safeClose(sink);
}
} else break;
}
return streamHolders.get(streamId);
}

private Map<Integer, StreamHolder> getStreamHolders() {
return streamHolders;
}
}

private static class StreamCacheEntry {
int streamId;
long time;

StreamCacheEntry(int streamId) {
this.streamId = streamId;
this.time = System.currentTimeMillis();
}

public boolean shouldEvict() {
return System.currentTimeMillis() - time > STREAM_CACHE_EVICTION_TIME_MS;
}
}
}

0 comments on commit eb372bb

Please sign in to comment.