Skip to content

Commit

Permalink
Ensure handshake timeout is applied in case of SNI (#2840)
Browse files Browse the repository at this point in the history
  • Loading branch information
violetagg committed Jun 21, 2023
1 parent 9856731 commit 877e20c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2022 VMware, Inc. or its affiliates, All Rights Reserved.
* Copyright (c) 2020-2023 VMware, Inc. or its affiliates, All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -57,21 +57,24 @@ void addSniHandler(Channel channel, boolean sslDebug) {
SslProvider.addSslReadHandler(pipeline, sslDebug);
}

final long handshakeTimeoutMillis;
final AsyncMapping<String, SslProvider> mappings;

SniProvider(AsyncMapping<String, SslProvider> mappings) {
SniProvider(AsyncMapping<String, SslProvider> mappings, long handshakeTimeoutMillis) {
this.mappings = mappings;
this.handshakeTimeoutMillis = handshakeTimeoutMillis;
}

SniProvider(Map<String, SslProvider> confPerDomainName, SslProvider defaultSslProvider) {
DomainWildcardMappingBuilder<SslProvider> mappingsSslProviderBuilder =
new DomainWildcardMappingBuilder<>(defaultSslProvider);
confPerDomainName.forEach(mappingsSslProviderBuilder::add);
this.mappings = new AsyncMappingAdapter(mappingsSslProviderBuilder.build());
this.handshakeTimeoutMillis = defaultSslProvider.handshakeTimeoutMillis;
}

SniHandler newSniHandler() {
return new SniHandler(mappings);
return new SniHandler(mappings, handshakeTimeoutMillis);
}

static final class AsyncMappingAdapter implements AsyncMapping<String, SslProvider> {
Expand All @@ -97,7 +100,8 @@ static final class SniHandler extends AbstractSniHandler<SslProvider> {

final AsyncMapping<String, SslProvider> mappings;

SniHandler(AsyncMapping<String, SslProvider> mappings) {
SniHandler(AsyncMapping<String, SslProvider> mappings, long handshakeTimeoutMillis) {
super(handshakeTimeoutMillis);
this.mappings = mappings;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ else if (builder.protocolSslContextSpec != null) {
}
}
else if (sniMappings != null) {
this.sniProvider = new SniProvider(sniMappings);
this.sniProvider = new SniProvider(sniMappings, builder.handshakeTimeoutMillis);
}
else {
this.sniProvider = null;
Expand Down Expand Up @@ -474,7 +474,7 @@ else if (sniMappings != null) {
this.sniProvider = updateAllSslProviderConfiguration(confPerDomainName, this, type);
}
else {
this.sniProvider = new SniProvider(sniMappings);
this.sniProvider = new SniProvider(sniMappings, from.handshakeTimeoutMillis);
}
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.channel.FixedRecvByteBufAllocator;
import io.netty.channel.group.ChannelGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.group.DefaultChannelGroup;
Expand Down Expand Up @@ -94,6 +95,7 @@
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.SslHandshakeTimeoutException;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.AttributeKey;
Expand Down Expand Up @@ -2209,6 +2211,51 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
assertThat(hostname.get()).isEqualTo("test.com");
}

@Test
void testSniSupportHandshakeTimeout() {
Http11SslContextSpec defaultSslContextBuilder =
Http11SslContextSpec.forServer(ssc.certificate(), ssc.privateKey());

Http11SslContextSpec clientSslContextBuilder =
Http11SslContextSpec.forClient()
.configure(builder -> builder.trustManager(InsecureTrustManagerFactory.INSTANCE));

AtomicReference<Throwable> error = new AtomicReference<>();
disposableServer =
createServer()
.childOption(ChannelOption.RCVBUF_ALLOCATOR, new FixedRecvByteBufAllocator(64))
.secure(spec -> spec.sslContext(defaultSslContextBuilder)
.handshakeTimeout(Duration.ofMillis(1))
.addSniMapping("*.test.com", domainSpec -> domainSpec.sslContext(defaultSslContextBuilder)))
.doOnChannelInit((obs, ch, addr) ->
ch.pipeline().addBefore(NettyPipeline.ReactiveBridge, "testSniSupportHandshakeTimeout",
new ChannelInboundHandlerAdapter() {

@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
if (evt instanceof SniCompletionEvent) {
error.set(((SniCompletionEvent) evt).cause());
}
ctx.fireUserEventTriggered(evt);
}
}))
.handle((req, res) -> res.sendString(Mono.just("testSniSupport")))
.bindNow();

createClient(disposableServer::address)
.secure(spec -> spec.sslContext(clientSslContextBuilder)
.serverNames(new SNIHostName("test.com")))
.get()
.uri("/")
.responseContent()
.aggregate()
.as(StepVerifier::create)
.expectError()
.verify(Duration.ofSeconds(5));

assertThat(error.get()).isNotNull().isInstanceOf(SslHandshakeTimeoutException.class);
}

@Test
void testIssue1286_HTTP11() throws Exception {
doTestIssue1286(Function.identity(), Function.identity(), false, false);
Expand Down

0 comments on commit 877e20c

Please sign in to comment.