Skip to content

Commit

Permalink
Add type information to all standard library calls
Browse files Browse the repository at this point in the history
  • Loading branch information
szeiger committed Oct 8, 2012
1 parent d7cb723 commit 04a47b5
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/main/scala/scala/slick/ast/Library.scala
Expand Up @@ -83,7 +83,7 @@ object Library {
class FunctionSymbol(val name: String) extends Symbol {

/** Create an untyped Apply of this Symbol */
def apply(ch: Node*): Apply = Apply(this, ch)
//def apply(ch: Node*): Apply = Apply(this, ch)

/** Match an Apply of this Symbol */
def unapplySeq(n: Node) = n match {
Expand Down
14 changes: 14 additions & 0 deletions src/main/scala/scala/slick/ast/Type.scala
Expand Up @@ -15,3 +15,17 @@ trait Typed {
object Typed {
def unapply(t: Typed) = Some(t.tpe)
}

sealed class StaticType(name: String) extends Type {
override def toString = "StaticType."+name
}

object StaticType {
object Boolean extends StaticType("Boolean")
object Char extends StaticType("Char")
object Int extends StaticType("Int")
object Long extends StaticType("Long")
object Null extends StaticType("Null")
object String extends StaticType("String")
object Unit extends StaticType("Unit")
}
4 changes: 2 additions & 2 deletions src/main/scala/scala/slick/compiler/Relational.scala
Expand Up @@ -105,7 +105,7 @@ class ConvertToComprehensions extends Phase {
case Bind(s1, Select(Ref(gen2), ElementSymbol(2)), Pure(ProductNode(Seq(Select(Ref(s2), field)))))
if (s2 == s1) && (gen2 == gen) => Select(Ref(gen), field)
case Library.CountAll(Select(Ref(gen2), ElementSymbol(2))) if gen2 == gen =>
Library.Count(ConstColumn(1))
Library.Count.typed(StaticType.Long, ConstColumn(1))
case Select(Ref(gen2), ElementSymbol(2)) if gen2 == gen => Ref(gen2)
case Select(Ref(gen2), ElementSymbol(1)) if gen2 == gen => newBy
}
Expand Down Expand Up @@ -258,7 +258,7 @@ class FuseComprehensions extends Phase {
val a2 = new AnonSymbol
val (c2b, call) = s match {
case Library.CountAll =>
(c2, Library.Count(ConstColumn(1)))
(c2, Library.Count.typed(StaticType.Long, ConstColumn(1)))
case s =>
val c3 = ensureStruct(c2)
// All standard aggregate functions operate on a single column
Expand Down
10 changes: 6 additions & 4 deletions src/main/scala/scala/slick/direct/SlickBackend.scala
Expand Up @@ -28,6 +28,8 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
,"scala.Double" /*typeOf[Double]*/ -> TypeMapper.DoubleTypeMapper
,"scala.String" /*typeOf[String]*/ -> TypeMapper.StringTypeMapper
,"java.lang.String" /*typeOf[String]*/ -> TypeMapper.StringTypeMapper // FIXME: typeOf[String] leads to java.lang.String, but param.typeSignature to String
,"Boolean" /*typeBof[Boolean]*/ -> TypeMapper.BooleanTypeMapper
,"scala.Boolean" /*typeBof[Boolean]*/ -> TypeMapper.BooleanTypeMapper
)

//def resolveSym( lhs:Type, name:String, rhs:Type* ) = lhs.member(newTermName(name).encodedName).asTerm.resolveOverloaded(actuals = rhs.toList)
Expand Down Expand Up @@ -192,10 +194,10 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
}
=>
term.decoded match {
case "+" => Library.Concat(s2sq( lhs ).node, s2sq( rhs ).node )
case "+" => Library.Concat.typed(sq.StaticType.String, s2sq( lhs ).node, s2sq( rhs ).node )
}

case Apply(op@Select(lhs,term),rhs::Nil) => {
case a@Apply(op@Select(lhs,term),rhs::Nil) => {
val actualTypes = lhs.tpe :: rhs.tpe :: Nil //.map(_.tpe).toList
val matching_ops = ( operatorMap.collect{
case (str2sym, types)
Expand All @@ -207,7 +209,7 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
})
matching_ops.size match{
case 0 => throw new SlickException("Operator not supported: "+ lhs.tpe +"."+term.decoded+"("+ rhs.tpe +")")
case 1 => matching_ops.head( s2sq( lhs ).node, s2sq( rhs ).node )
case 1 => matching_ops.head.typed(typeMappers(a.tpe.toString), s2sq( lhs ).node, s2sq( rhs ).node )
case _ => throw new SlickException("Internal Slick error: resolution of "+ lhs.tpe +" "+term.decoded+" "+ rhs.tpe +" was ambigious")
}
}
Expand All @@ -223,7 +225,7 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac

case Select(scala_lhs, term)
if scala_lhs.tpe.erasure <:< typeOf[QueryOps[_]].erasure && (term.decoded == "length" || term.decoded == "size")
=> sq.Pure( Library.CountAll( s2sq(scala_lhs).node ) )
=> sq.Pure( Library.CountAll.typed(sq.StaticType.Int, s2sq(scala_lhs).node ) )

case tree if tree.tpe.erasure <:< typeOf[BaseQueryable[_]].erasure
=> val (tpe,query) = toQuery( eval(tree).asInstanceOf[BaseQueryable[_]] ); query
Expand Down
Expand Up @@ -124,7 +124,7 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
protected def buildWhereClause(where: Seq[Node]) = building(WherePart) {
if(!where.isEmpty) {
b" where "
expr(where.reduceLeft((a, b) => Library.And(a, b)), true)
expr(where.reduceLeft((a, b) => Library.And.typed(typeMapperDelegates.booleanTypeMapperDelegate, a, b)), true)
}
}

Expand Down Expand Up @@ -214,8 +214,8 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
case Library.Database() if !capabilities.contains(SqlProfile.capabilities.functionDatabase) =>
b += "''"
case Library.Pi() if !hasPiFunction => b += pi
case Library.Degrees(ch) if !hasRadDegConversion => b"(180.0/!${Library.Pi()}*$ch)"
case Library.Radians(ch) if!hasRadDegConversion => b"(!${Library.Pi()}/180.0*$ch)"
case Library.Degrees(ch) if !hasRadDegConversion => b"(180.0/!${Library.Pi.typed(typeMapperDelegates.bigDecimalTypeMapperDelegate)}*$ch)"
case Library.Radians(ch) if!hasRadDegConversion => b"(!${Library.Pi.typed(typeMapperDelegates.bigDecimalTypeMapperDelegate)}/180.0*$ch)"
case s: SimpleFunction =>
if(s.scalar) b"{fn "
b"${s.name}("
Expand All @@ -236,7 +236,8 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
case Library.EndsWith(n, LiteralNode(s: String)) =>
b"\($n like ${quote("%"+likeEncode(s))} escape '^'\)"
case Library.Trim(n) =>
expr(Library.LTrim(Library.RTrim(n)), skipParens)
expr(Library.LTrim.typed(typeMapperDelegates.stringTypeMapperDelegate,
Library.RTrim.typed(typeMapperDelegates.stringTypeMapperDelegate, n)), skipParens)
case a @ Library.Cast(ch @ _*) =>
val tn =
if(ch.length == 2) ch(1).asInstanceOf[LiteralNode].value.asInstanceOf[String]
Expand Down Expand Up @@ -311,7 +312,7 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
b.sep(select, ", ")(field => b += symbolName(field) += " = ?")
if(!where.isEmpty) {
b" where "
expr(where.reduceLeft((a, b) => Library.And(a, b)), true)
expr(where.reduceLeft((a, b) => Library.And.typed(typeMapperDelegates.booleanTypeMapperDelegate, a, b)), true)
}
QueryBuilderResult(b.build, input.linearizer)
}
Expand All @@ -326,7 +327,7 @@ trait JdbcStatementBuilderComponent { driver: JdbcDriver =>
b"delete from $qtn"
if(!where.isEmpty) {
b" where "
expr(where.reduceLeft((a, b) => Library.And(a, b)), true)
expr(where.reduceLeft((a, b) => Library.And.typed(typeMapperDelegates.booleanTypeMapperDelegate, a, b)), true)
}
QueryBuilderResult(b.build, input.linearizer)
}
Expand Down
2 changes: 1 addition & 1 deletion src/main/scala/scala/slick/lifted/AbstractTable.scala
Expand Up @@ -23,7 +23,7 @@ abstract class AbstractTable[T](val schemaName: Option[String], val tableName: S
val q = Query[TT, U, TT](targetTable)(Shape.tableShape.asInstanceOf[Shape[TT, U, TT]])
val generator = new AnonSymbol
val aliased = q.unpackable.encodeRef(generator)
val fv = Library.==(Node(targetColumns(aliased.value)), Node(sourceColumns))
val fv = Library.==.typed(StaticType.Boolean, Node(targetColumns(aliased.value)), Node(sourceColumns))
val fk = ForeignKey(name, this, q.unpackable.asInstanceOf[ShapedValue[TT, _]],
targetTable, unpackp, sourceColumns, targetColumns, onUpdate, onDelete)
new ForeignKeyQuery[TT, U](Filter(generator, Node(q), fv), q.unpackable, IndexedSeq(fk), q, generator, aliased.value)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/scala/slick/lifted/Constraint.scala
Expand Up @@ -73,8 +73,8 @@ class ForeignKeyQuery[E <: TableNode, U](
def & (other: ForeignKeyQuery[E, U]): ForeignKeyQuery[E, U] = {
val newFKs = fks ++ other.fks
val conditions =
newFKs.map(fk => Library.==(Node(fk.targetColumns(aliasedValue)), Node(fk.sourceColumns))).
reduceLeft[Node]((a, b) => Library.And(a, b))
newFKs.map(fk => Library.==.typed(StaticType.Boolean, Node(fk.targetColumns(aliasedValue)), Node(fk.sourceColumns))).
reduceLeft[Node]((a, b) => Library.And.typed(StaticType.Boolean, a, b))
val newDelegate = Filter(generator, Node(targetBaseQuery), conditions)
new ForeignKeyQuery[E, U](newDelegate, base, newFKs, targetBaseQuery, generator, aliasedValue)
}
Expand Down
14 changes: 9 additions & 5 deletions src/main/scala/scala/slick/lifted/TypeMapper.scala
Expand Up @@ -6,6 +6,7 @@ import scala.slick.SlickException
import scala.slick.ast.Type
import scala.slick.driver.JdbcProfile
import scala.slick.jdbc.{PositionedParameters, PositionedResult}
import scala.reflect.ClassTag

/**
* A (usually implicit) TypeMapper object represents a Scala type that can be
Expand All @@ -27,7 +28,7 @@ import scala.slick.jdbc.{PositionedParameters, PositionedResult}
* }
* </pre></code>
*/
sealed trait TypeMapper[T] extends (JdbcProfile => TypeMapperDelegate[T]) with Type { self =>
sealed abstract class TypeMapper[T](implicit val classTag: ClassTag[T]) extends (JdbcProfile => TypeMapperDelegate[T]) with Type { self =>
def createOptionTypeMapper: OptionTypeMapper[T] = new OptionTypeMapper[T](self) {
def apply(profile: JdbcProfile) = self(profile).createOptionTypeMapperDelegate
def getBaseTypeMapper[U](implicit ev: Option[U] =:= Option[T]): TypeMapper[U] = self.asInstanceOf[TypeMapper[U]]
Expand Down Expand Up @@ -114,17 +115,20 @@ object TypeMapper {
trait BaseTypeMapper[T] extends TypeMapper[T] {
def getBaseTypeMapper[U](implicit ev: Option[U] =:= T) =
throw new SlickException("A BaseTypeMapper should not have an Option type")
override def toString = "TypeMapper[" + classTag.runtimeClass.getName + "]"
}

abstract class OptionTypeMapper[T](val base: TypeMapper[T]) extends TypeMapper[Option[T]]
abstract class OptionTypeMapper[T : ClassTag](val base: TypeMapper[T]) extends TypeMapper[Option[T]] {
override def toString = "TypeMapper[Option[" + base.classTag.runtimeClass.getName + "]]"
}

/**
* Adding this marker trait to a TypeMapper makes the type eligible for
* numeric operators.
*/
trait NumericTypeMapper

trait TypeMapperDelegate[T] { self =>
trait TypeMapperDelegate[T] extends Type { self =>
/**
* A zero value for the type. This is used as a default instead of NULL when
* used as a non-nullable column.
Expand Down Expand Up @@ -183,7 +187,7 @@ object TypeMapperDelegate {
yield f.get(null).asInstanceOf[Int] -> f.getName)
}

abstract class MappedTypeMapper[T,U](implicit tm: TypeMapper[U]) extends TypeMapper[T] { self =>
abstract class MappedTypeMapper[T : ClassTag,U](implicit tm: TypeMapper[U]) extends TypeMapper[T] { self =>
def map(t: T): U
def comap(u: U): T

Expand All @@ -207,7 +211,7 @@ abstract class MappedTypeMapper[T,U](implicit tm: TypeMapper[U]) extends TypeMap
}

object MappedTypeMapper {
def base[T, U](tmap: T => U, tcomap: U => T)(implicit tm: TypeMapper[U]): BaseTypeMapper[T] =
def base[T : ClassTag, U](tmap: T => U, tcomap: U => T)(implicit tm: TypeMapper[U]): BaseTypeMapper[T] =
new MappedTypeMapper[T, U] with BaseTypeMapper[T] {
def map(t: T) = tmap(t)
def comap(u: U) = tcomap(u)
Expand Down

0 comments on commit 04a47b5

Please sign in to comment.