Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions modules/core/shared/src/main/scala/Session.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import cats.effect._
import cats.effect.std.Console
import cats.syntax.all._
import fs2.concurrent.Signal
import fs2.io.net.{ Network, SocketGroup }
import fs2.io.net.{ Network, SocketGroup, SocketOption }
import fs2.Pipe
import fs2.Stream
import natchez.Trace
Expand Down Expand Up @@ -201,6 +201,9 @@ object Session {
"client_encoding" -> "UTF8",
)

val DefaultSocketOptions: List[SocketOption] =
List(SocketOption.noDelay(true))

object Recyclers {

/**
Expand Down Expand Up @@ -270,12 +273,13 @@ object Session {
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
ssl: SSL = SSL.None,
parameters: Map[String, String] = Session.DefaultConnectionParameters,
socketOptions: List[SocketOption] = Session.DefaultSocketOptions,
commandCache: Int = 1024,
queryCache: Int = 1024,
): Resource[F, Resource[F, Session[F]]] = {

def session(socketGroup: SocketGroup[F], sslOp: Option[SSLNegotiation.Options[F]], cache: Describe.Cache[F]): Resource[F, Session[F]] =
fromSocketGroup[F](socketGroup, host, port, user, database, password, debug, strategy, sslOp, parameters, cache)
def session(socketGroup: SocketGroup[F], sslOp: Option[SSLNegotiation.Options[F]], cache: Describe.Cache[F]): Resource[F, Session[F]] =
fromSocketGroup[F](socketGroup, host, port, user, database, password, debug, strategy, socketOptions, sslOp, parameters, cache)

val logger: String => F[Unit] = s => Console[F].println(s"TLS: $s")

Expand Down Expand Up @@ -330,13 +334,14 @@ object Session {
password: Option[String] = none,
debug: Boolean = false,
strategy: Typer.Strategy = Typer.Strategy.BuiltinsOnly,
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
parameters: Map[String, String],
describeCache: Describe.Cache[F],
): Resource[F, Session[F]] =
for {
namer <- Resource.eval(Namer[F])
proto <- Protocol[F](host, port, debug, namer, socketGroup, sslOptions, describeCache)
proto <- Protocol[F](host, port, debug, namer, socketGroup, socketOptions, sslOptions, describeCache)
_ <- Resource.eval(proto.startup(user, database, password, parameters))
sess <- Resource.eval(fromProtocol(proto, namer, strategy))
} yield sess
Expand Down
11 changes: 6 additions & 5 deletions modules/core/shared/src/main/scala/net/BitVectorSocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,18 +61,19 @@ object BitVectorSocket {
* @group Constructors
*/
def apply[F[_]](
host: String,
port: Int,
sg: SocketGroup[F],
sslOptions: Option[SSLNegotiation.Options[F]],
host: String,
port: Int,
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
)(implicit ev: MonadError[F, Throwable]): Resource[F, BitVectorSocket[F]] = {

def fail[A](msg: String): Resource[F, A] =
Resource.eval(ev.raiseError(new SkunkException(message = msg, sql = None)))

def sock: Resource[F, Socket[F]] = {
(Hostname.fromString(host), Port.fromInt(port)) match {
case (Some(validHost), Some(validPort)) => sg.client(SocketAddress(validHost, validPort), List(SocketOption.noDelay(true)))
case (Some(validHost), Some(validPort)) => sg.client(SocketAddress(validHost, validPort), socketOptions)
case (None, _) => fail(s"""Hostname: "$host" is not syntactically valid.""")
case (_, None) => fail(s"Port: $port falls out of the allowed range.")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
package skunk.net

import cats._
import cats.effect.{ Sync => _, _ }
import cats.effect.{Sync => _, _}
import cats.effect.implicits._
import cats.effect.std.{ Console, Queue }
import cats.effect.std.{Console, Queue}
import cats.syntax.all._
import fs2.concurrent._
import fs2.Stream
import skunk.data._
import skunk.net.message._
import fs2.io.net.SocketGroup
import fs2.io.net.{SocketGroup, SocketOption}

/**
* A `MessageSocket` that buffers incoming messages, removing and handling asynchronous back-end
Expand Down Expand Up @@ -80,10 +80,11 @@ object BufferedMessageSocket {
queueSize: Int,
debug: Boolean,
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
): Resource[F, BufferedMessageSocket[F]] =
for {
ms <- MessageSocket(host, port, debug, sg, sslOptions)
ms <- MessageSocket(host, port, debug, sg, socketOptions, sslOptions)
ams <- Resource.make(BufferedMessageSocket.fromMessageSocket[F](ms, queueSize))(_.terminate)
} yield ams

Expand Down
5 changes: 3 additions & 2 deletions modules/core/shared/src/main/scala/net/MessageSocket.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import scodec.codecs._
import scodec.interop.cats._
import skunk.net.message.{ Sync => _, _ }
import skunk.util.Origin
import fs2.io.net.SocketGroup
import fs2.io.net.{ SocketGroup, SocketOption }

/** A higher-level `BitVectorSocket` that speaks in terms of `Message`. */
trait MessageSocket[F[_]] {
Expand Down Expand Up @@ -91,10 +91,11 @@ object MessageSocket {
port: Int,
debug: Boolean,
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
): Resource[F, MessageSocket[F]] =
for {
bvs <- BitVectorSocket(host, port, sg, sslOptions)
bvs <- BitVectorSocket(host, port, sg, socketOptions, sslOptions)
ms <- Resource.eval(fromBitVectorSocket(bvs, debug))
} yield ms

Expand Down
5 changes: 3 additions & 2 deletions modules/core/shared/src/main/scala/net/Protocol.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import skunk.data._
import skunk.util.{ Namer, Origin }
import skunk.util.Typer
import natchez.Trace
import fs2.io.net.SocketGroup
import fs2.io.net.{ SocketGroup, SocketOption }
import skunk.net.protocol.Exchange
import skunk.net.protocol.Describe

Expand Down Expand Up @@ -193,11 +193,12 @@ object Protocol {
debug: Boolean,
nam: Namer[F],
sg: SocketGroup[F],
socketOptions: List[SocketOption],
sslOptions: Option[SSLNegotiation.Options[F]],
describeCache: Describe.Cache[F],
): Resource[F, Protocol[F]] =
for {
bms <- BufferedMessageSocket[F](host, port, 256, debug, sg, sslOptions) // TODO: should we expose the queue size?
bms <- BufferedMessageSocket[F](host, port, 256, debug, sg, socketOptions, sslOptions) // TODO: should we expose the queue size?
p <- Resource.eval(fromMessageSocket(bms, nam, describeCache))
} yield p

Expand Down
6 changes: 4 additions & 2 deletions modules/tests/shared/src/test/scala/BitVectorSocketTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@ class BitVectorSocketTest extends ffstest.FTest {
override def serverResource(address: Option[Host], port: Option[Port], options: List[SocketOption]): Resource[IO, (SocketAddress[IpAddress], fs2.Stream[IO, Socket[IO]])] = ???
}

private val socketOptions = List(SocketOption.noDelay(true))

test("Invalid host") {
BitVectorSocket("", 1, dummySg, None).use(_ => IO.unit).assertFailsWith[SkunkException]
BitVectorSocket("", 1, dummySg, socketOptions, None).use(_ => IO.unit).assertFailsWith[SkunkException]
.flatMap(e => assertEqual("message", e.message, """Hostname: "" is not syntactically valid."""))
}
test("Invalid port") {
BitVectorSocket("localhost", -1, dummySg, None).use(_ => IO.unit).assertFailsWith[SkunkException]
BitVectorSocket("localhost", -1, dummySg, socketOptions, None).use(_ => IO.unit).assertFailsWith[SkunkException]
.flatMap(e => assertEqual("message", e.message, "Port: -1 falls out of the allowed range."))
}

Expand Down