Skip to content

Commit

Permalink
[SPARK-31511][FOLLOW-UP][TEST][SQL] Make BytesToBytesMap iterators th…
Browse files Browse the repository at this point in the history
…read-safe

### What changes were proposed in this pull request?
Before SPARK-31511 is fixed, `BytesToBytesMap` iterator() is not thread-safe and may cause data inaccuracy.
We need to add a unit test.

### Why are the changes needed?
Increase test coverage to ensure that iterator() is thread-safe.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
add ut

Closes apache#29669 from cxzl25/SPARK-31511-test.

Authored-by: sychen <sychen@ctrip.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
cxzl25 authored and cloud-fan committed Sep 8, 2020
1 parent 55d38a4 commit bd3dc2f
Showing 1 changed file with 39 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,45 @@ class HashedRelationSuite extends SharedSparkSession {
assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
}

test("SPARK-31511: Make BytesToBytesMap iterators thread-safe") {
val ser = sparkContext.env.serializer.newInstance()
val key = Seq(BoundReference(0, LongType, false))

val unsafeProj = UnsafeProjection.create(
Seq(BoundReference(0, LongType, false), BoundReference(1, IntegerType, true)))
val rows = (0 until 10000).map(i => unsafeProj(InternalRow(Int.int2long(i), i + 1)).copy())
val unsafeHashed = UnsafeHashedRelation(rows.iterator, key, 1, mm)

val os = new ByteArrayOutputStream()
val thread1 = new Thread {
override def run(): Unit = {
val out = new ObjectOutputStream(os)
unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
out.flush()
}
}

val thread2 = new Thread {
override def run(): Unit = {
val threadOut = new ObjectOutputStream(new ByteArrayOutputStream())
unsafeHashed.asInstanceOf[UnsafeHashedRelation].writeExternal(threadOut)
threadOut.flush()
}
}

thread1.start()
thread2.start()
thread1.join()
thread2.join()

val unsafeHashed2 = ser.deserialize[UnsafeHashedRelation](ser.serialize(unsafeHashed))
val os2 = new ByteArrayOutputStream()
val out2 = new ObjectOutputStream(os2)
unsafeHashed2.writeExternal(out2)
out2.flush()
assert(java.util.Arrays.equals(os.toByteArray, os2.toByteArray))
}

// This test require 4G heap to run, should run it manually
ignore("build HashedRelation that is larger than 1G") {
val unsafeProj = UnsafeProjection.create(
Expand Down

0 comments on commit bd3dc2f

Please sign in to comment.