Skip to content

Commit

Permalink
add python api, fix date functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jul 29, 2015
1 parent 01943d0 commit e47ff2c
Show file tree
Hide file tree
Showing 8 changed files with 380 additions and 349 deletions.
76 changes: 64 additions & 12 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
__all__ += ['lag', 'lead', 'ntile']

__all__ += [
'date_format',
'date_format', 'date_add', 'date_sub', 'add_months', 'months_between',
'year', 'quarter', 'month', 'hour', 'minute', 'second',
'dayofmonth', 'dayofyear', 'weekofyear']

Expand Down Expand Up @@ -716,7 +716,7 @@ def date_format(dateCol, format):
[Row(date=u'04/08/2015')]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.date_format(dateCol, format))
return Column(sc._jvm.functions.date_format(_to_java_column(dateCol), format))


@since(1.5)
Expand All @@ -729,7 +729,7 @@ def year(col):
[Row(year=2015)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.year(col))
return Column(sc._jvm.functions.year(_to_java_column(col)))


@since(1.5)
Expand All @@ -742,7 +742,7 @@ def quarter(col):
[Row(quarter=2)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.quarter(col))
return Column(sc._jvm.functions.quarter(_to_java_column(col)))


@since(1.5)
Expand All @@ -755,7 +755,7 @@ def month(col):
[Row(month=4)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.month(col))
return Column(sc._jvm.functions.month(_to_java_column(col)))


@since(1.5)
Expand All @@ -768,7 +768,7 @@ def dayofmonth(col):
[Row(day=8)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.dayofmonth(col))
return Column(sc._jvm.functions.dayofmonth(_to_java_column(col)))


@since(1.5)
Expand All @@ -781,7 +781,7 @@ def dayofyear(col):
[Row(day=98)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.dayofyear(col))
return Column(sc._jvm.functions.dayofyear(_to_java_column(col)))


@since(1.5)
Expand All @@ -794,7 +794,7 @@ def hour(col):
[Row(hour=13)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.hour(col))
return Column(sc._jvm.functions.hour(_to_java_column(col)))


@since(1.5)
Expand All @@ -807,7 +807,7 @@ def minute(col):
[Row(minute=8)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.minute(col))
return Column(sc._jvm.functions.minute(_to_java_column(col)))


@since(1.5)
Expand All @@ -820,7 +820,7 @@ def second(col):
[Row(second=15)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.second(col))
return Column(sc._jvm.functions.second(_to_java_column(col)))


@since(1.5)
Expand All @@ -829,11 +829,63 @@ def weekofyear(col):
Extract the week number of a given date as integer.
>>> df = sqlContext.createDataFrame([('2015-04-08',)], ['a'])
>>> df.select(weekofyear('a').alias('week')).collect()
>>> df.select(weekofyear(df.a).alias('week')).collect()
[Row(week=15)]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.weekofyear(col))
return Column(sc._jvm.functions.weekofyear(_to_java_column(col)))


@since(1.5)
def date_add(start, days):
    """
    Compute the date that falls `days` days later than `start`.

    :param start: column handed to `_to_java_column` before the JVM call
    :param days: day count, forwarded unchanged to the JVM `date_add` function
    :return: a :class:`Column` wrapping the JVM result

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(date_add(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 4, 9))]
    """
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    shifted = jvm_functions.date_add(_to_java_column(start), days)
    return Column(shifted)


@since(1.5)
def date_sub(start, days):
    """
    Compute the date that falls `days` days earlier than `start`.

    :param start: column handed to `_to_java_column` before the JVM call
    :param days: day count, forwarded unchanged to the JVM `date_sub` function
    :return: a :class:`Column` wrapping the JVM result

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(date_sub(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 4, 7))]
    """
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    shifted = jvm_functions.date_sub(_to_java_column(start), days)
    return Column(shifted)


@since(1.5)
def add_months(start, months):
    """
    Compute the date that falls `months` months later than `start`.

    :param start: column handed to `_to_java_column` before the JVM call
    :param months: month count, forwarded unchanged to the JVM `add_months` function
    :return: a :class:`Column` wrapping the JVM result

    >>> df = sqlContext.createDataFrame([('2015-04-08',)], ['d'])
    >>> df.select(add_months(df.d, 1).alias('d')).collect()
    [Row(d=datetime.date(2015, 5, 8))]
    """
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    shifted = jvm_functions.add_months(_to_java_column(start), months)
    return Column(shifted)


@since(1.5)
def months_between(date1, date2):
    """
    Compute how many months lie between `date1` and `date2`.

    :param date1: first column, handed to `_to_java_column` before the JVM call
    :param date2: second column, handed to `_to_java_column` before the JVM call
    :return: a :class:`Column` wrapping the JVM result

    >>> df = sqlContext.createDataFrame([('1997-02-28 10:30:00', '1996-10-30')], ['t', 'd'])
    >>> df.select(months_between(df.t, df.d).alias('months')).collect()
    [Row(months=3.94959677)]
    """
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    # Convert both operands left-to-right, matching the JVM signature's argument order.
    first = _to_java_column(date1)
    second = _to_java_column(date2)
    return Column(jvm_functions.months_between(first, second))


@since(1.5)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ object HiveTypeCoercion {
Division ::
PropagateTypes ::
ImplicitTypeCasts ::
DateTimeOperations ::
Nil

// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
Expand Down Expand Up @@ -638,6 +639,24 @@ object HiveTypeCoercion {
}
}

/**
 * Turns Add/Subtract of DateType/TimestampType and IntervalType to TimeAdd/TimeSub.
 *
 * Only expressions with exactly one IntervalType operand are rewritten. When both operands
 * are intervals the expression is left untouched: swapping the operands of such an Add would
 * simply match again on the next application of this rule, so the plan would never stabilize,
 * and casting an interval to a timestamp in order to feed TimeAdd/TimeSub is not meaningful.
 */
object DateTimeOperations extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
    // Skip nodes whose children have not been resolved yet.
    case e if !e.childrenResolved => e

    // interval + x: normalize to x + interval so that the case below picks it up.
    case Add(left, right)
        if left.dataType == IntervalType && right.dataType != IntervalType =>
      Add(right, left) // switch the order

    // x + interval: do the arithmetic on a timestamp, then cast back to the type of x.
    case Add(left, right)
        if right.dataType == IntervalType && left.dataType != IntervalType =>
      Cast(TimeAdd(Cast(left, TimestampType), right), left.dataType)

    // x - interval: same approach as addition, using TimeSub.
    case Subtract(left, right)
        if right.dataType == IntervalType && left.dataType != IntervalType =>
      Cast(TimeSub(Cast(left, TimestampType), right), left.dataType)
  }
}

/**
* Casts types according to the expected input types for [[Expression]]s.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -379,111 +379,95 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression)
}

/**
* Time Adds Interval.
* Adds an interval to a timestamp.
*/
case class TimeAdd(left: Expression, right: Expression)
case class TimeAdd(start: Expression, interval: Expression)
extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left + $right"
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DateType, TimestampType), IntervalType)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, IntervalType)

override def dataType: DataType = TimestampType

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[Interval]
left.dataType match {
case DateType =>
DateTimeUtils.dateAddFullInterval(
start.asInstanceOf[Int], itvl.months, itvl.microseconds)
case TimestampType =>
DateTimeUtils.timestampAddFullInterval(
start.asInstanceOf[Long], itvl.months, itvl.microseconds)
}
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], itvl.months, itvl.microseconds)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
left.dataType match {
case DateType =>
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.dateAddFullInterval($sd, $i.months, $i.microseconds)"""
})
case TimestampType => // TimestampType
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddFullInterval($sd, $i.months, $i.microseconds)"""
})
}
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)"""
})
}
}

/**
* Time Subtracts Interval.
* Subtracts an interval from a timestamp.
*/
case class TimeSub(left: Expression, right: Expression)
case class TimeSub(start: Expression, interval: Expression)
extends BinaryExpression with ExpectsInputTypes {

override def left: Expression = start
override def right: Expression = interval

override def toString: String = s"$left - $right"
override def inputTypes: Seq[AbstractDataType] =
Seq(TypeCollection(DateType, TimestampType), IntervalType)
override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, IntervalType)

override def dataType: DataType = TimestampType

override def nullSafeEval(start: Any, interval: Any): Any = {
val itvl = interval.asInstanceOf[Interval]
left.dataType match {
case DateType =>
DateTimeUtils.dateAddFullInterval(
start.asInstanceOf[Int], 0 - itvl.months, 0 - itvl.microseconds)
case TimestampType =>
DateTimeUtils.timestampAddFullInterval(
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
}
DateTimeUtils.timestampAddInterval(
start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds)
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
left.dataType match {
case DateType =>
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.dateAddFullInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
})
case TimestampType => // TimestampType
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddFullInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
})
}
defineCodeGen(ctx, ev, (sd, i) => {
s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)"""
})
}
}

/**
* Returns the date that is num_months after start_date.
*/
case class AddMonths(left: Expression, right: Expression)
case class AddMonths(startDate: Expression, numMonths: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = startDate
override def right: Expression = numMonths

override def inputTypes: Seq[AbstractDataType] = Seq(DateType, IntegerType)

override def dataType: DataType = DateType

override def nullSafeEval(start: Any, months: Any): Any = {
DateTimeUtils.dateAddYearMonthInterval(start.asInstanceOf[Int], months.asInstanceOf[Int])
DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int])
}

override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val dtu = DateTimeUtils.getClass.getName.stripSuffix("$")
defineCodeGen(ctx, ev, (sd, m) => {
s"""$dtu.dateAddYearMonthInterval($sd, $m)"""
s"""$dtu.dateAddMonths($sd, $m)"""
})
}
}

/**
* Returns the number of months between dates date1 and date2.
*/
case class MonthsBetween(left: Expression, right: Expression)
case class MonthsBetween(date1: Expression, date2: Expression)
extends BinaryExpression with ImplicitCastInputTypes {

override def left: Expression = date1
override def right: Expression = date2

override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType, TimestampType)

override def dataType: DataType = DoubleType
Expand Down
Loading

0 comments on commit e47ff2c

Please sign in to comment.