diff --git a/tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java b/tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java new file mode 100644 index 0000000..00fdd6f --- /dev/null +++ b/tiny-async-core/src/it/java/eu/toolchain/async/RecursionSafeAsyncCallerIntegrationTest.java @@ -0,0 +1,123 @@ +package eu.toolchain.async; + +import com.google.common.util.concurrent.AtomicLongMap; +import com.google.common.util.concurrent.MoreExecutors; +import eu.toolchain.async.RecursionSafeAsyncCaller; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.Before; +import org.junit.Test; + +import java.util.List; +import java.lang.ThreadLocal; +import java.lang.Thread; + +import com.google.common.util.concurrent.AtomicLongMap; + +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class RecursionSafeAsyncCallerIntegrationTest { + + private AtomicLongMap recursionDepthPerThread; + private AtomicLongMap maxRecursionDepthPerThread; + AtomicLong totIterations; + + @SuppressWarnings("unchecked") + @Before + public void setup() throws Exception { + recursionDepthPerThread = AtomicLongMap.create(); + maxRecursionDepthPerThread = AtomicLongMap.create(); + totIterations = new AtomicLong(0); + } + + + public void testBasicRecursionMethod(RecursionSafeAsyncCaller caller, ConcurrentLinkedQueue testData) { + + class RecursionRunnable implements Runnable { + RecursionSafeAsyncCaller caller; + ConcurrentLinkedQueue testData; + + RecursionRunnable(RecursionSafeAsyncCaller caller, ConcurrentLinkedQueue testData) { + this.caller = caller; + this.testData = testData; + } + + @Override + public void run() { + Long threadId = Thread.currentThread().getId(); + Long currDepth = recursionDepthPerThread.addAndGet(threadId, 1L); + Long currMax = maxRecursionDepthPerThread.get(threadId); + if (currDepth > currMax) { + maxRecursionDepthPerThread.put(threadId, currDepth); + } + + if (testData.size() == 0) + return; + testData.poll(); + totIterations.incrementAndGet(); + + // Recursive call, via caller + testBasicRecursionMethod(caller, testData); + + recursionDepthPerThread.addAndGet(threadId, -1L); + } + }; + + RecursionRunnable runnable = new RecursionRunnable(caller, testData); + caller.execute(runnable); + } + + @Test + public void testBasic() throws Exception { + final long MAX_RECURSION_DEPTH = 2; + ExecutorService executorServiceReal = Executors.newFixedThreadPool(10); + AsyncCaller caller2 = mock(AsyncCaller.class); + RecursionSafeAsyncCaller recursionCaller = + new RecursionSafeAsyncCaller(executorServiceReal, caller2, MAX_RECURSION_DEPTH); + ConcurrentLinkedQueue + testData = new ConcurrentLinkedQueue<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)); + + testBasicRecursionMethod(recursionCaller, testData); + + // Wait for recursive calls in thread pool to create new work until done, or timeout. + // The executorServiceReal.shutdown() below is not enough since potentially some new work + // is still on the way in, due to recursive nature of the test. + long startTime = System.currentTimeMillis(); + long maxTime = 1000; + while (testData.size() > 0) { + if (System.currentTimeMillis() - startTime > maxTime) { + // Timeout, test will fail in the assert below + break; + } + Thread.sleep(10); + } + + executorServiceReal.shutdown(); + executorServiceReal.awaitTermination(1000, TimeUnit.MILLISECONDS); + + assert(testData.size() == 0); + assert(totIterations.get() == 10); + + Long maxStackDepth=-1L; + Map readOnlyMap = maxRecursionDepthPerThread.asMap(); + for (Long key : readOnlyMap.keySet()) { + Long val = readOnlyMap.get(key); + if (val > maxStackDepth) + maxStackDepth = val; + } + assert(maxStackDepth != -1L); + + // Checking with +1 since our initial call to testBasicRecursionMethod() above adds 1 + assert(maxStackDepth <= MAX_RECURSION_DEPTH+1); + } +} diff --git a/tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java b/tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java new file mode 100644 index 0000000..ccb33ab --- /dev/null +++ b/tiny-async-core/src/main/java/eu/toolchain/async/RecursionSafeAsyncCaller.java @@ -0,0 +1,118 @@ +/* + * An AsyncCaller implementation that will try to run the call in the current thread as much as + * possible, while keeping track of recursion to avoid StackOverflowException in the thread. If + * recursion becomes too deep, the next call is deferred to a separate thread (normal thread pool). + * State is kept per-thread - stack overflow will be avoided for any thread that passes this code. + * It is vital to choose a suitable maximum recursion depth. + */ +package eu.toolchain.async; + +import java.util.concurrent.ExecutorService; + +public final class RecursionSafeAsyncCaller implements AsyncCaller { + private final ExecutorService executorService; + private final AsyncCaller caller; + private final long maxRecursionDepth; + + private class ThreadLocalInteger extends ThreadLocal { + protected Integer initialValue() { + return 0; + } + }; + private final ThreadLocalInteger recursionDepthPerThread; + + public RecursionSafeAsyncCaller(ExecutorService executorService, AsyncCaller caller, long maxRecursionDepth) { + this.executorService = executorService; + this.caller = caller; + this.maxRecursionDepth = maxRecursionDepth; + this.recursionDepthPerThread = new ThreadLocalInteger(); + } + + public RecursionSafeAsyncCaller(ExecutorService executorService, AsyncCaller caller) { + this(executorService, caller, 100); + } + + @Override + public void resolve(final FutureDone handle, final T result) { + execute(() -> caller.resolve(handle, result)); + } + + @Override + public void fail(final FutureDone handle, final Throwable error) { + execute(() -> caller.fail(handle, error)); + } + + @Override + public void cancel(final FutureDone handle) { + execute(() -> caller.cancel(handle)); + } + + @Override + public void cancel(final FutureCancelled cancelled) { + execute(() -> caller.cancel(cancelled)); + } + + @Override + public void finish(final FutureFinished finishable) { + execute(() -> caller.finish(finishable)); + } + + @Override + public void resolve(final FutureResolved resolved, final T value) { + execute(() -> caller.resolve(resolved, value)); + } + + @Override + public void resolve(final StreamCollector collector, final T result) { + execute(() -> caller.resolve(collector, result)); + } + + @Override + public void fail(final StreamCollector collector, final Throwable error) { + execute(() -> caller.fail(collector, error)); + } + + @Override + public void cancel(final StreamCollector collector) { + execute(() -> caller.cancel(collector)); + } + + @Override + public void fail(final FutureFailed failed, final Throwable cause) { + execute(() -> caller.fail(failed, cause)); + } + + @Override + public void referenceLeaked(final T reference, final StackTraceElement[] stack) { + execute(() -> caller.referenceLeaked(reference, stack)); + } + + @Override + public void execute(final Runnable runnable) { + // Use thread local counter for recursionDepth + final Integer recursionDepth = recursionDepthPerThread.get(); + // ++ + recursionDepthPerThread.set(recursionDepth + 1); + + if (recursionDepth + 1 <= maxRecursionDepth) { + // Case A: Call immediately, this is default until we've reached deep recursion + runnable.run(); + } else { + /* + * Case B: Defer to a separate thread + * This happens when recursion depth of the current thread is larger than limit, to + * avoid stack overflow. + */ + executorService.submit(runnable); + } + + // -- + recursionDepthPerThread.set(recursionDepth); + } + + + @Override + public boolean isThreaded() { + return caller.isThreaded(); + } +} diff --git a/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java b/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java index 7c6b578..2a53036 100644 --- a/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java +++ b/tiny-async-core/src/main/java/eu/toolchain/async/TinyAsyncBuilder.java @@ -9,6 +9,7 @@ public class TinyAsyncBuilder { private AsyncCaller caller; private boolean threaded; + private boolean useRecursionSafeCaller; private ExecutorService executor; private ExecutorService callerExecutor; private ScheduledExecutorService scheduler; @@ -32,6 +33,24 @@ public TinyAsyncBuilder threaded(boolean threaded) { return this; } + /** + * Configure that all caller invocations should use a recursion safe mechanism. In the normal + * case this doesn't change the behaviour of caller and threadedCaller, but when deep recursion + * is detected in the current thread the next recursive call is deferred to a separate thread. + *

+ * Recursion is tracked for all threads that call the AsyncCallers. + *

+ * This will make even the non-threaded caller use a thread in the case of deep recursion. + * + * @param useRecursionSafeCaller Set {@code true} if all caller invocations should be done with + * a recursion safe mechanism, {@code false} otherwise. + * @return This builder. + */ + public TinyAsyncBuilder recursionSafeAsyncCaller(boolean useRecursionSafeCaller) { + this.useRecursionSafeCaller = useRecursionSafeCaller; + return this; + } + /** * Specify an asynchronous caller implementation. *

@@ -119,11 +138,17 @@ private AsyncCaller setupThreadedCaller(AsyncCaller caller, ExecutorService call return caller; } - if (callerExecutor != null) { - return new ExecutorAsyncCaller(callerExecutor, caller); + if (callerExecutor == null) { + return null; } - return null; + AsyncCaller threadedCaller = new ExecutorAsyncCaller(callerExecutor, caller); + + if (useRecursionSafeCaller) { + threadedCaller = new RecursionSafeAsyncCaller(callerExecutor, threadedCaller); + } + + return threadedCaller; } private ExecutorService setupDefaultExecutor() { @@ -164,18 +189,28 @@ private ExecutorService setupCallerExecutor(ExecutorService defaultExecutor) { * @return A caller implementation according to the provided configuration. */ private AsyncCaller setupCaller() { - if (caller == null) { - return new PrintStreamDefaultAsyncCaller( - System.err, Executors.newSingleThreadExecutor(new ThreadFactory() { - @Override - public Thread newThread(final Runnable r) { - final Thread thread = new Thread(r); - thread.setName("tiny-async-deferrer"); - return thread; - } - })); + if (caller != null) { + if (useRecursionSafeCaller && callerExecutor != null) { + // Wrap user supplied AsyncCaller + return new RecursionSafeAsyncCaller(callerExecutor, caller); + } + return caller; + } + + AsyncCaller newCaller = new PrintStreamDefaultAsyncCaller( + System.err, Executors.newSingleThreadExecutor(new ThreadFactory() { + @Override + public Thread newThread(final Runnable r) { + final Thread thread = new Thread(r); + thread.setName("tiny-async-deferrer"); + return thread; + } + })); + + if (useRecursionSafeCaller && callerExecutor != null) { + newCaller = new RecursionSafeAsyncCaller(callerExecutor, newCaller); } - return caller; + return newCaller; } } diff --git a/tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java b/tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java new file mode 100644 index 0000000..7bf4bb9 --- /dev/null +++ b/tiny-async-core/src/test/java/eu/toolchain/async/RecursionSafeAsyncCallerTest.java @@ -0,0 +1,140 @@ +package eu.toolchain.async; + +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.runners.MockitoJUnitRunner; +import org.mockito.stubbing.Answer; + +import java.util.concurrent.ExecutorService; + +import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +@RunWith(MockitoJUnitRunner.class) +public class RecursionSafeAsyncCallerTest { + private final Object result = new Object(); + private final Throwable cause = new Exception(); + + private AsyncCaller caller; + + private RecursionSafeAsyncCaller underTest; + + @Mock + private FutureDone done; + @Mock + private FutureCancelled cancelled; + @Mock + private FutureFinished finished; + @Mock + private FutureResolved resolved; + @Mock + private FutureFailed failed; + + @Mock + private StreamCollector streamCollector; + + private StackTraceElement[] stack = new StackTraceElement[0]; + + + @Before + public void setup() { + ExecutorService executor = mock(ExecutorService.class); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + final Runnable runnable = (Runnable) invocation.getArguments()[0]; + runnable.run(); + return null; + } + }).when(executor).submit(any(Runnable.class)); + + caller = mock(AsyncCaller.class); + underTest = new RecursionSafeAsyncCaller(executor, caller); + } + + @Test + public void testIsThreaded() { + underTest.isThreaded(); + verify(caller).isThreaded(); + } + + @Test + public void testResolveFutureDone() { + underTest.resolve(done, result); + verify(caller).resolve(done, result); + } + + @Test + public void testFailFutureDone() { + underTest.fail(done, cause); + verify(caller).fail(done, cause); + } + + @Test + public void testCancelFutureDone() { + underTest.cancel(done); + verify(caller).cancel(done); + } + + @Test + public void testRunFutureCancelled() { + underTest.cancel(cancelled); + verify(caller).cancel(cancelled); + } + + @Test + public void testRunFutureFinished() { + underTest.finish(finished); + verify(caller).finish(finished); + } + + @Test + public void testRunFutureResolved() { + underTest.resolve(resolved, result); + verify(caller).resolve(resolved, result); + } + + @Test + public void testRunFutureFailed() { + underTest.fail(failed, cause); + verify(caller).fail(failed, cause); + } + + @Test + public void testResolveStreamCollector() { + underTest.resolve(streamCollector, result); + verify(caller).resolve(streamCollector, result); + } + + @Test + public void testFailStreamCollector() { + underTest.fail(streamCollector, cause); + verify(caller).fail(streamCollector, cause); + } + + @Test + public void testCancelStreamCollector() { + underTest.cancel(streamCollector); + verify(caller).cancel(streamCollector); + } + + @Test + public void testLeakedManagedReference() { + underTest.referenceLeaked(result, stack); + verify(caller).referenceLeaked(result, stack); + } + + @Test + public void testExecute() { + Runnable runnable = mock(Runnable.class); + underTest.execute(runnable); + verify(runnable).run(); + } +}