Skip to content

Commit

Permalink
Add Reverting and Awaiting Bucket Rate Limits (#1127)
Browse files Browse the repository at this point in the history
There are two new features introduced in this commit:
The first one allows to await a rate limit instead of cancelling the command invocation.
The second one allows to return a ticket to the rate limit bucket.

**Awaiting Rate Limits**
Buckets can now be configured to return an action. The action indicates how a framework should react.
This means, buckets can indicate to cancel or delay command invocation.
In order to set this feature, call the chain-method `await_ratelimits` on the bucket builder.

**Returning Tickets**
If a command never reaches the critical zone that requires a rate limit, we can now undo and return the ticket.
This can be down by returning `RevertBucket` from a command.
The concept is simple, a command may perform an expensive computation or issue a request via web API.
However the command may never reach that part, imagine incorrect user input.
The command would still go on cooldown, which can be a pain for the user.
By returning `RevertBucket`, the cooldown can be undone.
  • Loading branch information
Lakelezz committed Dec 18, 2020
1 parent 1b4d408 commit 1589475
Show file tree
Hide file tree
Showing 3 changed files with 206 additions and 24 deletions.
9 changes: 6 additions & 3 deletions examples/e05_command_framework/src/main.rs
Expand Up @@ -15,7 +15,7 @@ use serenity::{
framework::standard::{
Args, CommandOptions, CommandResult, CommandGroup,
DispatchError, HelpOptions, help_commands, Reason, StandardFramework,
buckets::LimitedFor,
buckets::{RevertBucket, LimitedFor},
macros::{command, group, help, check, hook},
},
http::Http,
Expand Down Expand Up @@ -259,7 +259,9 @@ async fn main() {
// Can't be used more than once per 5 seconds:
.bucket("emoji", |b| b.delay(5)).await
// Can't be used more than 2 times per 30 seconds, with a 5 second delay applying per channel.
.bucket("complicated", |b| b.delay(5).time_span(30).limit(2).limit_for(LimitedFor::Channel)).await
// Optionally `await_ratelimits` will delay until the command can be executed instead of
// cancelling the command invocation.
.bucket("complicated", |b| b.delay(5).time_span(30).limit(2).limit_for(LimitedFor::Channel).await_ratelimits()).await
// The `#[group]` macro generates `static` instances of the options set for the group.
// They're made in the pattern: `#name_GROUP` for the group instance and `#name_GROUP_OPTIONS`.
// #name is turned all uppercase
Expand Down Expand Up @@ -467,7 +469,8 @@ async fn ping(ctx: &Context, msg: &Message) -> CommandResult {
async fn cat(ctx: &Context, msg: &Message) -> CommandResult {
msg.channel_id.say(&ctx.http, ":cat:").await?;

Ok(())
// We can return one ticket to the bucket undoing the ratelimit.
Err(RevertBucket.into())
}

#[command]
Expand Down
40 changes: 34 additions & 6 deletions src/framework/standard/mod.rs
Expand Up @@ -12,12 +12,14 @@ pub use args::{Args, Delimiter, Error as ArgError, Iter, RawArguments};
pub use configuration::{Configuration, WithWhiteSpace};
pub use structures::*;

use structures::buckets::Bucket;
use structures::buckets::{Bucket, BucketAction};
pub use structures::buckets::BucketBuilder;

use parse::{ParseError, Invoke};
use parse::map::{CommandMap, GroupMap, Map};

use self::buckets::RevertBucket;

use super::Framework;
use crate::client::Context;
use crate::model::{
Expand Down Expand Up @@ -273,15 +275,32 @@ impl StandardFramework {
return Some(DispatchError::BlockedChannel);
}

{
let mut buckets = self.buckets.lock().await;
// Try passing the command's bucket.
// exiting the loop if no command ratelimit has been hit or
// early-return when ratelimits cancel the framework invocation.
// Otherwise, delay and loop again to check if we passed the bucket.
loop {
let mut duration = None;

{
let mut buckets = self.buckets.lock().await;

if let Some(ref mut bucket) = command.bucket.as_ref().and_then(|b| buckets.get_mut(*b)) {
if let Some(ref mut bucket) = command.bucket.as_ref().and_then(|b| buckets.get_mut(*b)) {

if let Some(rate_limit) = bucket.take(ctx, msg).await {
return Some(DispatchError::Ratelimited(rate_limit))
if let Some(bucket_action) = bucket.take(ctx, msg).await {

duration = match bucket_action {
BucketAction::CancelWith(duration) => return Some(DispatchError::Ratelimited(duration)),
BucketAction::DelayFor(duration) => Some(duration),
};
}
}
}

match duration {
Some(duration) => tokio::time::delay_for(duration).await,
None => break,
}
}

for check in group.checks.iter().chain(command.checks.iter()) {
Expand Down Expand Up @@ -707,6 +726,15 @@ impl Framework for StandardFramework {

let res = (command.fun)(&mut ctx, &msg, args).await;

// Check if the command wants to revert the bucket by giving back a ticket.
if matches!(res, Err(ref e) if e.is::<RevertBucket>()) {
let mut buckets = self.buckets.lock().await;

if let Some(ref mut bucket) = command.options.bucket.as_ref().and_then(|b| buckets.get_mut(*b)) {
bucket.give(&ctx, &msg).await;
}
}

if let Some(after) = &self.after {
after(&mut ctx, &msg, name, res).await;
}
Expand Down
181 changes: 166 additions & 15 deletions src/framework/standard/structures/buckets.rs
Expand Up @@ -11,14 +11,29 @@ pub(crate) struct Ratelimit {
pub delay: Duration,
pub limit: Option<(Duration, u32)>,
}
pub(crate) struct UnitRatelimit {
pub last_time: Option<Instant>,
pub set_time: Instant,
pub tickets: u32,
}

impl UnitRatelimit {
fn new(creation_time: Instant) -> Self {
Self {
last_time: None,
set_time: creation_time,
tickets: 0,
}
}
}

#[derive(Default)]
pub(crate) struct UnitRatelimit {
pub(crate) struct UnitRatelimitTimes {
pub last_time: Option<Instant>,
pub set_time: Option<Instant>,
pub tickets: u32,
}

/// A bucket offers fine-grained control over the execution of commands.
pub(crate) enum Bucket {
/// The bucket will collect tickets for every invocation of a command.
Global(TicketCounter),
Expand All @@ -38,7 +53,7 @@ pub(crate) enum Bucket {

impl Bucket {
#[inline]
pub async fn take(&mut self, ctx: &Context, msg: &Message) -> Option<Duration> {
pub async fn take(&mut self, ctx: &Context, msg: &Message) -> Option<BucketAction> {
match self {
Self::Global(counter) => counter.take(ctx, msg, 0).await,
Self::User(counter) => counter.take(ctx, msg, msg.author.id.0).await,
Expand All @@ -61,16 +76,65 @@ impl Bucket {
},
}
}

#[inline]
pub async fn give(&mut self, ctx: &Context, msg: &Message) {
match self {
Self::Global(counter) => counter.give(ctx, msg, 0).await,
Self::User(counter) => counter.give(ctx, msg, msg.author.id.0).await,
Self::Guild(counter) => {
if let Some(guild_id) = msg.guild_id {
counter.give(ctx, msg, guild_id.0).await
}
}
Self::Channel(counter) => counter.give(ctx, msg, msg.channel_id.0).await,
// This requires the cache, as messages do not contain their channel's
// category.
#[cfg(feature = "cache")]
Self::Category(counter) =>
if let Some(category_id) = msg.category_id(ctx).await {
counter.give(ctx, msg, category_id.0).await
}
}
}
}

/// Keeps track of who owns how many tickets and when they accessed the last
/// time.
pub(crate) struct TicketCounter {
pub ratelimit: Ratelimit,
pub tickets_for: HashMap<u64, UnitRatelimit>,
pub check: Option<Check>,
pub await_ratelimits: bool,
}

/// A bucket may return results based on how it set up.
///
/// By default, it will return `CancelWith` when a limit is hit.
/// This is intended to cancel the command invocation and propagate the
/// duration to the user.
///
/// If the bucket is set to await durations, it will suggest to wait
/// for the bucket by returning `DelayFor` and then delay for the duration,
/// and then try taking a ticket again.
pub enum BucketAction {
CancelWith(Duration),
DelayFor(Duration),
}

impl TicketCounter {
pub async fn take(&mut self, ctx: &Context, msg: &Message, id: u64) -> Option<Duration> {
/// Tries to check whether the invocation is permitted by the ticket counter
/// and if a ticket can be taken; it does not return a
/// a ticket but a duration until a ticket can be taken.
///
/// The duration will be wrapped in an action for the caller to perform
/// if wanted. This may inform them to directly cancel trying to take a ticket
/// or delay the take until later.
///
/// However there is no contract: It does not matter what
/// the caller ends up doing, receiving some action eventually means
/// no ticket can be taken and the duration must elapse.
pub async fn take(&mut self, ctx: &Context, msg: &Message, id: u64) -> Option<BucketAction> {
if let Some(ref check) = self.check {

if !(check)(ctx, msg).await {
Expand All @@ -82,36 +146,99 @@ impl TicketCounter {
let Self {
tickets_for, ratelimit, ..
} = self;
let ticket_owner = tickets_for.entry(id).or_default();
let ticket_owner = tickets_for.entry(id)
.or_insert_with(|| UnitRatelimit::new(now));

// Check if too many tickets have been taken already.
// If all tickets are exhausted, return the needed delay
// for this invocation.
if let Some((timespan, limit)) = ratelimit.limit {

if (ticket_owner.tickets + 1) > limit {
if let Some(res) = ticket_owner
.set_time
.and_then(|x| (x + timespan).checked_duration_since(now))
if let Some(res) = (ticket_owner
.set_time + timespan).checked_duration_since(now)
{
return Some(res);
return Some(if self.await_ratelimits {
BucketAction::DelayFor(res)
} else {
BucketAction::CancelWith(res)
})
} else {
ticket_owner.tickets = 0;
ticket_owner.set_time = Some(now);
ticket_owner.set_time = now;
}
}
}

if let Some(res) = ticket_owner
// Check if `ratelimit.delay`-time passed between the last and
// the current invocation
// If the time did not pass, return the needed delay for this
// invocation.
if let Some(ratelimit) = ticket_owner
.last_time
.and_then(|x| (x + ratelimit.delay).checked_duration_since(now))
{
return Some(res);
return Some(if self.await_ratelimits {
BucketAction::DelayFor(ratelimit)
} else {
BucketAction::CancelWith(ratelimit)
})
} else {
ticket_owner.tickets += 1;
ticket_owner.last_time = Some(now);
}

None
}

/// Reverts the last ticket step performed by returning a ticket for the
/// matching ticket holder.
/// Only call this if the mutable owner already took a ticket in this
/// atomic execution of calling `take` and `give`.
pub async fn give(&mut self, ctx: &Context, msg: &Message, id: u64) {
if let Some(ref check) = self.check {

if !(check)(ctx, msg).await {
return
}
}

if let Some(ticket_owner) = self.tickets_for.get_mut(&id) {

// Remove a ticket if one is available.
if ticket_owner.tickets > 0 {
ticket_owner.tickets -= 1;
}

let delay = self.ratelimit.delay;
// Substract one step of time that would have to pass.
// This tries to bypass a problem of keeping track of when tickets
// were taken.
// When a ticket is taken, the bucket sets `last_time`, by
// substracting the delay, once a ticket is allowed to be
// taken.
// If the value is set to `None` this could possibly reset the
// bucket.
ticket_owner.last_time = ticket_owner
.last_time
.and_then(|i| i.checked_sub(delay));
}
}
}

/// An error struct that can be returned from a command to set the
/// bucket one step back.
#[derive(Debug)]
pub struct RevertBucket;

impl std::fmt::Display for RevertBucket {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "RevertBucket")
}
}

impl std::error::Error for RevertBucket {}

/// Decides what a bucket will use to collect tickets for.
pub enum LimitedFor {
/// The bucket will collect tickets for every invocation of a command.
Expand All @@ -137,13 +264,26 @@ impl Default for LimitedFor {
}
}

#[derive(Default)]
pub struct BucketBuilder {
pub(crate) delay: Duration,
pub(crate) time_span: Duration,
pub(crate) limit: u32,
pub(crate) check: Option<Check>,
pub(crate) limited_for: LimitedFor,
pub(crate) await_ratelimits: bool,
}

impl Default for BucketBuilder {
fn default() -> Self {
Self {
delay: Duration::default(),
time_span: Duration::default(),
limit: 1,
check: None,
limited_for: LimitedFor::default(),
await_ratelimits: false,
}
}
}

impl BucketBuilder {
Expand Down Expand Up @@ -213,8 +353,6 @@ impl BucketBuilder {

/// Number of invocations allowed per [`time_span`].
///
/// Expressed in seconds.
///
/// [`time_span`]: Self::time_span
#[inline]
pub fn limit(&mut self, n: u32) -> &mut Self {
Expand All @@ -240,6 +378,18 @@ impl BucketBuilder {
self
}

/// If this is set to `true`, the invocation of the command will be delayed
/// and won't return a duration to wait to dispatch errors, but actually
/// await until the duration has been elapsed.
///
/// By default, ratelimits will become dispatch errors.
#[inline]
pub fn await_ratelimits(&mut self) -> &mut Self {
self.await_ratelimits = true;

self
}

/// Constructs the bucket.
#[inline]
pub(crate) fn construct(self) -> Bucket {
Expand All @@ -250,6 +400,7 @@ impl BucketBuilder {
},
tickets_for: HashMap::new(),
check: self.check,
await_ratelimits: self.await_ratelimits,
};

match self.limited_for {
Expand Down

0 comments on commit 1589475

Please sign in to comment.