Commit

Fix #99: Implement args streaming without akka-stream
povder committed Jun 29, 2017
1 parent d6bfa4a commit ba50cf9
Showing 14 changed files with 553 additions and 207 deletions.
14 changes: 8 additions & 6 deletions build.sbt
@@ -50,10 +50,12 @@ lazy val core = (project in file("rdbc-pgsql-core"))
Library.rdbcUtil,
Library.typesafeConfig,
Library.scalaLogging,
Library.akkaStream,
Library.sourcecode,
Library.scodecBits,
Library.stm
Library.stm,
Library.logback % Test,
Library.scalatest % Test,
Library.reactiveStreamsTck % Test
),
buildInfoPackage := "io.rdbc.pgsql.core"
)
@@ -80,11 +82,11 @@ lazy val nettyTransport = (project in file("rdbc-pgsql-transport-netty"))
Library.nettyHandler,
Library.rdbcTypeconv,
Library.rdbcUtil,
Library.rdbcTests,
Library.scalaLogging,
Library.logback,
Library.scalatest,
Library.pgsql
Library.logback % Test,
Library.rdbcTests % Test,
Library.scalatest % Test,
Library.pgsql % Test
),
buildInfoPackage := "io.rdbc.pgsql.transport.netty"
).dependsOn(core, scodec)
10 changes: 5 additions & 5 deletions project/Library.scala
@@ -9,17 +9,17 @@ object Library {
val rdbcTypeconv = "io.rdbc" %% "rdbc-typeconv" % rdbcVersion
val rdbcUtil = "io.rdbc" %% "rdbc-util" % rdbcVersion
val reactiveStreams = "org.reactivestreams" % "reactive-streams" % "1.0.0"
val akkaStream = "com.typesafe.akka" %% "akka-stream" % "2.4.18"
val scodecBits = "org.scodec" %% "scodec-bits" % "1.1.4"
val scodecCore = "org.scodec" %% "scodec-core" % "1.10.3"
val typesafeConfig = "com.typesafe" % "config" % "1.3.1"
val scalaLogging = "com.typesafe.scala-logging" %% "scala-logging" % "3.5.0"
val logback = "ch.qos.logback" % "logback-classic" % "1.2.3"
val sourcecode = "com.lihaoyi" %% "sourcecode" % "0.1.3"
val nettyHandler = "io.netty" % "netty-handler" % nettyVersion
val stm = "org.scala-stm" %% "scala-stm" % "0.8"

val rdbcTests = "io.rdbc" %% "rdbc-tests" % rdbcVersion % Test
val scalatest = "org.scalatest" %% "scalatest" % "3.0.3" % Test
val pgsql = "ru.yandex.qatools.embed" % "postgresql-embedded" % "2.1" % Test
val logback = "ch.qos.logback" % "logback-classic" % "1.2.3"
val rdbcTests = "io.rdbc" %% "rdbc-tests" % rdbcVersion
val scalatest = "org.scalatest" %% "scalatest" % "3.0.3"
val pgsql = "ru.yandex.qatools.embed" % "postgresql-embedded" % "2.1"
val reactiveStreamsTck = "org.reactivestreams" % "reactive-streams-tck" % "1.0.0"
}
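
With akka-stream gone, argument streaming sits directly on the org.reactivestreams interfaces that the existing reactiveStreams compile dependency already provides; the new reactiveStreamsTck test dependency lets the hand-written subscriber be checked against the specification's rules. As a reminder of the contract such a subscriber must honor, a minimal one-element-at-a-time consumer could look like this (an illustration only, not code from this commit):

object RsContractDemo {

  import org.reactivestreams.{Publisher, Subscriber, Subscription}

  // Demand-driven consumption: a Publisher may invoke onNext only as many
  // times as the Subscriber has requested via Subscription.request(n).
  def consumeOneByOne[A](publisher: Publisher[A]): Unit = {
    publisher.subscribe(new Subscriber[A] {
      private[this] var sub: Subscription = _

      def onSubscribe(s: Subscription): Unit = {
        sub = s
        sub.request(1) // signal initial demand; nothing flows before this
      }

      def onNext(elem: A): Unit = {
        println(s"received $elem")
        sub.request(1) // ask for the next element
      }

      def onError(t: Throwable): Unit = t.printStackTrace()

      def onComplete(): Unit = println("stream completed")
    })
  }
}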
@@ -19,8 +19,6 @@ package io.rdbc.pgsql.core
import java.nio.charset.Charset
import java.util.concurrent.atomic.AtomicInteger

import akka.stream.Materializer
import akka.stream.scaladsl.{Sink, Source}
import io.rdbc.api.exceptions.{ConnectionClosedException, ConnectionValidationException, IllegalSessionStateException}
import io.rdbc.implbase.ConnectionPartialImpl
import io.rdbc.pgsql.core.StmtCacheConfig.{Disabled, Enabled}
@@ -39,6 +37,7 @@ import io.rdbc.sapi._
import io.rdbc.util.Logging
import io.rdbc.util.Preconditions._
import io.rdbc.util.scheduler.TaskScheduler
import org.reactivestreams.Publisher

import scala.concurrent.duration.FiniteDuration
import scala.concurrent.{ExecutionContext, Future, Promise}
@@ -49,20 +48,22 @@ abstract class AbstractPgConnection(val id: ConnId,
implicit private[this] val out: ChannelWriter,
implicit protected val ec: ExecutionContext,
scheduler: TaskScheduler,
requestCanceler: RequestCanceler,
implicit private[this] val streamMaterializer: Materializer)
requestCanceler: RequestCanceler)
extends Connection
with ConnectionPartialImpl
with WriteFailureHandler
with FatalErrorHandler
with PgStatementExecutor
with BatchExecutor
with Logging {

private[this] val fsmManager = new PgSessionFsmManager(id, this)
@volatile private[this] var sessionParams = SessionParams.default
@volatile private[this] var maybeBackendKeyData = Option.empty[BackendKeyData]
private[this] val stmtCounter = new AtomicInteger(0)

private[this] def argConverter = new StmtArgConverter(config.pgTypes, sessionParams)

@volatile private[this] var maybeStmtCache = config.stmtCacheConfig match {
case Disabled => None
case Enabled(capacity) => Some(LruStmtCache.empty(capacity))
@@ -83,7 +84,8 @@ abstract class AbstractPgConnection(val id: ConnId,
stmtExecutor = this,
pgTypes = config.pgTypes,
sessionParams = sessionParams,
nativeStmt = PgNativeStatement.parse(RdbcSql(finalSql))
nativeStmt = PgNativeStatement.parse(RdbcSql(finalSql)),
argConverter = argConverter
)
}

@@ -292,8 +294,10 @@ abstract class AbstractPgConnection(val id: ConnId,

object ParseAndBind {
def apply(bind: Bind): ParseAndBind = ParseAndBind(None, bind)

def apply(parse: Parse, bind: Bind): ParseAndBind = ParseAndBind(Some(parse), bind)
}

case class ParseAndBind(parse: Option[Parse], bind: Bind)

private def newParseAndBind(nativeSql: NativeSql, params: Vector[Argument]): ParseAndBind = {
@@ -319,7 +323,7 @@ abstract class AbstractPgConnection(val id: ConnId,

override private[core] def executeStatementForRowsAffected(nativeSql: NativeSql,
params: Vector[Argument])(
implicit timeout: Timeout): Future[Long] = traced {
implicit timeout: Timeout): Future[Long] = traced {
fsmManager.ifReadyF { (reqId, _) =>
logger.debug(s"Executing write-only statement '$nativeSql'")

@@ -344,60 +348,77 @@ abstract class AbstractPgConnection(val id: ConnId,
}
}

override private[core] def subscribeToStatementArgsStream(nativeSql: NativeSql,
paramsSource: ArgsSource): Future[Unit] = traced {
fsmManager.ifReadyF { (_, _) =>
sourceWithParseWritten(nativeSql, paramsSource)
.batch(max = config.maxBatchSize, seed = Vector(_))(_ :+ _)
.mapAsyncUnordered(parallelism = 1)(executeBatch)
.runWith(Sink.last)
.map(txStatus => fsmManager.triggerTransition(newState = Idle(txStatus)))
.map(_ => ())
private[core]
def subscribeToStatementArgsStream[A](nativeStatement: PgNativeStatement,
argsSource: Publisher[A],
argsConverter: A => Vector[Argument]): Future[Unit] = {
fsmManager.ifReadyF { (_, txStatus) =>
val subscriber = new StatementArgsSubscriber(
nativeStmt = nativeStatement,
bufferCapacity = config.maxBatchSize,
minDemandRequest = 10, //TODO
initialTxStatus = txStatus,
batchExecutor = this,
argConverter = argsConverter
)
argsSource.subscribe(subscriber)
subscriber.done.map(_ => ())
}
}
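
StatementArgsSubscriber itself is added elsewhere in this changeset; as a rough sketch of the shape such a subscriber can take (the names, signatures, and simplifications below are assumptions, not the committed code), it converts streamed elements, buffers them up to a batch size, and re-signals demand one batch at a time:

import org.reactivestreams.{Subscriber, Subscription}
import scala.concurrent.{ExecutionContext, Future, Promise}

// Sketch: buffers converted elements and hands every full batch to runBatch.
// The reactive-streams spec guarantees that on* signals arrive serially,
// so the mutable state below needs no extra synchronization.
class BatchingSubscriber[A, B](batchSize: Int,
                               convert: A => B,
                               runBatch: Vector[B] => Future[Unit])
                              (implicit ec: ExecutionContext)
  extends Subscriber[A] {

  private[this] val donePromise = Promise[Unit]
  private[this] var sub: Subscription = _
  private[this] var buf = Vector.empty[B]
  private[this] var pending: Future[Unit] = Future.unit // runs batches in order

  def done: Future[Unit] = donePromise.future // completes when the stream ends

  def onSubscribe(s: Subscription): Unit = {
    sub = s
    sub.request(batchSize.toLong) // initial demand
  }

  def onNext(elem: A): Unit = {
    buf :+= convert(elem)
    if (buf.size >= batchSize) {
      val batch = buf
      buf = Vector.empty
      pending = pending.flatMap(_ => runBatch(batch)) // chain, don't interleave
      sub.request(batchSize.toLong) // demand for the next batch
    }
  }

  def onError(t: Throwable): Unit = donePromise.tryFailure(t)

  def onComplete(): Unit = {
    val leftover = buf
    val all =
      if (leftover.nonEmpty) pending.flatMap(_ => runBatch(leftover))
      else pending
    all.onComplete(donePromise.tryComplete)
  }
}

The real subscriber must additionally send Parse with the first batch and thread TxStatus through the connection FSM, which is what the batchExecutor, initialTxStatus, and minDemandRequest parameters above exist for.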

/** Transforms params source to a source which upon materialization sends "Parse" to the backend before
* any of source's elements are processed. */
private def sourceWithParseWritten(nativeSql: NativeSql,
paramsSource: ArgsSource): ArgsSource = {
paramsSource.prefixAndTail(1).flatMapConcat { case (head, tail) =>
val firstParams = head.head
out.write(Parse(None, nativeSql, firstParams.map(_.dataTypeOid)))
Source(head).concat(tail)
private def batchMessages(nativeStmt: PgNativeStatement,
batch: Vector[Vector[Argument]],
first: Boolean): Vector[PgFrontendMessage] = {
val execute = Execute(optionalPortalName = None, optionalFetchSize = None)
val parseVec = {
if (first) {
Vector(Parse(None, nativeStmt.sql, batch.head.map(_.dataTypeOid))) //TODO guard against an empty batch?
} else Vector.empty
}
}

private def executeBatch(batch: Vector[Vector[Argument]]): Future[TxStatus] = {
val execute = Execute(optionalPortalName = None, optionalFetchSize = None)
val batchMsgs = batch.flatMap { params =>
parseVec ++ batch.flatMap { params =>
Vector(Bind(execute.optionalPortalName, None, params, ReturnColFormats.AllBinary), execute)
}
} :+ Sync
}
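
For a first batch of two argument vectors args1 and args2, the messages written for one extended-query round trip (the sequence reconstructed from the code above and the Sync handling in executeBatch below) are:

// first = true, batch = Vector(args1, args2):
Vector(
  Parse(None, nativeStmt.sql, args1.map(_.dataTypeOid)), // prepare once per stream
  Bind(None, None, args1, ReturnColFormats.AllBinary), Execute(None, None),
  Bind(None, None, args2, ReturnColFormats.AllBinary), Execute(None, None),
  Sync // the backend answers with ReadyForQuery carrying the transaction status
)

Later batches repeat only the Bind/Execute pairs and the trailing Sync, since first = false skips the Parse.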

override private[core]
def executeBatch(nativeStmt: PgNativeStatement,
batch: Vector[Vector[Argument]],
first: Boolean): Future[TxStatus] = {
val batchPromise = Promise[TxStatus]
fsmManager.triggerTransition(State.executingBatch(batchPromise))
val batchMsgs = batchMessages(nativeStmt, batch, first)
out
.writeAndFlush(batchMsgs :+ Sync)
.writeAndFlush(batchMsgs)
.recoverWith(writeFailureHandler)
.flatMap(_ => batchPromise.future)
}

override private[core]
def completeBatch(txStatus: TxStatus): Unit = {
fsmManager.triggerTransition(Idle(txStatus))
}
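
Together, executeBatch and completeBatch form the lifecycle the args subscriber drives through the BatchExecutor trait; a sketch of the expected call order (the exact protocol is an assumption based on the signatures above):

// executeBatch(stmt, batch1, first = true)   // Parse, Bind/Execute pairs, Sync
// executeBatch(stmt, batch2, first = false)  // Bind/Execute pairs, Sync
// ...                                        // one call per buffered batch
// completeBatch(lastTxStatus)                // returns the FSM to Idle(txStatus)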

override private[core] def handleWriteError(cause: Throwable): Unit = traced {
handleFatalError("Write error occurred, the connection will be closed", cause)
}

private def simpleQueryIgnoreResult(sql: NativeSql)(implicit timeout: Timeout): Future[Unit] = traced {
fsmManager.ifReadyF { (reqId, _) =>
val queryPromise = Promise[Unit]
fsmManager.triggerTransition(State.simpleQuerying(queryPromise))
out
.writeAndFlush(Query(sql))
.recoverWith(writeFailureHandler)
.map(_ => newTimeoutHandler(reqId, timeout).map(_.scheduleTimeoutTask(reqId)))
.flatMap { maybeTimeoutTask =>
queryPromise.future.andThen { case _ =>
maybeTimeoutTask.foreach(_.cancel())
fsmManager.ifReadyF {
(reqId, _) =>
val queryPromise = Promise[Unit]
fsmManager.triggerTransition(State.simpleQuerying(queryPromise))
out
.writeAndFlush(Query(sql))
.recoverWith(writeFailureHandler)
.map(_ => newTimeoutHandler(reqId, timeout).map(_.scheduleTimeoutTask(reqId)))
.flatMap {
maybeTimeoutTask =>
queryPromise.future.andThen {
case _ =>
maybeTimeoutTask.foreach(_.cancel())
}
}
}
}
}

@@ -412,20 +433,26 @@ abstract class AbstractPgConnection(val id: ConnId,
private def handleParamStatusChange(p: ParameterStatus): Unit = traced {
p match {
case ParameterStatus(SessionParamKey("client_encoding"), SessionParamVal(pgCharsetName)) =>
handleCharsetChange(pgCharsetName) { charset =>
handleClientCharsetChange(charset)
sessionParams = sessionParams.copy(clientCharset = charset)
handleCharsetChange(pgCharsetName) {
charset =>
handleClientCharsetChange(charset)
sessionParams = sessionParams.copy(clientCharset = charset)
}

case ParameterStatus(SessionParamKey("server_encoding"), SessionParamVal(pgCharsetName)) =>
handleCharsetChange(pgCharsetName) { charset =>
handleServerCharsetChange(charset)
sessionParams = sessionParams.copy(serverCharset = charset)
handleCharsetChange(pgCharsetName) {
charset =>
handleServerCharsetChange(charset)
sessionParams = sessionParams.copy(serverCharset = charset)
}

case _ => ()
}
logger.debug(s"Session parameter '${p.key.value}' is now set to '${p.value.value}'")
logger.debug(s"Session parameter '${
p.key.value
}' is now set to '${
p.value.value
}'")
}

private def nextStmtName(): StmtName = traced {
@@ -435,18 +462,21 @@ abstract class AbstractPgConnection(val id: ConnId,
private def doRelease(cause: Throwable): Future[Unit] = traced {
out
.writeAndFlush(Terminate)
.recover { case writeEx =>
logger.error("Write error occurred when terminating connection", writeEx)
.recover {
case writeEx =>
logger.error("Write error occurred when terminating connection", writeEx)
}
.flatMap { _ =>
val connClosedEx = cause match {
case ex: ConnectionClosedException => ex
case ex => new ConnectionClosedException("Connection closed", ex)
}
fsmManager.triggerTransition(ConnectionClosed(connClosedEx))
out.close().recover { case closeEx =>
logger.error("Channel close error occurred when terminating connection", closeEx)
}
.flatMap {
_ =>
val connClosedEx = cause match {
case ex: ConnectionClosedException => ex
case ex => new ConnectionClosedException("Connection closed", ex)
}
fsmManager.triggerTransition(ConnectionClosed(connClosedEx))
out.close().recover {
case closeEx =>
logger.error("Channel close error occurred when terminating connection", closeEx)
}
}
}

@@ -468,8 +498,9 @@ abstract class AbstractPgConnection(val id: ConnId,
val shouldCancel = fsmManager.startHandlingTimeout(reqId)
if (shouldCancel) {
logger.debug(s"Timeout occurred for request '$reqId', cancelling it")
maybeBackendKeyData.foreach { bkd =>
requestCanceler(bkd).onComplete(_ => fsmManager.finishHandlingTimeout())
maybeBackendKeyData.foreach {
bkd =>
requestCanceler(bkd).onComplete(_ => fsmManager.finishHandlingTimeout())
}
} else {
logger.debug(s"Timeout task ran for request '$reqId', but this request is not being executed anymore")
@@ -481,18 +512,20 @@ abstract class AbstractPgConnection(val id: ConnId,
}

private[core] def deallocateStatement(nativeSql: NativeSql): Future[Unit] = traced {
maybeStmtCache.fold(Future.unit) { stmtCache =>
fsmManager.ifReadyF { (_, txStatus) =>
stmtCache.evict(nativeSql) match {
case Some((newCache, evictedName)) =>
maybeStmtCache = Some(newCache)
deallocateCached(evictedName)

case None =>
fsmManager.triggerTransition(Idle(txStatus))
Future.unit
maybeStmtCache.fold(Future.unit) {
stmtCache =>
fsmManager.ifReadyF {
(_, txStatus) =>
stmtCache.evict(nativeSql) match {
case Some((newCache, evictedName)) =>
maybeStmtCache = Some(newCache)
deallocateCached(evictedName)

case None =>
fsmManager.triggerTransition(Idle(txStatus))
Future.unit
}
}
}
}
}

@@ -501,9 +534,10 @@ abstract class AbstractPgConnection(val id: ConnId,
fsmManager.triggerTransition(new DeallocatingStatement(promise))
out
.writeAndFlush(CloseStatement(Some(stmtName)), Sync)
.recoverWith { case writeEx =>
handleWriteError(writeEx)
Future.failed(writeEx)
.recoverWith {
case writeEx =>
handleWriteError(writeEx)
Future.failed(writeEx)
}
.flatMap(_ => promise.future)
}
@@ -27,5 +27,5 @@ object StmtCacheConfig {

final case class PgConnectionConfig(pgTypes: PgTypeRegistry,
typeConverters: TypeConverterRegistry,
maxBatchSize: Long,
maxBatchSize: Int,
stmtCacheConfig: StmtCacheConfig)