diff --git a/zanata-war/src/main/java/org/zanata/limits/RateLimitingProcessor.java b/zanata-war/src/main/java/org/zanata/limits/RateLimitingProcessor.java index 99214580ba..26cb3c685c 100644 --- a/zanata-war/src/main/java/org/zanata/limits/RateLimitingProcessor.java +++ b/zanata-war/src/main/java/org/zanata/limits/RateLimitingProcessor.java @@ -1,9 +1,12 @@ package org.zanata.limits; +import java.io.IOException; import java.io.PrintWriter; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; import javax.servlet.FilterChain; +import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; @@ -15,6 +18,8 @@ import org.jboss.seam.servlet.ContextualHttpServletRequest; import org.zanata.ApplicationConfiguration; import com.google.common.base.Objects; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.RateLimiter; import lombok.extern.slf4j.Slf4j; /** @@ -27,9 +32,10 @@ */ @Slf4j public class RateLimitingProcessor extends ContextualHttpServletRequest { - // http://tools.ietf.org/html/rfc6585 public static final int TOO_MANY_REQUEST = 429; + + private static final RateLimiter logLimiter = RateLimiter.create(1); private final String apiKey; private final FilterChain filterChain; private final ServletRequest servletRequest; @@ -71,18 +77,26 @@ public void process() throws Exception { log.debug("check semaphore for {}", this); - if (rateLimiter.tryAcquire()) { - try { - filterChain.doFilter(servletRequest, servletResponse); - } finally { - log.debug("releasing semaphore for {}", apiKey); - rateLimiter.release(); + Runnable runnable = new Runnable() { + @Override + public void run() { + try { + filterChain.doFilter(servletRequest, servletResponse); + } + catch (IOException e) { + throw Throwables.propagate(e); + } + catch (ServletException e) { + throw Throwables.propagate(e); + } + } + }; + if (!rateLimiter.tryAcquireAndRun(runnable)) { + if (logLimiter.tryAcquire(1, TimeUnit.SECONDS)) { + log.warn( + "{} has too many concurrent requests. Returning status 429", + apiKey); } - } else { - // TODO pahuang rate limit the logging otherwise it may become excessive - log.warn( - "{} has too many concurrent requests. Returning status 429", - apiKey); httpResponse.setStatus(TOO_MANY_REQUEST); PrintWriter writer = httpResponse.getWriter(); writer.append(String.format( diff --git a/zanata-war/src/main/java/org/zanata/limits/RestCallLimiter.java b/zanata-war/src/main/java/org/zanata/limits/RestCallLimiter.java index 93ce478900..27dd42d2e8 100644 --- a/zanata-war/src/main/java/org/zanata/limits/RestCallLimiter.java +++ b/zanata-war/src/main/java/org/zanata/limits/RestCallLimiter.java @@ -4,6 +4,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.base.Objects; +import com.google.common.base.Throwables; import com.google.common.util.concurrent.RateLimiter; import com.google.common.util.concurrent.Uninterruptibles; import lombok.EqualsAndHashCode; @@ -17,11 +18,12 @@ */ @Slf4j public class RestCallLimiter { - private Semaphore maxConcurrentSemaphore; - private Semaphore maxActiveSemaphore; + private volatile Semaphore maxConcurrentSemaphore; + private volatile Semaphore maxActiveSemaphore; private RateLimiter rateLimiter; private RateLimitConfig limitConfig; - private volatile ActiveLimitChange change; + private volatile LimitChange activeChange; + private volatile LimitChange concurrentChange; public RestCallLimiter(RateLimitConfig limitConfig) { this.limitConfig = limitConfig; @@ -31,60 +33,89 @@ public RestCallLimiter(RateLimitConfig limitConfig) { rateLimiter = RateLimiter.create(limitConfig.rateLimitPerSecond); } - public boolean tryAcquire() { - log.debug("before try acquire concurrent semaphore:{}", - maxConcurrentSemaphore); - boolean got = maxConcurrentSemaphore.tryAcquire(); - log.debug("get permit:{}", got); - if (got) { - acquireActiveAndRatePermit(); - log.debug("got all permits and ready to go: {}", this); + public boolean tryAcquireAndRun(Runnable taskAfterAcquire) { + applyConcurrentPermitChangeIfApplicable(); + boolean gotConcurrentPermit = maxConcurrentSemaphore.tryAcquire(); + log.debug("try acquire [concurrent] permit:{}", gotConcurrentPermit); + if (gotConcurrentPermit) { + try { + if (acquireActiveAndRatePermit()) { + try { + taskAfterAcquire.run(); + } finally { + log.debug("releasing active concurrent semaphore"); + maxActiveSemaphore.release(); + } + } else { + throw new RuntimeException( + "Couldn't get an [active] permit in time"); + } + } finally { + log.debug("releasing max [concurrent] semaphore"); + maxConcurrentSemaphore.release(); + } + } + return gotConcurrentPermit; + } + + private boolean acquireActiveAndRatePermit() { + applyActivePermitChangeIfApplicable(); + log.debug("before acquire [active] semaphore:{}", maxActiveSemaphore); + try { + boolean gotActivePermit = + maxActiveSemaphore.tryAcquire(5, TimeUnit.MINUTES); + log.debug( + "got [active] semaphore [{}] and before acquire rate limit permit:{}", + gotActivePermit, rateLimiter); + if (gotActivePermit) { + rateLimiter.acquire(); + } + return gotActivePermit; + } catch (InterruptedException e) { + throw Throwables.propagate(e); + } + } + + private void applyConcurrentPermitChangeIfApplicable() { + if (concurrentChange != null) { + synchronized (this) { + if (concurrentChange != null) { + log.debug( + "change max [concurrent] semaphore with new permit ", + concurrentChange.newLimit); + maxConcurrentSemaphore = + new Semaphore(concurrentChange.newLimit, true); + concurrentChange = null; + } + } } - return got; } - private void acquireActiveAndRatePermit() { - if (change != null) { + private void applyActivePermitChangeIfApplicable() { + if (activeChange != null) { synchronized (this) { - if (change != null) { + if (activeChange != null) { // since this block is synchronized, there won't be new // permit acquired from maxActiveSemaphore other than this // thread. It ought to be the last and only one entering in // this block. It will have to wait for all other previous // blocked threads to complete before changing the semaphore log.debug( - "detects max active permit change [{}]. Will sleep until all blocking threads [#{}] released.", - change, maxActiveSemaphore.getQueueLength()); - while (maxActiveSemaphore.availablePermits() != change.oldLimit) { + "detects max [active] permit change [{}]. Will sleep until all blocking threads [#{}] released.", + activeChange, maxActiveSemaphore.getQueueLength()); + while (maxActiveSemaphore.availablePermits() != activeChange.oldLimit) { Uninterruptibles.sleepUninterruptibly(1, TimeUnit.NANOSECONDS); } - log.debug("change max active semaphore with new permit"); - maxActiveSemaphore = new Semaphore(change.newLimit, true); - change = null; + log.debug( + "change max [active] semaphore with new permit {}", + activeChange.newLimit); + maxActiveSemaphore = + new Semaphore(activeChange.newLimit, true); + activeChange = null; } } } - log.debug("before acquire active semaphore:{}", maxActiveSemaphore); - maxActiveSemaphore.acquireUninterruptibly(); -// if we want to enable timeout here, -// we must ensure release is not called when it timed out -// try { -// boolean gotIt = maxActiveSemaphore.tryAcquire(30, TimeUnit.SECONDS); -// if (!gotIt) { -// // timed out -// throw new WebApplicationException(Response.status( -// Response.Status.SERVICE_UNAVAILABLE) -// .entity("System too busy").build()); -// } -// } -// catch (InterruptedException e) { -// throw Throwables.propagate(e); -// } - log.debug( - "got active semaphore and before acquire rate limit permit:{}", - rateLimiter); - rateLimiter.acquire(); } public void release() { @@ -96,27 +127,29 @@ public void release() { public void changeConfig(RateLimitConfig newLimitConfig) { if (newLimitConfig.maxConcurrent != limitConfig.maxConcurrent) { - changeConcurrentLimit(newLimitConfig.maxConcurrent); + changeConcurrentLimit(limitConfig.maxConcurrent, + newLimitConfig.maxConcurrent); } if (newLimitConfig.rateLimitPerSecond != limitConfig.rateLimitPerSecond) { changeRateLimitPermitsPerSecond(newLimitConfig.rateLimitPerSecond); } + if (newLimitConfig.maxActive != limitConfig.maxActive) { + changeActiveLimit(limitConfig.maxActive, newLimitConfig.maxActive); + } limitConfig = newLimitConfig; } - protected synchronized void changeConcurrentLimit(int maxConcurrent) { - log.info("max concurrent limit changed: {}", maxConcurrent); - maxConcurrentSemaphore = new Semaphore(maxConcurrent); + protected synchronized void + changeConcurrentLimit(int oldLimit, int newLimit) { + this.concurrentChange = new LimitChange(oldLimit, newLimit); } protected synchronized void changeRateLimitPermitsPerSecond(double permits) { - log.info("rate limit changed: {}", permits); rateLimiter.setRate(permits); } protected synchronized void changeActiveLimit(int oldLimit, int newLimit) { - this.change = new ActiveLimitChange(oldLimit, newLimit); - log.info("max active limit changed: {}", change); + this.activeChange = new LimitChange(oldLimit, newLimit); } public int availableConcurrentPermit() { @@ -136,8 +169,10 @@ public String toString() { return Objects .toStringHelper(this) .add("id", super.toString()) - .add("maxConcurrent(available)", maxConcurrentSemaphore.availablePermits()) - .add("maxActive(available)", maxActiveSemaphore.availablePermits()) + .add("maxConcurrent(available)", + maxConcurrentSemaphore.availablePermits()) + .add("maxActive(available)", + maxActiveSemaphore.availablePermits()) .add("maxActive(queue)", maxActiveSemaphore.getQueueLength()) .add("rateLimiter", rateLimiter).toString(); } @@ -153,7 +188,7 @@ public static class RateLimitConfig { @RequiredArgsConstructor @ToString - private static class ActiveLimitChange { + private static class LimitChange { private final int oldLimit; private final int newLimit; } diff --git a/zanata-war/src/test/java/org/zanata/ZanataRestTest.java b/zanata-war/src/test/java/org/zanata/ZanataRestTest.java index d5fc72c96b..c153d4fc4a 100644 --- a/zanata-war/src/test/java/org/zanata/ZanataRestTest.java +++ b/zanata-war/src/test/java/org/zanata/ZanataRestTest.java @@ -159,11 +159,7 @@ protected void prepareProviders() { protected void prepareSeamAutowire() { seamAutowire .reset() - .ignoreNonResolvable() - .use(SeamAutowire.getComponentName(JndiBackedConfig.class), - jndiBackedConfig) - .use(SeamAutowire.getComponentName(RateLimitManager.class), - new RateLimitManager()); + .ignoreNonResolvable(); } /** diff --git a/zanata-war/src/test/java/org/zanata/limits/RateLimitingProcessorTest.java b/zanata-war/src/test/java/org/zanata/limits/RateLimitingProcessorTest.java index 55538531ea..2041273509 100644 --- a/zanata-war/src/test/java/org/zanata/limits/RateLimitingProcessorTest.java +++ b/zanata-war/src/test/java/org/zanata/limits/RateLimitingProcessorTest.java @@ -122,27 +122,5 @@ public Void call() throws Exception { verify(response, atLeastOnce()).setStatus(429); // one should go through verify(filterChain).doFilter(request, response); - // semaphore is released - assertThat(rateLimitManager.getIfPresent(API_KEY) - .availableConcurrentPermit(), Matchers.equalTo(1)); - } - - @Test - public void willReleaseSemaphoreWhenThereIsException() throws IOException, - ServletException { - when(rateLimitManager.getLimitConfig()).thenReturn( - new RestCallLimiter.RateLimitConfig(1, 1, 100.0)); - when(applicationConfiguration.getRateLimitSwitch()).thenReturn(true); - doThrow(new RuntimeException("bad")).when(filterChain).doFilter( - request, response); - - try { - processor.process(); - } catch (Exception e) { - // I know - } - - assertThat(rateLimitManager.getIfPresent(API_KEY) - .availableConcurrentPermit(), Matchers.equalTo(1)); } } diff --git a/zanata-war/src/test/java/org/zanata/limits/RestCallLimiterTest.java b/zanata-war/src/test/java/org/zanata/limits/RestCallLimiterTest.java index be6cd88090..3bbee945af 100644 --- a/zanata-war/src/test/java/org/zanata/limits/RestCallLimiterTest.java +++ b/zanata-war/src/test/java/org/zanata/limits/RestCallLimiterTest.java @@ -1,13 +1,17 @@ package org.zanata.limits; +import java.io.IOException; import java.util.Collections; import java.util.List; import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; +import javax.servlet.ServletException; + import org.apache.log4j.Level; import org.apache.log4j.LogManager; import org.apache.log4j.Logger; @@ -16,20 +20,29 @@ import org.junit.rules.TestRule; import org.junit.runner.Description; import org.junit.runners.model.Statement; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.testng.annotations.BeforeClass; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -import org.zanata.limits.RestCallLimiter; import com.google.common.base.Function; import com.google.common.base.Predicate; import com.google.common.base.Stopwatch; import com.google.common.base.Throwables; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.ListeningExecutorService; +import com.google.common.util.concurrent.MoreExecutors; import com.google.common.util.concurrent.Uninterruptibles; import lombok.extern.slf4j.Slf4j; import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.doThrow; /** * @author Patrick Huang task = new Callable() { + @Override + public Boolean call() throws Exception { + return limiter.tryAcquireAndRun(runntable); + } + }; + int numOfThreads = maxConcurrent + 1; + List result = + submitConcurrentTasksAndGetResult(task, numOfThreads); + log.debug("result: {}", result); + // requests that are within the max concurrent limit should get permit + Iterable successRequest = + Iterables.filter(result, new Predicate() { + @Override + public boolean apply(Boolean input) { + return input; + } + }); + assertThat(successRequest, + Matchers. iterableWithSize(maxConcurrent)); + // last request which exceeds the limit will fail to get permit + assertThat(result, Matchers.hasItem(false)); + } + + static List submitConcurrentTasksAndGetResult(Callable task, + int numOfThreads) throws InterruptedException, ExecutionException { + List> tasks = Collections.nCopies(numOfThreads, task); + final ListeningExecutorService executorService = + MoreExecutors.listeningDecorator(Executors + .newFixedThreadPool(numOfThreads)); + List> listenableFutures = + Lists.transform(tasks, new ToListenableFuture( + executorService)); + + ListenableFuture> listListenableFuture = + Futures.successfulAsList(listenableFutures); + return listListenableFuture.get(); } @Test public void canOnlyHaveMaxActiveConcurrentRequest() - throws InterruptedException { - // Given: each thread will take 15ms to do its job - final int timeSpendDoingWork = 15; + throws InterruptedException, ExecutionException { + // Given: each thread will take some time to do its job + final int timeSpentDoingWork = 30; + runnableWillTakeTime(timeSpentDoingWork); // When: max concurrent threads are accessing simultaneously - Callable callable = taskAcquireThenRelease(timeSpendDoingWork); - List> tasks = - Collections.nCopies(maxConcurrent, callable); - ExecutorService executorService = - Executors.newFixedThreadPool(maxConcurrent); - List> futures = executorService.invokeAll(tasks); + Callable callable = + taskToAcquireAndMeasureBlockedTime(timeSpentDoingWork); // Then: only max active threads will be served immediately while others // will block until them finish - List timeUsedInMillis = - getTimeUsedInMillisRoundedUpToTens(futures); - log.info("result: {}", timeUsedInMillis); + List timeBlockedInMillis = + submitConcurrentTasksAndGetResult(callable, maxConcurrent); + log.debug("result: {}", timeBlockedInMillis); Iterable blocked = - Iterables.filter(timeUsedInMillis, new BlockedPredicate()); + Iterables.filter(timeBlockedInMillis, new BlockedPredicate()); assertThat(blocked, Matchers. iterableWithSize(maxConcurrent - maxActive)); } + void runnableWillTakeTime(final int timeSpentDoingWork) { + Mockito.doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocation) throws Throwable { + Uninterruptibles.sleepUninterruptibly(timeSpentDoingWork, + TimeUnit.MILLISECONDS); + return null; + } + }).when(runntable).run(); + } + @Test public void activeRequestWillBeRateLimited() throws InterruptedException { // Given: I am within max active threads count and can do my job RIGHT // NOW and permits per second is 1 - final int timeSpendDoingWork = 0; double permitsPerSecond = 1; limiter.changeRateLimitPermitsPerSecond(permitsPerSecond); // When: I start working on heavy duty stuff - Callable callable = taskAcquireThenRelease(timeSpendDoingWork); + Callable callable = taskToAcquireAndMeasureBlockedTime(0); List> tasks = Collections.nCopies(maxActive, callable); ExecutorService executorService = Executors.newFixedThreadPool(maxActive); @@ -145,14 +213,37 @@ public void activeRequestWillBeRateLimited() throws InterruptedException { } @Test - public void changeMaxConcurrentLimitWillTakeEffectImmediately() { + private void changeMaxConcurrentLimitWillTakeEffectImmediately() + throws ExecutionException, InterruptedException { + runnableWillTakeTime(10); + + // we start off with only 1 concurrent permit limiter = new RestCallLimiter(new RestCallLimiter.RateLimitConfig(1, 10, 1)); - assertThat(limiter.tryAcquire(), Matchers.is(true)); - assertThat(limiter.tryAcquire(), Matchers.is(false)); - limiter.changeConcurrentLimit(2); - assertThat(limiter.tryAcquire(), Matchers.is(true)); + Callable task = new Callable() { + + @Override + public Boolean call() throws Exception { + return limiter.tryAcquireAndRun(runntable); + } + }; + + int numOfThreads = 2; + List result = + submitConcurrentTasksAndGetResult(task, numOfThreads); + assertThat(result, Matchers.containsInAnyOrder(true, false)); + assertThat(limiter.availableConcurrentPermit(), Matchers.is(1)); + + // change permit to match number of threads + limiter.changeConcurrentLimit(1, numOfThreads); + + List resultAfterChange = + submitConcurrentTasksAndGetResult(task, numOfThreads); + assertThat(resultAfterChange, Matchers.contains(true, true)); + + assertThat(limiter.availableConcurrentPermit(), + Matchers.is(numOfThreads)); } @Test @@ -160,16 +251,17 @@ public void changeMaxActiveLimitWhenNoBlockedThreads() { limiter = new RestCallLimiter(new RestCallLimiter.RateLimitConfig(3, 3, 1000)); - assertThat(acquireAndMeasureBlockedTime(), Matchers.equalTo(0L)); - limiter.release(); + limiter.tryAcquireAndRun(runntable); limiter.changeActiveLimit(3, 2); - assertThat(acquireAndMeasureBlockedTime(), Matchers.equalTo(0L)); - limiter.release(); + // change won't happen until next request comes in + limiter.tryAcquireAndRun(runntable); + assertThat(limiter.availableActivePermit(), Matchers.is(2)); limiter.changeActiveLimit(2, 1); - assertThat(acquireAndMeasureBlockedTime(), Matchers.equalTo(0L)); - assertThat(limiter.availableActivePermit(), Matchers.is(0)); + + limiter.tryAcquireAndRun(runntable); + assertThat(limiter.availableActivePermit(), Matchers.is(1)); } @Test @@ -182,7 +274,10 @@ public void changeMaxActiveLimitWhenHasBlockedThreads() // When: below requests are fired simultaneously // 3 requests (each takes 20ms) and 1 request should block - Callable callable = taskAcquireThenRelease(20); + final int timeSpentDoingWork = 20; + runnableWillTakeTime(timeSpentDoingWork); + Callable callable = + taskToAcquireAndMeasureBlockedTime(timeSpentDoingWork); List> requests = Collections.nCopies(3, callable); // 1 task to update the active permit with 5ms delay // (so that it will happen while there is a blocked request) @@ -205,9 +300,7 @@ public Long call() throws Exception { // ensure this happen after change limit took place Uninterruptibles .sleepUninterruptibly(10, TimeUnit.MILLISECONDS); - long blockedTime = acquireAndMeasureBlockedTime(); - limiter.release(); - return blockedTime; + return tryAcquireAndMeasureBlockedTime(timeSpentDoingWork); } }; List> delayedRequests = @@ -238,31 +331,47 @@ public Long call() throws Exception { assertThat(blocked, Matchers. iterableWithSize(3)); } - // it will measure acquire blocking time and return it - private Callable taskAcquireThenRelease( - final int timeSpendDoingWorkInMillis) { + @Test + public void willReleaseSemaphoreWhenThereIsException() throws IOException, + ServletException { + doThrow(new RuntimeException("bad")).when(runntable).run(); + + try { + limiter.tryAcquireAndRun(runntable); + } catch (Exception e) { + // I know + } + + assertThat(limiter.availableConcurrentPermit(), + Matchers.equalTo(maxConcurrent)); + assertThat(limiter.availableActivePermit(), Matchers.equalTo(maxActive)); + } + + /** + * it will measure acquire blocking time and return it. + */ + private Callable taskToAcquireAndMeasureBlockedTime( + final long timeSpentDoingWork) { return new Callable() { @Override public Long call() throws Exception { - long blockedTime = acquireAndMeasureBlockedTime(); - // spend some time doing some real work - Uninterruptibles.sleepUninterruptibly( - timeSpendDoingWorkInMillis, TimeUnit.MILLISECONDS); - limiter.release(); - return blockedTime; + return tryAcquireAndMeasureBlockedTime(timeSpentDoingWork); } }; } - private long acquireAndMeasureBlockedTime() { + private long tryAcquireAndMeasureBlockedTime(long timeSpentDoingWork) { Stopwatch stopwatch = new Stopwatch(); stopwatch.start(); - limiter.tryAcquire(); + limiter.tryAcquireAndRun(runntable); stopwatch.stop(); - long blockedTime = stopwatch.elapsedMillis(); - log.info("blocked: {}", blockedTime); - return roundToTens(blockedTime); + long timeSpent = stopwatch.elapsedMillis(); + log.debug("real time try acquire and run task takes: {}", timeSpent); + long blockedTime = + roundToTens(timeSpent) - roundToTens(timeSpentDoingWork); + log.debug("blocked: {}", blockedTime); + return blockedTime; } private static List getTimeUsedInMillisRoundedUpToTens( @@ -280,13 +389,28 @@ public Long apply(Future input) { } private static long roundToTens(long arg) { - return Math.round(arg / 10.0) * 10; + return arg / 10 * 10; } private static class BlockedPredicate implements Predicate { + @Override public boolean apply(Long input) { return input > 0; } } + + private static class ToListenableFuture implements + Function, ListenableFuture> { + private final ListeningExecutorService executorService; + + public ToListenableFuture(ListeningExecutorService executorService) { + this.executorService = executorService; + } + + @Override + public ListenableFuture apply(Callable input) { + return executorService.submit(input); + } + } }