diff --git a/CHANGES b/CHANGES index 46bcf6852e..bb0bb25e21 100644 --- a/CHANGES +++ b/CHANGES @@ -50,6 +50,12 @@ API Changes: Bug Fixes: + * util-core: AsyncMeter had a bug where if the burst size was smaller than + the number of disbursed tokens, it would discard all of the tokens over + the disbursal limit. Changed to instead process tokens in the wait queue + with leftover tokens. This improves behavior where the actual period is + smaller than can actually be simulated with the given timer. ``RB_ID=836742`` + * util-zk: Fixed race when an existing permit is released between the time the list was gotten and the data was checked. ``RB_ID=835856`` diff --git a/util-core/src/main/scala/com/twitter/concurrent/AsyncMeter.scala b/util-core/src/main/scala/com/twitter/concurrent/AsyncMeter.scala index b4cb94843b..3ce8292b47 100644 --- a/util-core/src/main/scala/com/twitter/concurrent/AsyncMeter.scala +++ b/util-core/src/main/scala/com/twitter/concurrent/AsyncMeter.scala @@ -238,20 +238,20 @@ class AsyncMeter private( } private[this] def updateAndGet(tokens: Int): Boolean = { - refreshTokens() + bucket.put(getNumRefreshTokens()) bucket.tryGet(tokens) } // we refresh the bucket with as many tokens as we have accrued since we last // refreshed. - private[this] def refreshTokens(): Unit = bucket.put(synchronized { + private[this] def getNumRefreshTokens(): Int = synchronized { val newTokens = period.numPeriods(elapsed()) elapsed = Stopwatch.start() val num = newTokens + remainder val floor = math.floor(num) remainder = num - floor floor.toInt - }) + } private[this] def restartTimerIfDead(): Unit = synchronized { if (!running) { @@ -264,10 +264,25 @@ class AsyncMeter private( // it's safe to race on allow, because polling loop is locked private[this] final def allow(): Unit = { - refreshTokens() + // tokens represents overflow from lack of granularity. we don't want to + // store more than `burstSize` tokens, but we want to be able to process + // load at the rate we advertise to, even if we can't refresh to `burstSize` + // as fast as `burstDuration` would like. we get around this by ensuring + // that we disburse the full amount to waiters, which ensures correct + // behavior for small `burstSize` and `burstDuration` below the minimum + // granularity. + var tokens = getNumRefreshTokens() + + if (tokens > burstSize) { + tokens -= burstSize + bucket.put(burstSize) + } else { + bucket.put(tokens) + tokens = 0 + } // we loop here so that we can satisfy more than one promise at a time. - // imagine that start with no tokens, we distribute ten tokens, and our + // imagine that we start with no tokens, we distribute ten tokens, and our // waiters are waiting for 4, 1, 6, 3 tokens. we should distribute 4, and // 1, and ask 6 and 3 to keep waiting until we have more tokens. while (true) { @@ -284,7 +299,14 @@ class AsyncMeter private( // tokens that we're missing with the Stopwatch. task.close() None - case (p, num) if bucket.tryGet(num) => + case (p, num) if num < tokens => + tokens -= num + q.poll() // we wait to remove until after we're able to get tokens + Some(p) + case (p, num) if bucket.tryGet(num - tokens) => + // we must zero tokens because we're implicitly pulling from the + // tokens first, and then the token bucket + tokens = 0 q.poll() // we wait to remove until after we're able to get tokens Some(p) case _ => diff --git a/util-core/src/test/scala/com/twitter/concurrent/AsyncMeterTest.scala b/util-core/src/test/scala/com/twitter/concurrent/AsyncMeterTest.scala index b66a5cc609..ed246e3ad5 100644 --- a/util-core/src/test/scala/com/twitter/concurrent/AsyncMeterTest.scala +++ b/util-core/src/test/scala/com/twitter/concurrent/AsyncMeterTest.scala @@ -112,6 +112,64 @@ class AsyncMeterTest extends FunSuite { } } + test("AsyncMeter should handle small burst sizes and periods smaller than timer granularity") { + val timer = new MockTimer + Time.withCurrentTimeFrozen { ctl => + val meter = newMeter(1, 500.microseconds, 100)(timer) + val ready = meter.await(1) + assert(ready.isDone) + + val first = meter.await(1) + val second = meter.await(1) + assert(!first.isDefined) + assert(!second.isDefined) + + ctl.advance(1.millisecond) + timer.tick() + assert(first.isDone) + assert(second.isDone) + } + } + + test("AsyncMeter should handle small, short bursts with big token amounts") { + val timer = new MockTimer + Time.withCurrentTimeFrozen { ctl => + val meter = newMeter(2, 500.microseconds, 100)(timer) + val ready = meter.await(2) + assert(ready.isDone) + + val first = meter.await(1) + val second = meter.await(2) + val third = meter.await(1) + assert(!first.isDefined) + assert(!second.isDefined) + assert(!third.isDefined) + + ctl.advance(1.millisecond) + timer.tick() + assert(first.isDone) + assert(second.isDone) + assert(third.isDone) + } + } + + test("AsyncMeter should hit the full rate even with insufficient granularity") { + val timer = new MockTimer + Time.withCurrentTimeFrozen { ctl => + val meter = newUnboundedMeter(1, 500.microseconds)(timer) + val ready = Future.join(Seq.fill(1000)(meter.await(1))).join { + FuturePool.unboundedPool { + for (_ <- 0 until 500) { + ctl.advance(1.millisecond) + timer.tick() + } + } + } + Await.ready(ready, 5.seconds) + assert(ready.isDefined) + } + } + test("AsyncMeter should allow an expensive call to be satisfied slowly") { val timer = new MockTimer Time.withCurrentTimeFrozen { ctl =>