[SPARK-42168][3.2][SQL][PYTHON] Fix required child distribution of FlatMapCoGroupsInPandas (as in CoGroup)

### What changes were proposed in this pull request?
Make `FlatMapCoGroupsInPandas` (used by PySpark) report its required child distribution as `HashClusteredDistribution`, rather than `ClusteredDistribution`. That is the same distribution as reported by `CoGroup` (used by Scala).

### Why are the changes needed?
This allows the `EnsureRequirements` rule to correctly recognize that `FlatMapCoGroupsInPandas`, which requires `HashClusteredDistribution(id, day)`, is not compatible with `HashPartitioning(day, id)`, whereas `ClusteredDistribution(id, day)` is compatible with `HashPartitioning(day, id)`.
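
For illustration, here is a minimal sketch against Spark 3.2's Catalyst internals (not part of this patch; the attribute names are chosen to mirror the example below) showing why `HashPartitioning(day, id)` satisfies the order-insensitive `ClusteredDistribution(id, day)` but not `HashClusteredDistribution(id, day)`:

```Scala
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashClusteredDistribution, HashPartitioning}
import org.apache.spark.sql.types.LongType

val id = AttributeReference("id", LongType)()
val day = AttributeReference("day", LongType)()

// partitioning produced by the window over ("day", "id")
val partitioning = HashPartitioning(Seq(day, id), 200)

// order-insensitive check: satisfied, so EnsureRequirements inserts no Exchange
partitioning.satisfies(ClusteredDistribution(Seq(id, day)))      // true

// order-sensitive check: not satisfied, so EnsureRequirements adds an Exchange
partitioning.satisfies(HashClusteredDistribution(Seq(id, day)))  // false
```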

The following example returns an incorrect result in Spark 3.0, 3.1, and 3.2.

```Scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, lit, sum}

val ids = 1000
val days = 1000
val parts = 10

val id_df = spark.range(ids)
val day_df = spark.range(days).withColumnRenamed("id", "day")
val id_day_df = id_df.join(day_df)
// these redundant aliases are needed to work around bug SPARK-42132
val left_df = id_day_df.select($"id".as("id"), $"day".as("day"), lit("left").as("side")).repartition(parts).cache()
val right_df = id_day_df.select($"id".as("id"), $"day".as("day"), lit("right").as("side")).repartition(parts).cache()  //.withColumnRenamed("id", "id2")

// note the column order is different to the groupBy("id", "day") column order below
val window = Window.partitionBy("day", "id")

case class Key(id: BigInt, day: BigInt)
case class Value(id: BigInt, day: BigInt, side: String)
case class Sum(id: BigInt, day: BigInt, side: String, day_sum: BigInt)

val left_grouped_df = left_df.groupBy("id", "day").as[Key, Value]
val right_grouped_df = right_df.withColumn("day_sum", sum(col("day")).over(window)).groupBy("id", "day").as[Key, Sum]

val df = left_grouped_df.cogroup(right_grouped_df)((key: Key, left: Iterator[Value], right: Iterator[Sum]) => left)

df.explain()
df.show(5)
```

Output was
```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#8L, day#9L], [id#29L, day#30L], cogroup(id#8L, day#9L, side#10, id#29L, day#30L, side#31, day_sum#54L), [id#64L, day#65L, lefts#66, rights#67]
   :- Sort [id#8L ASC NULLS FIRST, day#9L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#8L, day#9L, 200), ENSURE_REQUIREMENTS, [plan_id=117]
   :     +- ...
   +- Sort [id#29L ASC NULLS FIRST, day#30L ASC NULLS FIRST], false, 0
      +- Project [id#29L, day#30L, id#29L, day#30L, side#31, day_sum#54L]
         +- Window [sum(day#30L) windowspecdefinition(day#30L, id#29L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS day_sum#54L], [day#30L, id#29L]
            +- Sort [day#30L ASC NULLS FIRST, id#29L ASC NULLS FIRST], false, 0
               +- Exchange hashpartitioning(day#30L, id#29L, 200), ENSURE_REQUIREMENTS, [plan_id=112]
                  +- ...

+---+---+-----+------+
| id|day|lefts|rights|
+---+---+-----+------+
|  0|  3|    0|     1|
|  0|  4|    0|     1|
|  0| 13|    1|     0|
|  0| 27|    0|     1|
|  0| 31|    0|     1|
+---+---+-----+------+
only showing top 5 rows
```

Output now is
```
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- FlatMapCoGroupsInPandas [id#8L, day#9L], [id#29L, day#30L], cogroup(id#8L, day#9L, side#10, id#29L, day#30L, side#31, day_sum#54L), [id#64L, day#65L, lefts#66, rights#67]
   :- Sort [id#8L ASC NULLS FIRST, day#9L ASC NULLS FIRST], false, 0
   :  +- Exchange hashpartitioning(id#8L, day#9L, 200), ENSURE_REQUIREMENTS, [plan_id=117]
   :     +- ...
   +- Sort [id#29L ASC NULLS FIRST, day#30L ASC NULLS FIRST], false, 0
      +- Exchange hashpartitioning(id#29L, day#30L, 200), ENSURE_REQUIREMENTS, [plan_id=118]
         +- Project [id#29L, day#30L, id#29L, day#30L, side#31, day_sum#54L]
            +- Window [sum(day#30L) windowspecdefinition(day#30L, id#29L, specifiedwindowframe(RowFrame, unboundedpreceding$(), unboundedfollowing$())) AS day_sum#54L], [day#30L, id#29L]
               +- Sort [day#30L ASC NULLS FIRST, id#29L ASC NULLS FIRST], false, 0
                  +- Exchange hashpartitioning(day#30L, id#29L, 200), ENSURE_REQUIREMENTS, [plan_id=112]
                     +- ...

+---+---+-----+------+
| id|day|lefts|rights|
+---+---+-----+------+
|  0| 13|    1|     1|
|  0| 63|    1|     1|
|  0| 89|    1|     1|
|  0| 95|    1|     1|
|  0| 96|    1|     1|
+---+---+-----+------+
only showing top 5 rows
```

Spark 3.3 [reworked](https://github.com/apache/spark/pull/32875/files#diff-e938569a4ca4eba8f7e10fe473d4f9c306ea253df151405bcaba880a601f075fR75-R76) `HashClusteredDistribution` (apache#32875), so it is not sensitive to this issue even though it uses `ClusteredDistribution`.

### Does this PR introduce _any_ user-facing change?
Yes, it fixes a correctness issue.

### How was this patch tested?
A unit test in `EnsureRequirementsSuite`.

Closes apache#39717 from EnricoMi/branch-3.2-cogroup-window-bug.

Authored-by: Enrico Minack <github@enrico.minack.dev>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
EnricoMi authored and kazuyukitanimura committed Jan 27, 2023
1 parent 164fa3f commit 382d82a
Showing 3 changed files with 106 additions and 5 deletions.
47 changes: 46 additions & 1 deletion python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -17,8 +17,9 @@

import unittest

from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf
from pyspark.sql.functions import array, explode, col, lit, udf, pandas_udf, sum
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.sql.window import Window
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
@@ -215,6 +216,50 @@ def test_self_join(self):

        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())

    def test_with_window_function(self):
        # SPARK-42168: a window function with same partition keys but differing key order
        ids = 2
        days = 100
        vals = 10000
        parts = 10

        id_df = self.spark.range(ids)
        day_df = self.spark.range(days).withColumnRenamed("id", "day")
        vals_df = self.spark.range(vals).withColumnRenamed("id", "value")
        df = id_df.join(day_df).join(vals_df)

        left_df = df.withColumnRenamed("value", "left").repartition(parts).cache()
        # SPARK-42132: this bug requires us to alias all columns from df here
        right_df = df.select(
            col("id").alias("id"), col("day").alias("day"), col("value").alias("right")
        ).repartition(parts).cache()

        # note the column order is different to the groupBy("id", "day") column order below
        window = Window.partitionBy("day", "id")

        left_grouped_df = left_df.groupBy("id", "day")
        right_grouped_df = right_df \
            .withColumn("day_sum", sum(col("day")).over(window)) \
            .groupBy("id", "day")

        def cogroup(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame:
            return pd.DataFrame([{
                "id": left["id"][0] if not left.empty else (
                    right["id"][0] if not right.empty else None
                ),
                "day": left["day"][0] if not left.empty else (
                    right["day"][0] if not right.empty else None
                ),
                "lefts": len(left.index),
                "rights": len(right.index)
            }])

        df = left_grouped_df.cogroup(right_grouped_df) \
            .applyInPandas(cogroup, schema="id long, day long, lefts integer, rights integer")

        actual = df.orderBy("id", "day").take(days)
        self.assertEqual(actual, [Row(0, day, vals, vals) for day in range(days)])

    @staticmethod
    def _test_with_key(left, right, isLeft):

6 changes: 3 additions & 3 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala
@@ -21,7 +21,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution, HashClusteredDistribution, Partitioning}
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
import org.apache.spark.sql.execution.python.PandasGroupUtils._
import org.apache.spark.sql.types.StructType
@@ -66,8 +66,8 @@ case class FlatMapCoGroupsInPandasExec(
  override def outputPartitioning: Partitioning = left.outputPartitioning

  override def requiredChildDistribution: Seq[Distribution] = {
    val leftDist = if (leftGroup.isEmpty) AllTuples else ClusteredDistribution(leftGroup)
    val rightDist = if (rightGroup.isEmpty) AllTuples else ClusteredDistribution(rightGroup)
    val leftDist = if (leftGroup.isEmpty) AllTuples else HashClusteredDistribution(leftGroup)
    val rightDist = if (rightGroup.isEmpty) AllTuples else HashClusteredDistribution(rightGroup)
    leftDist :: rightDist :: Nil
  }

58 changes: 57 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/execution/exchange/EnsureRequirementsSuite.scala
@@ -17,13 +17,18 @@

package org.apache.spark.sql.execution.exchange

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.api.python.PythonEvalType
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.Sum
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, PartitioningCollection}
import org.apache.spark.sql.execution.{DummySparkPlan, SortExec}
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.execution.python.FlatMapCoGroupsInPandasExec
import org.apache.spark.sql.execution.window.WindowExec
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

class EnsureRequirementsSuite extends SharedSparkSession {
  private val exprA = Literal(1)
@@ -135,4 +140,55 @@ class EnsureRequirementsSuite extends SharedSparkSession {
      }.size == 2)
    }
  }

test("SPARK-42168: FlatMapCoGroupInPandas and Window function with differing key order") {
val lKey = AttributeReference("key", IntegerType)()
val lKey2 = AttributeReference("key2", IntegerType)()

val rKey = AttributeReference("key", IntegerType)()
val rKey2 = AttributeReference("key2", IntegerType)()
val rValue = AttributeReference("value", IntegerType)()

val left = DummySparkPlan()
val right = WindowExec(
Alias(
WindowExpression(
Sum(rValue).toAggregateExpression(),
WindowSpecDefinition(
Seq(rKey2, rKey),
Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, UnboundedFollowing)
)
), "sum")() :: Nil,
Seq(rKey2, rKey),
Nil,
DummySparkPlan()
)

val pythonUdf = PythonUDF("pyUDF", null,
StructType(Seq(StructField("value", IntegerType))),
Seq.empty,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
true)

val flapMapCoGroup = FlatMapCoGroupsInPandasExec(
Seq(lKey, lKey2),
Seq(rKey, rKey2),
pythonUdf,
AttributeReference("value", IntegerType)() :: Nil,
left,
right
)

val result = EnsureRequirements.apply(flapMapCoGroup)
result match {
case FlatMapCoGroupsInPandasExec(leftKeys, rightKeys, _, _,
SortExec(leftOrder, false, _, _), SortExec(rightOrder, false, _, _)) =>
assert(leftKeys === Seq(lKey, lKey2))
assert(rightKeys === Seq(rKey, rKey2))
assert(leftKeys.map(k => SortOrder(k, Ascending)) === leftOrder)
assert(rightKeys.map(k => SortOrder(k, Ascending)) === rightOrder)
case other => fail(other.toString)
}
}
}
