diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 7f2383dedc035..ab02addfb4d25 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.types.{Interval, UTF8String}
 
 
 object Cast {
@@ -55,6 +55,9 @@ object Cast {
 
     case (_, DateType) => true
 
+    case (StringType, IntervalType) => true
+    case (IntervalType, StringType) => true
+
     case (StringType, _: NumericType) => true
     case (BooleanType, _: NumericType) => true
     case (DateType, _: NumericType) => true
@@ -232,6 +235,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case _ => _ => null
   }
 
+  // IntervalConverter
+  // Strings are parsed by Interval.fromString; any other source type casts to null.
+  private[this] def castToInterval(from: DataType): Any => Any = from match {
+    case StringType =>
+      buildCast[UTF8String](_, s => Interval.fromString(s.toString))
+    case _ => _ => null
+  }
+
   // LongConverter
   private[this] def castToLong(from: DataType): Any => Any = from match {
     case StringType =>
@@ -405,6 +416,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
     case DateType => castToDate(from)
     case decimal: DecimalType => castToDecimal(from, decimal)
     case TimestampType => castToTimestamp(from)
+    case IntervalType => castToInterval(from)
     case BooleanType => castToBoolean(from)
     case ByteType => castToByte(from)
     case ShortType => castToShort(from)
@@ -442,6 +454,12 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
       case (_, StringType) =>
         defineCodeGen(ctx, ev, c => s"${ctx.stringType}.fromString(String.valueOf($c))")
 
+      // Interval.fromString returns null for malformed input, so use the interpreted fallback.
+      // NOTE(review): defineCodeGen appears to take the result's null flag from the child only,
+      // which would leave that null unflagged in generated code -- confirm.
+      case (StringType, IntervalType) =>
+        super.genCode(ctx, ev)
+
       // fallback for DecimalType, this must be before other numeric types
       case (_, dt: DecimalType) =>
         super.genCode(ctx, ev)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
index 919fdd470b79a..1de161c367a1d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala
@@ -563,4 +563,16 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper {
       InternalRow(0L)))
   }
 
+  test("cast between string and interval") {
+    import org.apache.spark.unsafe.types.Interval
+
+    checkEvaluation(Cast(Literal("interval -3 month 7 hours"), IntervalType),
+      new Interval(-3, 7 * Interval.MICROS_PER_HOUR))
+    checkEvaluation(Cast(Literal.create(
+      new Interval(15, -3 * Interval.MICROS_PER_DAY), IntervalType), StringType),
+      "interval 1 years 3 months -3 days")
+    // A string that does not follow the interval format casts to null.
+    checkEvaluation(Cast(Literal("interval 3 moth 1 hour"), IntervalType), null)
+  }
+
 }
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
index 0af982d4844c2..eb7475e9df869 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/Interval.java
@@ -18,6 +18,8 @@
 package org.apache.spark.unsafe.types;
 
 import java.io.Serializable;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
 
 /**
  * The internal representation of interval type.
@@ -30,6 +32,83 @@ public final class Interval implements Serializable {
   public static final long MICROS_PER_DAY = MICROS_PER_HOUR * 24;
   public static final long MICROS_PER_WEEK = MICROS_PER_DAY * 7;
 
+  /**
+   * Builds the regex fragment that matches one optional unit part of an interval string,
+   * such as " 3 years".
+   *
+   * Any unit may be left out, so the fragment is an optional non-capturing group. Inside it
+   * come the separating whitespace, the number (an optional "-" plus digits, captured because
+   * its value is needed later), more whitespace, and the unit name with an optional "s".
+   */
+  private static String unitRegex(String unit) {
+    return "(?:\\s+(-?\\d+)\\s+" + unit + "s?)?";
+  }
+
+  /**
+   * Matches "interval" followed by the optional unit parts in a fixed order. Groups 1 to 9
+   * capture the year, month, week, day, hour, minute, second, millisecond and microsecond
+   * counts; a group is null when its unit is absent.
+   */
+  private static final Pattern INTERVAL_PATTERN = Pattern.compile(
+    "interval" + unitRegex("year") + unitRegex("month") + unitRegex("week") +
+    unitRegex("day") + unitRegex("hour") + unitRegex("minute") + unitRegex("second") +
+    unitRegex("millisecond") + unitRegex("microsecond"));
+
+  /**
+   * Parses a captured count, reading an absent (null) group as 0.
+   *
+   * @throws NumberFormatException if the digits do not fit in a long
+   */
+  private static long toLong(String s) {
+    return s == null ? 0L : Long.parseLong(s);
+  }
+
+  /** Returns true if the value survives narrowing to an int unchanged. */
+  private static boolean fitsInInt(long value) {
+    return value >= Integer.MIN_VALUE && value <= Integer.MAX_VALUE;
+  }
+
+  /**
+   * Parses a string such as "interval 1 year -3 days" into an Interval.
+   *
+   * Returns null when the input is null, does not match the expected format, names no unit
+   * at all, contains a count that does not fit in a long, or yields a month total that does
+   * not fit in an int. Overflow of the microsecond sum is not checked.
+   */
+  public static Interval fromString(String s) {
+    if (s == null) {
+      return null;
+    }
+    Matcher m = INTERVAL_PATTERN.matcher(s);
+    // Every unit is optional, so a bare "interval" matches the pattern; reject it explicitly.
+    if (!m.matches() || s.equals("interval")) {
+      return null;
+    }
+    try {
+      long years = toLong(m.group(1));
+      long months = toLong(m.group(2));
+      // Bound both operands first so that years * 12 + months cannot overflow a long.
+      if (!fitsInInt(years) || !fitsInInt(months)) {
+        return null;
+      }
+      long totalMonths = years * 12 + months;
+      if (!fitsInInt(totalMonths)) {
+        return null;
+      }
+      long microseconds = toLong(m.group(3)) * MICROS_PER_WEEK;
+      microseconds += toLong(m.group(4)) * MICROS_PER_DAY;
+      microseconds += toLong(m.group(5)) * MICROS_PER_HOUR;
+      microseconds += toLong(m.group(6)) * MICROS_PER_MINUTE;
+      microseconds += toLong(m.group(7)) * MICROS_PER_SECOND;
+      microseconds += toLong(m.group(8)) * MICROS_PER_MILLI;
+      microseconds += toLong(m.group(9));
+      return new Interval((int) totalMonths, microseconds);
+    } catch (NumberFormatException e) {
+      // The regex accepts any number of digits; a count too large to parse is invalid input.
+      return null;
+    }
+  }
+
   public final int months;
   public final long microseconds;
 
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
index 0f4f38b2b03be..44a949a371f2b 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/IntervalSuite.java
@@ -56,4 +56,58 @@ public void toStringTest() {
     i = new Interval(34, 3 * MICROS_PER_WEEK + 13 * MICROS_PER_HOUR + 123);
     assertEquals(i.toString(), "interval 2 years 10 months 3 weeks 13 hours 123 microseconds");
   }
+
+  @Test
+  public void fromStringTest() {
+    testSingleUnit("year", 3, 36, 0);
+    testSingleUnit("month", 3, 3, 0);
+    testSingleUnit("week", 3, 0, 3 * MICROS_PER_WEEK);
+    testSingleUnit("day", 3, 0, 3 * MICROS_PER_DAY);
+    testSingleUnit("hour", 3, 0, 3 * MICROS_PER_HOUR);
+    testSingleUnit("minute", 3, 0, 3 * MICROS_PER_MINUTE);
+    testSingleUnit("second", 3, 0, 3 * MICROS_PER_SECOND);
+    testSingleUnit("millisecond", 3, 0, 3 * MICROS_PER_MILLI);
+    testSingleUnit("microsecond", 3, 0, 3);
+
+    String input;
+
+    input = "interval -5 years 23 month";
+    Interval result = new Interval(-5 * 12 + 23, 0);
+    assertEquals(Interval.fromString(input), result);
+
+    // Error cases
+    input = "interval 3month 1 hour";
+    assertEquals(Interval.fromString(input), null);
+
+    input = "interval 3 moth 1 hour";
+    assertEquals(Interval.fromString(input), null);
+
+    input = "interval";
+    assertEquals(Interval.fromString(input), null);
+
+    input = "int";
+    assertEquals(Interval.fromString(input), null);
+
+    input = "";
+    assertEquals(Interval.fromString(input), null);
+
+    input = null;
+    assertEquals(Interval.fromString(input), null);
+
+    // Counts that cannot be represented are rejected instead of throwing or wrapping around.
+    input = "interval 99999999999999999999 years";
+    assertEquals(Interval.fromString(input), null);
+
+    input = "interval 2147483647 years";
+    assertEquals(Interval.fromString(input), null);
+  }
+
+  // Checks that the singular and plural forms of a unit parse to the expected interval.
+  private void testSingleUnit(String unit, int number, int months, long microseconds) {
+    String input1 = "interval " + number + " " + unit;
+    String input2 = "interval " + number + " " + unit + "s";
+    Interval result = new Interval(months, microseconds);
+    assertEquals(Interval.fromString(input1), result);
+    assertEquals(Interval.fromString(input2), result);
+  }
 }