[SPARK-47385] Fix tuple encoders with Option inputs
## What changes were proposed in this pull request?

apache#40755 adds a null check on the input of the child deserializer in the tuple encoder. This breaks the deserializer for the `Option` type, because a null input should be deserialized into `None` rather than null. This PR adds a boolean parameter to `ExpressionEncoder.tuple` so that the null check is applied only for the caller that apache#40755 intended to fix.
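As a plain-Scala illustration (not Spark code, and the names below are hypothetical stand-ins) of why a blanket null check breaks `Option` fields: the `Option[T]` deserializer is null-tolerant and must turn a null input into `None`, whereas a null-safe wrapper short-circuits a null input to null before the child deserializer ever runs.

```scala
object OptionNullDemo {
  // Toy stand-in for the Option[T] child deserializer: it is
  // null-tolerant and maps a null input to None, never to null.
  def decodeOption(input: Integer): Option[Int] =
    Option(input).map(_.intValue)

  // Sketch of the null-safe wrapper apache#40755 introduced:
  // it short-circuits a null input to null instead of calling
  // the child deserializer.
  def nullSafe[A >: Null](input: Integer, decode: Integer => A): A =
    if (input == null) null else decode(input)

  def main(args: Array[String]): Unit = {
    // Correct behaviour for an Option field: null decodes to None.
    println(decodeOption(null))            // None
    // Wrapped, the field becomes null instead of None.
    println(nullSafe(null, decodeOption))  // null
  }
}
```

This is why the new `useNullSafeDeserializer` parameter defaults to false: the wrapper is only safe for deserializers that would otherwise fail on null input.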

## How was this patch tested?

Unit test.

Closes apache#45508 from chenhao-db/SPARK-47385.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
chenhao-db authored and sweisdb committed Apr 1, 2024
1 parent 8d7ed2d commit ab5197a
Showing 3 changed files with 23 additions and 3 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -73,8 +73,14 @@ object ExpressionEncoder {
* Given a set of N encoders, constructs a new encoder that produces objects as items in an
* N-tuple. Note that these encoders should be unresolved so that information about
* name/positional binding is preserved.
* When `useNullSafeDeserializer` is true, the deserialization result for a child will be null if
* the input is null. It is false by default, as most deserializers handle a null input properly
* and don't require an extra null check. Some deserializers must turn a null input into a
* non-null result, such as the one for `Option[T]`, which maps null to `None`; the flag must not
* be set to true in that case.
*/
def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
def tuple(
encoders: Seq[ExpressionEncoder[_]],
useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
if (encoders.length > 22) {
throw QueryExecutionErrors.elementsOfTupleExceedLimitError()
}
@@ -119,7 +125,7 @@ object ExpressionEncoder {
case GetColumnByOrdinal(0, _) => input
}

if (enc.objSerializer.nullable) {
if (useNullSafeDeserializer && enc.objSerializer.nullable) {
nullSafe(input, childDeserializer)
} else {
childDeserializer
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1246,7 +1246,9 @@ class Dataset[T] private[sql](
JoinHint.NONE)).analyzed.asInstanceOf[Join]

implicit val tuple2Encoder: Encoder[(T, U)] =
ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
ExpressionEncoder
.tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true)
.asInstanceOf[Encoder[(T, U)]]

withTypedPlan(JoinWith.typedJoinWith(
joined,
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2576,6 +2576,18 @@ class DatasetSuite extends QueryTest
assert(result == expected)
}

test("SPARK-47385: Tuple encoder with Option inputs") {
implicit val enc: Encoder[(SingleData, Option[SingleData])] =
Encoders.tuple(Encoders.product[SingleData], Encoders.product[Option[SingleData]])

val input = Seq(
(SingleData(1), Some(SingleData(1))),
(SingleData(2), None)
)
val ds = spark.createDataFrame(input).as[(SingleData, Option[SingleData])]
checkDataset(ds, input: _*)
}

test("SPARK-43124: Show does not trigger job execution on CommandResults") {
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> "") {
withTable("t1") {
