Skip to content

Commit

Permalink
Allow building of non-erased collections through ClassTags.
Browse files Browse the repository at this point in the history
- All Types and Shapes carry a ClassTag.

- Executors use the result type's ClassTag to get the correct builder
  for the result.

- Provide a CollectionTypeConstructor for Array types.
  • Loading branch information
szeiger committed Mar 19, 2014
1 parent 6e981e5 commit 3fc16d8
Show file tree
Hide file tree
Showing 18 changed files with 139 additions and 66 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,10 @@ class ExecutorTest extends TestkitTest[RelationalTestDB] {
assertEquals(e2, r2a)
assertEquals(e2, r2b)
assertEquals(e2, r2c)

val r3a = ts.to[Array].run
assertTrue(r3a.isInstanceOf[Array[(Int, String)]])
val r3b = ts.to[Array].map(_.a).run
assertTrue(r3b.isInstanceOf[Array[Int]])
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.typesafe.slick.testkit.tests
import org.junit.Assert._

import com.typesafe.slick.testkit.util.{JdbcTestDB, TestkitTest}
import scala.reflect.ClassTag

class JdbcMapperTest extends TestkitTest[JdbcTestDB] {
import tdb.profile.simple._
Expand Down Expand Up @@ -191,7 +192,7 @@ class JdbcMapperTest extends TestkitTest[JdbcTestDB] {
case class Pair[A, B](a: A, b: B)

// A Shape that maps Pair to a ProductNode
final class PairShape[Level <: ShapeLevel, M <: Pair[_,_], U <: Pair[_,_], P <: Pair[_,_]](val shapes: Seq[Shape[_, _, _, _]]) extends MappedScalaProductShape[Level, Pair[_,_], M, U, P] {
final class PairShape[Level <: ShapeLevel, M <: Pair[_,_], U <: Pair[_,_] : ClassTag, P <: Pair[_,_]](val shapes: Seq[Shape[_, _, _, _]]) extends MappedScalaProductShape[Level, Pair[_,_], M, U, P] {
def buildValue(elems: IndexedSeq[Any]) = Pair(elems(0), elems(1))
def copy(shapes: Seq[Shape[_, _, _, _]]) = new PairShape(shapes)
}
Expand Down
6 changes: 4 additions & 2 deletions src/main/scala/scala/slick/ast/Node.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package scala.slick.ast

import scala.language.existentials
import scala.slick.SlickException
import scala.slick.util.Logging
import TypeUtil.typeToTypeUtil
import Util._
import scala.reflect.ClassTag

/**
* A node in the query AST.
Expand Down Expand Up @@ -584,11 +586,11 @@ final case class CompiledStatement(statement: String, extra: Any, tpe: Type) ext
}

/** A client-side type mapping */
final case class TypeMapping(val child: Node, val toBase: Any => Any, val toMapped: Any => Any) extends UnaryNode with SimplyTypedNode { self =>
final case class TypeMapping(val child: Node, val toBase: Any => Any, val toMapped: Any => Any, classTag: ClassTag[_]) extends UnaryNode with SimplyTypedNode { self =>
type Self = TypeMapping
def nodeRebuild(ch: Node) = copy(child = ch)
override def toString = "TypeMapping"
protected def buildType = new MappedScalaType(child.nodeType, toBase, toMapped)
protected def buildType = new MappedScalaType(child.nodeType, toBase, toMapped, classTag)
}

/** A parameter from a QueryTemplate which gets turned into a bind variable. */
Expand Down
85 changes: 63 additions & 22 deletions src/main/scala/scala/slick/ast/Type.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ package scala.slick.ast
import scala.language.{implicitConversions, higherKinds}
import scala.slick.SlickException
import scala.collection.generic.CanBuild
import scala.collection.mutable.Builder
import scala.reflect.ClassTag
import scala.collection.mutable.{Builder, ArrayBuilder}
import scala.reflect.{ClassTag, classTag => mkClassTag}
import Util._
import scala.collection.mutable.ArrayBuffer
import scala.annotation.implicitNotFound
import scala.slick.util.TupleSupport

/** Super-trait for all types */
trait Type {
Expand All @@ -20,6 +22,8 @@ trait Type {
throw new SlickException("No type for symbol "+sym+" found in "+this)
/** The structural view of this type */
def structural: Type = this
/** A ClassTag for the erased type of this type's Scala values */
def classTag: ClassTag[_]
}

/** An atomic type (i.e. a type which does not contain other types) */
Expand All @@ -28,11 +32,11 @@ trait AtomicType extends Type {
def children: Seq[Type] = Seq.empty
}

final case class StructType(elements: Seq[(Symbol, Type)]) extends Type {
final case class StructType(elements: IndexedSeq[(Symbol, Type)]) extends Type {
override def toString = "{" + elements.iterator.map{ case (s, t) => s + ": " + t }.mkString(", ") + "}"
lazy val symbolToIndex: Map[Symbol, Int] =
elements.zipWithIndex.map { case ((sym, _), idx) => (sym, idx) }(collection.breakOut)
def children: Seq[Type] = elements.map(_._2)
def children: IndexedSeq[Type] = elements.map(_._2)
def mapChildren(f: Type => Type): StructType =
mapOrNone(elements.map(_._2))(f) match {
case Some(types2) => StructType((elements, types2).zipped.map((e, t) => (e._1, t)))
Expand All @@ -42,12 +46,14 @@ final case class StructType(elements: Seq[(Symbol, Type)]) extends Type {
case ElementSymbol(idx) => elements(idx-1)._2
case _ => elements.find(x => x._1 == sym).map(_._2).getOrElse(super.select(sym))
}
def classTag = TupleSupport.classTagForArity(elements.size)
}

trait OptionType extends Type {
override def toString = "Option[" + elementType + "]"
def elementType: Type
def children: Seq[Type] = Seq(elementType)
def classTag = OptionType.classTag
}

object OptionType {
Expand All @@ -59,6 +65,7 @@ object OptionType {
else OptionType(e2)
}
}
private val classTag = mkClassTag[Option[_]]
}

final case class ProductType(elements: IndexedSeq[Type]) extends Type {
Expand All @@ -75,6 +82,7 @@ final case class ProductType(elements: IndexedSeq[Type]) extends Type {
def children: Seq[Type] = elements
def numberedElements: Iterator[(ElementSymbol, Type)] =
elements.iterator.zipWithIndex.map { case (t, i) => (new ElementSymbol(i+1), t) }
def classTag = TupleSupport.classTagForArity(elements.size)
}

final case class CollectionType(cons: CollectionTypeConstructor, elementType: Type) extends Type {
Expand All @@ -85,50 +93,82 @@ final case class CollectionType(cons: CollectionTypeConstructor, elementType: Ty
else CollectionType(cons, e2)
}
def children: Seq[Type] = Seq(elementType)
def classTag = cons.classTag
}

/** Represents a type constructor that can be usd for a collection-valued query.
* The relevant information for Slick is whether the elements of the collection
* keep their insertion order (isSequential) and whether only distinct elements
* are allowed (isUnique). */
trait CollectionTypeConstructor {
/** The ClassTag for the type constructor */
def classTag: ClassTag[_]
/** Determines if order is relevant */
def isSequential: Boolean
/** Determines if only distinct elements are allowed */
def isUnique: Boolean
def createErasedBuilder: Builder[Any, Any]
/** Create a `Builder` for the collection type, given a ClassTag for the element type */
def createBuilder[E : ClassTag]: Builder[E, Any]
/** Return a CollectionTypeConstructor which builds a subtype of Iterable
* but has the same properties otherwise. */
def iterableSubstitute: CollectionTypeConstructor =
if(isUnique && !isSequential) TypedCollectionTypeConstructor.set
else TypedCollectionTypeConstructor.seq
//TODO We should have a better substitute for (isUnique && isSequential)
}

trait TypedCollectionTypeConstructor[C[_]] extends CollectionTypeConstructor {
def createErasedBuilder: Builder[Any, C[Any]]
@implicitNotFound("Cannot use collection in a query\n collection type: ${C}[_]\n requires implicit of type: scala.slick.ast.TypedCollectionTypeConstructor[${C}]")
abstract class TypedCollectionTypeConstructor[C[_]](val classTag: ClassTag[C[_]]) extends CollectionTypeConstructor {
override def toString = s"Coll[$classTag]"
def createBuilder[E : ClassTag]: Builder[E, C[E]]
}

class ErasedCollectionTypeConstructor[C[_]](canBuildFrom: CanBuild[Any, C[Any]], tag: ClassTag[C[_]]) extends TypedCollectionTypeConstructor[C] {
override def toString = s"Coll[$canBuildFrom]"
val isSequential = classOf[scala.collection.Seq[_]].isAssignableFrom(tag.runtimeClass)
val isUnique = classOf[scala.collection.Set[_]].isAssignableFrom(tag.runtimeClass)
def createErasedBuilder: Builder[Any, C[Any]] = canBuildFrom()
class ErasedCollectionTypeConstructor[C[_]](canBuildFrom: CanBuild[Any, C[Any]], classTag: ClassTag[C[_]]) extends TypedCollectionTypeConstructor[C](classTag) {
val isSequential = classOf[scala.collection.Seq[_]].isAssignableFrom(classTag.runtimeClass)
val isUnique = classOf[scala.collection.Set[_]].isAssignableFrom(classTag.runtimeClass)
def createBuilder[E : ClassTag] = canBuildFrom().asInstanceOf[Builder[E, C[E]]]
}

object TypedCollectionTypeConstructor {
private[this] val arrayClassTag = mkClassTag[Array[_]]
/** The standard TypedCollectionTypeConstructor for Seq */
def seq = forColl[Vector]
/** The standard TypedCollectionTypeConstructor for Set */
def set = forColl[Set]
/** Get a TypedCollectionTypeConstructor for an Iterable type */
implicit def forColl[C[X] <: Iterable[X]](implicit cbf: CanBuild[Any, C[Any]], tag: ClassTag[C[_]]): TypedCollectionTypeConstructor[C] =
new ErasedCollectionTypeConstructor[C](cbf, tag)
/** Get a TypedCollectionTypeConstructor for an Array type */
implicit val forArray: TypedCollectionTypeConstructor[Array] = new TypedCollectionTypeConstructor[Array](arrayClassTag) {
def isSequential = true
def isUnique = false
def createBuilder[E : ClassTag]: Builder[E, Array[E]] = ArrayBuilder.make[E]
}
}

final class MappedScalaType(val baseType: Type, _toBase: Any => Any, _toMapped: Any => Any) extends Type {
final class MappedScalaType(val baseType: Type, _toBase: Any => Any, _toMapped: Any => Any, val classTag: ClassTag[_]) extends Type {
def toBase(v: Any): Any = _toBase(v)
def toMapped(v: Any): Any = _toMapped(v)
override def toString = s"Mapped[$baseType]"
def mapChildren(f: Type => Type): MappedScalaType = {
val e2 = f(baseType)
if(e2 eq baseType) this
else new MappedScalaType(e2, _toBase, _toMapped)
else new MappedScalaType(e2, _toBase, _toMapped, classTag)
}
def children: Seq[Type] = Seq(baseType)
override def select(sym: Symbol) = baseType.select(sym)
}

/** The standard type for freshly constructed nodes without an explicit type. */
final case object UnassignedType extends AtomicType
case object UnassignedType extends AtomicType {
def classTag = throw new SlickException("UnassignedType does not have a ClassTag")
}

/** The type of a structural view of a NominalType before computing the
* proper type in the `inferTypes` phase. */
final case class UnassignedStructuralType(sym: TypeSymbol) extends AtomicType
final case class UnassignedStructuralType(sym: TypeSymbol) extends AtomicType {
def classTag = throw new SlickException("UnassignedStructuralType does not have a ClassTag")
}

/* A type with a name, as used by tables.
*
Expand All @@ -153,6 +193,7 @@ final case class NominalType(sym: TypeSymbol)(val structuralView: Type) extends
case n: NominalType => n.sourceNominalType
case _ => this
}
def classTag = structuralView.classTag
}

/** Something that has a type */
Expand Down Expand Up @@ -227,7 +268,7 @@ object TypeUtilOps {
import TypeUtil.typeToTypeUtil

def replace(tpe: Type, f: PartialFunction[Type, Type]): Type =
f.applyOrElse(tpe, ({ case t: Type => t.mapChildren(_.replace(f)) }): PartialFunction[Type, Type])
f.applyOrElse(tpe, { case t: Type => t.mapChildren(_.replace(f)) }: PartialFunction[Type, Type])

def collect[T](tpe: Type, pf: PartialFunction[Type, T]): Iterable[T] = {
val b = new ArrayBuffer[T]
Expand Down Expand Up @@ -268,8 +309,8 @@ trait ScalaType[T] extends TypedType[T] {
final def scalaType = this
}

class ScalaBaseType[T](implicit val tag: ClassTag[T], val ordering: scala.math.Ordering[T]) extends ScalaType[T] with BaseTypedType[T] {
override def toString = "ScalaType[" + tag.runtimeClass.getName + "]"
class ScalaBaseType[T](implicit val classTag: ClassTag[T], val ordering: scala.math.Ordering[T]) extends ScalaType[T] with BaseTypedType[T] {
override def toString = "ScalaType[" + classTag.runtimeClass.getName + "]"
def nullable = false
def ordered = ordering ne null
def scalaOrderingFor(ord: Ordering) = {
Expand All @@ -285,9 +326,9 @@ class ScalaBaseType[T](implicit val tag: ClassTag[T], val ordering: scala.math.O
}
}
}
override def hashCode = tag.hashCode
override def hashCode = classTag.hashCode
override def equals(o: Any) = o match {
case t: ScalaBaseType[_] => tag == t.tag
case t: ScalaBaseType[_] => classTag == t.classTag
case _ => false
}
}
Expand All @@ -307,7 +348,7 @@ object ScalaBaseType {

private[this] val all: Map[ClassTag[_], ScalaBaseType[_]] =
Seq(booleanType, bigDecimalType, byteType, charType, doubleType,
floatType, intType, longType, nullType, shortType, stringType).map(s => (s.tag, s)).toMap
floatType, intType, longType, nullType, shortType, stringType).map(s => (s.classTag, s)).toMap

def apply[T](implicit tag: ClassTag[T], ord: scala.math.Ordering[T] = null): ScalaBaseType[T] =
all.getOrElse(tag, new ScalaBaseType[T]).asInstanceOf[ScalaBaseType[T]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import scala.language.experimental.macros
import scala.annotation.unchecked.{uncheckedVariance => uv}
import scala.reflect.macros.Context
import scala.slick.lifted.{MappedScalaProductShape, Shape, ShapeLevel}
import scala.reflect.ClassTag

/** A heterogenous list where each element has its own type. */
sealed abstract class HList extends Product {
Expand Down Expand Up @@ -126,7 +127,7 @@ sealed abstract class HList extends Product {
final object HList {
import syntax._

final class HListShape[Level <: ShapeLevel, M <: HList, U <: HList, P <: HList](val shapes: Seq[Shape[_, _, _, _]]) extends MappedScalaProductShape[Level, HList, M, U, P] {
final class HListShape[Level <: ShapeLevel, M <: HList, U <: HList : ClassTag, P <: HList](val shapes: Seq[Shape[_, _, _, _]]) extends MappedScalaProductShape[Level, HList, M, U, P] {
def buildValue(elems: IndexedSeq[Any]) = elems.foldRight(HNil: HList)(_ :: _)
def copy(shapes: Seq[Shape[_, _, _, _]]) = new HListShape(shapes)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class CreateResultSetMapping extends Phase {
case StructType(ch) =>
ProductNode(ch.map { case (_, t) => f(t) })
case t: MappedScalaType =>
TypeMapping(f(t.baseType), t.toBase, t.toMapped)
TypeMapping(f(t.baseType), t.toBase, t.toMapped, t.classTag)
case n @ NominalType(ts) => tables.get(ts) match {
case Some(n) => f(n.nodeType)
case None => f(n.structuralView)
Expand Down
4 changes: 3 additions & 1 deletion src/main/scala/scala/slick/direct/SlickBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import scala.slick.driver._
import scala.slick.{ast => sq}
import scala.slick.ast.{Library, FunctionSymbol, Dump, ColumnOption}
import scala.slick.compiler.CompilerState
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeRef
import scala.annotation.StaticAnnotation
import scala.reflect.runtime.universe._
Expand Down Expand Up @@ -213,7 +214,8 @@ class SlickBackend( val driver: JdbcDriver, mapper:Mapper ) extends QueryableBac
)( (v match {
case v:Vector[_] => v
case v:Product => v.productIterator.toVector
}):_* )
}):_* ),
ClassTag(typetag.mirror.runtimeClass(typetag.tpe))
))
new Query( tableExp, Scope() )
}
Expand Down
13 changes: 7 additions & 6 deletions src/main/scala/scala/slick/driver/JdbcExecutorComponent.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scala.slick.driver

import scala.slick.ast.{CompiledStatement, First, ResultSetMapping, Node}
import scala.collection.mutable.Builder
import scala.slick.ast._
import scala.slick.ast.Util._
import scala.slick.ast.TypeUtil._
import scala.slick.util.SQLBuilder
Expand All @@ -21,14 +22,14 @@ trait JdbcExecutorComponent extends SqlExecutorComponent { driver: JdbcDriver =>
tree.findNode(_.isInstanceOf[CompiledStatement]).get
.asInstanceOf[CompiledStatement].extra.asInstanceOf[SQLBuilder.Result].sql

def run(implicit session: Backend#Session): R = (tree match {
case rsm: ResultSetMapping =>
val b = rsm.nodeType.asCollectionType.cons.createErasedBuilder
def run(implicit session: Backend#Session): R = tree match {
case rsm @ ResultSetMapping(_, _, CompiledMapping(_, elemType)) :@ CollectionType(cons, el) =>
val b = cons.createBuilder(el.classTag).asInstanceOf[Builder[Any, R]]
createQueryInvoker[Any](rsm, param).foreach({ x => b += x }, 0)(session)
b.result()
case First(rsm: ResultSetMapping) =>
createQueryInvoker[Any](rsm, param).first
}).asInstanceOf[R]
createQueryInvoker[R](rsm, param).first
}
}

class UnshapedQueryExecutorDef[M](value: M) extends super.UnshapedQueryExecutorDef[M](value) {
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/scala/slick/driver/JdbcTypesComponent.scala
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ trait JdbcTypesComponent extends RelationalTypesComponent { driver: JdbcDriver =
}
}

abstract class MappedJdbcType[T, U](implicit tmd: JdbcType[U], tag: ClassTag[T]) extends JdbcType[T] {
abstract class MappedJdbcType[T, U](implicit tmd: JdbcType[U], val classTag: ClassTag[T]) extends JdbcType[T] {
def map(t: T): U
def comap(u: U): T

Expand Down Expand Up @@ -112,7 +112,7 @@ trait JdbcTypesComponent extends RelationalTypesComponent { driver: JdbcDriver =
throw new SlickException("No SQL type name found in java.sql.Types for code "+t))
}

abstract class DriverJdbcType[T : ClassTag] extends JdbcType[T] with BaseTypedType[T] {
abstract class DriverJdbcType[T](implicit val classTag: ClassTag[T]) extends JdbcType[T] with BaseTypedType[T] {
def scalaType = ScalaBaseType[T]
def sqlTypeName: String = driver.defaultSqlTypeName(this)
def valueToSQLLiteral(value: T) =
Expand Down
Loading

0 comments on commit 3fc16d8

Please sign in to comment.