diff --git a/build.sbt b/build.sbt index 82b504b17..7d98d0c3a 100644 --- a/build.sbt +++ b/build.sbt @@ -53,7 +53,10 @@ ThisBuild / libraryDependencySchemes ++= Seq( import com.typesafe.tools.mima.core._ ThisBuild / mimaBinaryIssueFilters ++= List( - ProblemFilters.exclude[DirectMissingMethodProblem]("skunk.net.BitVectorSocket.fromSocket") + ProblemFilters.exclude[DirectMissingMethodProblem]("skunk.net.BitVectorSocket.fromSocket"), + ProblemFilters.exclude[MissingTypesProblem]("skunk.net.protocol.Startup$"), + ProblemFilters.exclude[DirectMissingMethodProblem]("skunk.net.protocol.Startup.authenticationSASL"), + ProblemFilters.exclude[MissingClassProblem]("skunk.net.protocol.StartupCompanionPlatform"), ) ThisBuild / tlFatalWarnings := false @@ -128,14 +131,12 @@ lazy val core = crossProject(JVMPlatform, JSPlatform, NativePlatform) "org.typelevel" %%% "otel4s-semconv-metrics" % otel4sVersion, "org.tpolecat" %%% "sourcepos" % "1.2.0", "org.typelevel" %%% "twiddles-core" % "1.0.0", + "com.armanbilge" %%% "saslprep" % "0.1.2", ) ++ Seq( "com.beachape" %%% "enumeratum" % "1.9.0", ).filterNot(_ => tlIsScala3.value) - ).jvmSettings( - libraryDependencies += "com.ongres.scram" % "client" % "2.1", ).platformsSettings(JSPlatform, NativePlatform)( libraryDependencies ++= Seq( - "com.armanbilge" %%% "saslprep" % "0.1.2", "io.github.cquiroz" %%% "scala-java-time" % "2.6.0", "io.github.cquiroz" %%% "locales-minimal-en_us-db" % "1.5.4" ), diff --git a/modules/core/js-native/src/main/scala/protocol/StartupPlatform.scala b/modules/core/js-native/src/main/scala/protocol/StartupPlatform.scala deleted file mode 100644 index 601c1f5f7..000000000 --- a/modules/core/js-native/src/main/scala/protocol/StartupPlatform.scala +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2018-2024 by Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package skunk.net.protocol - -import cats.MonadThrow -import cats.syntax.all._ -import org.typelevel.otel4s.trace.Tracer -import skunk.net.MessageSocket -import skunk.net.message._ -import skunk.exception.{ - SCRAMProtocolException, - UnsupportedSASLMechanismsException -} - -private[protocol] trait StartupCompanionPlatform { this: Startup.type => - - private[protocol] def authenticationSASL[F[_]: MonadThrow: MessageSocket: Tracer]( - sm: StartupMessage, - password: Option[String], - mechanisms: List[String] - ): F[Unit] = - Tracer[F].span("authenticationSASL").surround { - if (mechanisms.contains(Scram.SaslMechanism)) { - for { - pw <- requirePassword[F](sm, password) - channelBinding = Scram.NoChannelBinding - clientFirstBare = Scram.clientFirstBareWithRandomNonce - _ <- send(Scram.saslInitialResponse(channelBinding, clientFirstBare)) - serverFirstBytes <- flatExpectStartup(sm) { - case AuthenticationSASLContinue(serverFirstBytes) => serverFirstBytes.pure[F] - } - serverFirst <- Scram.ServerFirst.decode(serverFirstBytes) match { - case Some(serverFirst) => serverFirst.pure[F] - case None => - new SCRAMProtocolException( - s"Failed to parse server-first-message in SASLInitialResponse: ${serverFirstBytes.toHex}." - ).raiseError[F, Scram.ServerFirst] - } - (response, expectedVerifier) = Scram.saslChallenge(pw, channelBinding, serverFirst, clientFirstBare, serverFirstBytes) - _ <- send(response) - serverFinalBytes <- flatExpectStartup(sm) { - case AuthenticationSASLFinal(serverFinalBytes) => serverFinalBytes.pure[F] - } - _ <- Scram.ServerFinal.decode(serverFinalBytes) match { - case Some(serverFinal) => - if (serverFinal.verifier == expectedVerifier) ().pure[F] - else new SCRAMProtocolException( - s"Expected verifier ${expectedVerifier.value.toHex} but received ${serverFinal.verifier.value.toHex}." - ).raiseError[F, Unit] - case None => - new SCRAMProtocolException( - s"Failed to parse server-final-message in AuthenticationSASLFinal: ${serverFinalBytes.toHex}." - ).raiseError[F, Unit] - } - _ <- flatExpectStartup(sm) { case AuthenticationOk => ().pure[F] } - } yield () - } else { - new UnsupportedSASLMechanismsException(mechanisms).raiseError[F, Unit] - } - } - -} diff --git a/modules/core/jvm/src/main/scala/net/message/ScramPlatform.scala b/modules/core/jvm/src/main/scala/net/message/ScramPlatform.scala new file mode 100644 index 000000000..163ee140a --- /dev/null +++ b/modules/core/jvm/src/main/scala/net/message/ScramPlatform.scala @@ -0,0 +1,39 @@ +// Copyright (c) 2018-2024 by Rob Norris and Contributors +// This software is licensed under the MIT License (MIT). +// For more information see LICENSE or https://opensource.org/licenses/MIT + +package skunk.net.message + +import scodec.bits.ByteVector + +import java.security.SecureRandom +import javax.crypto.{Mac, SecretKeyFactory} +import javax.crypto.spec.{PBEKeySpec, SecretKeySpec} + +private[message] trait ScramPlatform { this: Scram.type => + + def clientFirstBareWithRandomNonce: ByteVector = { + val random = new SecureRandom() + val nonceBytes = new Array[Byte](32) + random.nextBytes(nonceBytes) + val nonce = ByteVector.view(nonceBytes).toBase64 + clientFirstBareWithNonce(nonce) + } + + private[message] def HMAC(key: ByteVector, str: ByteVector): ByteVector = { + val mac = Mac.getInstance("HmacSHA256") + val keySpec = new SecretKeySpec(key.toArray, "HmacSHA256") + mac.init(keySpec) + ByteVector.view(mac.doFinal(str.toArray)) + } + + private[message] def H(input: ByteVector): ByteVector = + input.sha256 + + private[message] def Hi(str: String, salt: ByteVector, iterations: Int): ByteVector = { + val spec = new PBEKeySpec(str.toCharArray, salt.toArray, iterations, 256) + val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256") + val key = factory.generateSecret(spec) + ByteVector.view(key.getEncoded).take(32) + } +} diff --git a/modules/core/jvm/src/main/scala/net/protocol/StartupPlatform.scala b/modules/core/jvm/src/main/scala/net/protocol/StartupPlatform.scala deleted file mode 100644 index e66ba0b74..000000000 --- a/modules/core/jvm/src/main/scala/net/protocol/StartupPlatform.scala +++ /dev/null @@ -1,64 +0,0 @@ -// Copyright (c) 2018-2024 by Rob Norris and Contributors -// This software is licensed under the MIT License (MIT). -// For more information see LICENSE or https://opensource.org/licenses/MIT - -package skunk.net.protocol - -import com.ongres.scram.client.ScramClient -import com.ongres.scram.common.stringprep.StringPreparations - -import cats.MonadError -import cats.syntax.all._ -import org.typelevel.otel4s.trace.Tracer -import scala.util.control.NonFatal -import skunk.net.MessageSocket -import skunk.net.message._ -import skunk.exception.{ - SCRAMProtocolException, - UnsupportedSASLMechanismsException -} - -private[protocol] trait StartupCompanionPlatform { this: Startup.type => - - private[protocol] def authenticationSASL[F[_]: MessageSocket: Tracer]( - sm: StartupMessage, - password: Option[String], - mechanisms: List[String] - )( - implicit ev: MonadError[F, Throwable] - ): F[Unit] = - Tracer[F].span("authenticationSASL").surround { - for { - client <- { - try ScramClient. - channelBinding(ScramClient.ChannelBinding.NO). - stringPreparation(StringPreparations.SASL_PREPARATION). - selectMechanismBasedOnServerAdvertised(mechanisms.toArray: _*). - setup().pure[F] - catch { - case _: IllegalArgumentException => new UnsupportedSASLMechanismsException(mechanisms).raiseError[F, ScramClient] - case NonFatal(t) => new SCRAMProtocolException(t.getMessage).raiseError[F, ScramClient] - } - } - session = client.scramSession("*") - _ <- send(SASLInitialResponse(client.getScramMechanism.getName, bytesUtf8(session.clientFirstMessage))) - serverFirstBytes <- flatExpectStartup(sm) { - case AuthenticationSASLContinue(serverFirstBytes) => serverFirstBytes.pure[F] - } - serverFirst <- guardScramAction { - session.receiveServerFirstMessage(new String(serverFirstBytes.toArray, "UTF-8")) - } - pw <- requirePassword[F](sm, password) - clientFinal = serverFirst.clientFinalProcessor(pw) - _ <- send(SASLResponse(bytesUtf8(clientFinal.clientFinalMessage))) - serverFinalBytes <- flatExpectStartup(sm) { - case AuthenticationSASLFinal(serverFinalBytes) => serverFinalBytes.pure[F] - } - _ <- guardScramAction { - clientFinal.receiveServerFinalMessage(new String(serverFinalBytes.toArray, "UTF-8")).pure[F] - } - _ <- flatExpectStartup(sm) { case AuthenticationOk => ().pure[F] } - } yield () - } - -} diff --git a/modules/core/js-native/src/main/scala/message/Scram.scala b/modules/core/shared/src/main/scala/net/message/Scram.scala similarity index 90% rename from modules/core/js-native/src/main/scala/message/Scram.scala rename to modules/core/shared/src/main/scala/net/message/Scram.scala index 8720a9258..e0ba728f4 100644 --- a/modules/core/js-native/src/main/scala/message/Scram.scala +++ b/modules/core/shared/src/main/scala/net/message/Scram.scala @@ -10,7 +10,7 @@ import scodec.codecs.utf8 /** * Partial implementation of [RFC5802](https://tools.ietf.org/html/rfc5802), as needed by PostgreSQL. - * + * * That is, only features used by PostgreSQL are implemented -- e.g., channel binding is not supported and * optional message fields omitted by PostgreSQL are not supported. */ @@ -34,7 +34,7 @@ private[skunk] object Scram extends ScramPlatform { utf8.decodeValue(bytes.bits).toOption.flatMap { case Pattern(r, s, i) => Some(ServerFirst(r, ByteVector.fromValidBase64(s), i.toInt)) - case _ => + case _ => None } } @@ -56,7 +56,7 @@ private[skunk] object Scram extends ScramPlatform { utf8.decodeValue(bytes.bits).toOption.flatMap { case Pattern(v) => Some(ServerFinal(Verifier(ByteVector.fromValidBase64(v)))) - case _ => + case _ => None } } @@ -78,21 +78,21 @@ private[skunk] object Scram extends ScramPlatform { SASLInitialResponse(SaslMechanism, channelBinding ++ clientFirstBare) def saslChallenge( - password: String, - channelBinding: ByteVector, - serverFirst: ServerFirst, - clientFirstBare: ByteVector, + password: String, + channelBinding: ByteVector, + serverFirst: ServerFirst, + clientFirstBare: ByteVector, serverFirstBytes: ByteVector ): (SASLResponse, Verifier) = { val clientFinalMessageWithoutProof = ClientFinalWithoutProof(channelBinding.toBase64, serverFirst.nonce) - val (clientProof, expectedVerifier) = + val (clientProof, expectedVerifier) = makeClientProofAndServerSignature( - password, - serverFirst.salt, - serverFirst.iterations, - clientFirstBare, - serverFirstBytes, + password, + serverFirst.salt, + serverFirst.iterations, + clientFirstBare, + serverFirstBytes, clientFinalMessageWithoutProof.encode) (SASLResponse(clientFinalMessageWithoutProof.encodeWithProof(clientProof)), expectedVerifier) } -} +} \ No newline at end of file diff --git a/modules/core/shared/src/main/scala/net/protocol/Startup.scala b/modules/core/shared/src/main/scala/net/protocol/Startup.scala index aab57729c..81aea2aeb 100644 --- a/modules/core/shared/src/main/scala/net/protocol/Startup.scala +++ b/modules/core/shared/src/main/scala/net/protocol/Startup.scala @@ -4,7 +4,7 @@ package skunk.net.protocol -import cats.{ApplicativeError, MonadError} +import cats.{ApplicativeError, MonadError, MonadThrow} import cats.syntax.all._ import org.typelevel.otel4s.Attribute import org.typelevel.otel4s.trace.Span @@ -18,7 +18,8 @@ import skunk.exception.{ SCRAMProtocolException, StartupException, SkunkException, - UnsupportedAuthenticationSchemeException + UnsupportedAuthenticationSchemeException, + UnsupportedSASLMechanismsException } import org.typelevel.otel4s.metrics.Histogram import cats.effect.MonadCancel @@ -27,7 +28,7 @@ trait Startup[F[_]] { def apply(user: String, database: String, password: Option[String], parameters: Map[String, String]): F[Unit] } -object Startup extends StartupCompanionPlatform { +object Startup { def apply[F[_]: Exchange: MessageSocket: Tracer](opDuration: Histogram[F, Double])( implicit ev: MonadCancel[F, Throwable] @@ -93,6 +94,51 @@ object Startup extends StartupCompanionPlatform { } } + private def authenticationSASL[F[_]: MonadThrow: MessageSocket: Tracer]( + sm: StartupMessage, + password: Option[String], + mechanisms: List[String] + ): F[Unit] = + Tracer[F].span("authenticationSASL").surround { + if (mechanisms.contains(Scram.SaslMechanism)) { + for { + pw <- requirePassword[F](sm, password) + channelBinding = Scram.NoChannelBinding + clientFirstBare = Scram.clientFirstBareWithRandomNonce + _ <- send(Scram.saslInitialResponse(channelBinding, clientFirstBare)) + serverFirstBytes <- flatExpectStartup(sm) { + case AuthenticationSASLContinue(serverFirstBytes) => serverFirstBytes.pure[F] + } + serverFirst <- Scram.ServerFirst.decode(serverFirstBytes) match { + case Some(serverFirst) => serverFirst.pure[F] + case None => + new SCRAMProtocolException( + s"Failed to parse server-first-message in SASLInitialResponse: ${serverFirstBytes.toHex}." + ).raiseError[F, Scram.ServerFirst] + } + (response, expectedVerifier) = Scram.saslChallenge(pw, channelBinding, serverFirst, clientFirstBare, serverFirstBytes) + _ <- send(response) + serverFinalBytes <- flatExpectStartup(sm) { + case AuthenticationSASLFinal(serverFinalBytes) => serverFinalBytes.pure[F] + } + _ <- Scram.ServerFinal.decode(serverFinalBytes) match { + case Some(serverFinal) => + if (serverFinal.verifier == expectedVerifier) ().pure[F] + else new SCRAMProtocolException( + s"Expected verifier ${expectedVerifier.value.toHex} but received ${serverFinal.verifier.value.toHex}." + ).raiseError[F, Unit] + case None => + new SCRAMProtocolException( + s"Failed to parse server-final-message in AuthenticationSASLFinal: ${serverFinalBytes.toHex}." + ).raiseError[F, Unit] + } + _ <- flatExpectStartup(sm) { case AuthenticationOk => ().pure[F] } + } yield () + } else { + new UnsupportedSASLMechanismsException(mechanisms).raiseError[F, Unit] + } + } + private[protocol] def requirePassword[F[_]](sm: StartupMessage, password: Option[String])( implicit ev: ApplicativeError[F, Throwable] ): F[String] =