Skip to content

Commit

Permalink
[SPARK-8945][SQL] Add add and subtract expressions for IntervalType
Browse files Browse the repository at this point in the history
JIRA: https://issues.apache.org/jira/browse/SPARK-8945

Add add and subtract expressions for IntervalType.

Author: Liang-Chi Hsieh <viirya@appier.com>

This patch had conflicts when merged, resolved by
Committer: Reynold Xin <rxin@databricks.com>

Closes apache#7398 from viirya/interval_add_subtract and squashes the following commits:

acd1f1e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract
5abae28 [Liang-Chi Hsieh] For comments.
6f5b72e [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract
dbe3906 [Liang-Chi Hsieh] For comments.
13a2fc5 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into interval_add_subtract
83ec129 [Liang-Chi Hsieh] Remove intervalMethod.
acfe1ab [Liang-Chi Hsieh] Fix scala style.
d3e9d0e [Liang-Chi Hsieh] Add add and subtract expressions for IntervalType.
  • Loading branch information
viirya authored and rxin committed Jul 17, 2015
1 parent 305e77c commit eba6a1a
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.Interval


case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def dataType: DataType = child.dataType

Expand All @@ -36,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
/**
 * Produces the code for this expression by handing `defineCodeGen` a template that negates the
 * child's generated value. The template is chosen from this expression's data type:
 *  - decimal: calls `unary_$minus()` on the value,
 *  - other numeric types: applies `-` and casts the result to the type's Java type,
 *  - interval: calls `negate()` on the value.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  dataType match {
    case _: DecimalType =>
      // `$$` is an escaped dollar sign, so the emitted method name is `unary_$minus`.
      defineCodeGen(ctx, ev, value => s"$value.unary_$$minus()")
    case numericType: NumericType =>
      // The negated value is cast to the Java type reported by the codegen context.
      defineCodeGen(ctx, ev, value => s"(${ctx.javaType(numericType)})(-($value))")
    case _: IntervalType =>
      defineCodeGen(ctx, ev, value => s"$value.negate()")
  }
}

protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
/**
 * Negates an already-evaluated child value.
 *
 * Interval values are negated through `Interval.negate()`; every other value goes through the
 * `numeric` helper. The input is presumably never null here, as the method name suggests.
 */
protected override def nullSafeEval(input: Any): Any = dataType match {
  case _: IntervalType => input.asInstanceOf[Interval].negate()
  case _ => numeric.negate(input)
}
}

case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def prettyName: String = "positive"

override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

override def dataType: DataType = child.dataType

Expand Down Expand Up @@ -95,32 +103,66 @@ private[sql] object BinaryArithmetic {

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

override def symbol: String = "+"
override def decimalMethod: String = "$plus"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
/**
 * Adds two already-evaluated operands.
 *
 * When the result type is the interval type both operands are treated as `Interval`s and
 * combined with `Interval.add`; otherwise the `numeric` helper for the result type is used.
 * The inputs are presumably never null here, as the method name suggests.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
  case _: IntervalType =>
    val lhs = input1.asInstanceOf[Interval]
    val rhs = input2.asInstanceOf[Interval]
    lhs.add(rhs)
  case _ =>
    numeric.plus(input1, input2)
}

/**
 * Produces the code for this addition by handing `defineCodeGen` a template that combines the
 * two operands' generated values. The template depends on the result type:
 *  - decimal: calls `$plus` on the left value,
 *  - byte / short: applies the operator and casts the result to the type's Java type,
 *  - interval: calls `add` on the left value,
 *  - anything else: applies the operator directly.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val template: (String, String) => String = dataType match {
    case _: DecimalType =>
      // `$$` is an escaped dollar sign, so the emitted method name is `$plus`.
      (lhs, rhs) => s"$lhs.$$plus($rhs)"
    case ByteType | ShortType =>
      // Java's `+` on byte/short operands yields an int, so the result is cast back.
      (lhs, rhs) => s"(${ctx.javaType(dataType)})($lhs $symbol $rhs)"
    case IntervalType =>
      (lhs, rhs) => s"$lhs.add($rhs)"
    case _ =>
      (lhs, rhs) => s"$lhs $symbol $rhs"
  }
  defineCodeGen(ctx, ev, template)
}
}

case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = NumericType
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

override def symbol: String = "-"
override def decimalMethod: String = "$minus"

override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)

private lazy val numeric = TypeUtils.getNumeric(dataType)

protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
/**
 * Subtracts the second already-evaluated operand from the first.
 *
 * When the result type is the interval type both operands are treated as `Interval`s and
 * combined with `Interval.subtract`; otherwise the `numeric` helper for the result type is used.
 * The inputs are presumably never null here, as the method name suggests.
 */
protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
  case _: IntervalType =>
    val minuend = input1.asInstanceOf[Interval]
    val subtrahend = input2.asInstanceOf[Interval]
    minuend.subtract(subtrahend)
  case _ =>
    numeric.minus(input1, input2)
}

/**
 * Produces the code for this subtraction by handing `defineCodeGen` a template that combines
 * the two operands' generated values. The template depends on the result type:
 *  - decimal: calls `$minus` on the left value,
 *  - byte / short: applies the operator and casts the result to the type's Java type,
 *  - interval: calls `subtract` on the left value,
 *  - anything else: applies the operator directly.
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  val template: (String, String) => String = dataType match {
    case _: DecimalType =>
      // `$$` is an escaped dollar sign, so the emitted method name is `$minus`.
      (lhs, rhs) => s"$lhs.$$minus($rhs)"
    case ByteType | ShortType =>
      // Java's `-` on byte/short operands yields an int, so the result is cast back.
      (lhs, rhs) => s"(${ctx.javaType(dataType)})($lhs $symbol $rhs)"
    case IntervalType =>
      (lhs, rhs) => s"$lhs.subtract($rhs)"
    case _ =>
      (lhs, rhs) => s"$lhs $symbol $rhs"
  }
  defineCodeGen(ctx, ev, template)
}
}

case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types._


// These classes are here to avoid issues with serialization and integration with quasiquotes.
Expand Down Expand Up @@ -69,6 +69,7 @@ class CodeGenContext {
mutableStates += ((javaType, variableName, initialValue))
}

// Fully qualified name of the Interval class (via classOf[Interval].getName). It is the string
// returned for IntervalType where this context maps data types to Java type names.
final val intervalType: String = classOf[Interval].getName
final val JAVA_BOOLEAN = "boolean"
final val JAVA_BYTE = "byte"
final val JAVA_SHORT = "short"
Expand Down Expand Up @@ -137,6 +138,7 @@ class CodeGenContext {
case dt: DecimalType => "Decimal"
case BinaryType => "byte[]"
case StringType => "UTF8String"
case IntervalType => intervalType
case _: StructType => "InternalRow"
case _: ArrayType => s"scala.collection.Seq"
case _: MapType => s"scala.collection.Map"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types._

object Literal {
def apply(v: Any): Literal = v match {
Expand All @@ -42,6 +42,7 @@ object Literal {
case t: Timestamp => Literal(DateTimeUtils.fromJavaTimestamp(t), TimestampType)
case d: Date => Literal(DateTimeUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case i: Interval => Literal(i, IntervalType)
case null => Literal(null, NullType)
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,12 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)

/**
 * A type collection that accepts any numeric type or the interval type. It is meant only for
 * the operations that are defined for intervals as well as for numbers: unary minus,
 * unary positive, add and subtract.
 */
val NumericAndInterval = TypeCollection(NumericType, IntervalType)

def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)

def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
}

test("check types for unary arithmetic") {
assertError(UnaryMinus('stringField), "expected to be of type numeric")
assertError(UnaryMinus('stringField), "type (numeric or interval)")
assertError(Abs('stringField), "expected to be of type numeric")
assertError(BitwiseNot('stringField), "expected to be of type integral")
}
Expand All @@ -78,8 +78,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertErrorForDifferingTypes(MaxOf('intField, 'booleanField))
assertErrorForDifferingTypes(MinOf('intField, 'booleanField))

assertError(Add('booleanField, 'booleanField), "accepts numeric type")
assertError(Subtract('booleanField, 'booleanField), "accepts numeric type")
assertError(Add('booleanField, 'booleanField), "accepts (numeric or interval) type")
assertError(Subtract('booleanField, 'booleanField), "accepts (numeric or interval) type")
assertError(Multiply('booleanField, 'booleanField), "accepts numeric type")
assertError(Divide('booleanField, 'booleanField), "accepts numeric type")
assertError(Remainder('booleanField, 'booleanField), "accepts numeric type")
Expand Down
17 changes: 17 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1492,4 +1492,21 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
// Currently we don't yet support nanosecond
checkIntervalParseError("select interval 23 nanosecond")
}

test("SPARK-8945: add and subtract expressions for interval type") {
  import org.apache.spark.unsafe.types.Interval

  // Expected components of the interval literal selected below:
  // 3 years minus 3 months, and 7 weeks plus 123 microseconds expressed in microseconds.
  val months = 12 * 3 - 3
  val micros = 7L * 1000 * 1000 * 3600 * 24 * 7 + 123

  val df = sql("select interval 3 years -3 month 7 week 123 microseconds as i")
  val interval = df("i")

  // The literal itself.
  checkAnswer(df, Row(new Interval(months, micros)))

  // interval + interval
  checkAnswer(
    df.select(interval + new Interval(2, 123)),
    Row(new Interval(months + 2, micros + 123)))

  // interval - interval
  checkAnswer(
    df.select(interval - new Interval(2, 123)),
    Row(new Interval(months - 2, micros - 123)))

  // unary minus
  checkAnswer(
    df.select(-interval),
    Row(new Interval(-months, -micros)))
}
}
16 changes: 16 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,22 @@ public Interval(int months, long microseconds) {
this.microseconds = microseconds;
}

/**
 * Returns a new interval whose months and microseconds are the component-wise sums of this
 * interval and {@code that}. Neither operand is modified. Plain Java {@code +} is used, so
 * no overflow checking is performed.
 *
 * @param that the interval to add to this one
 * @return the component-wise sum as a new {@code Interval}
 */
public Interval add(Interval that) {
  return new Interval(
    this.months + that.months,
    this.microseconds + that.microseconds);
}

/**
 * Returns a new interval whose months and microseconds are the component-wise differences
 * of this interval and {@code that}. Neither operand is modified. Plain Java {@code -} is
 * used, so no overflow checking is performed.
 *
 * @param that the interval to subtract from this one
 * @return the component-wise difference as a new {@code Interval}
 */
public Interval subtract(Interval that) {
  return new Interval(
    this.months - that.months,
    this.microseconds - that.microseconds);
}

/**
 * Returns a new interval with both the months and the microseconds components negated.
 * This interval is not modified.
 *
 * @return the negated interval as a new {@code Interval}
 */
public Interval negate() {
  int negatedMonths = -this.months;
  long negatedMicroseconds = -this.microseconds;
  return new Interval(negatedMonths, negatedMicroseconds);
}

@Override
public boolean equals(Object other) {
if (this == other) return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,44 @@ public void fromStringTest() {
assertEquals(Interval.fromString(input), null);
}

@Test
public void addTest() {
  // Both operands have positive components.
  Interval left = Interval.fromString("interval 3 month 1 hour");
  Interval right = Interval.fromString("interval 2 month 100 hour");
  assertEquals(left.add(right), new Interval(5, 101 * MICROS_PER_HOUR));

  // Negative left operand, positive right operand.
  left = Interval.fromString("interval -10 month -81 hour");
  right = Interval.fromString("interval 75 month 200 hour");
  assertEquals(left.add(right), new Interval(65, 119 * MICROS_PER_HOUR));
}

@Test
public void subtractTest() {
  // Both operands have positive components; the hour difference goes negative.
  Interval left = Interval.fromString("interval 3 month 1 hour");
  Interval right = Interval.fromString("interval 2 month 100 hour");
  assertEquals(left.subtract(right), new Interval(1, -99 * MICROS_PER_HOUR));

  // Negative left operand, positive right operand.
  left = Interval.fromString("interval -10 month -81 hour");
  right = Interval.fromString("interval 75 month 200 hour");
  assertEquals(left.subtract(right), new Interval(-85, -281 * MICROS_PER_HOUR));
}

private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
Expand Down

0 comments on commit eba6a1a

Please sign in to comment.