Skip to content

Commit

Permalink
[FLINK-6926] [table] Add support for SHA2 (Compatible with SQL vendor…
Browse files Browse the repository at this point in the history
…s like MySQL)

This closes apache#5324.
  • Loading branch information
genged authored and twalthr committed May 15, 2018
1 parent 5544ab5 commit 8eec662
Show file tree
Hide file tree
Showing 13 changed files with 280 additions and 37 deletions.
26 changes: 19 additions & 7 deletions docs/dev/table/sql.md
Expand Up @@ -2419,7 +2419,7 @@ MD5(string)
{% endhighlight %}
</td>
<td>
<p>Returns the MD5 hash of the string argument as a string of 32 hexadecimal digits; null if <i>string</i> is null.</p>
<p>Returns the MD5 hash of the <i>string</i> argument as a string of 32 hexadecimal digits; null if <i>string</i> is null.</p>
</td>
</tr>

Expand All @@ -2430,7 +2430,7 @@ SHA1(string)
{% endhighlight %}
</td>
<td>
<p>Returns the SHA-1 hash of the string argument as a string of 40 hexadecimal digits; null if <i>string</i> is null.</p>
<p>Returns the SHA-1 hash of the <i>string</i> argument as a string of 40 hexadecimal digits; null if <i>string</i> is null.</p>
</td>
</tr>

Expand All @@ -2441,7 +2441,7 @@ SHA224(string)
{% endhighlight %}
</td>
<td>
<p>Returns the SHA-224 hash of the string argument as a string of 56 hexadecimal digits; null if <i>string</i> is null.</p>
<p>Returns the SHA-224 hash of the <i>string</i> argument as a string of 56 hexadecimal digits; null if <i>string</i> is null.</p>
</td>
</tr>

Expand All @@ -2452,7 +2452,7 @@ SHA256(string)
{% endhighlight %}
</td>
<td>
<p>Returns the SHA-256 hash of the string argument as a string of 64 hexadecimal digits; null if <i>string</i> is null.</p>
<p>Returns the SHA-256 hash of the <i>string</i> argument as a string of 64 hexadecimal digits; null if <i>string</i> is null.</p>
</td>
</tr>

Expand All @@ -2463,7 +2463,7 @@ SHA384(string)
{% endhighlight %}
</td>
<td>
<p>Returns the SHA-384 hash of the string argument as a string of 96 hexadecimal digits; null if <i>string</i> is null.</p>
<p>Returns the SHA-384 hash of the <i>string</i> argument as a string of 96 hexadecimal digits; null if <i>string</i> is null.</p>
</td>
</tr>

Expand All @@ -2474,9 +2474,21 @@ SHA512(string)
{% endhighlight %}
</td>
<td>
<p>Returns the SHA-512 hash of the string argument as a string of 128 hexadecimal digits; null if <i>string</i> is null.</p>
<p>Returns the SHA-512 hash of the <i>string</i> argument as a string of 128 hexadecimal digits; null if <i>string</i> is null.</p>
</td>
</tr>
</tr>

<tr>
<td>
{% highlight text %}
SHA2(string, hashLength)
{% endhighlight %}
</td>
<td>
<p>Returns the hash of <i>string</i> computed with the SHA-2 family of hash functions (SHA-224, SHA-256, SHA-384, or SHA-512). The first argument <i>string</i> is the string to be hashed. The second argument <i>hashLength</i> is the bit length of the result and must be one of 224, 256, 384, or 512; any other value causes a runtime error. Returns <i>null</i> if <i>string</i> or <i>hashLength</i> is <i>null</i>.
</p>
</td>
</tr>
</tbody>
</table>

Expand Down
Expand Up @@ -775,7 +775,7 @@ trait ImplicitExpressionOperations {
/**
* Returns the SHA-224 hash of the string argument; null if string is null.
*
* @return string of 56 hexadecimal digits or null
*/
def sha224() = Sha224(expr)

Expand All @@ -789,16 +789,26 @@ trait ImplicitExpressionOperations {
/**
* Returns the SHA-384 hash of the string argument; null if string is null.
*
* @return string of 96 hexadecimal digits or null
*/
def sha384() = Sha384(expr)

/**
* Returns the SHA-512 hash of the string argument; null if string is null.
*
* @return string of 128 hexadecimal digits or null
*/
def sha512() = Sha512(expr)

/**
  * Hashes this string expression with one member of the SHA-2 family
  * (SHA-224, SHA-256, SHA-384, or SHA-512), selected by the requested bit length.
  *
  * @param hashLength expression giving the bit length of the digest; one of
  *                   224, 256, 384, or 512
  * @return the hash as a string, or null if either this expression or `hashLength` is null
  */
def sha2(hashLength: Expression) = Sha2(child = expr, hashLength = hashLength)

}

/**
Expand Down
Expand Up @@ -1895,7 +1895,7 @@ abstract class CodeGenerator(
}

/**
* Adds a reusable MessageDigest to the member area of the generated [[Function]].
* Adds a known reusable MessageDigest to the member area of the generated [[Function]].
*
* @return member variable term
*/
Expand All @@ -1904,20 +1904,69 @@ abstract class CodeGenerator(

val field =
s"""
|final java.security.MessageDigest $fieldTerm;
|""".stripMargin
|final java.security.MessageDigest $fieldTerm;
|""".stripMargin
reusableMemberStatements.add(field)

val fieldInit =
val init =
s"""
|try {
| $fieldTerm = java.security.MessageDigest.getInstance("$algorithm");
|} catch (java.security.NoSuchAlgorithmException e) {
| throw new RuntimeException("Algorithm for '$algorithm' is not available.", e);
|}
|try {
| $fieldTerm = java.security.MessageDigest.getInstance("$algorithm");
|} catch (java.security.NoSuchAlgorithmException e) {
| throw new RuntimeException("Algorithm for '$algorithm' is not available.", e);
|}
|""".stripMargin

reusableInitStatements.add(init)
fieldTerm
}

/**
  * Adds a reusable SHA-2 MessageDigest, whose bit length is given by a literal, to the member
  * area of the generated [[Function]].
  *
  * A `final java.security.MessageDigest` member is declared, and an init statement is
  * registered that evaluates the literal and creates the digest once via
  * `MessageDigest.getInstance("SHA-" + bitLen)`. The generated init code throws a
  * `RuntimeException` if the bit length is not 224, 256, 384, or 512, or if the algorithm is
  * not available. With null checks enabled, a null literal leaves the member set to `null`
  * instead of creating a digest.
  *
  * @param constant generated expression for the hash bit length; must be a literal
  * @return member variable term
  */
def addReusableSha2MessageDigest(constant: GeneratedExpression): String = {
  require(constant.literal, "Literal expected")
  val fieldTerm = newName("messageDigest")

  val field =
    s"""
      |final java.security.MessageDigest $fieldTerm;
      |""".stripMargin
  reusableMemberStatements.add(field)

  // the literal's result term is embedded in the generated Java code, so the
  // validation below runs when the generated init code executes
  val bitLen = constant.resultTerm
  val init = s"""
    |if ($bitLen == 224 || $bitLen == 256 || $bitLen == 384 || $bitLen == 512) {
    | try {
    | $fieldTerm = java.security.MessageDigest.getInstance("SHA-" + $bitLen);
    | } catch (java.security.NoSuchAlgorithmException e) {
    | throw new RuntimeException(
    | "Algorithm for 'SHA-" + $bitLen + "' is not available.", e);
    | }
    |} else {
    | throw new RuntimeException("Unsupported algorithm.");
    |}
    |""".stripMargin

  // constant.code has to run first because it defines the result/null terms used by init
  val nullableInit = if (nullCheck) {
    s"""
      |${constant.code}
      |if (${constant.nullTerm}) {
      | $fieldTerm = null;
      |} else {
      | $init
      |}
      |""".stripMargin
  } else {
    s"""
      |${constant.code}
      |$init
      |""".stripMargin
  }
  // register the init code exactly once; this method defines no other init statement
  reusableInitStatements.add(nullableInit)

  fieldTerm
}
}
Expand Up @@ -583,6 +583,12 @@ object FunctionGenerator {
new HashCalcCallGen("SHA-512")
)

// SHA2(string, hashLength): the digest algorithm is chosen from the second operand.
// In the two-operand branches of HashCalcCallGen shown in this change, the algorithm
// name passed to the constructor is not read, so "SHA-2" only serves as a label here.
addSqlFunction(
ScalarSqlFunctions.SHA2,
Seq(STRING_TYPE_INFO, INT_TYPE_INFO),
new HashCalcCallGen("SHA-2")
)

// ----------------------------------------------------------------------------------------------

/**
Expand Down
Expand Up @@ -21,6 +21,7 @@ package org.apache.flink.table.codegen.calls
import org.apache.commons.codec.Charsets
import org.apache.commons.codec.binary.Hex
import org.apache.flink.api.common.typeinfo.BasicTypeInfo._
import org.apache.flink.table.codegen.CodeGenUtils.newName
import org.apache.flink.table.codegen.calls.CallGenerator.generateCallWithStmtIfArgsNotNull
import org.apache.flink.table.codegen.{CodeGenerator, GeneratedExpression}

Expand All @@ -31,12 +32,44 @@ class HashCalcCallGen(algName: String) extends CallGenerator {
operands: Seq[GeneratedExpression])
: GeneratedExpression = {

val md = codeGenerator.addReusableMessageDigest(algName)
val (initStmt, md) = operands.size match {

// for function calls of MD5, SHA1, SHA224, SHA256, SHA384, SHA512
case 1 =>
(None, codeGenerator.addReusableMessageDigest(algName))

// for function calls of SHA2 with constant bit length
case 2 if operands(1).literal =>
(None, codeGenerator.addReusableSha2MessageDigest(operands(1)))

// for function calls of SHA2 with variable bit length
case 2 =>
val messageDigest = newName("messageDigest")
val bitLen = operands(1).resultTerm
val init =
s"""
|final java.security.MessageDigest $messageDigest;
|if ($bitLen == 224 || $bitLen == 256 || $bitLen == 384 || $bitLen == 512) {
| try {
| $messageDigest = java.security.MessageDigest.getInstance("SHA-" + $bitLen);
| } catch (java.security.NoSuchAlgorithmException e) {
| throw new RuntimeException(
| "Algorithm for 'SHA-" + $bitLen + "' is not available.", e);
| }
|} else {
| throw new RuntimeException("Unsupported algorithm.");
|}
|""".stripMargin
(Some(init), messageDigest)
}

generateCallWithStmtIfArgsNotNull(codeGenerator.nullCheck, STRING_TYPE_INFO, operands) {
(terms) =>
val auxiliaryStmt =
s"$md.update(${terms.head}.getBytes(${classOf[Charsets].getCanonicalName}.UTF_8));"
s"""
|${initStmt.getOrElse("")}
|$md.update(${terms.head}.getBytes(${classOf[Charsets].getCanonicalName}.UTF_8));
|""".stripMargin
val result = s"${classOf[Hex].getCanonicalName}.encodeHexString($md.digest())"
(Some(auxiliaryStmt), result)
}
Expand Down
Expand Up @@ -101,3 +101,24 @@ case class Sha512(child: Expression) extends UnaryExpression with InputTypeSpec
relBuilder.call(ScalarSqlFunctions.SHA512, child.toRexNode)
}
}

/**
  * Expression for `SHA2(child, hashLength)`: hashes a string operand with the SHA-2 variant
  * selected by the bit-length operand.
  *
  * @param child      expression producing the string to hash
  * @param hashLength expression producing the requested bit length
  */
case class Sha2(child: Expression, hashLength: Expression)
  extends BinaryExpression with InputTypeSpec {

  // operands in call order: data first, bit length second
  override private[flink] def left: Expression = child
  override private[flink] def right: Expression = hashLength

  // the digest is returned as a string
  override private[flink] def resultType: TypeInformation[_] = STRING_TYPE_INFO

  // (string to hash, integer bit length)
  override private[flink] def expectedTypes: Seq[TypeInformation[_]] =
    List(STRING_TYPE_INFO, INT_TYPE_INFO)

  override def toString: String = s"($child).sha2($hashLength)"

  // translates to a call of the SHA2 SQL function with both operands converted in order
  override private[flink] def toRexNode(implicit relBuilder: RelBuilder): RexNode = {
    val dataNode = left.toRexNode
    val bitLengthNode = right.toRexNode
    relBuilder.call(ScalarSqlFunctions.SHA2, dataNode, bitLengthNode)
  }
}


Expand Up @@ -139,6 +139,16 @@ object ScalarSqlFunctions {
SqlFunctionCategory.STRING
)

/**
  * SQL function `SHA2(DATA, HASH_LENGTH)` in the STRING category. It accepts a string
  * operand followed by an integer operand.
  *
  * NOTE(review): ReturnTypes.ARG0_NULLABLE derives the result type from the first operand.
  * The hash string generally has a different length than the input, so confirm that an
  * operand-derived character type is what is wanted here.
  */
val SHA2 = new SqlFunction(
"SHA2",
SqlKind.OTHER_FUNCTION,
ReturnTypes.ARG0_NULLABLE,
InferTypes.RETURN_TYPE,
OperandTypes.sequence("'(DATA, HASH_LENGTH)'",
OperandTypes.STRING, OperandTypes.NUMERIC_INTEGER),
SqlFunctionCategory.STRING
)

val DATE_FORMAT = new SqlFunction(
"DATE_FORMAT",
SqlKind.OTHER_FUNCTION,
Expand Down
Expand Up @@ -19,7 +19,7 @@
package org.apache.flink.table.plan

import org.apache.flink.api.common.typeutils.CompositeType
import org.apache.flink.table.api.{OverWindow, TableEnvironment}
import org.apache.flink.table.api.{OverWindow, TableEnvironment, ValidationException}
import org.apache.flink.table.expressions._
import org.apache.flink.table.plan.logical.{LogicalNode, Project}

Expand Down Expand Up @@ -92,6 +92,11 @@ object ProjectionTranslator {
case _ => (x._1, x._2)
}
}

// Expression is null
case null =>
throw new ValidationException("Scala 'null' is not a valid expression. " +
"Use 'Null(TYPE)' to specify typed null expressions. For example: Null(Types.INT)")
}
}

Expand Down
Expand Up @@ -84,7 +84,14 @@ abstract class LogicalNode extends TreeNode[LogicalNode] {
resolvedNode.expressionPostOrderTransform {
case a: Attribute if !a.valid =>
val from = children.flatMap(_.output).map(_.name).mkString(", ")
failValidation(s"Cannot resolve [${a.name}] given input [$from].")
// give helpful error message for null literals
if (a.name == "null") {
failValidation(s"Cannot resolve field [${a.name}] given input [$from]. If you want to " +
s"express a null literal, use 'Null(TYPE)' for typed null expressions. " +
s"For example: Null(INT)")
} else {
failValidation(s"Cannot resolve field [${a.name}] given input [$from].")
}

case e: Expression if e.validateInput().isFailure =>
failValidation(s"Expression $e failed on input check: " +
Expand Down
Expand Up @@ -276,7 +276,8 @@ object FunctionCatalog {
"sha224" -> classOf[Sha224],
"sha256" -> classOf[Sha256],
"sha384" -> classOf[Sha384],
"sha512" -> classOf[Sha512]
"sha512" -> classOf[Sha512],
"sha2" -> classOf[Sha2]
)

/**
Expand Down Expand Up @@ -439,6 +440,7 @@ class BasicOperatorTable extends ReflectiveSqlOperatorTable {
ScalarSqlFunctions.SHA256,
ScalarSqlFunctions.SHA384,
ScalarSqlFunctions.SHA512,
ScalarSqlFunctions.SHA2,
// EXTENSIONS
BasicOperatorTable.TUMBLE,
BasicOperatorTable.HOP,
Expand Down

0 comments on commit 8eec662

Please sign in to comment.