Skip to content

Commit

Permalink
Add a generic collector for events (#1429)
Browse files Browse the repository at this point in the history
This creates a new collector that can be used to collect arbitrary event types, rather than the specific event types supported by the simpler collectors.
  • Loading branch information
sbrocket committed Jul 14, 2021
1 parent 6192107 commit 3117f1d
Show file tree
Hide file tree
Showing 10 changed files with 642 additions and 8 deletions.
3 changes: 3 additions & 0 deletions Cargo.toml
Expand Up @@ -144,6 +144,9 @@ version = "2.1"
version = "0.2"
package = "http"

[dev-dependencies.tokio-test]
version = "0.4"

[features]
# Defaults with different backends
default = ["default_no_backend", "rustls_backend"]
Expand Down
36 changes: 34 additions & 2 deletions examples/e10_collectors/src/main.rs
Expand Up @@ -5,7 +5,7 @@ use std::{collections::HashSet, env, time::Duration};

use serenity::{
async_trait,
collector::MessageCollectorBuilder,
collector::{EventCollectorBuilder, MessageCollectorBuilder},
framework::standard::{
help_commands,
macros::{command, group, help},
Expand Down Expand Up @@ -166,7 +166,39 @@ async fn challenge(ctx: &Context, msg: &Message, _: Args) -> CommandResult {
score += 1;
}

let _ = msg.reply(ctx, &format!("You completed {} out of 3 tasks correctly!", score)).await;
// We can also collect arbitrary events using the generic EventCollector. For example, here we
// collect updates to the messages that the user sent above and check for them updating all 5 of
// them.
let builder = EventCollectorBuilder::new(&ctx)
.add_event_type(EventType::MessageUpdate)
.timeout(Duration::from_secs(20));
// Only collect MessageUpdate events for the 5 MessageIds we're interested in.
let mut collector = collected.iter().fold(builder, |b, msg| b.add_message_id(msg.id)).await?;

let _ = msg.reply(ctx, "Edit each of those 5 messages in 20 seconds").await;
let mut edited = HashSet::new();
while let Some(event) = collector.next().await {
match event.as_ref() {
Event::MessageUpdate(e) => {
edited.insert(e.id);
},
e => panic!("Unexpected event type received: {:?}", e.event_type()),
}
if edited.len() >= 5 {
break;
}
}

if edited.len() >= 5 {
score += 1;
let _ = msg.reply(ctx, "Great! You edited 5 out of 5").await;
} else {
let _ = msg.reply(ctx, &format!("You only edited {} out of 5", edited.len())).await;
}

let _ = msg
.reply(ctx, &format!("TIME'S UP! You completed {} out of 4 tasks correctly!", score))
.await;

Ok(())
}
13 changes: 11 additions & 2 deletions src/client/bridge/gateway/shard_messenger.rs
Expand Up @@ -5,7 +5,7 @@ use super::{ChunkGuildFilter, ShardClientMessage, ShardRunnerMessage};
#[cfg(all(feature = "unstable_discord_api", feature = "collector"))]
use crate::collector::ComponentInteractionFilter;
#[cfg(feature = "collector")]
use crate::collector::{MessageFilter, ReactionFilter};
use crate::collector::{EventFilter, MessageFilter, ReactionFilter};
use crate::gateway::InterMessage;
use crate::model::prelude::*;

Expand Down Expand Up @@ -245,6 +245,15 @@ impl ShardMessenger {
self.tx.unbounded_send(InterMessage::Client(Box::new(ShardClientMessage::Runner(msg))))
}

/// Sets a new filter for an event collector.
#[inline]
#[cfg(feature = "collector")]
#[cfg_attr(docsrs, doc(cfg(feature = "collector")))]
pub fn set_event_filter(&self, collector: EventFilter) {
#[allow(clippy::let_underscore_must_use)]
let _ = self.send_to_shard(ShardRunnerMessage::SetEventFilter(collector));
}

/// Sets a new filter for a message collector.
#[inline]
#[cfg(feature = "collector")]
Expand All @@ -254,7 +263,7 @@ impl ShardMessenger {
let _ = self.send_to_shard(ShardRunnerMessage::SetMessageFilter(collector));
}

/// Sets a new filter for a message collector.
/// Sets a new filter for a reaction collector.
#[cfg(feature = "collector")]
#[cfg_attr(docsrs, doc(cfg(feature = "collector")))]
pub fn set_reaction_filter(&self, collector: ReactionFilter) {
Expand Down
15 changes: 14 additions & 1 deletion src/client/bridge/gateway/shard_runner.rs
Expand Up @@ -21,7 +21,7 @@ use crate::client::{EventHandler, RawEventHandler};
#[cfg(all(feature = "unstable_discord_api", feature = "collector"))]
use crate::collector::ComponentInteractionFilter;
#[cfg(feature = "collector")]
use crate::collector::{LazyArc, LazyReactionAction, MessageFilter, ReactionFilter};
use crate::collector::{EventFilter, LazyArc, LazyReactionAction, MessageFilter, ReactionFilter};
#[cfg(feature = "framework")]
use crate::framework::Framework;
use crate::gateway::{GatewayError, InterMessage, ReconnectType, Shard, ShardAction};
Expand Down Expand Up @@ -49,6 +49,8 @@ pub struct ShardRunner {
voice_manager: Option<Arc<dyn VoiceGatewayManager + Send + Sync + 'static>>,
cache_and_http: Arc<CacheAndHttp>,
#[cfg(feature = "collector")]
event_filters: Vec<EventFilter>,
#[cfg(feature = "collector")]
message_filters: Vec<MessageFilter>,
#[cfg(feature = "collector")]
reaction_filters: Vec<ReactionFilter>,
Expand All @@ -75,6 +77,8 @@ impl ShardRunner {
voice_manager: opt.voice_manager,
cache_and_http: opt.cache_and_http,
#[cfg(feature = "collector")]
event_filters: Vec::new(),
#[cfg(feature = "collector")]
message_filters: Vec::new(),
#[cfg(feature = "collector")]
reaction_filters: Vec::new(),
Expand Down Expand Up @@ -242,6 +246,9 @@ impl ShardRunner {
},
_ => {},
}

let mut event = LazyArc::new(event);
retain(&mut self.event_filters, |f| f.send_event(&mut event));
}

/// Clones the internal copy of the Sender to the shard runner.
Expand Down Expand Up @@ -436,6 +443,12 @@ impl ShardRunner {
self.shard.update_presence().await.is_ok()
},
#[cfg(feature = "collector")]
ShardClientMessage::Runner(ShardRunnerMessage::SetEventFilter(collector)) => {
self.event_filters.push(collector);

true
},
#[cfg(feature = "collector")]
ShardClientMessage::Runner(ShardRunnerMessage::SetMessageFilter(collector)) => {
self.message_filters.push(collector);

Expand Down
6 changes: 5 additions & 1 deletion src/client/bridge/gateway/shard_runner_message.rs
Expand Up @@ -3,7 +3,7 @@ use async_tungstenite::tungstenite::Message;
#[cfg(all(feature = "unstable_discord_api", feature = "collector"))]
use crate::collector::ComponentInteractionFilter;
#[cfg(feature = "collector")]
use crate::collector::{MessageFilter, ReactionFilter};
use crate::collector::{EventFilter, MessageFilter, ReactionFilter};
use crate::model::{
gateway::Activity,
id::{GuildId, UserId},
Expand Down Expand Up @@ -61,6 +61,10 @@ pub enum ShardRunnerMessage {
SetPresence(OnlineStatus, Option<Activity>),
/// Indicates that the client is to update the shard's presence's status.
SetStatus(OnlineStatus),
/// Sends a new filter for events to the shard.
#[cfg(feature = "collector")]
#[cfg_attr(docsrs, doc(cfg(feature = "collector")))]
SetEventFilter(EventFilter),
/// Sends a new filter for messages to the shard.
#[cfg(feature = "collector")]
#[cfg_attr(docsrs, doc(cfg(feature = "collector")))]
Expand Down
47 changes: 47 additions & 0 deletions src/collector/error.rs
@@ -0,0 +1,47 @@
use std::{
error::Error as StdError,
fmt::{Display, Formatter, Result as FmtResult},
};

/// An error that occured while working with a collector.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum Error {
/// No event types were passed to [add_event_type].
///
/// [add_event_type]: crate::collector::EventCollectorBuilder::add_event_type
NoEventTypes,
/// The combination of event types and ID filters used with [EventCollectorBuilder] is invalid
/// and will never match any events.
///
/// For example, the following always errors because GuildCreate never has a related user ID:
/// ```rust
/// # use serenity::{prelude::*, collector::{CollectorError, EventCollectorBuilder}, model::prelude::*};
/// # let (sender, _) = futures::channel::mpsc::unbounded();
/// # let ctx = serenity::client::bridge::gateway::ShardMessenger::new(sender);
/// # tokio_test::block_on(async move {
/// assert!(matches!(
/// EventCollectorBuilder::new(&ctx)
/// .add_event_type(EventType::GuildCreate)
/// .add_user_id(UserId::default())
/// .await,
/// Err(SerenityError::Collector(CollectorError::InvalidEventIdFilters)),
/// ));
/// # });
/// ```
/// [EventCollectorBuilder]: crate::collector::EventCollectorBuilder
InvalidEventIdFilters,
}

impl Display for Error {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
match self {
Error::NoEventTypes => f.write_str("No event types provided"),
Error::InvalidEventIdFilters => {
f.write_str("Invalid event type + id filters, would never match any events")
},
}
}
}

impl StdError for Error {}

0 comments on commit 3117f1d

Please sign in to comment.