21 changes: 21 additions & 0 deletions dataset/src/main/scala/frameless/CanAccess.scala
@@ -0,0 +1,21 @@
package frameless

/** `CanAccess[_, A with B]` indicates that in this context it is possible to
* access columns from both table `A` and table `B`. The first type parameter
* is a dummy argument used for type inference.
*/
sealed trait CanAccess[-T, X]

object CanAccess {
private[this] val theInstance = new CanAccess[Nothing, Nothing] {}
private[frameless] def localCanAccessInstance[X]: CanAccess[Any, X] = theInstance.asInstanceOf[CanAccess[Any, X]]

implicit def globalCanAccessInstance[X]: CanAccess[X, X] = theInstance.asInstanceOf[CanAccess[X, X]]
// The trick works as follows: `(df: TypedDataset[T]).col('a)` looks for a
// `CanAccess[T, T]`, which is always available thanks to the
// `globalCanAccessInstance` implicit defined above. Join expressions (and
// other multi-dataset operations) take an `implicit a: CanAccess[Any, T with U] =>`
// closure. Because the first (dummy) type parameter of `CanAccess` is
// contravariant, the locally bound implicit is always preferred over
// `globalCanAccessInstance`, which implements the desired behavior.
}
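
A minimal, self-contained sketch of the contravariance trick described above; `Person`, `Address`, and `personCol` are invented stand-ins for `TypedDataset` and `col`, not part of this PR:

object CanAccessDemo extends App {
  sealed trait CanAccess[-T, X]
  object CanAccess {
    private val instance = new CanAccess[Nothing, Nothing] {}
    implicit def global[X]: CanAccess[X, X] =
      instance.asInstanceOf[CanAccess[X, X]]
  }

  trait Person; trait Address

  // Stand-in for `col`: callable only where some CanAccess[Person, X] exists.
  def personCol[X](implicit ca: CanAccess[Person, X]): String = "name"

  // Single-dataset scope: X = Person, provided by the companion's `global`.
  println(personCol[Person])

  // Join-like scope: the closure parameter CanAccess[Any, Person with Address]
  // conforms to CanAccess[Person, Person with Address] because the first type
  // parameter is contravariant, so columns of either table are accessible.
  def joinScope(cond: CanAccess[Any, Person with Address] => String): String =
    cond(CanAccess.global[Any].asInstanceOf[CanAccess[Any, Person with Address]])

  println(joinScope { implicit ca => personCol[Person with Address] })
}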
157 changes: 95 additions & 62 deletions dataset/src/main/scala/frameless/TypedDataset.scala
@@ -2,13 +2,14 @@ package frameless

import frameless.ops._
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, CreateStruct, EqualTo}
import org.apache.spark.sql.catalyst.plans.logical.{Join, Project}
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Literal}
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter, FullOuter}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter}
import org.apache.spark.sql._
import shapeless._
import shapeless.ops.hlist.{Prepend, ToTraversable, Tupler}
import CanAccess.localCanAccessInstance

/** [[TypedDataset]] is a safer interface for working with `Dataset`.
*
@@ -163,13 +164,14 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
*
* It is statically checked that column with such name exists and has type `A`.
*/
def col[A](column: Witness.Lt[Symbol])(
def col[A, X](column: Witness.Lt[Symbol])(
implicit
ca: CanAccess[T, X],
exists: TypedColumn.Exists[T, column.T, A],
encoder: TypedEncoder[A]
): TypedColumn[T, A] = {
): TypedColumn[X, A] = {
val colExpr = dataset.col(column.value.name).as[A](TypedExpressionEncoder[A])
new TypedColumn[T, A](colExpr)
new TypedColumn[X, A](colExpr)
}

object colMany extends SingletonProductArgs {
@@ -288,6 +290,82 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
): GroupedByManyOps[T, TK, K, KT] = new GroupedByManyOps[T, TK, K, KT](self, groupedBy)
}

/** Computes the inner join of `this` `Dataset` with the `other` `Dataset`,
* returning a `Tuple2` for each pair where the condition evaluates to `true`.
*/
def joinInner[U](other: TypedDataset[U])(condition: CanAccess[Any, T with U] => TypedColumn[T with U, Boolean])
(implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] = {
import FramelessInternals._
val leftPlan = logicalPlan(dataset)
val rightPlan = logicalPlan(other.dataset)
val join = resolveSelfJoin(Join(leftPlan, rightPlan, Inner, Some(condition(localCanAccessInstance).expr)))
val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, U)])
TypedDataset.create[(T, U)](joinedDs)
}
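
A hypothetical usage sketch of the condition-closure API; the case classes, data, and implicit SparkSession are invented for illustration, and `===` is frameless's typed equality:

case class Person(name: String, cityId: Int)
case class City(id: Int, cityName: String)

// Assumes an implicit SparkSession and the usual frameless encoders in scope.
val people: TypedDataset[Person] = TypedDataset.create(Seq(Person("Ada", 1)))
val cities: TypedDataset[City] = TypedDataset.create(Seq(City(1, "London")))

// The `implicit ca` parameter is the CanAccess[Any, Person with City] that
// lets `col` be called on both datasets inside the condition.
val joined: TypedDataset[(Person, City)] =
  people.joinInner(cities) { implicit ca =>
    people.col('cityId) === cities.col('id)
  }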

/** Computes the Cartesian product (cross join) of `this` `Dataset` with the `other` `Dataset`. */
def joinCross[U](other: TypedDataset[U])
(implicit e: TypedEncoder[(T, U)]): TypedDataset[(T, U)] =
new TypedDataset(self.dataset.joinWith(other.dataset, new Column(Literal(true)), "cross"))

/** Computes the full outer join of `this` `Dataset` with the `other` `Dataset`,
* returning a `Tuple2` for each pair where the condition evaluates to `true`.
*/
def joinFull[U](other: TypedDataset[U])(condition: CanAccess[Any, T with U] => TypedColumn[T with U, Boolean])
(implicit e: TypedEncoder[(Option[T], Option[U])]): TypedDataset[(Option[T], Option[U])] = {
import FramelessInternals._
val leftPlan = logicalPlan(dataset)
val rightPlan = logicalPlan(other.dataset)
val join = resolveSelfJoin(Join(leftPlan, rightPlan, FullOuter, Some(condition(localCanAccessInstance).expr)))
val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(Option[T], Option[U])])
TypedDataset.create[(Option[T], Option[U])](joinedDs)
}

/** Computes the right outer join of `this` `Dataset` with the `other` `Dataset`,
* returning a `Tuple2` for each pair where the condition evaluates to `true`.
*/
def joinRight[U](other: TypedDataset[U])(condition: CanAccess[Any, T with U] => TypedColumn[T with U, Boolean])
(implicit e: TypedEncoder[(Option[T], U)]): TypedDataset[(Option[T], U)] = {
import FramelessInternals._
val leftPlan = logicalPlan(dataset)
val rightPlan = logicalPlan(other.dataset)
val join = resolveSelfJoin(Join(leftPlan, rightPlan, RightOuter, Some(condition(localCanAccessInstance).expr)))
val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(Option[T], U)])
TypedDataset.create[(Option[T], U)](joinedDs)
}

/** Computes the left outer join of `this` `Dataset` with the `other` `Dataset`,
* returning a `Tuple2` for each pair where the condition evaluates to `true`.
*/
def joinLeft[U](other: TypedDataset[U])(condition: CanAccess[Any, T with U] => TypedColumn[T with U, Boolean])
(implicit e: TypedEncoder[(T, Option[U])]): TypedDataset[(T, Option[U])] = {
import FramelessInternals._
val leftPlan = logicalPlan(dataset)
val rightPlan = logicalPlan(other.dataset)
val join = resolveSelfJoin(Join(leftPlan, rightPlan, LeftOuter, Some(condition(localCanAccessInstance).expr)))
val joinedPlan = joinPlan(dataset, join, leftPlan, rightPlan)
val joinedDs = mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, Option[U])])

TypedDataset.create[(T, Option[U])](joinedDs)
}
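
Under the same hypothetical schema as the `joinInner` sketch above, the outer variants wrap the non-driving side in `Option` (both sides, for `joinFull`):

val withCity: TypedDataset[(Person, Option[City])] =
  people.joinLeft(cities) { implicit ca =>
    people.col('cityId) === cities.col('id)
  }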

/** Computes the left semi join of `this` `Dataset` with the `other` `Dataset`,
* returning only the rows of `this` `Dataset` for which the condition holds for at least one row of the `other` `Dataset`.
*/
def joinLeftSemi[U](other: TypedDataset[U])(condition: CanAccess[Any, T with U] => TypedColumn[T with U, Boolean]): TypedDataset[T] =
new TypedDataset(self.dataset.join(other.dataset, condition(localCanAccessInstance).untyped, "leftsemi")
.as[T](TypedExpressionEncoder(encoder)))

/** Computes the left anti join of `this` `Dataset` with the `other` `Dataset`,
* returning only the rows of `this` `Dataset` for which the condition holds for no row of the `other` `Dataset`.
*/
def joinLeftAnti[U](other: TypedDataset[U])(condition: CanAccess[Any, T with U] => TypedColumn[T with U, Boolean]): TypedDataset[T] =
new TypedDataset(self.dataset.join(other.dataset, condition(localCanAccessInstance).untyped, "leftanti")
.as[T](TypedExpressionEncoder(encoder)))
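
A sketch of the semi/anti variants under the same hypothetical schema; unlike the other joins they keep the left-hand type instead of producing tuples:

// People with at least one matching city, and people with none.
val matched: TypedDataset[Person] =
  people.joinLeftSemi(cities) { implicit ca =>
    people.col('cityId) === cities.col('id)
  }
val unmatched: TypedDataset[Person] =
  people.joinLeftAnti(cities) { implicit ca =>
    people.col('cityId) === cities.col('id)
  }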

/** Fixes SPARK-6231; for more details see the original code in [[Dataset#join]]. */
private def resolveSelfJoin(join: Join): Join = {
val plan = FramelessInternals.ofRows(dataset.sparkSession, join).queryExecution.analyzed.asInstanceOf[Join]
@@ -315,84 +393,39 @@ class TypedDataset[T] protected[frameless](val dataset: Dataset[T])(implicit val
}
}

def join[A, B](
right: TypedDataset[A],
leftCol: TypedColumn[T, B],
rightCol: TypedColumn[A, B]
): TypedDataset[(T, A)] = {
implicit def re = right.encoder

val leftPlan = FramelessInternals.logicalPlan(dataset)
val rightPlan = FramelessInternals.logicalPlan(right.dataset)
val condition = EqualTo(leftCol.expr, rightCol.expr)

val join = resolveSelfJoin(Join(leftPlan, rightPlan, Inner, Some(condition)))
val joined = FramelessInternals.executePlan(dataset, join)
val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

val joinedPlan = Project(List(
Alias(CreateStruct(leftOutput), "_1")(),
Alias(CreateStruct(rightOutput), "_2")()
), joined.analyzed)

val joinedDs = FramelessInternals.mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, A)])

TypedDataset.create[(T, A)](joinedDs)
}

def joinLeft[A: TypedEncoder, B](
right: TypedDataset[A],
leftCol: TypedColumn[T, B],
rightCol: TypedColumn[A, B]
)(implicit e: TypedEncoder[(T, Option[A])]): TypedDataset[(T, Option[A])] = {
val leftPlan = FramelessInternals.logicalPlan(dataset)
val rightPlan = FramelessInternals.logicalPlan(right.dataset)
val condition = EqualTo(leftCol.expr, rightCol.expr)

val join = resolveSelfJoin(Join(leftPlan, rightPlan, LeftOuter, Some(condition)))
val joined = FramelessInternals.executePlan(dataset, join)
val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

val joinedPlan = Project(List(
Alias(CreateStruct(leftOutput), "_1")(),
Alias(CreateStruct(rightOutput), "_2")()
), joined.analyzed)

val joinedDs = FramelessInternals.mkDataset(dataset.sqlContext, joinedPlan, TypedExpressionEncoder[(T, Option[A])])

TypedDataset.create[(T, Option[A])](joinedDs)
}

/** Takes a function from A => R and converts it to a UDF for TypedColumn[T, A] => TypedColumn[T, R].
*/
def makeUDF[A: TypedEncoder, R: TypedEncoder](f: A => R):
TypedColumn[T, A] => TypedColumn[T, R] = functions.udf(f)
TypedColumn[T, A] => TypedColumn[T, R] =
functions.udf(f)

/** Takes a function from (A1, A2) => R and converts it to a UDF for
* (TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R].
*/
def makeUDF[A1: TypedEncoder, A2: TypedEncoder, R: TypedEncoder](f: (A1, A2) => R):
(TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] = functions.udf(f)
(TypedColumn[T, A1], TypedColumn[T, A2]) => TypedColumn[T, R] =
functions.udf(f)

/** Takes a function from (A1, A2, A3) => R and converts it to a UDF for
* (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R].
*/
def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3) => R):
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] = functions.udf(f)
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3]) => TypedColumn[T, R] =
functions.udf(f)

/** Takes a function from (A1, A2, A3, A4) => R and converts it to a UDF for
* (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R].
*/
def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4) => R):
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] = functions.udf(f)
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4]) => TypedColumn[T, R] =
functions.udf(f)

/** Takes a function from (A1, A2, A3, A4, A5) => R and converts it to a UDF for
* (TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R].
*/
def makeUDF[A1: TypedEncoder, A2: TypedEncoder, A3: TypedEncoder, A4: TypedEncoder, A5: TypedEncoder, R: TypedEncoder](f: (A1, A2, A3, A4, A5) => R):
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] = functions.udf(f)
(TypedColumn[T, A1], TypedColumn[T, A2], TypedColumn[T, A3], TypedColumn[T, A4], TypedColumn[T, A5]) => TypedColumn[T, R] =
functions.udf(f)
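
A short usage sketch for `makeUDF`, reusing the hypothetical `people` dataset from the join examples:

// Lift a plain Scala function into a typed column transformation.
val shout: TypedColumn[Person, String] => TypedColumn[Person, String] =
  people.makeUDF((name: String) => name.toUpperCase)

val shouted: TypedColumn[Person, String] = shout(people.col('name))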

/** Type-safe projection from type T to Tuple1[A]
* {{{
dataset/src/main/scala/org/apache/spark/sql/FramelessInternals.scala
@@ -1,7 +1,8 @@
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.ObjectType

@@ -23,17 +24,25 @@ object FramelessInternals

def logicalPlan(ds: Dataset[_]): LogicalPlan = ds.logicalPlan

def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution = {
def executePlan(ds: Dataset[_], plan: LogicalPlan): QueryExecution =
ds.sparkSession.sessionState.executePlan(plan)

def joinPlan(ds: Dataset[_], plan: LogicalPlan, leftPlan: LogicalPlan, rightPlan: LogicalPlan): LogicalPlan = {
val joined = executePlan(ds, plan)
val leftOutput = joined.analyzed.output.take(leftPlan.output.length)
val rightOutput = joined.analyzed.output.takeRight(rightPlan.output.length)

// Repackage the flat join output as a Tuple2: the left plan's columns become
// struct `_1` and the right plan's columns become struct `_2`.
Project(List(
Alias(CreateStruct(leftOutput), "_1")(),
Alias(CreateStruct(rightOutput), "_2")()
), joined.analyzed)
}

def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] = {
def mkDataset[T](sqlContext: SQLContext, plan: LogicalPlan, encoder: Encoder[T]): Dataset[T] =
new Dataset(sqlContext, plan, encoder)
}

def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame = {
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
Dataset.ofRows(sparkSession, logicalPlan)
}

// because org.apache.spark.sql.types.UserDefinedType is private[spark]
type UserDefinedType[A >: Null] = org.apache.spark.sql.types.UserDefinedType[A]
12 changes: 6 additions & 6 deletions dataset/src/test/scala/frameless/ColTests.scala
@@ -13,18 +13,18 @@ class ColTests extends TypedDatasetSuite {
x4.col('a)
t4.col('_1)

x4.col[Int]('a)
t4.col[Int]('_1)
x4.col[Int, X4[Int, String, Long, Boolean]]('a)
t4.col[Int, (Int, String, Long, Boolean)]('_1)

illTyped("x4.col[String]('a)", "No column .* of type String in frameless.X4.*")
illTyped("x4.col[String, X4[Int, String, Long, Boolean]]('a)", "No column .* of type String in frameless.X4.*")

x4.col('b)
t4.col('_2)

x4.col[String]('b)
t4.col[String]('_2)
x4.col[String, X4[Int, String, Long, Boolean]]('b)
t4.col[String, (Int, String, Long, Boolean)]('_2)

illTyped("x4.col[Int]('b)", "No column .* of type Int in frameless.X4.*")
illTyped("x4.col[Int, X4[Int, String, Long, Boolean]]('b)", "No column .* of type Int in frameless.X4.*")

()
}