Skip to content

Commit

Permalink
Add sslContextBuilderCustomizer(Function<SslContextBuilder, SslContex…
Browse files Browse the repository at this point in the history
…tBuilder>)

We now accept a customizer function to customize the SSL configuration through SslContextBuilder. The customizer is applied each time a connection gets established.

[resolves #152]
  • Loading branch information
mp911de committed May 5, 2020
1 parent f385598 commit 33b0223
Show file tree
Hide file tree
Showing 10 changed files with 132 additions and 43 deletions.
73 changes: 66 additions & 7 deletions src/main/java/io/r2dbc/mssql/MssqlConnectionConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,33 @@

package io.r2dbc.mssql;

import io.netty.handler.ssl.SslContextBuilder;
import io.r2dbc.mssql.client.ClientConfiguration;
import io.r2dbc.mssql.client.ssl.ExpectedHostnameX509TrustManager;
import io.r2dbc.mssql.client.ssl.TrustAllTrustManager;
import io.r2dbc.mssql.codec.DefaultCodecs;
import io.r2dbc.mssql.message.tds.Redirect;
import io.r2dbc.mssql.util.Assert;
import io.r2dbc.mssql.util.StringUtils;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.tcp.SslProvider;
import reactor.util.annotation.Nullable;

import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.security.GeneralSecurityException;
import java.security.KeyStore;
import java.time.Duration;
import java.util.Optional;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Predicate;

import static reactor.netty.tcp.SslProvider.DefaultConfigurationType.TCP;

/**
* Connection configuration information for connecting to a Microsoft SQL database.
* Allows configuration of the connection endpoint, login credentials, database and trace details such as application name and connection Id.
Expand Down Expand Up @@ -73,10 +85,13 @@ public final class MssqlConnectionConfiguration {

private final boolean ssl;

private final Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer;

private final String username;

private MssqlConnectionConfiguration(@Nullable String applicationName, @Nullable UUID connectionId, Duration connectTimeout, @Nullable String database, String host, String hostNameInCertificate
, CharSequence password, Predicate<String> preferCursoredExecution, int port, boolean sendStringParametersAsUnicode, boolean ssl, String username) {
, CharSequence password, Predicate<String> preferCursoredExecution, int port, boolean sendStringParametersAsUnicode, boolean ssl,
Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer, String username) {

this.applicationName = applicationName;
this.connectionId = connectionId;
Expand All @@ -89,6 +104,7 @@ private MssqlConnectionConfiguration(@Nullable String applicationName, @Nullable
this.port = port;
this.sendStringParametersAsUnicode = sendStringParametersAsUnicode;
this.ssl = ssl;
this.sslContextBuilderCustomizer = sslContextBuilderCustomizer;
this.username = Assert.requireNonNull(username, "username must not be null");
}

Expand Down Expand Up @@ -126,11 +142,11 @@ MssqlConnectionConfiguration withRedirect(Redirect redirect) {
}

return new MssqlConnectionConfiguration(this.applicationName, this.connectionId, this.connectTimeout, this.database, redirectServerName, hostNameInCertificate, this.password,
this.preferCursoredExecution, redirect.getPort(), this.sendStringParametersAsUnicode, this.ssl, this.username);
this.preferCursoredExecution, redirect.getPort(), this.sendStringParametersAsUnicode, this.ssl, this.sslContextBuilderCustomizer, this.username);
}

ClientConfiguration toClientConfiguration() {
return new DefaultClientConfiguration(this.connectTimeout, this.host, this.hostNameInCertificate, this.port, this.ssl);
return new DefaultClientConfiguration(this.connectTimeout, this.host, this.hostNameInCertificate, this.port, this.ssl, sslContextBuilderCustomizer);
}

ConnectionOptions toConnectionOptions() {
Expand All @@ -152,6 +168,7 @@ public String toString() {
sb.append(", port=").append(this.port);
sb.append(", sendStringParametersAsUnicode=").append(this.sendStringParametersAsUnicode);
sb.append(", ssl=").append(this.ssl);
sb.append(", sslContextBuilderCustomizer=").append(this.sslContextBuilderCustomizer);
sb.append(", username=\"").append(this.username).append('\"');
sb.append(']');
return sb.toString();
Expand Down Expand Up @@ -280,6 +297,8 @@ public static final class Builder {

private boolean ssl;

private Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer = Function.identity();

private String username;

private Builder() {
Expand Down Expand Up @@ -428,6 +447,21 @@ public Builder sendStringParametersAsUnicode(boolean sendStringParametersAsUnico
return this;
}

/**
* Configure a {@link SslContextBuilder} customizer. The customizer gets applied on each SSL connection attempt to allow for just-in-time configuration updates. The {@link Function} gets
* called with the prepared {@link SslContextBuilder} that has all configuration options applied. The customizer may return the same builder or return a new builder instance to be used to
* build the SSL context.
*
* @param sslContextBuilderCustomizer customizer function
* @return this {@link Builder}
* @throws IllegalArgumentException if {@code sslContextBuilderCustomizer} is {@code null}
* @since 0.8.3
*/
public Builder sslContextBuilderCustomizer(Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer) {
this.sslContextBuilderCustomizer = Assert.requireNonNull(sslContextBuilderCustomizer, "sslContextBuilderCustomizer must not be null");
return this;
}

/**
* Configure the username.
*
Expand All @@ -453,7 +487,7 @@ public MssqlConnectionConfiguration build() {

return new MssqlConnectionConfiguration(this.applicationName, this.connectionId, this.connectTimeout, this.database, this.host, this.hostNameInCertificate, this.password,
this.preferCursoredExecution, this.port,
this.sendStringParametersAsUnicode, this.ssl, this.username);
this.sendStringParametersAsUnicode, this.ssl, this.sslContextBuilderCustomizer, this.username);
}
}

Expand All @@ -469,13 +503,17 @@ private static class DefaultClientConfiguration implements ClientConfiguration {

private final boolean ssl;

DefaultClientConfiguration(Duration connectTimeout, String host, String hostNameInCertificate, int port, boolean ssl) {
private final Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer;

DefaultClientConfiguration(Duration connectTimeout, String host, String hostNameInCertificate, int port, boolean ssl,
Function<SslContextBuilder, SslContextBuilder> sslContextBuilderCustomizer) {

this.connectTimeout = connectTimeout;
this.host = host;
this.hostNameInCertificate = hostNameInCertificate;
this.port = port;
this.ssl = ssl;
this.sslContextBuilderCustomizer = sslContextBuilderCustomizer;
}

@Override
Expand Down Expand Up @@ -504,8 +542,29 @@ public boolean isSslEnabled() {
}

@Override
public String getHostNameInCertificate() {
return this.hostNameInCertificate;
public SslProvider getSslProvider() throws GeneralSecurityException {

SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();

TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
KeyStore ks = null;
tmf.init(ks);

TrustManager[] trustManagers = tmf.getTrustManagers();
TrustManager result;

if (isSslEnabled()) {
result = new ExpectedHostnameX509TrustManager((X509TrustManager) trustManagers[0], this.hostNameInCertificate);
} else {
result = TrustAllTrustManager.INSTANCE;
}

sslContextBuilder.trustManager(result);

return SslProvider.builder()
.sslContext(this.sslContextBuilderCustomizer.apply(sslContextBuilder))
.defaultConfiguration(TCP)
.build();
}
}
}
13 changes: 13 additions & 0 deletions src/main/java/io/r2dbc/mssql/MssqlConnectionFactoryProvider.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.r2dbc.mssql;

import io.netty.handler.ssl.SslContextBuilder;
import io.r2dbc.mssql.util.Assert;
import io.r2dbc.spi.ConnectionFactoryOptions;
import io.r2dbc.spi.ConnectionFactoryProvider;
Expand All @@ -25,6 +26,7 @@

import java.time.Duration;
import java.util.UUID;
import java.util.function.Function;
import java.util.function.Predicate;

import static io.r2dbc.spi.ConnectionFactoryOptions.CONNECT_TIMEOUT;
Expand Down Expand Up @@ -73,6 +75,13 @@ public final class MssqlConnectionFactoryProvider implements ConnectionFactoryPr
*/
public static final Option<Boolean> SEND_STRING_PARAMETERS_AS_UNICODE = Option.valueOf("sendStringParametersAsUnicode");

/**
* Customizer {@link Function} for {@link SslContextBuilder}.
*
* @since 0.8.3
*/
public static final Option<Function<SslContextBuilder, SslContextBuilder>> SSL_CONTEXT_BUILDER_CUSTOMIZER = Option.valueOf("sslContextBuilderCustomizer");

/**
* Driver option value.
*/
Expand Down Expand Up @@ -167,6 +176,10 @@ public MssqlConnectionFactory create(ConnectionFactoryOptions connectionFactoryO
builder.username(connectionFactoryOptions.getRequiredValue(USER));
builder.applicationName(connectionFactoryOptions.getRequiredValue(USER));

if (connectionFactoryOptions.hasOption(SSL_CONTEXT_BUILDER_CUSTOMIZER)) {
builder.sslContextBuilderCustomizer(connectionFactoryOptions.getRequiredValue(SSL_CONTEXT_BUILDER_CUSTOMIZER));
}

MssqlConnectionConfiguration configuration = builder.build();
if (this.logger.isDebugEnabled()) {
this.logger.debug(String.format("Creating MssqlConnectionFactory with configuration [%s] from options [%s]", configuration, connectionFactoryOptions));
Expand Down
11 changes: 9 additions & 2 deletions src/main/java/io/r2dbc/mssql/client/ReactorNettyClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import io.netty.channel.ChannelPipeline;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.r2dbc.mssql.client.ssl.TdsSslHandler;
Expand Down Expand Up @@ -52,6 +53,7 @@
import reactor.core.publisher.SynchronousSink;
import reactor.netty.Connection;
import reactor.netty.resources.ConnectionProvider;
import reactor.netty.tcp.SslProvider;
import reactor.netty.tcp.TcpClient;
import reactor.util.Logger;
import reactor.util.Loggers;
Expand All @@ -70,6 +72,8 @@
import java.util.function.Predicate;
import java.util.function.Supplier;

import static reactor.netty.tcp.SslProvider.DefaultConfigurationType.TCP;

/**
* An implementation of a TDS client based on the Reactor Netty project.
*
Expand Down Expand Up @@ -383,8 +387,11 @@ public boolean isSslEnabled() {
}

@Override
public String getHostNameInCertificate() {
return host;
public SslProvider getSslProvider() {
return SslProvider.builder()
.sslContext(SslContextBuilder.forClient())
.defaultConfiguration(TCP)
.build();
}
}, null, null);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
*
* @author Mark Paluch
*/
final class ExpectedHostnameX509TrustManager implements X509TrustManager {
public final class ExpectedHostnameX509TrustManager implements X509TrustManager {

private static final Logger logger = Loggers.getLogger(TdsSslHandler.class);

Expand All @@ -42,7 +42,7 @@ final class ExpectedHostnameX509TrustManager implements X509TrustManager {

private final Predicate<String> matcher;

ExpectedHostnameX509TrustManager(X509TrustManager tm, String expectedHostName) {
public ExpectedHostnameX509TrustManager(X509TrustManager tm, String expectedHostName) {

this.defaultTrustManager = tm;
this.expectedHostName = expectedHostName;
Expand Down
11 changes: 8 additions & 3 deletions src/main/java/io/r2dbc/mssql/client/ssl/SslConfiguration.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@

package io.r2dbc.mssql.client.ssl;

import reactor.netty.tcp.SslProvider;

import java.security.GeneralSecurityException;

/**
* SSL Configuration for SQL Server connections.
* <p>Microsoft SQL server supports various SSL setups:
Expand All @@ -28,7 +32,7 @@
* <p>
* Supported mode uses SSL during login to encrypt login credentials. SSL is disabled after login.
* The client supports login-time SSL even when {@link #isSslEnabled()} is {@code false}. This mode does not validate certificates.
* <p>Enabling {@link #isSslEnabled() SSL} enables also SSL certificate validation using {@link #getHostNameInCertificate()}.
* <p>Enabling {@link #isSslEnabled() SSL} enables also SSL certificate validation.
*
* @author Mark Paluch
*/
Expand All @@ -40,7 +44,8 @@ public interface SslConfiguration {
boolean isSslEnabled();

/**
* @return expected hostname in the SSL certificate.
* @return the {@link SslProvider}.
* @since 0.8.3
*/
String getHostNameInCertificate();
SslProvider getSslProvider() throws GeneralSecurityException;
}
30 changes: 4 additions & 26 deletions src/main/java/io/r2dbc/mssql/client/ssl/TdsSslHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,8 @@
import reactor.util.Loggers;
import reactor.util.annotation.Nullable;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;
import java.security.GeneralSecurityException;
import java.security.KeyStore;

/**
* SSL handling for TDS connections.
Expand Down Expand Up @@ -118,31 +113,14 @@ void setState(SslState state) {
* @return the configured {@link SslHandler}.
* @throws GeneralSecurityException thrown on security API errors.
*/
private static SslHandler createSslHandler(SslConfiguration sslConfiguration) throws GeneralSecurityException {
private static SslHandler createSslHandler(SslConfiguration sslConfiguration, ByteBufAllocator allocator) throws GeneralSecurityException {

TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
SSLContext sslContext = SSLContext.getInstance("TLS");
KeyStore ks = null;
tmf.init(ks);
SSLEngine sslEngine = sslConfiguration.getSslProvider().getSslContext()
.newEngine(allocator);

TrustManager[] trustManagers = tmf.getTrustManagers();
TrustManager[] tms = new TrustManager[]{getTrustManager(sslConfiguration, trustManagers[0])};
sslContext.init(null, tms, null);

SSLEngine sslEngine = sslContext.createSSLEngine();
sslEngine.setUseClientMode(true);
return new SslHandler(sslEngine);
}

private static TrustManager getTrustManager(SslConfiguration sslConfiguration, TrustManager trustManager) {

if (sslConfiguration.isSslEnabled()) {
return new ExpectedHostnameX509TrustManager((X509TrustManager) trustManager, sslConfiguration.getHostNameInCertificate());
}

return TrustAllTrustManager.INSTANCE;
}

/**
* Lazily register {@link SslHandler} if needed.
*
Expand All @@ -156,7 +134,7 @@ public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exc
if (evt == SslState.LOGIN_ONLY || evt == SslState.CONNECTION) {

this.state = (SslState) evt;
this.sslHandler = createSslHandler(this.sslConfiguration);
this.sslHandler = createSslHandler(this.sslConfiguration, ctx.alloc());

LOGGER.debug(this.connectionContext.getMessage("Registering Context Proxy and SSL Event Handlers to propagate SSL events to channelRead()"));
ctx.pipeline().addAfter(getClass().getName(), ContextProxy.class.getName(), new ContextProxy());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
*
* @author Mark Paluch
*/
enum TrustAllTrustManager implements X509TrustManager {
public enum TrustAllTrustManager implements X509TrustManager {

INSTANCE;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,14 @@ void constructorNoUsername() {
.withMessage("username must not be null");
}

@Test
void constructorNoSslCustomizer() {
assertThatIllegalArgumentException().isThrownBy(() -> MssqlConnectionConfiguration.builder()
.sslContextBuilderCustomizer(null)
.build())
.withMessage("sslContextBuilderCustomizer must not be null");
}

@Test
void redirect() {
MssqlConnectionConfiguration configuration = MssqlConnectionConfiguration.builder()
Expand Down
Loading

0 comments on commit 33b0223

Please sign in to comment.