Skip to content

Commit

Permalink
SASL test verifies connection
Browse files Browse the repository at this point in the history
  • Loading branch information
reneleonhardt committed May 26, 2024
1 parent e44091e commit 3b5a7f0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package io.r2dbc.postgresql.authentication;

import com.ongres.scram.client.ScramClient;
import com.ongres.scram.common.StringPreparation;
import com.ongres.scram.common.exception.ScramException;
import com.ongres.scram.common.util.TlsServerEndpoint;
import io.r2dbc.postgresql.client.ConnectionContext;
Expand All @@ -25,6 +24,9 @@
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;

import static com.ongres.scram.common.StringPreparation.POSTGRESQL_PREPARATION;
import static com.ongres.scram.common.util.TlsServerEndpoint.TLS_SERVER_END_POINT;

public class SASLAuthenticationHandler implements AuthenticationHandler {

private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class);
Expand Down Expand Up @@ -82,22 +84,16 @@ public FrontendMessage handle(AuthenticationMessage message) {
}

private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {

char[] password = new char[this.password.length()];
for (int i = 0; i < password.length; i++) {
password[i] = this.password.charAt(i);
}

ScramClient.FinalBuildStage builder = ScramClient.builder()
.advertisedMechanisms(message.getAuthenticationMechanisms())
.username(this.username) // ignored by the server, use startup message
.password(password)
.stringPreparation(StringPreparation.POSTGRESQL_PREPARATION);
.password(password.toString().toCharArray())
.stringPreparation(POSTGRESQL_PREPARATION);

SSLSession sslSession = this.context.getSslSession();

if (sslSession != null && sslSession.isValid()) {
builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession));
builder.channelBinding(TLS_SERVER_END_POINT, extractSslEndpoint(sslSession));
}

this.scramClient = builder.build();
Expand All @@ -107,14 +103,9 @@ private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {

private static byte[] extractSslEndpoint(SSLSession sslSession) {
try {
Certificate[] certificates = sslSession.getPeerCertificates();
if (certificates != null && certificates.length > 0) {
Certificate peerCert = certificates[0]; // First certificate is the peer's certificate
if (peerCert instanceof X509Certificate) {
X509Certificate cert = (X509Certificate) peerCert;
return TlsServerEndpoint.getChannelBindingData(cert);

}
Certificate[] certificates = sslSession.getPeerCertificates(); // First certificate is the peer's certificate
if (certificates != null && certificates.length > 0 && certificates[0] instanceof X509Certificate ) {
return TlsServerEndpoint.getChannelBindingData((X509Certificate) certificates[0]);
}
} catch (CertificateException | SSLException e) {
LOG.debug("Cannot extract X509Certificate from SSL session", e);
Expand All @@ -125,7 +116,6 @@ private static byte[] extractSslEndpoint(SSLSession sslSession) {
private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
try {
this.scramClient.serverFirstMessage(ByteBufferUtils.decode(message.getData()));

return new SASLResponse(ByteBufferUtils.encode(this.scramClient.clientFinalMessage().toString()));
} catch (ScramException e) {
throw Exceptions.propagate(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

import java.util.Collections;

import static com.ongres.scram.common.StringPreparation.POSTGRESQL_PREPARATION;
import static io.r2dbc.postgresql.util.TestByteBufAllocator.TEST;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
Expand Down Expand Up @@ -85,6 +86,7 @@ void createAuthenticationSASL() {
.advertisedMechanisms(Collections.singletonList("SCRAM-SHA-256"))
.username("test-username")
.password("test-password".toCharArray())
.stringPreparation(POSTGRESQL_PREPARATION)
.build();

// @formatter:off
Expand All @@ -103,6 +105,12 @@ void createAuthenticationSASL() {
.username("test-username")
.password("test-password")
.build();

new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration)
.create()
.as(StepVerifier::create)
.expectNextCount(1)
.verifyComplete();
}

@Test
Expand Down

0 comments on commit 3b5a7f0

Please sign in to comment.