Skip to content

Commit

Permalink
Merge pull request #149 from slick/tmp/matchSymbols
Browse files Browse the repository at this point in the history
change slick backend to use symbols for matching on Scala trees (direct embedding)
  • Loading branch information
cvogt committed May 1, 2013
2 parents 3ee3803 + 45ea275 commit 5ad68bb
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 129 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,15 @@ class QueryableTest(val tdb: TestDB) extends DBTest {

// filter with more complex condition
assertMatch(
query.filter( c => c.sales > 5 || "Chris" == c.name ),
inMem.filter( c => c.sales > 5 || "Chris" == c.name )
query.filter( c => c.sales > 2 || "Colombian_Decaf" == c.name ),
inMem.filter( c => c.sales > 2 || "Colombian_Decaf" == c.name )
)


assertMatch(
query.filter( c => c.sales > 2 && "Colombian_Decaf" == c.name ),
inMem.filter( c => c.sales > 2 && "Colombian_Decaf" == c.name )
)

// type annotations FIXME canBuildFrom
assertMatch(
query.map[String]( (_:Coffee).name : String ),
Expand Down
284 changes: 158 additions & 126 deletions src/main/scala/scala/slick/direct/SlickBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,46 @@ import scala.reflect.ClassTag
import scala.slick.compiler.CompilerState
import scala.reflect.runtime.universe.TypeRef
import scala.slick.ast.ColumnOption
import scala.annotation.StaticAnnotation
import scala.reflect.runtime.universe._
import scala.reflect.runtime.{currentMirror=>cm}

/** maps a Scala method to a Slick FunctionSymbol */
final case class slickOp(to:FunctionSymbol) extends StaticAnnotation
/** denotes the Scala type the mapped interface refers to */
final class scalaType[+T](t:Type) extends StaticAnnotation
trait OperationMapping{
// Supported operators by Slick
// Slick also supports == for all supported types
@scalaType[Int](typeOf[Int])
trait IntOps{
@ slickOp(Library.+) def +(i:Int) : Int
@ slickOp(Library.+) def +(i:Double) : Double
@ slickOp(Library.<) def <(i:Int) : Boolean
@ slickOp(Library.<) def <(i:Double) : Boolean
@ slickOp(Library.>) def >(i:Int) : Boolean
@ slickOp(Library.>) def >(i:Double) : Boolean
}
@scalaType[Double](typeOf[Double])
trait DoubleOps{
@ slickOp(Library.+) def +(i:Int) : Double
@ slickOp(Library.+) def +(i:Double) : Double
@ slickOp(Library.<) def <(i:Int) : Boolean
@ slickOp(Library.<) def <(i:Double) : Boolean
@ slickOp(Library.>) def >(i:Int) : Boolean
@ slickOp(Library.>) def >(i:Double) : Boolean
}
@scalaType[Boolean](typeOf[Boolean])
trait BooleanOps{
@ slickOp(Library.Not) def unary_! : Boolean
@ slickOp(Library.Or) def ||( b:Boolean ) : Boolean
@ slickOp(Library.And) def &&( b:Boolean ) : Boolean
}
//@scalaType[String](typeOf[String]) // <- scalac crash SI-7426
trait StringOps{
@ slickOp(Library.Concat) def +(i:String) : String
}
}

trait QueryableBackend

Expand Down Expand Up @@ -44,55 +84,68 @@ import CustomNodes._

class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBackend{
type Session = JdbcDriver#Backend#Session
import scala.reflect.runtime.universe._
import scala.reflect.runtime.{currentMirror=>cm}

val columnTypes = Map( // FIXME use symbols instead of strings for type names here
"Int" /*typeOf[Int]*/ -> driver.columnTypes.intJdbcType
,"Double" /*typeOf[Double]*/ -> driver.columnTypes.doubleJdbcType
,"String" /*typeOf[String]*/ -> driver.columnTypes.stringJdbcType
,"scala.Int" /*typeOf[Int]*/ -> driver.columnTypes.intJdbcType
,"scala.Double" /*typeOf[Double]*/ -> driver.columnTypes.doubleJdbcType
,"scala.String" /*typeOf[String]*/ -> driver.columnTypes.stringJdbcType
,"java.lang.String" /*typeOf[String]*/ -> driver.columnTypes.stringJdbcType // FIXME: typeOf[String] leads to java.lang.String, but param.typeSignature to String
,"Boolean" /*typeBof[Boolean]*/ -> driver.columnTypes.booleanJdbcType
,"scala.Boolean" /*typeBof[Boolean]*/ -> driver.columnTypes.booleanJdbcType
)

//def resolveSym( lhs:Type, name:String, rhs:Type* ) = lhs.member(newTermName(name).encodedName).asTerm.resolveOverloaded(actuals = rhs.toList)

val operatorMap : Vector[ (Map[String, FunctionSymbol], List[List[Type]]) ] = {
import Library._
Vector(
Map( "unary_!" -> Library.Not )
->
List(List(typeOf[Any])),
Map( "==" -> Library.==, "!=" -> Library.== )
->
List(
List(typeOf[Any]),
List(typeOf[Any])
),
Map( "+" -> Library.+, "<" -> <, ">" -> > )
->
List(
List(typeOf[Int],typeOf[Double]),
List(typeOf[Int],typeOf[Double])
),
Map( "+" -> Concat )
->
List(
List(typeOf[String],typeOf[java.lang.String]),
List(typeOf[String],typeOf[java.lang.String])
),
Map( "||" -> <, "&&" -> > )
->
List(
List(typeOf[Boolean]),
List(typeOf[Boolean])
)
import slick.ast.StaticType
val columnTypes = {
import driver.columnTypes._
Map( // FIXME use symbols instead of strings for type names here
typeOf[Int].typeSymbol -> StaticType.Int
,typeOf[Double].typeSymbol -> doubleJdbcType
,typeOf[String].typeSymbol -> StaticType.String
,typeOf[Boolean].typeSymbol -> StaticType.Boolean
)
}
/** generates a map from Scala symbols to Slick FunctionSymbols from description in OperatorMapping */
val operatorSymbolMap : Map[Symbol,FunctionSymbol] = {
def annotations[T:TypeTag]( m:Symbol ) = m.annotations.filter{
case Annotation(tpe,_,_) => tpe <:< typeOf[T]
}
typeOf[OperationMapping]
.members
// only take annotated members
.filter(annotations[scalaType[_]](_).size > 0)
.flatMap(
_.typeSignature
.members
.filter(annotations[slickOp](_).size > 0)
)
.map{
specOp =>
val scalaType =
annotations[scalaType[_]](specOp.owner).head.tpe match{
case TypeRef(tpe,sym,args) => args.head
}
val specOpName = specOp.name
def argTypeSyms( s:Symbol ) = s.asMethod.paramss.map(_.map(_.typeSignature.typeSymbol))
// resolve overloaded methods
scalaType.member(specOpName)
.asTerm
.alternatives
.find(
scalaOp =>{
argTypeSyms( scalaOp ) == argTypeSyms( specOp )
}
)
.getOrElse{
throw new SlickException("Could not find Scala method: "+scalaType+"."+specOpName+argTypeSyms( specOp ))
}
.->( annotations[slickOp](specOp).head
match { case Annotation(_,args,_) =>
// look up FunctionSymbol from annotation
// FIXME: make this simpler
val op = args.head.symbol
val mod = op.owner.companionSymbol.asModule
val i = cm.reflectModule(mod).instance
cm.reflect(i).reflectMethod(
op.owner.companionSymbol.typeSignature.member(op.name).asMethod
)().asInstanceOf[FunctionSymbol]
}
)
}
.toMap
.+ ( typeOf[String].member(newTermName("+").encodedName) -> Library.Concat) // workaround for SI-7426
}

def isMapped( sym:Symbol ) = operatorSymbolMap.contains(sym) || sym.name.decoded == "==" || sym.name.decoded == "!="

object removeTypeAnnotations extends Transformer {
def apply( tree:Tree ) = transform(tree)
Expand Down Expand Up @@ -121,30 +174,33 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac

def columnName( sym:Symbol ) = mapper.fieldToColumn( sym )
def columnType( tpe:Type ) = {
val underlying = columnTypes(typeName(underlyingType(tpe)))
if( tpe.typeSymbol == typeOf[Option[_]].typeSymbol ){
val underlying = columnTypes(underlyingTypeSymbol(tpe))
if( isNullable(tpe) ){
underlying.optionType
} else {
underlying
}
}
private def columnField( sym:Symbol ) =
sq.FieldSymbol( columnName(sym) )(
if(isNullable(sym)) List(ColumnOption.Nullable) else List()
if( isNullable(sym) )
List(ColumnOption.Nullable)
else
List()
, columnType(sym.typeSignature)
)
private def typeName( sym:Symbol ) : String = sym.name.decoded
private def typeName( tpe:Type ) : String = typeName( tpe.typeSymbol )
private def isNullable( sym:Symbol ) = typeName(sym) == "Option"
private def isNullable( sym:Symbol ) = sym == typeOf[Option[_]].typeSymbol
private def isNullable( tpe:Type ) : Boolean = isNullable(tpe.typeSymbol)
private def underlyingType( tpe:Type ) =
private def underlyingTypeSymbol( tpe:Type ) : Symbol =
if( isNullable(tpe) )
tpe match {
case TypeRef(_,_,args) => args(0)
case TypeRef(_,_,args) => args(0).typeSymbol
case t => throw new Exception("failed to compute underlying type of "+tpe)
}
else tpe
private def canBeMapped( tpe:Type ) : Boolean = columnTypes.isDefinedAt( typeName(underlyingType(tpe)) )
else tpe.typeSymbol
private def canBeMapped( tpe:Type ) : Boolean = columnTypes.isDefinedAt(underlyingTypeSymbol(tpe))
private def columnSelect( sym:Symbol, sq_symbol:sq.Node ) =
sq.Select(
sq.Ref(sq_symbol.nodeIntrinsicSymbol),
Expand Down Expand Up @@ -205,34 +261,28 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
case o:This => throw new SlickException( "Cannot handle reference to a query in non-static symbol "+o.symbol )
case _ => throw new SlickException("Cannot eval: " + showRaw(tree))
}
def matchingOps(term:Name,actualTypes:List[Type]) = {
operatorMap.collect{
case (str2sym, types)
if str2sym.isDefinedAt( term.decoded )
&& types.zipWithIndex.forall{
case (expectedTypes, index) => expectedTypes.exists( actualTypes(index) <:< _ )
}
=> str2sym( term.decoded )
}
}

private def scala2scalaquery_typed( tree:Tree, scope : Scope ) : Query = {
def s2sq( tree:Tree, scope:Scope=scope ) : Query = scala2scalaquery_typed( tree, scope )
def applyOp( lhs:Tree, term:Name, args:List[Tree], resultType : Type ) : sq.Node = {
val actualTypes = lhs.tpe :: args.map(_.tpe)
val sig = lhs.tpe +"."+term.decoded+(if(args.length > 0)"("+ args.map(_.tpe).mkString(",") +")" else "")
def mapOp( op:Tree, args:List[Tree] ) : sq.Node = {
val Select(lhs:Tree, term:Name) = op
if( term.decoded == "!=" ){
Library.Not.typed(
columnTypes("Boolean"),
applyOp( lhs, newTermName("=="), args, resultType )
columnTypes(typeOf[Boolean].typeSymbol),
mapOp( Select(lhs, newTermName("==")), args )
)
} else {
val matchingOps_ = matchingOps(term,actualTypes)
matchingOps_.size match{
case 0 => throw new SlickException("Operator not supported: "+ sig)
case 1 => matchingOps_.head.typed(columnTypes(resultType.toString), (s2sq( lhs ).node :: args.map( s2sq(_).node )) : _* )
case _ => throw new SlickException("Internal Slick error: resolution of "+ sig +" was ambigious")
}
val (slickOp,slickType) =
if(term.decoded == "=="){
Library.== -> columnTypes(typeOf[Boolean].typeSymbol)
} else {
val sym = op.symbol.asMethod
if( !operatorSymbolMap.keys.toList.contains(sym) ){
throw new SlickException("Direct embedding does not support method "+sym.owner.name+"."+sym.name.decoded+sym.paramss.map(_.map(_.typeSignature.normalize)).mkString("").toString.replace("List","")+":"+sym.returnType)
}
operatorSymbolMap( sym ) -> columnTypes(sym.returnType.typeSymbol)
}
slickOp.typed( slickType, (lhs::args).map(s2sq(_).node) : _* )
}
}
implicit def node2Query(node:sq.Node) = new Query( node, scope )
Expand All @@ -250,7 +300,7 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
}
case ident@Ident(name) => scope(ident.symbol)

case Select( t, term ) if t.tpe.erasure <:< typeOf[BaseQueryable[_]].erasure && term.decoded == "queryable" => s2sq(t)
case op@Select( t, _ ) if op.symbol == typeOf[BaseQueryable[_]].member(newTermName("queryable")) => s2sq(t)

// match columns
case Select(from,name) if mapper.isMapped( from.tpe.widen )
Expand All @@ -266,15 +316,25 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
case x => s2sq( Literal(Constant(x)) )
}
*/

case Apply( Select( queryOps, term ), queryable::Nil )
if queryOps.tpe <:< typeOf[QueryOps.type] && queryable.tpe.erasure <:< typeOf[BaseQueryable[_]].erasure && term.decoded == "query"
=> s2sq( queryable ).node
case a@Apply(op@Select(lhs,term),args) if isMapped( op.symbol )
=> mapOp(op,args)

case op@Select(lhs,term) if isMapped( op.symbol )
=> mapOp(op,List())

case Apply( op, queryable::Nil )
if op.symbol == typeOf[QueryOps.type].member(newTermName("query"))
=> s2sq( queryable ).node

// match queryable methods
case Apply(Select(scala_lhs,term),rhs::Nil)
if scala_lhs.tpe.erasure <:< typeOf[QueryOps[_]].erasure
case op@Select(scala_lhs, term) if typeOf[QueryOps[_]].members.toList.contains(op.symbol) =>
term.decoded match {
case "length" => sq.Pure( Library.CountAll.typed[Int](s2sq(scala_lhs).node ) )
}

case Apply(op@Select(scala_lhs,term),args) if typeOf[QueryOps[_]].members.toList.contains(op.symbol)
=>
val (rhs::Nil) = args
val sq_lhs = s2sq( scala_lhs ).node
val sq_symbol = new sq.AnonSymbol
def flattenAndPrepareForSortBy( node:sq.Node ) : Seq[(sq.Node,sq.Ordering)] = node match {
Expand Down Expand Up @@ -318,62 +378,34 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
)
}

// FIXME: this case is required because of a bug, but should be covered by the next case
case d@Apply(Select(lhs,term),rhs::Nil)
if {
/*println("_a__")
println(showRaw(d))
println(showRaw(lhs))
println(rhs.symbol.asInstanceOf[scala.reflect.internal.Symbols#FreeTerm].value)
println(rhs.tpe)
println("_b__")*/
(
(string_types contains lhs.tpe.widen.toString) //(lhs.tpe <:< typeOf[String])
&& (string_types contains rhs.tpe.widen.toString) // (rhs.tpe <:< typeOf[String] )
&& (List("+").contains( term.decoded ))
)
}
=>
term.decoded match {
case "+" => Library.Concat.typed[String](s2sq( lhs ).node, s2sq( rhs ).node )
}

case a@Apply(op@Select(lhs,term),args) if matchingOps(term,lhs.tpe :: args.map(_.tpe)).size > 0 => applyOp(lhs,term,args,a.tpe)
case op@Select(lhs,term) if matchingOps(term,lhs.tpe :: Nil).size > 0 => applyOp(lhs,term,List(),op.tpe)

// Tuples
case Apply(
Select(Select(Ident(package_), class_), method_),
op,
components
)
if package_.decoded == "scala" && class_.decoded.startsWith("Tuple") && method_.decoded == "apply" // FIXME: match smarter than matching strings
if definitions
.TupleClass
.filter(_ != NoSymbol)
.map( _.companionSymbol.typeSignature.member( newTermName("apply") ) )
.contains( op.symbol )
=>
sq.ProductNode( components.map(s2sq(_).node) )

case Select(scala_lhs, term)
if scala_lhs.tpe.erasure <:< typeOf[QueryOps[_]].erasure && (term.decoded == "length" || term.decoded == "size")
=> sq.Pure( Library.CountAll.typed[Int](s2sq(scala_lhs).node ) )
case Apply( op, scala_rhs::Nil )
if typeOf[NullAndReverseOrder].member(newTermName("nonesLast")).asTerm.alternatives.contains(op.symbol)
=> Nullsorting( s2sq(scala_rhs).node, Nullsorting.Last )

case Apply(
Select(_, term),
scala_rhs::Nil
) if term.decoded == "nonesLast" =>
Nullsorting( s2sq(scala_rhs).node, Nullsorting.Last )
case Apply( op, scala_rhs::Nil )
if typeOf[NullAndReverseOrder].member(newTermName("nonesFirst")).asTerm.alternatives.contains(op.symbol)
=> Nullsorting( s2sq(scala_rhs).node, Nullsorting.First )

case Apply(
Select(_, term),
scala_rhs::Nil
) if term.decoded == "nonesFirst" =>
Nullsorting( s2sq(scala_rhs).node, Nullsorting.First )
case Apply( op, scala_rhs::Nil )
if typeOf[NullAndReverseOrder].member(newTermName("reversed")).asTerm.alternatives.contains(op.symbol)
=> Reverse( s2sq(scala_rhs).node )

case tree if tree.tpe.erasure <:< typeOf[BaseQueryable[_]].erasure
=> val (tpe,query) = toQuery( eval(tree).asInstanceOf[BaseQueryable[_]] ); query

case Apply(
Select(_, term),
scala_rhs::Nil
) if term.decoded == "reversed" =>
Reverse( s2sq(scala_rhs).node )
case tree => throw new Exception( "You probably used currently not supported scala code in a query. No match for:\n" + showRaw(tree) )
}
} catch{
Expand Down

0 comments on commit 5ad68bb

Please sign in to comment.