Skip to content

Commit

Permalink
Merge pull request #214 from takapi327/feature/2024-05-Convenience-me…
Browse files Browse the repository at this point in the history
…thod-added-for-sql-construction

Feature/2024 05 convenience method added for sql construction
  • Loading branch information
takapi327 committed May 25, 2024
2 parents 402d5d4 + 87e559c commit b5a3f83
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import ldbc.sql.logging.*
* @tparam F
* The effect type
*/
case class Mysql[F[_]: Temporal](statement: String, params: Seq[ParameterBinder[F]]) extends SQL[F]:
case class Mysql[F[_]: Temporal](statement: String, params: List[ParameterBinder[F]]) extends SQL[F]:

@targetName("combine")
override def ++(sql: SQL[F]): SQL[F] =
Mysql[F](statement ++ " " ++ sql.statement, params ++ sql.params)
Mysql[F](statement ++ sql.statement, params ++ sql.params)

override def update(using logHandler: LogHandler[F]): Kleisli[F, Connection[F], Int] = Kleisli { connection =>
(for
Expand All @@ -40,8 +40,8 @@ case class Mysql[F[_]: Temporal](statement: String, params: Seq[ParameterBinder[
case (param, index) => param.bind(statement, index + 1)
} >> statement.executeUpdate() <* statement.close()
yield result)
.onError(ex => logHandler.run(LogEvent.ExecFailure(statement, params.map(_.parameter).toList, ex)))
<* logHandler.run(LogEvent.Success(statement, params.map(_.parameter).toList))
.onError(ex => logHandler.run(LogEvent.ExecFailure(statement, params.map(_.parameter), ex)))
<* logHandler.run(LogEvent.Success(statement, params.map(_.parameter)))
}

override def returning[T <: String | Int | Long](using
Expand All @@ -59,8 +59,8 @@ case class Mysql[F[_]: Temporal](statement: String, params: Seq[ParameterBinder[
} >> statement.executeUpdate() >> statement.getGeneratedKeys()
result <- summon[ResultSetConsumer[F, T]].consume(resultSet) <* statement.close()
yield result)
.onError(ex => logHandler.run(LogEvent.ExecFailure(statement, params.map(_.parameter).toList, ex)))
<* logHandler.run(LogEvent.Success(statement, params.map(_.parameter).toList))
.onError(ex => logHandler.run(LogEvent.ExecFailure(statement, params.map(_.parameter), ex)))
<* logHandler.run(LogEvent.Success(statement, params.map(_.parameter)))
}

private[ldbc] override def connection[T](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@

package ldbc

import cats.data.Kleisli
import cats.syntax.all.*

import cats.effect.*
import cats.effect.kernel.Resource.ExitCase

import ldbc.sql.*

Expand All @@ -20,6 +24,33 @@ package object connector:
val expressions = args.iterator
Mysql[F](strings.mkString("?"), expressions.toList)

private trait ConnectionSyntax[F[_]: Temporal]:

extension [T](connectionKleisli: Kleisli[F, Connection[F], T])

def readOnly(connection: Connection[F]): F[T] =
connection.setReadOnly(true) *> connectionKleisli.run(connection)

def autoCommit(connection: Connection[F]): F[T] =
connection.setReadOnly(false) *> connection.setAutoCommit(true) *> connectionKleisli.run(connection)

def transaction(connection: Connection[F]): F[T] =
val acquire = connection.setReadOnly(false) *> connection.setAutoCommit(false) *> Temporal[F].pure(connection)

val release = (connection: Connection[F], exitCase: ExitCase) =>
exitCase match
case ExitCase.Errored(_) | ExitCase.Canceled => connection.rollback()
case _ => connection.commit()

Resource
.makeCase(acquire)(release)
.use(connectionKleisli.run)

def rollback(connection: Connection[F]): F[T] =
connection.setReadOnly(false) *> connection.setAutoCommit(false) *> connectionKleisli.run(
connection
) <* connection.rollback()

/**
* Top-level imports provide aliases for the most commonly used types and modules. A typical starting set of imports
* might look something like this.
Expand All @@ -29,4 +60,4 @@ package object connector:
* import ldbc.connector.io.*
* }}}
*/
val io: StringContextSyntax[IO] = new StringContextSyntax[IO] {}
val io: StringContextSyntax[IO] & ConnectionSyntax[IO] = new StringContextSyntax[IO] with ConnectionSyntax[IO] {}
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@ class SQLStringContextQueryTest extends CatsEffectSuite:
assertIO(
connection.use { conn =>
(for
result1 <- sql"SELECT 1".toList[Tuple1[Int]]
result2 <- sql"SELECT 2".headOption[Tuple1[Int]]
result3 <- sql"SELECT 3".unsafe[Tuple1[Int]]
result1 <- sql"SELECT 1".toList[Int]
result2 <- sql"SELECT 2".headOption[Int]
result3 <- sql"SELECT 3".unsafe[Int]
yield (result1, result2, result3)).run(conn)
},
(List(Tuple1(1)), Some(Tuple1(2)), Tuple1(3))
(List(1), Some(2), 3)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ class SQLStringContextUpdateTest extends CatsEffectSuite:
_ <- sql"CREATE TABLE `string_context_bit_table1`(`bit_column` BIT NOT NULL)".update
count <- sql"INSERT INTO `string_context_bit_table1`(`bit_column`) VALUES (b'1')".update
_ <- sql"DROP TABLE `string_context_bit_table1`".update
yield count).run(conn)
yield count).transaction(conn)
},
1
)
Expand All @@ -49,7 +49,7 @@ class SQLStringContextUpdateTest extends CatsEffectSuite:
_ <- sql"CREATE TABLE `string_context_bit_table2`(`bit_column` BIT NOT NULL)".update
count <- sql"INSERT INTO `string_context_bit_table2`(`bit_column`) VALUES (b'0'),(b'1')".update
_ <- sql"DROP TABLE `string_context_bit_table2`".update
yield count).run(conn)
yield count).transaction(conn)
},
2
)
Expand All @@ -64,8 +64,29 @@ class SQLStringContextUpdateTest extends CatsEffectSuite:
_ <- sql"INSERT INTO `returning_auto_inc`(`id`, `c1`) VALUES ($None, ${ "column 1" })".update
generated <- sql"INSERT INTO `returning_auto_inc`(`id`, `c1`) VALUES ($None, ${ "column 2" })".returning[Long]
_ <- sql"DROP TABLE `returning_auto_inc`".update
yield generated).run(conn)
yield generated).transaction(conn)
},
2L
)
}

test("Not a single submission of result data rolled back in transaction has been reflected. ") {
assertIO(
connection.use { conn =>
for
_ <-
sql"CREATE TABLE `transaction_rollback_test`(`id` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY, `c1` VARCHAR(255) NOT NULL)".update
.autoCommit(conn)
result <- sql"INSERT INTO `transaction_rollback_test`(`id`, `c1`) VALUES ($None, ${ "column 1" })".update
.flatMap(_ =>
sql"INSERT INTO `transaction_rollback_test`(`id`, `xxx`) VALUES ($None, ${ "column 2" })".update
)
.transaction(conn)
.attempt
count <- sql"SELECT count(*) FROM `transaction_rollback_test`".unsafe[Int].readOnly(conn)
_ <- sql"DROP TABLE `transaction_rollback_test`".update.autoCommit(conn)
yield count
},
0
)
}
4 changes: 2 additions & 2 deletions module/ldbc-dsl/shared/src/main/scala/ldbc/dsl/Mysql.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import ldbc.sql.logging.*
* @tparam F
* The effect type
*/
case class Mysql[F[_]: Temporal](statement: String, params: Seq[ParameterBinder[F]]) extends SQL[F]:
case class Mysql[F[_]: Temporal](statement: String, params: List[ParameterBinder[F]]) extends SQL[F]:

@targetName("combine")
override def ++(sql: SQL[F]): SQL[F] =
Mysql[F](statement ++ " " ++ sql.statement, params ++ sql.params)
Mysql[F](statement ++ sql.statement, params ++ sql.params)

override def update(using logHandler: LogHandler[F]): Kleisli[F, Connection[F], Int] = Kleisli { connection =>
(for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,4 @@ trait StringContextSyntax[F[_]: Temporal]:
inline def sql(inline args: ParameterBinder[F]*): SQL[F] =
val strings = sc.parts.iterator
val expressions = args.iterator
Mysql(strings.mkString("?"), expressions.toSeq)
Mysql(strings.mkString("?"), expressions.toList)
92 changes: 23 additions & 69 deletions module/ldbc-sql/src/main/scala/ldbc/sql/SQL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import scala.annotation.targetName

import cats.{ Monad, MonadError }
import cats.data.Kleisli
import cats.kernel.Semigroup
import cats.syntax.all.*

import ldbc.sql.util.FactoryCompat
Expand Down Expand Up @@ -41,19 +42,9 @@ trait SQL[F[_]: Monad]:
/**
* Methods for returning an array of data to be retrieved from the database.
*/
inline def toList[T <: Tuple](using FactoryCompat[T, List[T]], LogHandler[F]): Kleisli[F, Connection[F], List[T]] =
given Kleisli[F, ResultSet[F], T] = Kleisli { resultSet =>
ResultSetReader
.fold[F, T]
.toList
.zipWithIndex
.traverse {
case (reader: ResultSetReader[F, Any], index) => reader.read(resultSet, index + 1)
}
.map(list => Tuple.fromArray(list.toArray).asInstanceOf[T])
}

connectionToList[T](statement, params)
def toList[T](using reader: ResultSetReader[F, T], logHandler: LogHandler[F]): Kleisli[F, Connection[F], List[T]] =
given Kleisli[F, ResultSet[F], T] = Kleisli(resultSet => reader.read(resultSet, 1))
connection[List[T]](statement, params, summon[ResultSetConsumer[F, List[T]]])

inline def toList[P <: Product](using
mirror: Mirror.ProductOf[P],
Expand All @@ -71,25 +62,18 @@ trait SQL[F[_]: Monad]:
.map(list => mirror.fromProduct(Tuple.fromArray(list.toArray)))
}

connectionToList[P](statement, params)
connection[List[P]](statement, params, summon[ResultSetConsumer[F, List[P]]])

/**
* A method to return the data to be retrieved from the database as Option type. If there are multiple data, the
* first one is retrieved.
*/
inline def headOption[T <: Tuple](using LogHandler[F]): Kleisli[F, Connection[F], Option[T]] =
given Kleisli[F, ResultSet[F], T] = Kleisli { resultSet =>
ResultSetReader
.fold[F, T]
.toList
.zipWithIndex
.traverse {
case (reader: ResultSetReader[F, Any], index) => reader.read(resultSet, index + 1)
}
.map(list => Tuple.fromArray(list.toArray).asInstanceOf[T])
}

connectionToHeadOption[T](statement, params)
def headOption[T](using
reader: ResultSetReader[F, T],
logHandler: LogHandler[F]
): Kleisli[F, Connection[F], Option[T]] =
given Kleisli[F, ResultSet[F], T] = Kleisli(resultSet => reader.read(resultSet, 1))
connection[Option[T]](statement, params, summon[ResultSetConsumer[F, Option[T]]])

inline def headOption[P <: Product](using
mirror: Mirror.ProductOf[P],
Expand All @@ -106,25 +90,19 @@ trait SQL[F[_]: Monad]:
.map(list => mirror.fromProduct(Tuple.fromArray(list.toArray)))
}

connectionToHeadOption[P](statement, params)
connection[Option[P]](statement, params, summon[ResultSetConsumer[F, Option[P]]])

/**
* A method to return the data to be retrieved from the database as is. If the data does not exist, an exception is
* raised. Use the [[headOption]] method if you want to retrieve individual data.
*/
inline def unsafe[T <: Tuple](using MonadError[F, Throwable], LogHandler[F]): Kleisli[F, Connection[F], T] =
given Kleisli[F, ResultSet[F], T] = Kleisli { resultSet =>
ResultSetReader
.fold[F, T]
.toList
.zipWithIndex
.traverse {
case (reader: ResultSetReader[F, Any], index) => reader.read(resultSet, index + 1)
}
.map(list => Tuple.fromArray(list.toArray).asInstanceOf[T])
}

connectionToUnsafe[T](statement, params)
def unsafe[T](using
reader: ResultSetReader[F, T],
logHandler: LogHandler[F],
ev: MonadError[F, Throwable]
): Kleisli[F, Connection[F], T] =
given Kleisli[F, ResultSet[F], T] = Kleisli(resultSet => reader.read(resultSet, 1))
connection[T](statement, params, summon[ResultSetConsumer[F, T]])

inline def unsafe[P <: Product](using
mirror: Mirror.ProductOf[P],
Expand All @@ -142,7 +120,7 @@ trait SQL[F[_]: Monad]:
.map(list => mirror.fromProduct(Tuple.fromArray(list.toArray)))
}

connectionToUnsafe[P](statement, params)
connection[P](statement, params, summon[ResultSetConsumer[F, P]])

/**
* A method to return the number of rows updated by the SQL statement.
Expand All @@ -160,31 +138,7 @@ trait SQL[F[_]: Monad]:
consumer: ResultSetConsumer[F, T]
)(using logHandler: LogHandler[F]): Kleisli[F, Connection[F], T]

/**
* Methods for returning an array of data to be retrieved from the database.
*/
private def connectionToList[T](
statement: String,
params: Seq[ParameterBinder[F]]
)(using Kleisli[F, ResultSet[F], T], LogHandler[F], FactoryCompat[T, List[T]]): Kleisli[F, Connection[F], List[T]] =
connection[List[T]](statement, params, summon[ResultSetConsumer[F, List[T]]])

/**
* A method to return the data to be retrieved from the database as Option type. If there are multiple data, the first
* one is retrieved.
*/
private def connectionToHeadOption[T](
statement: String,
params: Seq[ParameterBinder[F]]
)(using Kleisli[F, ResultSet[F], T], LogHandler[F]): Kleisli[F, Connection[F], Option[T]] =
connection[Option[T]](statement, params, summon[ResultSetConsumer[F, Option[T]]])
object SQL:

/**
* A method to return the data to be retrieved from the database as is. If the data does not exist, an exception is
* raised. Use the [[connectionToHeadOption]] method if you want to retrieve individual data.
*/
private def connectionToUnsafe[T](
statement: String,
params: Seq[ParameterBinder[F]]
)(using Kleisli[F, ResultSet[F], T], LogHandler[F], MonadError[F, Throwable]): Kleisli[F, Connection[F], T] =
connection[T](statement, params, summon[ResultSetConsumer[F, T]])
given [F[_]]: Semigroup[SQL[F]] with
override def combine(x: SQL[F], y: SQL[F]): SQL[F] = x ++ y

0 comments on commit b5a3f83

Please sign in to comment.