Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ ThisBuild / libraryDependencySchemes ++= Seq(

import com.typesafe.tools.mima.core._
ThisBuild / mimaBinaryIssueFilters ++= List(
ProblemFilters.exclude[DirectMissingMethodProblem]("skunk.net.BitVectorSocket.fromSocket")
ProblemFilters.exclude[DirectMissingMethodProblem]("skunk.net.BitVectorSocket.fromSocket"),
ProblemFilters.exclude[MissingTypesProblem]("skunk.net.protocol.Startup$"),
ProblemFilters.exclude[DirectMissingMethodProblem]("skunk.net.protocol.Startup.authenticationSASL"),
ProblemFilters.exclude[MissingClassProblem]("skunk.net.protocol.StartupCompanionPlatform"),
)

ThisBuild / tlFatalWarnings := false
Expand Down Expand Up @@ -128,14 +131,12 @@ lazy val core = crossProject(JVMPlatform, JSPlatform, NativePlatform)
"org.typelevel" %%% "otel4s-semconv-metrics" % otel4sVersion,
"org.tpolecat" %%% "sourcepos" % "1.2.0",
"org.typelevel" %%% "twiddles-core" % "1.0.0",
"com.armanbilge" %%% "saslprep" % "0.1.2",
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not urgent, but since it's becoming a core depedency we could move this under org.typelevel and 1.0? I can do that.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that would be great!

) ++ Seq(
"com.beachape" %%% "enumeratum" % "1.9.0",
).filterNot(_ => tlIsScala3.value)
).jvmSettings(
libraryDependencies += "com.ongres.scram" % "client" % "2.1",
).platformsSettings(JSPlatform, NativePlatform)(
libraryDependencies ++= Seq(
"com.armanbilge" %%% "saslprep" % "0.1.2",
"io.github.cquiroz" %%% "scala-java-time" % "2.6.0",
"io.github.cquiroz" %%% "locales-minimal-en_us-db" % "1.5.4"
),
Expand Down

This file was deleted.

39 changes: 39 additions & 0 deletions modules/core/jvm/src/main/scala/net/message/ScramPlatform.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2018-2024 by Rob Norris and Contributors
// This software is licensed under the MIT License (MIT).
// For more information see LICENSE or https://opensource.org/licenses/MIT

package skunk.net.message

import scodec.bits.ByteVector

import java.security.SecureRandom
import javax.crypto.{Mac, SecretKeyFactory}
import javax.crypto.spec.{PBEKeySpec, SecretKeySpec}

private[message] trait ScramPlatform { this: Scram.type =>

def clientFirstBareWithRandomNonce: ByteVector = {
val random = new SecureRandom()
val nonceBytes = new Array[Byte](32)
random.nextBytes(nonceBytes)
val nonce = ByteVector.view(nonceBytes).toBase64
clientFirstBareWithNonce(nonce)
}

private[message] def HMAC(key: ByteVector, str: ByteVector): ByteVector = {
val mac = Mac.getInstance("HmacSHA256")
val keySpec = new SecretKeySpec(key.toArray, "HmacSHA256")
mac.init(keySpec)
ByteVector.view(mac.doFinal(str.toArray))
}

private[message] def H(input: ByteVector): ByteVector =
input.sha256

private[message] def Hi(str: String, salt: ByteVector, iterations: Int): ByteVector = {
val spec = new PBEKeySpec(str.toCharArray, salt.toArray, iterations, 256)
val factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256")
val key = factory.generateSecret(spec)
ByteVector.view(key.getEncoded).take(32)
}
}
64 changes: 0 additions & 64 deletions modules/core/jvm/src/main/scala/net/protocol/StartupPlatform.scala

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import scodec.codecs.utf8

/**
* Partial implementation of [RFC5802](https://tools.ietf.org/html/rfc5802), as needed by PostgreSQL.
*
*
* That is, only features used by PostgreSQL are implemented -- e.g., channel binding is not supported and
* optional message fields omitted by PostgreSQL are not supported.
*/
Expand All @@ -34,7 +34,7 @@ private[skunk] object Scram extends ScramPlatform {
utf8.decodeValue(bytes.bits).toOption.flatMap {
case Pattern(r, s, i) =>
Some(ServerFirst(r, ByteVector.fromValidBase64(s), i.toInt))
case _ =>
case _ =>
None
}
}
Expand All @@ -56,7 +56,7 @@ private[skunk] object Scram extends ScramPlatform {
utf8.decodeValue(bytes.bits).toOption.flatMap {
case Pattern(v) =>
Some(ServerFinal(Verifier(ByteVector.fromValidBase64(v))))
case _ =>
case _ =>
None
}
}
Expand All @@ -78,21 +78,21 @@ private[skunk] object Scram extends ScramPlatform {
SASLInitialResponse(SaslMechanism, channelBinding ++ clientFirstBare)

def saslChallenge(
password: String,
channelBinding: ByteVector,
serverFirst: ServerFirst,
clientFirstBare: ByteVector,
password: String,
channelBinding: ByteVector,
serverFirst: ServerFirst,
clientFirstBare: ByteVector,
serverFirstBytes: ByteVector
): (SASLResponse, Verifier) = {
val clientFinalMessageWithoutProof = ClientFinalWithoutProof(channelBinding.toBase64, serverFirst.nonce)
val (clientProof, expectedVerifier) =
val (clientProof, expectedVerifier) =
makeClientProofAndServerSignature(
password,
serverFirst.salt,
serverFirst.iterations,
clientFirstBare,
serverFirstBytes,
password,
serverFirst.salt,
serverFirst.iterations,
clientFirstBare,
serverFirstBytes,
clientFinalMessageWithoutProof.encode)
(SASLResponse(clientFinalMessageWithoutProof.encodeWithProof(clientProof)), expectedVerifier)
}
}
}
52 changes: 49 additions & 3 deletions modules/core/shared/src/main/scala/net/protocol/Startup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

package skunk.net.protocol

import cats.{ApplicativeError, MonadError}
import cats.{ApplicativeError, MonadError, MonadThrow}
import cats.syntax.all._
import org.typelevel.otel4s.Attribute
import org.typelevel.otel4s.trace.Span
Expand All @@ -18,7 +18,8 @@ import skunk.exception.{
SCRAMProtocolException,
StartupException,
SkunkException,
UnsupportedAuthenticationSchemeException
UnsupportedAuthenticationSchemeException,
UnsupportedSASLMechanismsException
}
import org.typelevel.otel4s.metrics.Histogram
import cats.effect.MonadCancel
Expand All @@ -27,7 +28,7 @@ trait Startup[F[_]] {
def apply(user: String, database: String, password: Option[String], parameters: Map[String, String]): F[Unit]
}

object Startup extends StartupCompanionPlatform {
object Startup {

def apply[F[_]: Exchange: MessageSocket: Tracer](opDuration: Histogram[F, Double])(
implicit ev: MonadCancel[F, Throwable]
Expand Down Expand Up @@ -93,6 +94,51 @@ object Startup extends StartupCompanionPlatform {
}
}

private def authenticationSASL[F[_]: MonadThrow: MessageSocket: Tracer](
sm: StartupMessage,
password: Option[String],
mechanisms: List[String]
): F[Unit] =
Tracer[F].span("authenticationSASL").surround {
if (mechanisms.contains(Scram.SaslMechanism)) {
for {
pw <- requirePassword[F](sm, password)
channelBinding = Scram.NoChannelBinding
clientFirstBare = Scram.clientFirstBareWithRandomNonce
_ <- send(Scram.saslInitialResponse(channelBinding, clientFirstBare))
serverFirstBytes <- flatExpectStartup(sm) {
case AuthenticationSASLContinue(serverFirstBytes) => serverFirstBytes.pure[F]
}
serverFirst <- Scram.ServerFirst.decode(serverFirstBytes) match {
case Some(serverFirst) => serverFirst.pure[F]
case None =>
new SCRAMProtocolException(
s"Failed to parse server-first-message in SASLInitialResponse: ${serverFirstBytes.toHex}."
).raiseError[F, Scram.ServerFirst]
}
(response, expectedVerifier) = Scram.saslChallenge(pw, channelBinding, serverFirst, clientFirstBare, serverFirstBytes)
_ <- send(response)
serverFinalBytes <- flatExpectStartup(sm) {
case AuthenticationSASLFinal(serverFinalBytes) => serverFinalBytes.pure[F]
}
_ <- Scram.ServerFinal.decode(serverFinalBytes) match {
case Some(serverFinal) =>
if (serverFinal.verifier == expectedVerifier) ().pure[F]
else new SCRAMProtocolException(
s"Expected verifier ${expectedVerifier.value.toHex} but received ${serverFinal.verifier.value.toHex}."
).raiseError[F, Unit]
case None =>
new SCRAMProtocolException(
s"Failed to parse server-final-message in AuthenticationSASLFinal: ${serverFinalBytes.toHex}."
).raiseError[F, Unit]
}
_ <- flatExpectStartup(sm) { case AuthenticationOk => ().pure[F] }
} yield ()
} else {
new UnsupportedSASLMechanismsException(mechanisms).raiseError[F, Unit]
}
}

private[protocol] def requirePassword[F[_]](sm: StartupMessage, password: Option[String])(
implicit ev: ApplicativeError[F, Throwable]
): F[String] =
Expand Down
Loading