Skip to content

Commit

Permalink
[SPARK-8944][SQL] Support casting between IntervalType and StringType
Browse files Browse the repository at this point in the history
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes apache#7355 from cloud-fan/fromString and squashes the following commits:

3bbb9d6 [Wenchen Fan] fix code gen
7dab957 [Wenchen Fan] naming fix
0fbbe19 [Wenchen Fan] address comments
ac1f3d1 [Wenchen Fan] Support casting between IntervalType and StringType
  • Loading branch information
cloud-fan authored and rxin committed Jul 13, 2015
1 parent 92540d2 commit 6b89943
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{Interval, UTF8String}


object Cast {
Expand Down Expand Up @@ -55,6 +55,9 @@ object Cast {

case (_, DateType) => true

case (StringType, IntervalType) => true
case (IntervalType, StringType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
Expand Down Expand Up @@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _ => _ => null
}

// IntervalConverter
private[this] def castToInterval(from: DataType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => Interval.fromString(s.toString))
case _ => _ => null
}

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
Expand Down Expand Up @@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case IntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
Expand Down Expand Up @@ -442,6 +453,10 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, StringType) =>
defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")

case (StringType, IntervalType) =>
defineCodeGen(ctx, ev, c =>
s"org.apache.spark.unsafe.types.Interval.fromString($c.toString())")

// fallback for DecimalType, this must be before other numeric types
case (_, dt: DecimalType) =>
super.genCode(ctx, ev)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(0L)))
}

test("case between string and interval") {
import org.apache.spark.unsafe.types.Interval

checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
checkEvaluation(Cast(Literal.create(
new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
"interval 1 years 3 months -3 days")
}

}
48 changes: 48 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.spark.unsafe.types;

import java.io.Serializable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* The internal representation of interval type.
Expand All @@ -30,6 +32,52 @@ public final class Interval implements Serializable {
public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;

/**
* A function to generate regex which matches interval string's unit part like "3 years".
*
* First, we can leave out some units in interval string, and we only care about the value of
* unit, so here we use non-capturing group to wrap the actual regex.
* At the beginning of the actual regex, we should match spaces before the unit part.
* Next is the number part, starts with an optional "-" to represent negative value. We use
* capturing group to wrap this part as we need the value later.
* Finally is the unit name, ends with an optional "s".
*/
private static String unitRegex(String unit) {
return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
}

private static Pattern p = Pattern.compile("interval" + unitRegex("year") + unitRegex("month") +
unitRegex("week") + unitRegex("day") + unitRegex("hour") + unitRegex("minute") +
unitRegex("second") + unitRegex("millisecond") + unitRegex("microsecond"));

private static long toLong(String s) {
if (s == null) {
return 0;
} else {
return Long.valueOf(s);
}
}

public static Interval fromString(String s) {
if (s == null) {
return null;
}
Matcher m = p.matcher(s);
if (!m.matches() || s.equals("interval")) {
return null;
} else {
long months = toLong(m.group(1)) * 12 + toLong(m.group(2));
long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
microseconds += toLong(m.group(9));
return new Interval((int) months, microseconds);
}
}

public final int months;
public final long microseconds;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,50 @@ public void toStringTest() {
i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
}

@Test
public void fromStringTest() {
testSingleUnit("year", 3, 36, 0);
testSingleUnit("month", 3, 3, 0);
testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK);
testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY);
testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR);
testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE);
testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND);
testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI);
testSingleUnit("microsecond", 3, 0, 3);

String input;

input = "interval -5 years 23 month";
Interval result = new Interval(-5 * 12 + 23, 0);
assertEquals(Interval.fromString(input), result);

// Error cases
input = "interval 3month 1 hour";
assertEquals(Interval.fromString(input), null);

input = "interval 3 moth 1 hour";
assertEquals(Interval.fromString(input), null);

input = "interval";
assertEquals(Interval.fromString(input), null);

input = "int";
assertEquals(Interval.fromString(input), null);

input = "";
assertEquals(Interval.fromString(input), null);

input = null;
assertEquals(Interval.fromString(input), null);
}

private void testSingleUnit(String unit, int number, int months, long microseconds) {
String input1 = "interval " + number + " " + unit;
String input2 = "interval " + number + " " + unit + "s";
Interval result = new Interval(months, microseconds);
assertEquals(Interval.fromString(input1), result);
assertEquals(Interval.fromString(input2), result);
}
}

0 comments on commit 6b89943

Please sign in to comment.