Skip to content

Commit

Permalink
Remove ParameterValue erasure
Browse files Browse the repository at this point in the history
  • Loading branch information
cchantep committed Nov 28, 2013
1 parent 841fe3b commit 0e085dc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 13 deletions.
46 changes: 36 additions & 10 deletions framework/src/anorm/src/main/scala/anorm/Anorm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -381,18 +381,44 @@ object ToStatement {
}

import SqlParser._
case class ParameterValue[A](aValue: A, statementSetter: ToStatement[A]) {
def set(s: java.sql.PreparedStatement, index: Int) = statementSetter.set(s, index, aValue)

/**
* Prepared parameter value.
*/
trait ParameterValue {

/**
* Sets this value on given statement at specified index.
*
* @param s SQL Statement
* @param index Parameter index
*/
def set(s: java.sql.PreparedStatement, index: Int): Unit
}

/**
* Value factory for parameter.
*
* {{{
* val param = ParameterValue("str", setter)
*
* SQL("...").onParams(param)
* }}}
*/
object ParameterValue {
def apply[A](value: A, setter: ToStatement[A]) = new ParameterValue {
def set(s: java.sql.PreparedStatement, i: Int) = setter.set(s, i, value)
}
}

case class SimpleSql[T](sql: SqlQuery, params: Seq[(String, ParameterValue[_])], defaultParser: RowParser[T]) extends Sql {
case class SimpleSql[T](sql: SqlQuery, params: Seq[(String, ParameterValue)], defaultParser: RowParser[T]) extends Sql {

def on(args: (Any, ParameterValue[_])*): SimpleSql[T] = this.copy(params = (this.params) ++ args.map {
def on(args: (Any, ParameterValue)*): SimpleSql[T] = this.copy(params = (this.params) ++ args.map {
case (s: Symbol, v) => (s.name, v)
case (k, v) => (k.toString, v)
})

def onParams(args: ParameterValue[_]*): SimpleSql[T] = this.copy(params = (this.params) ++ sql.argsInitialOrder.zip(args))
def onParams(args: ParameterValue*): SimpleSql[T] = this.copy(params = (this.params) ++ sql.argsInitialOrder.zip(args))

def list()(implicit connection: java.sql.Connection): Seq[T] = as(defaultParser*)

Expand Down Expand Up @@ -422,13 +448,13 @@ case class SimpleSql[T](sql: SqlQuery, params: Seq[(String, ParameterValue[_])],
def withQueryTimeout(seconds: Option[Int]): SimpleSql[T] = this.copy(sql = sql.withQueryTimeout(seconds))
}

case class BatchSql(sql: SqlQuery, params: Seq[Seq[(String, ParameterValue[_])]]) {
case class BatchSql(sql: SqlQuery, params: Seq[Seq[(String, ParameterValue)]]) {

def addBatch(args: (String, ParameterValue[_])*): BatchSql = this.copy(params = (this.params) :+ args)
def addBatchList(paramsMapList: TraversableOnce[Seq[(String, ParameterValue[_])]]): BatchSql = this.copy(params = (this.params) ++ paramsMapList)
def addBatch(args: (String, ParameterValue)*): BatchSql = this.copy(params = (this.params) :+ args)
def addBatchList(paramsMapList: TraversableOnce[Seq[(String, ParameterValue)]]): BatchSql = this.copy(params = (this.params) ++ paramsMapList)

def addBatchParams(args: ParameterValue[_]*): BatchSql = this.copy(params = (this.params) :+ sql.argsInitialOrder.zip(args))
def addBatchParamsList(paramsSeqList: TraversableOnce[Seq[ParameterValue[_]]]): BatchSql = this.copy(params = (this.params) ++ paramsSeqList.map(paramsSeq => sql.argsInitialOrder.zip(paramsSeq)))
def addBatchParams(args: ParameterValue*): BatchSql = this.copy(params = (this.params) :+ sql.argsInitialOrder.zip(args))
def addBatchParamsList(paramsSeqList: TraversableOnce[Seq[ParameterValue]]): BatchSql = this.copy(params = (this.params) ++ paramsSeqList.map(paramsSeq => sql.argsInitialOrder.zip(paramsSeq)))

def getFilledStatement(connection: java.sql.Connection, getGeneratedKeys: Boolean = false) = {
val statement = if (getGeneratedKeys) connection.prepareStatement(sql.query, java.sql.Statement.RETURN_GENERATED_KEYS)
Expand Down
5 changes: 2 additions & 3 deletions framework/src/anorm/src/main/scala/anorm/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@ package object anorm {

implicit def implicitID[ID](id: Id[ID] with NotNull): ID = id.id

implicit def toParameterValue[A](a: A)(implicit p: ToStatement[A]): ParameterValue[A] =
ParameterValue(a, p)
implicit def toParameterValue[A](a: A)(implicit p: ToStatement[A]): ParameterValue = ParameterValue(a, p)

def SQL(stmt: String) = Sql.sql(stmt)

}
}

0 comments on commit 0e085dc

Please sign in to comment.