diff --git a/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java b/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java index 0ca21ac9..653f0382 100644 --- a/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java +++ b/src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java @@ -1,7 +1,6 @@ package io.r2dbc.postgresql.authentication; import com.ongres.scram.client.ScramClient; -import com.ongres.scram.common.StringPreparation; import com.ongres.scram.common.exception.ScramException; import com.ongres.scram.common.util.TlsServerEndpoint; import io.r2dbc.postgresql.client.ConnectionContext; @@ -25,6 +24,9 @@ import java.security.cert.CertificateException; import java.security.cert.X509Certificate; +import static com.ongres.scram.common.StringPreparation.POSTGRESQL_PREPARATION; +import static com.ongres.scram.common.util.TlsServerEndpoint.TLS_SERVER_END_POINT; + public class SASLAuthenticationHandler implements AuthenticationHandler { private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class); @@ -82,22 +84,16 @@ public FrontendMessage handle(AuthenticationMessage message) { } private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) { - - char[] password = new char[this.password.length()]; - for (int i = 0; i < password.length; i++) { - password[i] = this.password.charAt(i); - } - ScramClient.FinalBuildStage builder = ScramClient.builder() .advertisedMechanisms(message.getAuthenticationMechanisms()) .username(this.username) // ignored by the server, use startup message - .password(password) - .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION); + .password(password.toString().toCharArray()) + .stringPreparation(POSTGRESQL_PREPARATION); SSLSession sslSession = this.context.getSslSession(); if (sslSession != null && sslSession.isValid()) { - builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession)); + builder.channelBinding(TLS_SERVER_END_POINT, extractSslEndpoint(sslSession)); } this.scramClient = builder.build(); @@ -107,14 +103,9 @@ private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) { private static byte[] extractSslEndpoint(SSLSession sslSession) { try { - Certificate[] certificates = sslSession.getPeerCertificates(); - if (certificates != null && certificates.length > 0) { - Certificate peerCert = certificates[0]; // First certificate is the peer's certificate - if (peerCert instanceof X509Certificate) { - X509Certificate cert = (X509Certificate) peerCert; - return TlsServerEndpoint.getChannelBindingData(cert); - - } + Certificate[] certificates = sslSession.getPeerCertificates(); // First certificate is the peer's certificate + if (certificates != null && certificates.length > 0 && certificates[0] instanceof X509Certificate ) { + return TlsServerEndpoint.getChannelBindingData((X509Certificate) certificates[0]); } } catch (CertificateException | SSLException e) { LOG.debug("Cannot extract X509Certificate from SSL session", e); @@ -125,7 +116,6 @@ private static byte[] extractSslEndpoint(SSLSession sslSession) { private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) { try { this.scramClient.serverFirstMessage(ByteBufferUtils.decode(message.getData())); - return new SASLResponse(ByteBufferUtils.encode(this.scramClient.clientFinalMessage().toString())); } catch (ScramException e) { throw Exceptions.propagate(e); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java index cb6db4f2..a52384e1 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java @@ -36,6 +36,7 @@ import java.util.Collections; +import static com.ongres.scram.common.StringPreparation.POSTGRESQL_PREPARATION; import static io.r2dbc.postgresql.util.TestByteBufAllocator.TEST; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; @@ -85,6 +86,7 @@ void createAuthenticationSASL() { .advertisedMechanisms(Collections.singletonList("SCRAM-SHA-256")) .username("test-username") .password("test-password".toCharArray()) + .stringPreparation(POSTGRESQL_PREPARATION) .build(); // @formatter:off @@ -103,6 +105,12 @@ void createAuthenticationSASL() { .username("test-username") .password("test-password") .build(); + + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration) + .create() + .as(StepVerifier::create) + .expectNextCount(1) + .verifyComplete(); } @Test