Skip to content
This repository was archived by the owner on Oct 16, 2022. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package eu.toolchain.async;

import com.google.common.util.concurrent.AtomicLongMap;
import com.google.common.util.concurrent.MoreExecutors;
import eu.toolchain.async.RecursionSafeAsyncCaller;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import org.junit.Before;
import org.junit.Test;

import java.util.List;
import java.lang.ThreadLocal;
import java.lang.Thread;

import com.google.common.util.concurrent.AtomicLongMap;

import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

public class RecursionSafeAsyncCallerIntegrationTest {

private AtomicLongMap<Long> recursionDepthPerThread;
private AtomicLongMap<Long> maxRecursionDepthPerThread;
AtomicLong totIterations;

@SuppressWarnings("unchecked")
@Before
public void setup() throws Exception {
recursionDepthPerThread = AtomicLongMap.create();
maxRecursionDepthPerThread = AtomicLongMap.create();
totIterations = new AtomicLong(0);
}


public void testBasicRecursionMethod(RecursionSafeAsyncCaller caller, ConcurrentLinkedQueue<Integer> testData) {

class RecursionRunnable implements Runnable {
RecursionSafeAsyncCaller caller;
ConcurrentLinkedQueue<Integer> testData;

RecursionRunnable(RecursionSafeAsyncCaller caller, ConcurrentLinkedQueue<Integer> testData) {
this.caller = caller;
this.testData = testData;
}

@Override
public void run() {
Long threadId = Thread.currentThread().getId();
Long currDepth = recursionDepthPerThread.addAndGet(threadId, 1L);
Long currMax = maxRecursionDepthPerThread.get(threadId);
if (currDepth > currMax) {
maxRecursionDepthPerThread.put(threadId, currDepth);
}

if (testData.size() == 0)
return;
testData.poll();
totIterations.incrementAndGet();

// Recursive call, via caller
testBasicRecursionMethod(caller, testData);

recursionDepthPerThread.addAndGet(threadId, -1L);
}
};

RecursionRunnable runnable = new RecursionRunnable(caller, testData);
caller.execute(runnable);
}

@Test
public void testBasic() throws Exception {
final long MAX_RECURSION_DEPTH = 2;
ExecutorService executorServiceReal = Executors.newFixedThreadPool(10);
AsyncCaller caller2 = mock(AsyncCaller.class);
RecursionSafeAsyncCaller recursionCaller =
new RecursionSafeAsyncCaller(executorServiceReal, caller2, MAX_RECURSION_DEPTH);
ConcurrentLinkedQueue<Integer>
testData = new ConcurrentLinkedQueue<>(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10));

testBasicRecursionMethod(recursionCaller, testData);

// Wait for recursive calls in thread pool to create new work until done, or timeout.
// The executorServiceReal.shutdown() below is not enough since potentially some new work
// is still on the way in, due to recursive nature of the test.
long startTime = System.currentTimeMillis();
long maxTime = 1000;
while (testData.size() > 0) {
if (System.currentTimeMillis() - startTime > maxTime) {
// Timeout, test will fail in the assert below
break;
}
Thread.sleep(10);
}

executorServiceReal.shutdown();
executorServiceReal.awaitTermination(1000, TimeUnit.MILLISECONDS);

assert(testData.size() == 0);
assert(totIterations.get() == 10);

Long maxStackDepth=-1L;
Map<Long, Long> readOnlyMap = maxRecursionDepthPerThread.asMap();
for (Long key : readOnlyMap.keySet()) {
Long val = readOnlyMap.get(key);
if (val > maxStackDepth)
maxStackDepth = val;
}
assert(maxStackDepth != -1L);

// Checking with +1 since our initial call to testBasicRecursionMethod() above adds 1
assert(maxStackDepth <= MAX_RECURSION_DEPTH+1);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* An AsyncCaller implementation that will try to run the call in the current thread as much as
* possible, while keeping track of recursion to avoid StackOverflowException in the thread. If
* recursion becomes too deep, the next call is deferred to a separate thread (normal thread pool).
* State is kept per-thread - stack overflow will be avoided for any thread that passes this code.
* It is vital to choose a suitable maximum recursion depth.
*/
package eu.toolchain.async;

import java.util.concurrent.ExecutorService;

public final class RecursionSafeAsyncCaller implements AsyncCaller {
private final ExecutorService executorService;
private final AsyncCaller caller;
private final long maxRecursionDepth;

private class ThreadLocalInteger extends ThreadLocal<Integer> {
protected Integer initialValue() {
return 0;
}
};
private final ThreadLocalInteger recursionDepthPerThread;

public RecursionSafeAsyncCaller(ExecutorService executorService, AsyncCaller caller, long maxRecursionDepth) {
this.executorService = executorService;
this.caller = caller;
this.maxRecursionDepth = maxRecursionDepth;
this.recursionDepthPerThread = new ThreadLocalInteger();
}

public RecursionSafeAsyncCaller(ExecutorService executorService, AsyncCaller caller) {
this(executorService, caller, 100);
}

@Override
public <T> void resolve(final FutureDone<T> handle, final T result) {
execute(() -> caller.resolve(handle, result));
}

@Override
public <T> void fail(final FutureDone<T> handle, final Throwable error) {
execute(() -> caller.fail(handle, error));
}

@Override
public <T> void cancel(final FutureDone<T> handle) {
execute(() -> caller.cancel(handle));
}

@Override
public void cancel(final FutureCancelled cancelled) {
execute(() -> caller.cancel(cancelled));
}

@Override
public void finish(final FutureFinished finishable) {
execute(() -> caller.finish(finishable));
}

@Override
public <T> void resolve(final FutureResolved<T> resolved, final T value) {
execute(() -> caller.resolve(resolved, value));
}

@Override
public <T, R> void resolve(final StreamCollector<T, R> collector, final T result) {
execute(() -> caller.resolve(collector, result));
}

@Override
public <T, R> void fail(final StreamCollector<T, R> collector, final Throwable error) {
execute(() -> caller.fail(collector, error));
}

@Override
public <T, R> void cancel(final StreamCollector<T, R> collector) {
execute(() -> caller.cancel(collector));
}

@Override
public void fail(final FutureFailed failed, final Throwable cause) {
execute(() -> caller.fail(failed, cause));
}

@Override
public <T> void referenceLeaked(final T reference, final StackTraceElement[] stack) {
execute(() -> caller.referenceLeaked(reference, stack));
}

@Override
public void execute(final Runnable runnable) {
// Use thread local counter for recursionDepth
final Integer recursionDepth = recursionDepthPerThread.get();
// ++
recursionDepthPerThread.set(recursionDepth + 1);

if (recursionDepth + 1 <= maxRecursionDepth) {
// Case A: Call immediately, this is default until we've reached deep recursion
runnable.run();
} else {
/*
* Case B: Defer to a separate thread
* This happens when recursion depth of the current thread is larger than limit, to
* avoid stack overflow.
*/
executorService.submit(runnable);
}

// --
recursionDepthPerThread.set(recursionDepth);
}


@Override
public boolean isThreaded() {
return caller.isThreaded();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
public class TinyAsyncBuilder {
private AsyncCaller caller;
private boolean threaded;
private boolean useRecursionSafeCaller;
private ExecutorService executor;
private ExecutorService callerExecutor;
private ScheduledExecutorService scheduler;
Expand All @@ -32,6 +33,24 @@ public TinyAsyncBuilder threaded(boolean threaded) {
return this;
}

/**
* Configure that all caller invocations should use a recursion safe mechanism. In the normal
* case this doesn't change the behaviour of caller and threadedCaller, but when deep recursion
* is detected in the current thread the next recursive call is deferred to a separate thread.
* <p>
* Recursion is tracked for all threads that call the AsyncCallers.
* <p>
* This will make even the non-threaded caller use a thread in the case of deep recursion.
*
* @param useRecursionSafeCaller Set {@code true} if all caller invocations should be done with
* a recursion safe mechanism, {@code false} otherwise.
* @return This builder.
*/
public TinyAsyncBuilder recursionSafeAsyncCaller(boolean useRecursionSafeCaller) {
this.useRecursionSafeCaller = useRecursionSafeCaller;
return this;
}

/**
* Specify an asynchronous caller implementation.
* <p>
Expand Down Expand Up @@ -119,11 +138,17 @@ private AsyncCaller setupThreadedCaller(AsyncCaller caller, ExecutorService call
return caller;
}

if (callerExecutor != null) {
return new ExecutorAsyncCaller(callerExecutor, caller);
if (callerExecutor == null) {
return null;
}

return null;
AsyncCaller threadedCaller = new ExecutorAsyncCaller(callerExecutor, caller);

if (useRecursionSafeCaller) {
threadedCaller = new RecursionSafeAsyncCaller(callerExecutor, threadedCaller);
}

return threadedCaller;
}

private ExecutorService setupDefaultExecutor() {
Expand Down Expand Up @@ -164,18 +189,28 @@ private ExecutorService setupCallerExecutor(ExecutorService defaultExecutor) {
* @return A caller implementation according to the provided configuration.
*/
private AsyncCaller setupCaller() {
if (caller == null) {
return new PrintStreamDefaultAsyncCaller(
System.err, Executors.newSingleThreadExecutor(new ThreadFactory() {
@Override
public Thread newThread(final Runnable r) {
final Thread thread = new Thread(r);
thread.setName("tiny-async-deferrer");
return thread;
}
}));
if (caller != null) {
if (useRecursionSafeCaller && callerExecutor != null) {
// Wrap user supplied AsyncCaller
return new RecursionSafeAsyncCaller(callerExecutor, caller);
}
return caller;
}

AsyncCaller newCaller = new PrintStreamDefaultAsyncCaller(
System.err, Executors.newSingleThreadExecutor(new ThreadFactory() {
@Override
public Thread newThread(final Runnable r) {
final Thread thread = new Thread(r);
thread.setName("tiny-async-deferrer");
return thread;
}
}));

if (useRecursionSafeCaller && callerExecutor != null) {
newCaller = new RecursionSafeAsyncCaller(callerExecutor, newCaller);
}

return caller;
return newCaller;
}
}
Loading