Skip to content

Commit

Permalink
[FLINK-6579] [table] Add proper support for BasicArrayTypeInfo
Browse files Browse the repository at this point in the history
This closes apache#3902.
  • Loading branch information
twalthr committed May 15, 2017
1 parent e23328e commit 0c2d0da
Show file tree
Hide file tree
Showing 10 changed files with 194 additions and 94 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.calcite.sql.`type`.SqlTypeName._
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.flink.api.common.functions._
import org.apache.flink.api.common.io.GenericInputFormat
import org.apache.flink.api.common.typeinfo.{AtomicType, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.api.java.typeutils._
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
Expand Down Expand Up @@ -1522,13 +1522,15 @@ class CodeGenerator(

case ITEM =>
operands.head.resultType match {
case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] =>
case _: ObjectArrayTypeInfo[_, _] |
_: BasicArrayTypeInfo[_, _] |
_: PrimitiveArrayTypeInfo[_] =>
val array = operands.head
val index = operands(1)
requireInteger(index)
generateArrayElementAt(this, array, index)

case map: MapTypeInfo[_, _] =>
case _: MapTypeInfo[_, _] =>
val key = operands(1)
generateMapGet(this, operands.head, key)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.calcite.avatica.util.DateTimeUtils.MILLIS_PER_DAY
import org.apache.calcite.avatica.util.{DateTimeUtils, TimeUnitRange}
import org.apache.calcite.util.BuiltInMethod
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.api.common.typeinfo.{NumericTypeInfo, PrimitiveArrayTypeInfo, SqlTimeTypeInfo, TypeInformation}
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.java.typeutils.{MapTypeInfo, ObjectArrayTypeInfo}
import org.apache.flink.table.codegen.CodeGenUtils._
import org.apache.flink.table.codegen.{CodeGenException, CodeGenerator, GeneratedExpression}
Expand Down Expand Up @@ -93,7 +93,8 @@ object ScalarOperators {
generateComparison("==", nullCheck, left, right)
}
// array types
else if (isArray(left.resultType) && left.resultType == right.resultType) {
else if (isArray(left.resultType) &&
left.resultType.getTypeClass == right.resultType.getTypeClass) {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
(leftTerm, rightTerm) => s"java.util.Arrays.equals($leftTerm, $rightTerm)"
}
Expand Down Expand Up @@ -133,7 +134,8 @@ object ScalarOperators {
generateComparison("!=", nullCheck, left, right)
}
// array types
else if (isArray(left.resultType) && left.resultType == right.resultType) {
else if (isArray(left.resultType) &&
left.resultType.getTypeClass == right.resultType.getTypeClass) {
generateOperatorIfNotNull(nullCheck, BOOLEAN_TYPE_INFO, left, right) {
(leftTerm, rightTerm) => s"!java.util.Arrays.equals($leftTerm, $rightTerm)"
}
Expand Down Expand Up @@ -456,6 +458,11 @@ object ScalarOperators {
case (fromTp, toTp) if fromTp == toTp =>
operand

// array identity casting
// (e.g. for Integer[] that can be ObjectArrayTypeInfo or BasicArrayTypeInfo)
case (fromTp, toTp) if isArray(fromTp) && fromTp.getTypeClass == toTp.getTypeClass =>
operand

// Date/Time/Timestamp -> String
case (dtt: SqlTimeTypeInfo[_], STRING_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
Expand All @@ -479,13 +486,13 @@ object ScalarOperators {
}

// Object array -> String
case (_:ObjectArrayTypeInfo[_, _], STRING_TYPE_INFO) =>
case (_: ObjectArrayTypeInfo[_, _] | _: BasicArrayTypeInfo[_, _], STRING_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"java.util.Arrays.deepToString($operandTerm)"
}

// Primitive array -> String
case (_:PrimitiveArrayTypeInfo[_], STRING_TYPE_INFO) =>
case (_: PrimitiveArrayTypeInfo[_], STRING_TYPE_INFO) =>
generateUnaryOperatorIfNotNull(nullCheck, targetType, operand) {
(operandTerm) => s"java.util.Arrays.toString($operandTerm)"
}
Expand Down Expand Up @@ -792,37 +799,45 @@ object ScalarOperators {

val resultTerm = newName("result")

def unboxArrayElement(componentInfo: TypeInformation[_]): GeneratedExpression = {
// get boxed array element
val resultTypeTerm = boxedTypeTermForTypeInfo(componentInfo)

val arrayAccessCode = if (codeGenerator.nullCheck) {
s"""
|${array.code}
|${index.code}
|$resultTypeTerm $resultTerm = (${array.nullTerm} || ${index.nullTerm}) ?
| null : ${array.resultTerm}[${index.resultTerm} - 1];
|""".stripMargin
} else {
s"""
|${array.code}
|${index.code}
|$resultTypeTerm $resultTerm = ${array.resultTerm}[${index.resultTerm} - 1];
|""".stripMargin
}

// generate unbox code
val unboxing = codeGenerator.generateInputFieldUnboxing(componentInfo, resultTerm)

unboxing.copy(code =
s"""
|$arrayAccessCode
|${unboxing.code}
|""".stripMargin
)
}

array.resultType match {

// unbox object array types
case oati: ObjectArrayTypeInfo[_, _] =>
// get boxed array element
val resultTypeTerm = boxedTypeTermForTypeInfo(oati.getComponentInfo)
unboxArrayElement(oati.getComponentInfo)

val arrayAccessCode = if (codeGenerator.nullCheck) {
s"""
|${array.code}
|${index.code}
|$resultTypeTerm $resultTerm = (${array.nullTerm} || ${index.nullTerm}) ?
| null : ${array.resultTerm}[${index.resultTerm} - 1];
|""".stripMargin
} else {
s"""
|${array.code}
|${index.code}
|$resultTypeTerm $resultTerm = ${array.resultTerm}[${index.resultTerm} - 1];
|""".stripMargin
}

// generate unbox code
val unboxing = codeGenerator.generateInputFieldUnboxing(oati.getComponentInfo, resultTerm)

unboxing.copy(code =
s"""
|$arrayAccessCode
|${unboxing.code}
|""".stripMargin
)
// unbox basic array types
case bati: BasicArrayTypeInfo[_, _] =>
unboxArrayElement(bati.getComponentInfo)

// no unboxing necessary
case pati: PrimitiveArrayTypeInfo[_] =>
Expand All @@ -841,6 +856,7 @@ object ScalarOperators {
val resultTerm = newName("result")
val resultType = array.resultType match {
case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo
case bati: BasicArrayTypeInfo[_, _] => bati.getComponentInfo
case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType
}
val resultTypeTerm = primitiveTypeTermForTypeInfo(resultType)
Expand All @@ -852,31 +868,38 @@ object ScalarOperators {
s"${array.resultTerm}.length"
}

def unboxArrayElement(componentInfo: TypeInformation[_]): String = {
// generate unboxing code
val unboxing = codeGenerator.generateInputFieldUnboxing(
componentInfo,
s"${array.resultTerm}[0]")

s"""
|${array.code}
|${if (codeGenerator.nullCheck) s"boolean $nullTerm;" else "" }
|$resultTypeTerm $resultTerm;
|switch ($arrayLengthCode) {
| case 0:
| ${if (codeGenerator.nullCheck) s"$nullTerm = true;" else "" }
| $resultTerm = $defaultValue;
| break;
| case 1:
| ${unboxing.code}
| ${if (codeGenerator.nullCheck) s"$nullTerm = ${unboxing.nullTerm};" else "" }
| $resultTerm = ${unboxing.resultTerm};
| break;
| default:
| throw new RuntimeException("Array has more than one element.");
|}
|""".stripMargin
}

val arrayAccessCode = array.resultType match {
case oati: ObjectArrayTypeInfo[_, _] =>
// generate unboxing code
val unboxing = codeGenerator.generateInputFieldUnboxing(
oati.getComponentInfo,
s"${array.resultTerm}[0]")
unboxArrayElement(oati.getComponentInfo)

s"""
|${array.code}
|${if (codeGenerator.nullCheck) s"boolean $nullTerm;" else "" }
|$resultTypeTerm $resultTerm;
|switch ($arrayLengthCode) {
| case 0:
| ${if (codeGenerator.nullCheck) s"$nullTerm = true;" else "" }
| $resultTerm = $defaultValue;
| break;
| case 1:
| ${unboxing.code}
| ${if (codeGenerator.nullCheck) s"$nullTerm = ${unboxing.nullTerm};" else "" }
| $resultTerm = ${unboxing.resultTerm};
| break;
| default:
| throw new RuntimeException("Array has more than one element.");
|}
|""".stripMargin
case bati: BasicArrayTypeInfo[_, _] =>
unboxArrayElement(bati.getComponentInfo)

case pati: PrimitiveArrayTypeInfo[_] =>
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ import org.apache.calcite.rex.RexNode
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo.INT_TYPE_INFO
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, PrimitiveArrayTypeInfo}
import org.apache.flink.api.common.typeinfo.{BasicArrayTypeInfo, BasicTypeInfo, PrimitiveArrayTypeInfo, TypeInformation}
import org.apache.flink.api.java.typeutils.ObjectArrayTypeInfo
import org.apache.flink.table.calcite.FlinkRelBuilder
import org.apache.flink.table.typeutils.TypeCheckUtils.isArray
import org.apache.flink.table.validate.{ValidationFailure, ValidationResult, ValidationSuccess}

import scala.collection.JavaConverters._
Expand Down Expand Up @@ -75,12 +76,13 @@ case class ArrayElementAt(array: Expression, index: Expression) extends Expressi

override private[flink] def resultType = array.resultType match {
case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo
case bati: BasicArrayTypeInfo[_, _] => bati.getComponentInfo
case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType
}

override private[flink] def validateInput(): ValidationResult = {
array.resultType match {
case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] =>
case ati: TypeInformation[_] if isArray(ati) =>
if (index.resultType == INT_TYPE_INFO) {
// check for common user mistake
index match {
Expand Down Expand Up @@ -114,7 +116,7 @@ case class ArrayCardinality(array: Expression) extends Expression {

override private[flink] def validateInput(): ValidationResult = {
array.resultType match {
case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => ValidationSuccess
case ati: TypeInformation[_] if isArray(ati) => ValidationSuccess
case other@_ => ValidationFailure(s"Array expected but was '$other'.")
}
}
Expand All @@ -134,12 +136,13 @@ case class ArrayElement(array: Expression) extends Expression {

override private[flink] def resultType = array.resultType match {
case oati: ObjectArrayTypeInfo[_, _] => oati.getComponentInfo
case bati: BasicArrayTypeInfo[_, _] => bati.getComponentInfo
case pati: PrimitiveArrayTypeInfo[_] => pati.getComponentType
}

override private[flink] def validateInput(): ValidationResult = {
array.resultType match {
case _: ObjectArrayTypeInfo[_, _] | _: PrimitiveArrayTypeInfo[_] => ValidationSuccess
case ati: TypeInformation[_] if isArray(ati) => ValidationSuccess
case other@_ => ValidationFailure(s"Array expected but was '$other'.")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.tools.RelBuilder
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.table.typeutils.TypeCheckUtils.{isComparable, isNumeric}
import org.apache.flink.table.typeutils.TypeCheckUtils
import org.apache.flink.table.typeutils.TypeCheckUtils.{isArray, isComparable, isNumeric}
import org.apache.flink.table.validate._

import scala.collection.JavaConversions._
Expand Down Expand Up @@ -56,6 +57,8 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
(left.resultType, right.resultType) match {
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) if isArray(lType) && lType.getTypeClass == rType.getTypeClass =>
ValidationSuccess
case (lType, rType) =>
ValidationFailure(s"Equality predicate on incompatible types: $lType and $rType")
}
Expand All @@ -70,6 +73,8 @@ case class NotEqualTo(left: Expression, right: Expression) extends BinaryCompari
(left.resultType, right.resultType) match {
case (lType, rType) if isNumeric(lType) && isNumeric(rType) => ValidationSuccess
case (lType, rType) if lType == rType => ValidationSuccess
case (lType, rType) if isArray(lType) && lType.getTypeClass == rType.getTypeClass =>
ValidationSuccess
case (lType, rType) =>
ValidationFailure(s"Inequality predicate on incompatible types: $lType and $rType")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ object TypeCheckUtils {

def isArray(dataType: TypeInformation[_]): Boolean = dataType match {
case _: ObjectArrayTypeInfo[_, _] |
_: PrimitiveArrayTypeInfo[_] |
_: BasicArrayTypeInfo[_, _] => true
_: BasicArrayTypeInfo[_, _] |
_: PrimitiveArrayTypeInfo[_] => true
case _ => false
}

Expand Down
Loading

0 comments on commit 0c2d0da

Please sign in to comment.