Skip to content

Commit

Permalink
Support casting between IntervalType and StringType
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jul 11, 2015
1 parent e14b545 commit ac1f3d1
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.types.{Interval, UTF8String}


object Cast {
Expand Down Expand Up @@ -55,6 +55,9 @@ object Cast {

case (_, DateType) => true

case (StringType, IntervalType) => true
case (IntervalType, StringType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
Expand Down Expand Up @@ -232,6 +235,13 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case _ => _ => null
}

// IntervalConverter
// Builds the conversion function used when the cast target is IntervalType.
// Strings are parsed with Interval.fromString; any other source type maps every input to null.
private[this] def castToInterval(from: DataType): Any => Any = {
  if (from == StringType) {
    value => buildCast[UTF8String](value, str => Interval.fromString(str.toString))
  } else {
    _ => null
  }
}

// LongConverter
private[this] def castToLong(from: DataType): Any => Any = from match {
case StringType =>
Expand Down Expand Up @@ -405,6 +415,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case DateType => castToDate(from)
case decimal: DecimalType => castToDecimal(from, decimal)
case TimestampType => castToTimestamp(from)
case IntervalType => castToInterval(from)
case BooleanType => castToBoolean(from)
case ByteType => castToByte(from)
case ShortType => castToShort(from)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -563,4 +563,14 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
InternalRow(0L)))
}

test("cast between string and interval") {
  import org.apache.spark.unsafe.types.Interval

  // String -> Interval: the textual units become a month count plus a microsecond count.
  checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
    new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
  // Interval -> String: 15 months is expected to be rendered as "1 years 3 months".
  checkEvaluation(Cast(Literal.create(
    new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
    "interval 1 years 3 months -3 days")
}

}
38 changes: 38 additions & 0 deletions unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.spark.unsafe.types;

import java.io.Serializable;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* The internal representation of interval type.
Expand All @@ -30,6 +32,42 @@ public final class Interval implements Serializable {
public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;

/**
 * Builds the optional regex fragment for a single interval unit.
 *
 * The fragment matches whitespace, a signed integer (captured), whitespace, and the unit name
 * with an optional trailing "s". The whole fragment is optional.
 *
 * @param unit singular unit name, e.g. "year"
 * @return the regex fragment for that unit
 */
private static String unitRegex(String unit) {
  final String template = "(?:\\s+(-?\\d+)\\s+%ss?)?";
  return String.format(template, unit);
}

// Pattern for the textual interval form: the literal "interval" followed by optional unit
// clauses in a fixed order (year, month, week, day, hour, minute, second, millisecond,
// microsecond). Capture groups 1 through 9 hold the signed amount of each unit, in that order.
// Declared final so the shared compiled pattern cannot be reassigned.
private static final Pattern p = Pattern.compile("interval" + unitRegex("year") +
  unitRegex("month") + unitRegex("week") + unitRegex("day") + unitRegex("hour") +
  unitRegex("minute") + unitRegex("second") + unitRegex("millisecond") +
  unitRegex("microsecond"));

/**
 * Converts an optional numeric capture into a primitive long.
 *
 * @param s decimal digits with an optional leading minus sign, or null when the unit was absent
 * @return 0 for null, otherwise the parsed value
 * @throws NumberFormatException if the text is not a valid long (for example, out of range)
 */
private static long toLong(String s) {
  // parseLong yields the primitive directly; Long.valueOf would box and then unbox the value.
  return s == null ? 0L : Long.parseLong(s);
}

/**
 * Parses the textual form of an interval, e.g. "interval -3 month 7 hours".
 *
 * Years and months are folded into the months field; weeks, days, hours, minutes, seconds,
 * milliseconds and microseconds are folded into the microseconds field. The microsecond sum
 * itself is plain long arithmetic and is not overflow-checked.
 *
 * @param s the text to parse, may be null
 * @return the parsed interval, or null when the input is null, does not match the expected
 *         format, names no unit at all, or carries a year/month amount that cannot be
 *         represented
 */
public static Interval fromString(String s) {
  if (s == null) {
    return null;
  }
  Matcher m = p.matcher(s);
  // Every unit clause is optional, so the bare keyword matches the pattern; reject it here.
  if (!m.matches() || s.equals("interval")) {
    return null;
  }
  try {
    long years = toLong(m.group(1));
    long extraMonths = toLong(m.group(2));
    // Bound both parts first so that years * 12 + extraMonths cannot wrap in long arithmetic.
    if (years != (int) years || extraMonths != (int) extraMonths) {
      return null;
    }
    long months = years * 12 + extraMonths;
    // The months field is an int; refuse totals the narrowing cast would silently corrupt.
    if (months != (int) months) {
      return null;
    }
    long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
    microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
    microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
    microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
    microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
    microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
    microseconds += toLong(m.group(9));
    return new Interval((int) months, microseconds);
  } catch (NumberFormatException e) {
    // Digits that match the pattern but do not fit in a long are treated like any other
    // unparsable input instead of escaping as an exception.
    return null;
  }
}

public final int months;
public final long microseconds;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,34 @@ public void toStringTest() {
i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
}

@Test
public void fromStringTest() {
  // Arguments are passed as (expected, actual) so that failure messages read correctly.

  // Well-formed inputs: units may be singular or plural and amounts may be negative.
  assertEquals(
    new Interval(0, 2 * MICROS_PER_WEEK - 6 * MICROS_PER_MINUTE),
    Interval.fromString("interval 2 weeks -6 minute"));

  assertEquals(
    new Interval(-5 * 12 + 23, 0),
    Interval.fromString("interval -5 years 23 month"));

  // Malformed or empty inputs all yield null.
  Interval none = null;

  // Missing whitespace between the amount and the unit.
  assertEquals(none, Interval.fromString("interval 3month 1 hour"));

  // Misspelled unit name.
  assertEquals(none, Interval.fromString("interval 3 moth 1 hour"));

  // Keyword with no units.
  assertEquals(none, Interval.fromString("interval"));

  // Null input.
  assertEquals(none, Interval.fromString(null));
}
}

0 comments on commit ac1f3d1

Please sign in to comment.