Skip to content

Commit

Permalink
[SPARK-8186] [SPARK-8187] [SPARK-8194] [SPARK-8198] [SPARK-9133] [SPA…
Browse files Browse the repository at this point in the history
…RK-9290] [SQL] functions: date_add, date_sub, add_months, months_between, time-interval calculation

This PR is based on apache#7589 , thanks to adrian-wang

Added SQL functions date_add, date_sub, add_months and months_between, and also added a rule for
add/subtract of date/timestamp and interval.

Closes apache#7589

cc rxin

Author: Daoyuan Wang <daoyuan.wang@intel.com>
Author: Davies Liu <davies@databricks.com>

Closes apache#7754 from davies/date_add and squashes the following commits:

e8c633a [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add
9e8e085 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add
6224ce4 [Davies Liu] fix conclict
bd18cd4 [Davies Liu] Merge branch 'master' of github.com:apache/spark into date_add
e47ff2c [Davies Liu] add python api, fix date functions
01943d0 [Davies Liu] Merge branch 'master' into date_add
522e91a [Daoyuan Wang] fix
e8a639a [Daoyuan Wang] fix
42df486 [Daoyuan Wang] fix style
87c4b77 [Daoyuan Wang] function add_months, months_between and some fixes
1a68e03 [Daoyuan Wang] poc of time interval calculation
c506661 [Daoyuan Wang] function date_add , date_sub
  • Loading branch information
adrian-wang authored and davies committed Jul 30, 2015
1 parent d8cfd53 commit 1abf7dc
Show file tree
Hide file tree
Showing 10 changed files with 791 additions and 162 deletions.
76 changes: 64 additions & 12 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
__all__ += ['lag', 'lead', 'ntile']

__all__ += [
'date_format',
'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
'year', 'quarter', 'month', 'hour', 'minute', 'second',
'dayofmonth', 'dayofyear', 'weekofyear']

Expand Down Expand Up @@ -716,7 +716,7 @@ def date_format(dateCol, format):
[Row(date=u'04/08/2015')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_format(dateCol, format))
return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))


@since(1.5)
Expand All @@ -729,7 +729,7 @@ def year(col):
[Row(year=2015)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.year(col))
return Column(sc._jvm.functions.year(_to_java_column(col)))


@since(1.5)
Expand All @@ -742,7 +742,7 @@ def quarter(col):
[Row(quarter=2)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.quarter(col))
return Column(sc._jvm.functions.quarter(_to_java_column(col)))


@since(1.5)
Expand All @@ -755,7 +755,7 @@ def month(col):
[Row(month=4)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.month(col))
return Column(sc._jvm.functions.month(_to_java_column(col)))


@since(1.5)
Expand All @@ -768,7 +768,7 @@ def dayofmonth(col):
[Row(day=8)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.dayofmonth(col))
return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))


@since(1.5)
Expand All @@ -781,7 +781,7 @@ def dayofyear(col):
[Row(day=98)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.dayofyear(col))
return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))


@since(1.5)
Expand All @@ -794,7 +794,7 @@ def hour(col):
[Row(hour=13)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.hour(col))
return Column(sc._jvm.functions.hour(_to_java_column(col)))


@since(1.5)
Expand All @@ -807,7 +807,7 @@ def minute(col):
[Row(minute=8)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.minute(col))
return Column(sc._jvm.functions.minute(_to_java_column(col)))


@since(1.5)
Expand All @@ -820,7 +820,7 @@ def second(col):
[Row(second=15)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.second(col))
return Column(sc._jvm.functions.second(_to_java_column(col)))


@since(1.5)
Expand All @@ -829,11 +829,63 @@ def weekofyear(col):
Extract the week number of a given date as integer.
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
>>> df.select(weekofyear('a').alias('week')).collect()
>>> df.select(weekofyear(df.a).alias('week')).collect()
[Row(week=15)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.weekofyear(col))
return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))


@since(1.5)
def date_add(start, days):
    """
    Returns the date that is `days` days after `start`.

    :param start: a date :class:`Column` (or column name) to shift.
    :param days: number of days to add; passed through unchanged to the JVM
        ``functions.date_add`` call, so an int literal is expected here.
    :return: a :class:`Column` of :class:`DateType`.

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(date_add(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 4, 9))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.date_add(_to_java_column(start), days))


@since(1.5)
def date_sub(start, days):
    """
    Returns the date that is `days` days before `start`.

    :param start: a date :class:`Column` (or column name) to shift.
    :param days: number of days to subtract; passed through unchanged to the
        JVM ``functions.date_sub`` call, so an int literal is expected here.
    :return: a :class:`Column` of :class:`DateType`.

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(date_sub(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 4, 7))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.date_sub(_to_java_column(start), days))


@since(1.5)
def add_months(start, months):
    """
    Returns the date that is `months` months after `start`.

    :param start: a date :class:`Column` (or column name) to shift.
    :param months: number of months to add; passed through unchanged to the
        JVM ``functions.add_months`` call, so an int literal is expected here.
    :return: a :class:`Column` of :class:`DateType`.

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(add_months(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 5, 8))]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.add_months(_to_java_column(start), months))


@since(1.5)
def months_between(date1, date2):
    """
    Returns the number of months between `date1` and `date2`.

    :param date1: a timestamp/date :class:`Column` (or column name).
    :param date2: a timestamp/date :class:`Column` (or column name).
    :return: a :class:`Column` of :class:`DoubleType` — fractional when the
        day-of-month components differ (see the doctest below).

    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
    [Row(months=3.9495967...)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.months_between(_to_java_column(date1), _to_java_column(date2)))


@since(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,9 +205,12 @@ object FunctionRegistry {
expression[Upper]("upper"),

// datetime functions
expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
expression[DateAdd]("date_add"),
expression[DateFormatClass]("date_format"),
expression[DateSub]("date_sub"),
expression[DayOfMonth]("day"),
expression[DayOfYear]("dayofyear"),
expression[DayOfMonth]("dayofmonth"),
Expand All @@ -216,6 +219,7 @@ object FunctionRegistry {
expression[LastDay]("last_day"),
expression[Minute]("minute"),
expression[Month]("month"),
expression[MonthsBetween]("months_between"),
expression[NextDay]("next_day"),
expression[Quarter]("quarter"),
expression[Second]("second"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ object HiveTypeCoercion {
Division ::
PropagateTypes ::
ImplicitTypeCasts ::
DateTimeOperations ::
Nil

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
Expand Down Expand Up @@ -638,6 +639,27 @@ object HiveTypeCoercion {
}
}

/**
 * Rewrites `Add`/`Subtract` between a CalendarIntervalType operand and a
 * DateType/TimestampType/StringType operand into [[TimeAdd]]/[[TimeSub]],
 * casting the result back to the non-interval operand's original type.
 */
object DateTimeOperations extends Rule[LogicalPlan] {

  // Types that may legally be combined with a calendar interval.
  private val datetimeLikeTypes = Seq(DateType, TimestampType, StringType)

  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    // Leave expressions alone until all of their children have resolved types.
    case e if !e.childrenResolved => e

    // Interval may appear on either side of an addition...
    case Add(lhs @ CalendarIntervalType(), rhs) if datetimeLikeTypes.contains(rhs.dataType) =>
      Cast(TimeAdd(rhs, lhs), rhs.dataType)
    case Add(lhs, rhs @ CalendarIntervalType()) if datetimeLikeTypes.contains(lhs.dataType) =>
      Cast(TimeAdd(lhs, rhs), lhs.dataType)
    // ...but only on the right-hand side of a subtraction.
    case Subtract(lhs, rhs @ CalendarIntervalType()) if datetimeLikeTypes.contains(lhs.dataType) =>
      Cast(TimeSub(lhs, rhs), lhs.dataType)
  }
}

/**
* Casts types according to the expected input types for [[Expression]]s.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

import scala.util.Try

Expand Down Expand Up @@ -63,6 +63,53 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback {
}
}

/**
 * Adds a number of days to startdate.
 *
 * Inputs are implicitly cast to [[DateType]] and [[IntegerType]]; since a date
 * is carried internally as an Int, the shifted date is a plain integer sum.
 */
case class DateAdd(startDate: Expression, days: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = startDate
  override def right: Expression = days

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

  override def dataType: DataType = DateType

  override def nullSafeEval(start: Any, d: Any): Any = {
    val baseDate = start.asInstanceOf[Int]
    val dayCount = d.asInstanceOf[Int]
    baseDate + dayCount
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Generated Java mirrors nullSafeEval: a single int addition.
    nullSafeCodeGen(ctx, ev, (baseDate, dayCount) => {
      s"${ev.primitive} = $baseDate + $dayCount;"
    })
  }
}

/**
 * Subtracts a number of days to startdate.
 *
 * Inputs are implicitly cast to [[DateType]] and [[IntegerType]]; since a date
 * is carried internally as an Int, the shifted date is a plain integer difference.
 */
case class DateSub(startDate: Expression, days: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = startDate
  override def right: Expression = days

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

  override def dataType: DataType = DateType

  override def nullSafeEval(start: Any, d: Any): Any = {
    val baseDate = start.asInstanceOf[Int]
    val dayCount = d.asInstanceOf[Int]
    baseDate - dayCount
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Generated Java mirrors nullSafeEval: a single int subtraction.
    nullSafeCodeGen(ctx, ev, (baseDate, dayCount) => {
      s"${ev.primitive} = $baseDate - $dayCount;"
    })
  }
}

case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType)
Expand Down Expand Up @@ -543,3 +590,109 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)

override def prettyName: String = "next_day"
}

/**
 * Adds an interval to timestamp.
 *
 * The interval's month and microsecond components are applied separately by
 * [[DateTimeUtils.timestampAddInterval]].
 */
case class TimeAdd(start: Expression, interval: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = start
  override def right: Expression = interval

  override def toString: String = s"$left + $right"
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

  override def dataType: DataType = TimestampType

  override def nullSafeEval(micros: Any, amount: Any): Any = {
    val iv = amount.asInstanceOf[CalendarInterval]
    DateTimeUtils.timestampAddInterval(micros.asInstanceOf[Long], iv.months, iv.microseconds)
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Call the DateTimeUtils object (strip the trailing '$' of its class name).
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    defineCodeGen(ctx, ev, (ts, iv) =>
      s"$dtu.timestampAddInterval($ts, $iv.months, $iv.microseconds)")
  }
}

/**
 * Subtracts an interval from timestamp.
 *
 * Implemented as an addition of the negated interval components via
 * [[DateTimeUtils.timestampAddInterval]].
 */
case class TimeSub(start: Expression, interval: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = start
  override def right: Expression = interval

  override def toString: String = s"$left - $right"
  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, CalendarIntervalType)

  override def dataType: DataType = TimestampType

  override def nullSafeEval(micros: Any, amount: Any): Any = {
    val iv = amount.asInstanceOf[CalendarInterval]
    DateTimeUtils.timestampAddInterval(
      micros.asInstanceOf[Long], 0 - iv.months, 0 - iv.microseconds)
  }

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Call the DateTimeUtils object (strip the trailing '$' of its class name).
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    defineCodeGen(ctx, ev, (ts, iv) =>
      s"$dtu.timestampAddInterval($ts, 0 - $iv.months, 0 - $iv.microseconds)")
  }
}

/**
 * Returns the date that is num_months after start_date.
 *
 * Delegates the calendar arithmetic to [[DateTimeUtils.dateAddMonths]].
 */
case class AddMonths(startDate: Expression, numMonths: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = startDate
  override def right: Expression = numMonths

  override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

  override def dataType: DataType = DateType

  override def nullSafeEval(start: Any, months: Any): Any =
    DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Call the DateTimeUtils object (strip the trailing '$' of its class name).
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    defineCodeGen(ctx, ev, (baseDate, monthCount) =>
      s"$dtu.dateAddMonths($baseDate, $monthCount)")
  }
}

/**
 * Returns number of months between dates date1 and date2.
 *
 * Both inputs are implicitly cast to [[TimestampType]] and the (possibly
 * fractional) month count is computed by [[DateTimeUtils.monthsBetween]].
 */
case class MonthsBetween(date1: Expression, date2: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = date1
  override def right: Expression = date2

  override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)

  override def dataType: DataType = DoubleType

  override def nullSafeEval(t1: Any, t2: Any): Any =
    DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long])

  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    // Call the DateTimeUtils object (strip the trailing '$' of its class name).
    val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
    defineCodeGen(ctx, ev, (lhsTs, rhsTs) =>
      s"$dtu.monthsBetween($lhsTs, $rhsTs)")
  }
}
Loading

0 comments on commit 1abf7dc

Please sign in to comment.