From ca7fbe7ffef9eefbedea610b057834dfefdce94a Mon Sep 17 00:00:00 2001 From: Zhang Zhichao <441586683@qq.com> Date: Fri, 28 Sep 2018 23:22:49 +0800 Subject: [PATCH] [EXT][SPARK-21860][core]Improve memory reuse for heap memory in `HeapMemoryAllocator` #19077 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit In HeapMemoryAllocator, when allocating memory from pool, and the key of pool is memory size. Actually some size of memory ,such as 1025bytes,1026bytes,......1032bytes, we can think they are the sameļ¼Œbecause we allocate memory in multiples of 8 bytes. In this case, we can improve memory reuse. --- .../unsafe/memory/HeapMemoryAllocator.java | 18 +++++++++------- .../spark/unsafe/PlatformUtilSuite.java | 21 +++++++++++++++++++ 2 files changed, 32 insertions(+), 7 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index a9603c1aba051..2733760dd19ef 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -46,9 +46,12 @@ private boolean shouldPool(long size) { @Override public MemoryBlock allocate(long size) throws OutOfMemoryError { - if (shouldPool(size)) { + int numWords = (int) ((size + 7) / 8); + long alignedSize = numWords * 8L; + assert (alignedSize >= size); + if (shouldPool(alignedSize)) { synchronized (this) { - final LinkedList> pool = bufferPoolsBySize.get(size); + final LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool != null) { while (!pool.isEmpty()) { final WeakReference arrayReference = pool.pop(); @@ -62,11 +65,11 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { return memory; } } - bufferPoolsBySize.remove(size); + bufferPoolsBySize.remove(alignedSize); } } } - long[] array = new long[(int) ((size + 7) / 8)]; + long[] array = new long[numWords]; MemoryBlock memory = new MemoryBlock(array, Platform.LONG_ARRAY_OFFSET, size); if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); @@ -98,12 +101,13 @@ public void free(MemoryBlock memory) { long[] array = (long[]) memory.obj; memory.setObjAndOffset(null, 0); - if (shouldPool(size)) { + long alignedSize = ((size + 7) / 8) * 8; + if (shouldPool(alignedSize)) { synchronized (this) { - LinkedList> pool = bufferPoolsBySize.get(size); + LinkedList> pool = bufferPoolsBySize.get(alignedSize); if (pool == null) { pool = new LinkedList<>(); - bufferPoolsBySize.put(size, pool); + bufferPoolsBySize.put(alignedSize, pool); } pool.add(new WeakReference<>(array)); } diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 62854837b05ed..08ae65290e151 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.unsafe; +import org.apache.spark.unsafe.memory.HeapMemoryAllocator; import org.apache.spark.unsafe.memory.MemoryAllocator; import org.apache.spark.unsafe.memory.MemoryBlock; @@ -134,4 +135,24 @@ public void memoryDebugFillEnabledInTest() { MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); MemoryAllocator.UNSAFE.free(offheap); } + + @Test + public void heapMemoryReuse() { + MemoryAllocator heapMem = new HeapMemoryAllocator(); + // The size is less than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`,allocate new memory every time. + MemoryBlock onheap1 = heapMem.allocate(513); + Object obj1 = onheap1.getBaseObject(); + heapMem.free(onheap1); + MemoryBlock onheap2 = heapMem.allocate(514); + Assert.assertNotEquals(obj1, onheap2.getBaseObject()); + // The size is greater than `HeapMemoryAllocator.POOLING_THRESHOLD_BYTES`, + // reuse the previous memory which has released. + MemoryBlock onheap3 = heapMem.allocate(1024 * 1024 + 1); + Assert.assertEquals(onheap3.size(), 1024 * 1024 + 1); + Object obj3 = onheap3.getBaseObject(); + heapMem.free(onheap3); + MemoryBlock onheap4 = heapMem.allocate(1024 * 1024 + 7); + Assert.assertEquals(onheap4.size(), 1024 * 1024 + 7); + Assert.assertEquals(obj3, onheap4.getBaseObject()); + } }