From e5e54a3614ffd2a9150921e84e5b813d5cbf285a Mon Sep 17 00:00:00 2001
From: Tom van Bussel
Date: Thu, 17 Sep 2020 12:35:40 +0200
Subject: [PATCH] [SPARK-32900][CORE] Allow UnsafeExternalSorter to spill when there are nulls

### What changes were proposed in this pull request?

This PR changes the way `UnsafeExternalSorter.SpillableIterator` checks whether it has already spilled, by checking whether `inMemSorter` is null. It also allows it to spill `UnsafeSorterIterator`s other than `UnsafeInMemorySorter.SortedIterator`.

### Why are the changes needed?

Before this PR, `UnsafeExternalSorter.SpillableIterator` could not spill when there are NULLs in the input and radix sorting is used. Spark currently determines whether `UnsafeExternalSorter.SpillableIterator` has not spilled yet by checking whether `upstream` is an instance of `UnsafeInMemorySorter.SortedIterator`. However, when radix sorting is used and there are NULLs in the input, `upstream` will be an instance of `UnsafeExternalSorter.ChainedIterator` instead, so Spark assumes that the `SpillableIterator` has already spilled and cannot spill again when it is supposed to.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

A test was added to `UnsafeExternalSorterSuite` (and therefore also to `UnsafeExternalSorterRadixSortSuite`). I manually confirmed that the test failed in `UnsafeExternalSorterRadixSortSuite` without this patch.

Closes #29772 from tomvanbussel/SPARK-32900.

Authored-by: Tom van Bussel
Signed-off-by: herman
---
 .../unsafe/sort/UnsafeExternalSorter.java     | 69 +++++++++++--------
 .../unsafe/sort/UnsafeInMemorySorter.java     |  1 +
 .../unsafe/sort/UnsafeSorterIterator.java     |  2 +
 .../unsafe/sort/UnsafeSorterSpillMerger.java  |  5 ++
 .../unsafe/sort/UnsafeSorterSpillReader.java  |  5 ++
 .../sort/UnsafeExternalSorterSuite.java       | 33 +++++++++
 6 files changed, 88 insertions(+), 27 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 55e4e609c3c7b..71b9a5bc11542 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -501,11 +501,15 @@ private static void spillIterator(UnsafeSorterIterator inMemIterator,
    */
   class SpillableIterator extends UnsafeSorterIterator {
     private UnsafeSorterIterator upstream;
-    private UnsafeSorterIterator nextUpstream = null;
     private MemoryBlock lastPage = null;
     private boolean loaded = false;
     private int numRecords = 0;
 
+    private Object currentBaseObject;
+    private long currentBaseOffset;
+    private int currentRecordLength;
+    private long currentKeyPrefix;
+
     SpillableIterator(UnsafeSorterIterator inMemIterator) {
       this.upstream = inMemIterator;
       this.numRecords = inMemIterator.getNumRecords();
@@ -516,23 +520,26 @@ public int getNumRecords() {
       return numRecords;
     }
 
+    @Override
+    public long getCurrentPageNumber() {
+      throw new UnsupportedOperationException();
+    }
+
     public long spill() throws IOException {
       synchronized (this) {
-        if (!(upstream instanceof UnsafeInMemorySorter.SortedIterator && nextUpstream == null
-          && numRecords > 0)) {
+        if (inMemSorter == null || numRecords <= 0) {
           return 0L;
         }
 
-        UnsafeInMemorySorter.SortedIterator inMemIterator =
-          ((UnsafeInMemorySorter.SortedIterator) upstream).clone();
+        long currentPageNumber = upstream.getCurrentPageNumber();
 
-       ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
+        ShuffleWriteMetrics writeMetrics = new ShuffleWriteMetrics();
         // Iterate over the records that have not been returned and spill them.
         final UnsafeSorterSpillWriter spillWriter =
           new UnsafeSorterSpillWriter(blockManager, fileBufferSizeBytes, writeMetrics, numRecords);
-        spillIterator(inMemIterator, spillWriter);
+        spillIterator(upstream, spillWriter);
         spillWriters.add(spillWriter);
-        nextUpstream = spillWriter.getReader(serializerManager);
+        upstream = spillWriter.getReader(serializerManager);
 
         long released = 0L;
         synchronized (UnsafeExternalSorter.this) {
@@ -540,8 +547,7 @@ public long spill() throws IOException {
           // is accessing the current record. We free this page in that caller's next loadNext()
           // call.
           for (MemoryBlock page : allocatedPages) {
-            if (!loaded || page.pageNumber !=
-              ((UnsafeInMemorySorter.SortedIterator)upstream).getCurrentPageNumber()) {
+            if (!loaded || page.pageNumber != currentPageNumber) {
               released += page.size();
               freePage(page);
             } else {
@@ -575,22 +581,26 @@ public void loadNext() throws IOException {
       try {
         synchronized (this) {
           loaded = true;
-          if (nextUpstream != null) {
-            // Just consumed the last record from in memory iterator
-            if(lastPage != null) {
-              // Do not free the page here, while we are locking `SpillableIterator`. The `freePage`
-              // method locks the `TaskMemoryManager`, and it's a bad idea to lock 2 objects in
-              // sequence. We may hit dead lock if another thread locks `TaskMemoryManager` and
-              // `SpillableIterator` in sequence, which may happen in
-              // `TaskMemoryManager.acquireExecutionMemory`.
-              pageToFree = lastPage;
-              lastPage = null;
-            }
-            upstream = nextUpstream;
-            nextUpstream = null;
+          // Just consumed the last record from in memory iterator
+          if (lastPage != null) {
+            // Do not free the page here, while we are locking `SpillableIterator`. The `freePage`
+            // method locks the `TaskMemoryManager`, and it's a bad idea to lock 2 objects in
+            // sequence. We may hit dead lock if another thread locks `TaskMemoryManager` and
+            // `SpillableIterator` in sequence, which may happen in
+            // `TaskMemoryManager.acquireExecutionMemory`.
+            pageToFree = lastPage;
+            lastPage = null;
           }
           numRecords--;
           upstream.loadNext();
+
+          // Keep track of the current base object, base offset, record length, and key prefix,
+          // so that the current record can still be read in case a spill is triggered and we
+          // switch to the spill writer's iterator.
+          currentBaseObject = upstream.getBaseObject();
+          currentBaseOffset = upstream.getBaseOffset();
+          currentRecordLength = upstream.getRecordLength();
+          currentKeyPrefix = upstream.getKeyPrefix();
         }
       } finally {
         if (pageToFree != null) {
@@ -601,22 +611,22 @@ public Object getBaseObject() {
-      return upstream.getBaseObject();
+      return currentBaseObject;
     }
 
     @Override
     public long getBaseOffset() {
-      return upstream.getBaseOffset();
+      return currentBaseOffset;
    }
 
     @Override
     public int getRecordLength() {
-      return upstream.getRecordLength();
+      return currentRecordLength;
     }
 
     @Override
     public long getKeyPrefix() {
-      return upstream.getKeyPrefix();
+      return currentKeyPrefix;
     }
   }
@@ -693,6 +703,11 @@ public int getNumRecords() {
       return numRecords;
     }
 
+    @Override
+    public long getCurrentPageNumber() {
+      return current.getCurrentPageNumber();
+    }
+
     @Override
     public boolean hasNext() {
       while (!current.hasNext() && !iterators.isEmpty()) {
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
index 660eb790a550b..ff641a24a7b3e 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java
@@ -330,6 +330,7 @@ public void loadNext() {
     @Override
     public long getBaseOffset() { return baseOffset; }
 
+    @Override
     public long getCurrentPageNumber() {
       return currentPageNumber;
     }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
index 1b3167fcc250c..d9f22311d07c2 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java
@@ -34,4 +34,6 @@ public abstract class UnsafeSorterIterator {
   public abstract long getKeyPrefix();
 
   public abstract int getNumRecords();
+
+  public abstract long getCurrentPageNumber();
 }
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
index ab800288dcb43..f8603c5799e9b 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java
@@ -70,6 +70,11 @@ public int getNumRecords() {
         return numRecords;
       }
 
+      @Override
+      public long getCurrentPageNumber() {
+        throw new UnsupportedOperationException();
+      }
+
       @Override
       public boolean hasNext() {
         return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext());
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
index a524c4790407d..db79efd008530 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java
@@ -89,6 +89,11 @@ public int getNumRecords() {
     return numRecords;
   }
 
+  @Override
+  public long getCurrentPageNumber() {
+    throw new UnsupportedOperationException();
+  }
+
   @Override
   public boolean hasNext() {
     return (numRecordsRemaining > 0);
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 43977717f6c97..087d090c1c60e 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -359,6 +359,39 @@ public void forcedSpillingWithReadIterator() throws Exception {
     assertSpillFilesWereCleanedUp();
   }
 
+  @Test
+  public void forcedSpillingNullsWithReadIterator() throws Exception {
+    final UnsafeExternalSorter sorter = newSorter();
+    long[] record = new long[100];
+    final int recordSize = record.length * 8;
+    final int n = (int) pageSizeBytes / recordSize * 3;
+    for (int i = 0; i < n; i++) {
+      boolean isNull = i % 2 == 0;
+      sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0, isNull);
+    }
+    assertTrue(sorter.getNumberOfAllocatedPages() >= 2);
+
+    UnsafeExternalSorter.SpillableIterator iter =
+      (UnsafeExternalSorter.SpillableIterator) sorter.getSortedIterator();
+    final int numRecordsToReadBeforeSpilling = n / 3;
+    for (int i = 0; i < numRecordsToReadBeforeSpilling; i++) {
+      assertTrue(iter.hasNext());
+      iter.loadNext();
+    }
+
+    assertTrue(iter.spill() > 0);
+    assertEquals(0, iter.spill());
+
+    for (int i = numRecordsToReadBeforeSpilling; i < n; i++) {
+      assertTrue(iter.hasNext());
+      iter.loadNext();
+    }
+    assertFalse(iter.hasNext());
+
+    sorter.cleanupResources();
+    assertSpillFilesWereCleanedUp();
+  }
+
   @Test
   public void forcedSpillingWithNotReadIterator() throws Exception {
     final UnsafeExternalSorter sorter = newSorter();
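
The gist of the change, without tracing the full diff: `SpillableIterator.spill()` now decides whether it can still spill by checking the in-memory state (`inMemSorter == null`) instead of checking whether `upstream` is an `UnsafeInMemorySorter.SortedIterator`, and `loadNext()` caches the current record's base object, offset, length, and key prefix so the record stays readable after `spill()` swaps `upstream` to the spill-file reader and frees the in-memory pages. The sketch below illustrates that pattern with hypothetical, heavily simplified types (`RecordIterator`, `ArrayIterator`, `SimpleSpillableIterator`, a `long[]` standing in for the sorter's pages); it is not Spark code.

```java
// A minimal, self-contained model of the pattern used in this patch. These types are
// hypothetical simplifications for illustration only; they are NOT Spark's classes.

interface RecordIterator {
  boolean hasNext();
  void loadNext();
  long getKeyPrefix();  // stands in for the base object / offset / length / prefix accessors
}

// Plays the role of both the in-memory sorted iterator and the spill-file reader.
final class ArrayIterator implements RecordIterator {
  private final long[] prefixes;
  private int pos = -1;

  ArrayIterator(long[] prefixes) { this.prefixes = prefixes; }

  @Override public boolean hasNext() { return pos + 1 < prefixes.length; }
  @Override public void loadNext() { pos++; }
  @Override public long getKeyPrefix() { return prefixes[pos]; }
}

// Plays the role of SpillableIterator: it can swap its upstream to a "spill reader"
// mid-iteration while keeping the current record readable.
final class SimpleSpillableIterator implements RecordIterator {
  private RecordIterator upstream;
  private long[] inMemoryData;     // stands in for inMemSorter and the allocated pages
  private long currentKeyPrefix;   // cached copy of the current record (the patch's new fields)
  private int remaining;

  SimpleSpillableIterator(long[] data) {
    this.inMemoryData = data;
    this.upstream = new ArrayIterator(data);
    this.remaining = data.length;
  }

  // "Have we spilled already?" is answered by looking at the in-memory state, not by
  // checking the concrete class of `upstream` -- the core idea of SPARK-32900.
  synchronized long spill() {
    if (inMemoryData == null || remaining <= 0) {
      return 0L;                                  // already spilled, or nothing left to spill
    }
    long[] notYetReturned = new long[remaining];  // "write the remaining records to disk"
    System.arraycopy(inMemoryData, inMemoryData.length - remaining, notYetReturned, 0, remaining);
    upstream = new ArrayIterator(notYetReturned); // stands in for the spill-file reader
    long released = inMemoryData.length * 8L;
    inMemoryData = null;                          // "free" the in-memory pages
    return released;
  }

  @Override public synchronized boolean hasNext() { return remaining > 0; }

  @Override public synchronized void loadNext() {
    upstream.loadNext();
    remaining--;
    // Cache the current record so it can still be read after a later spill() swaps upstream
    // and drops the in-memory data.
    currentKeyPrefix = upstream.getKeyPrefix();
  }

  @Override public long getKeyPrefix() { return currentKeyPrefix; }
}

public class SpillableIteratorSketch {
  public static void main(String[] args) {
    SimpleSpillableIterator iter = new SimpleSpillableIterator(new long[] {1, 2, 3, 4, 5, 6});
    iter.loadNext();
    iter.loadNext();                                              // two records read from memory
    System.out.println("spill released: " + iter.spill());        // > 0: first spill succeeds
    System.out.println("current record: " + iter.getKeyPrefix()); // still readable: 2
    System.out.println("second spill:   " + iter.spill());        // 0: already spilled
    while (iter.hasNext()) {
      iter.loadNext();
      System.out.println("next record:    " + iter.getKeyPrefix());
    }
  }
}
```

The real `SpillableIterator` additionally has to coordinate with `TaskMemoryManager` locking and free the last in-memory page lazily (the `lastPage`/`pageToFree` handling in `loadNext()`), which this sketch deliberately omits.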