From b829f1d9a76fbe6a377556f0db7ea3fab6fbeb4f Mon Sep 17 00:00:00 2001 From: Gabriel Gerhardsson Date: Tue, 8 Nov 2016 13:51:31 +0100 Subject: [PATCH 1/2] Add RecursionSafeAsyncCaller: safe to use for recursive calls RecursionSafeAsyncCaller will try to run the call in the current thread as much as possible, while keeping track of recursion to avoid StackOverflowError. If recursion becomes too deep, the next call is deferred to a separate thread (ExecutorService thread pool). State is kept per-thread - will be tracked for any thread that passes this code. It is important to choose a suitable maximum recursion depth. --- ...cursionSafeAsyncCallerIntegrationTest.java | 123 +++++++++++++++ .../async/RecursionSafeAsyncCaller.java | 118 +++++++++++++++ .../async/RecursionSafeAsyncCallerTest.java | 140 ++++++++++++++++++ 3 files changed, 381 insertions(+) create mode 100644 tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java create mode 100644 tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java create mode 100644 tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java diff --git a/tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java b/tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java new file mode 100644 index 0000000..00fdd6f --- /dev/null +++ b/tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java @@ -0,0 +1,123 @@ +package eu.toolchain.async; + +import com.google.common.util.concurrent.AtomicLongMap; +import com.google.common.util.concurrent.MoreExecutors; +import eu.toolchain.async.RecursionSafeAsyncCaller; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; +import java.lang.ThreadLocal; +import java.lang.Thread; + +import com.google.common.util.concurrent.AtomicLongMap; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class RecursionSafeAsyncCallerIntegrationTest { + + private AtomicLongMap recursionDepthPerThread; + private AtomicLongMap maxRecursionDepthPerThread; + AtomicLong totIterations; + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + recursionDepthPerThread = AtomicLongMap.create(); + maxRecursionDepthPerThread = AtomicLongMap.create(); + totIterations = new AtomicLong(0); + } + + + public void testBasicRecursionMethod(RecursionSafeAsyncCaller caller, ConcurrentLinkedQueue testData) { + + class RecursionRunnable implements Runnable { + RecursionSafeAsyncCaller caller; + ConcurrentLinkedQueue testData; + + RecursionRunnable(RecursionSafeAsyncCaller caller, ConcurrentLinkedQueue testData) { + this.caller = caller; + this.testData = testData; + } + + @Override + public void run() { + Long threadId = Thread.currentThread().getId(); + Long currDepth = recursionDepthPerThread.addAndGet(threadId, 1L); + Long currMax = maxRecursionDepthPerThread.get(threadId); + if (currDepth > currMax) { + maxRecursionDepthPerThread.put(threadId, currDepth); + } + + if (testData.size() == 0) + return; + testData.poll(); + totIterations.incrementAndGet(); + + // Recursive call, via caller + testBasicRecursionMethod(caller, testData); + + recursionDepthPerThread.addAndGet(threadId, -1L); + } + }; + + RecursionRunnable runnable = new RecursionRunnable(caller, testData); + caller.execute(runnable); + } + + @Test + public void testBasic() throws Exception { + final long MAX_RECURSION_DEPTH = 2; + ExecutorService executorServiceReal = Executors.newFixedThreadPool(10); + AsyncCaller caller2 = mock(AsyncCaller.class); + RecursionSafeAsyncCaller recursionCaller = + new RecursionSafeAsyncCaller(executorServiceReal, caller2, MAX_RECURSION_DEPTH); + ConcurrentLinkedQueue + testData = new ConcurrentLinkedQueue<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); + + testBasicRecursionMethod(recursionCaller, testData); + + // Wait for recursive calls in thread pool to create new work until done, or timeout. + // The executorServiceReal.shutdown() below is not enough since potentially some new work + // is still on the way in, due to recursive nature of the test. + long startTime = System.currentTimeMillis(); + long maxTime = 1000; + while (testData.size() > 0) { + if (System.currentTimeMillis() - startTime > maxTime) { + // Timeout, test will fail in the assert below + break; + } + Thread.sleep(10); + } + + executorServiceReal.shutdown(); + executorServiceReal.awaitTermination(1000, TimeUnit.MILLISECONDS); + + assert(testData.size() == 0); + assert(totIterations.get() == 10); + + Long maxStackDepth=-1L; + Map readOnlyMap = maxRecursionDepthPerThread.asMap(); + for (Long key : readOnlyMap.keySet()) { + Long val = readOnlyMap.get(key); + if (val > maxStackDepth) + maxStackDepth = val; + } + assert(maxStackDepth != -1L); + + // Checking with +1 since our initial call to testBasicRecursionMethod() above adds 1 + assert(maxStackDepth <= MAX_RECURSION_DEPTH+1); + } +} diff --git a/tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java b/tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java new file mode 100644 index 0000000..ccb33ab --- /dev/null +++ b/tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java @@ -0,0 +1,118 @@ +/* + * An AsyncCaller implementation that will try to run the call in the current thread as much as + * possible, while keeping track of recursion to avoid StackOverflowException in the thread. If + * recursion becomes too deep, the next call is deferred to a separate thread (normal thread pool). + * State is kept per-thread - stack overflow will be avoided for any thread that passes this code. + * It is vital to choose a suitable maximum recursion depth. + */ +package eu.toolchain.async; + +import java.util.concurrent.ExecutorService; + +public final class RecursionSafeAsyncCaller implements AsyncCaller { + private final ExecutorService executorService; + private final AsyncCaller caller; + private final long maxRecursionDepth; + + private class ThreadLocalInteger extends ThreadLocal { + protected Integer initialValue() { + return 0; + } + }; + private final ThreadLocalInteger recursionDepthPerThread; + + public RecursionSafeAsyncCaller(ExecutorService executorService, AsyncCaller caller, long maxRecursionDepth) { + this.executorService = executorService; + this.caller = caller; + this.maxRecursionDepth = maxRecursionDepth; + this.recursionDepthPerThread = new ThreadLocalInteger(); + } + + public RecursionSafeAsyncCaller(ExecutorService executorService, AsyncCaller caller) { + this(executorService, caller, 100); + } + + @Override + public void resolve(final FutureDone handle, final T result) { + execute(() -> caller.resolve(handle, result)); + } + + @Override + public void fail(final FutureDone handle, final Throwable error) { + execute(() -> caller.fail(handle, error)); + } + + @Override + public void cancel(final FutureDone handle) { + execute(() -> caller.cancel(handle)); + } + + @Override + public void cancel(final FutureCancelled cancelled) { + execute(() -> caller.cancel(cancelled)); + } + + @Override + public void finish(final FutureFinished finishable) { + execute(() -> caller.finish(finishable)); + } + + @Override + public void resolve(final FutureResolved resolved, final T value) { + execute(() -> caller.resolve(resolved, value)); + } + + @Override + public void resolve(final StreamCollector collector, final T result) { + execute(() -> caller.resolve(collector, result)); + } + + @Override + public void fail(final StreamCollector collector, final Throwable error) { + execute(() -> caller.fail(collector, error)); + } + + @Override + public void cancel(final StreamCollector collector) { + execute(() -> caller.cancel(collector)); + } + + @Override + public void fail(final FutureFailed failed, final Throwable cause) { + execute(() -> caller.fail(failed, cause)); + } + + @Override + public void referenceLeaked(final T reference, final StackTraceElement[] stack) { + execute(() -> caller.referenceLeaked(reference, stack)); + } + + @Override + public void execute(final Runnable runnable) { + // Use thread local counter for recursionDepth + final Integer recursionDepth = recursionDepthPerThread.get(); + // ++ + recursionDepthPerThread.set(recursionDepth + 1); + + if (recursionDepth + 1 <= maxRecursionDepth) { + // Case A: Call immediately, this is default until we've reached deep recursion + runnable.run(); + } else { + /* + * Case B: Defer to a separate thread + * This happens when recursion depth of the current thread is larger than limit, to + * avoid stack overflow. + */ + executorService.submit(runnable); + } + + // -- + recursionDepthPerThread.set(recursionDepth); + } + + + @Override + public boolean isThreaded() { + return caller.isThreaded(); + } +} diff --git a/tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java b/tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java new file mode 100644 index 0000000..7bf4bb9 --- /dev/null +++ b/tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java @@ -0,0 +1,140 @@ +package eu.toolchain.async; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +import java.util.concurrent.ExecutorService; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@RunWith(MockitoJUnitRunner.class) +public class RecursionSafeAsyncCallerTest { + private final Object result = new Object(); + private final Throwable cause = new Exception(); + + private AsyncCaller caller; + + private RecursionSafeAsyncCaller underTest; + + @Mock + private FutureDone done; + @Mock + private FutureCancelled cancelled; + @Mock + private FutureFinished finished; + @Mock + private FutureResolved resolved; + @Mock + private FutureFailed failed; + + @Mock + private StreamCollector streamCollector; + + private StackTraceElement[] stack = new StackTraceElement[0]; + + + @Before + public void setup() { + ExecutorService executor = mock(ExecutorService.class); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + final Runnable runnable = (Runnable) invocation.getArguments()[0]; + runnable.run(); + return null; + } + }).when(executor).submit(any(Runnable.class)); + + caller = mock(AsyncCaller.class); + underTest = new RecursionSafeAsyncCaller(executor, caller); + } + + @Test + public void testIsThreaded() { + underTest.isThreaded(); + verify(caller).isThreaded(); + } + + @Test + public void testResolveFutureDone() { + underTest.resolve(done, result); + verify(caller).resolve(done, result); + } + + @Test + public void testFailFutureDone() { + underTest.fail(done, cause); + verify(caller).fail(done, cause); + } + + @Test + public void testCancelFutureDone() { + underTest.cancel(done); + verify(caller).cancel(done); + } + + @Test + public void testRunFutureCancelled() { + underTest.cancel(cancelled); + verify(caller).cancel(cancelled); + } + + @Test + public void testRunFutureFinished() { + underTest.finish(finished); + verify(caller).finish(finished); + } + + @Test + public void testRunFutureResolved() { + underTest.resolve(resolved, result); + verify(caller).resolve(resolved, result); + } + + @Test + public void testRunFutureFailed() { + underTest.fail(failed, cause); + verify(caller).fail(failed, cause); + } + + @Test + public void testResolveStreamCollector() { + underTest.resolve(streamCollector, result); + verify(caller).resolve(streamCollector, result); + } + + @Test + public void testFailStreamCollector() { + underTest.fail(streamCollector, cause); + verify(caller).fail(streamCollector, cause); + } + + @Test + public void testCancelStreamCollector() { + underTest.cancel(streamCollector); + verify(caller).cancel(streamCollector); + } + + @Test + public void testLeakedManagedReference() { + underTest.referenceLeaked(result, stack); + verify(caller).referenceLeaked(result, stack); + } + + @Test + public void testExecute() { + Runnable runnable = mock(Runnable.class); + underTest.execute(runnable); + verify(runnable).run(); + } +} From 0cbff234dc72ab32a57cb07b13adcf24ada97ea6 Mon Sep 17 00:00:00 2001 From: Gabriel Gerhardsson Date: Tue, 8 Nov 2016 15:02:36 +0100 Subject: [PATCH 2/2] Make use of RecursionSafeAsyncCaller Add option to TinyAsyncBuilder to make use of RecursionSafeAsyncCaller in front of both the non-threaded caller and threadedCaller --- .../eu/toolchain/async/TinyAsyncBuilder.java | 63 ++++++++++++++----- 1 file changed, 49 insertions(+), 14 deletions(-) diff --git a/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java b/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java index 7c6b578..2a53036 100644 --- a/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java +++ b/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java @@ -9,6 +9,7 @@ public class TinyAsyncBuilder { private AsyncCaller caller; private boolean threaded; + private boolean useRecursionSafeCaller; private ExecutorService executor; private ExecutorService callerExecutor; private ScheduledExecutorService scheduler; @@ -32,6 +33,24 @@ public TinyAsyncBuilder threaded(boolean threaded) { return this; } + /** + * Configure that all caller invocations should use a recursion safe mechanism. In the normal + * case this doesn't change the behaviour of caller and threadedCaller, but when deep recursion + * is detected in the current thread the next recursive call is deferred to a separate thread. + *

+ * Recursion is tracked for all threads that call the AsyncCallers. + *

+ * This will make even the non-threaded caller use a thread in the case of deep recursion. + * + * @param useRecursionSafeCaller Set {@code true} if all caller invocations should be done with + * a recursion safe mechanism, {@code false} otherwise. + * @return This builder. + */ + public TinyAsyncBuilder recursionSafeAsyncCaller(boolean useRecursionSafeCaller) { + this.useRecursionSafeCaller = useRecursionSafeCaller; + return this; + } + /** * Specify an asynchronous caller implementation. *

@@ -119,11 +138,17 @@ private AsyncCaller setupThreadedCaller(AsyncCaller caller, ExecutorService call return caller; } - if (callerExecutor != null) { - return new ExecutorAsyncCaller(callerExecutor, caller); + if (callerExecutor == null) { + return null; } - return null; + AsyncCaller threadedCaller = new ExecutorAsyncCaller(callerExecutor, caller); + + if (useRecursionSafeCaller) { + threadedCaller = new RecursionSafeAsyncCaller(callerExecutor, threadedCaller); + } + + return threadedCaller; } private ExecutorService setupDefaultExecutor() { @@ -164,18 +189,28 @@ private ExecutorService setupCallerExecutor(ExecutorService defaultExecutor) { * @return A caller implementation according to the provided configuration. */ private AsyncCaller setupCaller() { - if (caller == null) { - return new PrintStreamDefaultAsyncCaller( - System.err, Executors.newSingleThreadExecutor(new ThreadFactory() { - @Override - public Thread newThread(final Runnable r) { - final Thread thread = new Thread(r); - thread.setName("tiny-async-deferrer"); - return thread; - } - })); + if (caller != null) { + if (useRecursionSafeCaller && callerExecutor != null) { + // Wrap user supplied AsyncCaller + return new RecursionSafeAsyncCaller(callerExecutor, caller); + } + return caller; + } + + AsyncCaller newCaller = new PrintStreamDefaultAsyncCaller( + System.err, Executors.newSingleThreadExecutor(new ThreadFactory() { + @Override + public Thread newThread(final Runnable r) { + final Thread thread = new Thread(r); + thread.setName("tiny-async-deferrer"); + return thread; + } + })); + + if (useRecursionSafeCaller && callerExecutor != null) { + newCaller = new RecursionSafeAsyncCaller(callerExecutor, newCaller); } - return caller; + return newCaller; } }